[Mlir-commits] [mlir] [mlir][vector] vector.from_elements canonicalizer when elements are ascending extracts (PR #139819)

James Newling llvmlistbot at llvm.org
Tue May 13 17:47:05 PDT 2025


================
@@ -2385,9 +2385,69 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+static LogicalResult
+rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp,
+                               PatternRewriter &rewriter) {
+
+  // The common source of vector.extract operations (if one exists), as well
+  // as its shape and rank. These are set in the first iteration of the loop
+  // over the operands (elements) of `fromElementsOp`.
+  Value source;
+  ArrayRef<int64_t> shape;
+  int64_t rank;
+
+  for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) {
+
+    // Check that the element is defined by an extract operation, and that
+    // the extract is on the same vector as all preceding elements.
+    auto extractOp =
+        dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+    if (!extractOp)
+      return failure();
+    Value currentSource = extractOp.getVector();
+    if (index == 0) {
+      source = currentSource;
+      shape = extractOp.getSourceVectorType().getShape();
+      rank = shape.size();
+    } else if (currentSource != source) {
+      return failure();
+    }
+
+    // Check that the (linearized) index of extraction is the same as the index
+    // in the result of `fromElementsOp`.
+    ArrayRef<int64_t> position = extractOp.getStaticPosition();
+    assert(position.size() == rank &&
+           "scalar extract must have full rank position");
+    int64_t stride{1};
+    int64_t offset{0};
+    for (auto [pos, size] :
+         llvm::zip(llvm::reverse(position), llvm::reverse(shape))) {
+      if (pos == ShapedType::kDynamic)
+        return failure();
+      offset += pos * stride;
+      stride *= size;
+    }
+    if (offset != index)
+      return failure();
+  }
+
+  rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp,
+                                           fromElementsOp.getType(), source);
+}
+
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
+  results.add(rewriteFromElementsAsShapeCast);
----------------
newling wrote:

Just copying the design pattern of `rewriteFromElementsAsSplat`, but this may be complex enough to warrant an actual Pattern with `notifyMatchFailure` calls.

https://github.com/llvm/llvm-project/pull/139819


More information about the Mlir-commits mailing list