[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)

Andrzej Warzyński llvmlistbot at llvm.org
Thu May 15 06:47:31 PDT 2025


================
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) the elements are extracted in ascending order.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisfied, or in some cases as a
+/// single vector_extract_strided_slice if (ii) and (iii) are not satisfied;
+/// this is left for future consideration.
+class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElements,
+                                PatternRewriter &rewriter) const override {
+
+    mlir::OperandRange elements = fromElements.getElements();
+    assert(!elements.empty() && "must be at least 1 element");
+    Value firstElement = elements.front();
+
+    ExtractOp extractOp =
+        dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp());
+    if (!extractOp) {
+      return rewriter.notifyMatchFailure(
+          fromElements, "first element not from vector.extract");
+    }
+    VectorType sourceType = extractOp.getSourceVectorType();
+    Value source = extractOp.getVector();
+
+    // Check condition (ii).
+    if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) {
+      return rewriter.notifyMatchFailure(fromElements,
+                                         "number of elements differ");
+    }
----------------
banach-space wrote:

[nit] I would "rebrand" this as "Condition (i)" (it's the first condition to be checked) and move it all the way to the top. It feels like a fairly high-level condition that deserves a special place :)
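
For illustration only, the three requirements from the doc comment can be sketched as a standalone check, with the element-count test placed first as suggested. This is plain C++, not MLIR code and not the PR's implementation. `ExtractedElement`, `linearize`, and `isShapeCastLike` are hypothetical names, and the sketch assumes that "ascending order" means the i-th operand sits at row-major position i of the source.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one from_elements operand: the source vector it
// was extracted from and the full static position of the extraction.
struct ExtractedElement {
  int sourceId;                  // Identifies the source vector (assumption).
  std::vector<int64_t> position; // One index per source dimension.
};

// Row-major linearization of `position` within `shape`.
// Returns -1 if the rank differs or any index is out of bounds.
int64_t linearize(const std::vector<int64_t> &position,
                  const std::vector<int64_t> &shape) {
  if (position.size() != shape.size())
    return -1;
  int64_t linear = 0;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (position[i] < 0 || position[i] >= shape[i])
      return -1;
    linear = linear * shape[i] + position[i];
  }
  return linear;
}

// Returns true if the operands could be replaced by a single reshape of the
// source identified by `sourceId` with shape `sourceShape`.
bool isShapeCastLike(const std::vector<ExtractedElement> &elements,
                     int sourceId, const std::vector<int64_t> &sourceShape) {
  int64_t numSourceElements = 1;
  for (int64_t dim : sourceShape)
    numSourceElements *= dim;

  // Element counts must match. This is the cheapest test, so it runs first.
  if (static_cast<int64_t>(elements.size()) != numSourceElements)
    return false;

  for (size_t i = 0; i < elements.size(); ++i) {
    // Every element must come from the same source.
    if (elements[i].sourceId != sourceId)
      return false;
    // Elements must appear in ascending (row-major) order.
    if (linearize(elements[i].position, sourceShape) !=
        static_cast<int64_t>(i))
      return false;
  }
  return true;
}
```

With the vector<1x2xi8> example from the doc comment, operands at positions [0, 0] and [0, 1] pass the check, while swapping them or dropping one makes it fail.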

https://github.com/llvm/llvm-project/pull/139819

