[Mlir-commits] [mlir] [mlir][vector] canonicalize vector.from_elements from ascending extracts (PR #139819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 14 12:48:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-vector

Author: James Newling (newling)

<details>
<summary>Changes</summary>

Example:
```mlir
%0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
%2 = vector.from_elements %0, %1 : vector<2xi8>
```

becomes
```mlir
%2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
```


---
Full diff: https://github.com/llvm/llvm-project/pull/139819.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+90) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (-69) 
- (added) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+155) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..1080263ed3eb6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
@@ -2385,9 +2386,98 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
   return success();
 }
 
+/// Rewrite vector.from_elements as vector.shape_cast, if possible.
+///
+/// Example:
+///   %0 = vector.extract %source[0, 0] : i8 from vector<1x2xi8>
+///   %1 = vector.extract %source[0, 1] : i8 from vector<1x2xi8>
+///   %2 = vector.from_elements %0, %1 : vector<2xi8>
+///
+/// becomes
+///   %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8>
+///
+/// The requirements for this to be valid are
+/// i) all elements are extracted from the same vector (source),
+/// ii) source and from_elements result have the same number of elements,
+/// iii) the elements are extracted in ascending order, i.e. the k'th operand
+///      is the element at row-major linear offset k of source. This applies
+///      to every operand, including the first.
+///
+/// It might be possible to rewrite vector.from_elements as a single
+/// vector.extract if (ii) is not satisfied, or in some cases as a single
+/// vector.extract_strided_slice if (ii) and (iii) are not satisfied, this is
+/// left for future consideration.
+class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FromElementsOp fromElements,
+                                PatternRewriter &rewriter) const override {
+    mlir::OperandRange elements = fromElements.getElements();
+    assert(!elements.empty() && "must be at least 1 element");
+
+    // The source vector and its type, taken from the extract that defines the
+    // first operand. All other operands are checked against these.
+    Value source;
+    VectorType sourceType;
+
+    for (auto [index, element] : llvm::enumerate(elements)) {
+      ExtractOp extractOp =
+          dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp());
+      if (!extractOp) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element not from vector.extract");
+      }
+
+      if (index == 0) {
+        source = extractOp.getVector();
+        sourceType = extractOp.getSourceVectorType();
+        // Check condition (ii).
+        if (static_cast<size_t>(sourceType.getNumElements()) !=
+            elements.size()) {
+          return rewriter.notifyMatchFailure(fromElements,
+                                             "number of elements differ");
+        }
+      }
+
+      // Check condition (i).
+      if (extractOp.getVector() != source) {
+        return rewriter.notifyMatchFailure(fromElements,
+                                           "element from different vector");
+      }
+
+      // Check condition (iii): linearize the position (row-major, innermost
+      // dimension first) and compare it with the operand index.
+      ArrayRef<int64_t> position = extractOp.getStaticPosition();
+      assert(position.size() == static_cast<size_t>(sourceType.getRank()) &&
+             "scalar extract must have full rank position");
+      int64_t stride{1};
+      int64_t offset{0};
+      for (auto [pos, size] : llvm::zip(llvm::reverse(position),
+                                        llvm::reverse(sourceType.getShape()))) {
+        if (pos == ShapedType::kDynamic) {
+          return rewriter.notifyMatchFailure(
+              fromElements, "elements not in ascending order (dynamic order)");
+        }
+        offset += pos * stride;
+        stride *= size;
+      }
+      if (offset != static_cast<int64_t>(index)) {
+        return rewriter.notifyMatchFailure(
+            fromElements, "elements not in ascending order (static order)");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements,
+                                             fromElements.getType(), source);
+    return success();
+  }
+};
+
 void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add(rewriteFromElementsAsSplat);
+  results.add<FromElementsToShapCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 99f0850000a16..6af517d988360 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2952,75 +2952,6 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
 
 // -----
 
-// CHECK-LABEL: func @extract_scalar_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
-  // Extract from 0D.
-  %0 = vector.from_elements %a : vector<f32>
-  %1 = vector.extract %0[] : f32 from vector<f32>
-
-  // Extract from 1D.
-  %2 = vector.from_elements %a : vector<1xf32>
-  %3 = vector.extract %2[0] : f32 from vector<1xf32>
-  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
-  %5 = vector.extract %4[4] : f32 from vector<5xf32>
-
-  // Extract from 2D.
-  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
-  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
-  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
-  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
-
-  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
-  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_1d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
-  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
-  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
-  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<3xf32>, vector<3xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @extract_2d_from_from_elements(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
-  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
-  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
-  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
-  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
-  // CHECK: return %[[splat1]], %[[splat2]]
-  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @from_elements_to_splat(
-//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
-func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
-  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
-  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
-  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
-  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
-  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
-  %2 = vector.from_elements %a : vector<f32>
-  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
-  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
-}
-
-// -----
-
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
new file mode 100644
index 0000000000000..14bf5d9df4783
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -0,0 +1,155 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// This file contains some tests of folding/canonicalizing vector.from_elements
+
+///===----------------------------------------------===//
+///  Tests of `rewriteFromElementsAsSplat`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @extract_scalar_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
+  // Extract from 0D.
+  %0 = vector.from_elements %a : vector<f32>
+  %1 = vector.extract %0[] : f32 from vector<f32>
+
+  // Extract from 1D.
+  %2 = vector.from_elements %a : vector<1xf32>
+  %3 = vector.extract %2[0] : f32 from vector<1xf32>
+  %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
+  %5 = vector.extract %4[4] : f32 from vector<5xf32>
+
+  // Extract from 2D.
+  %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32>
+  %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+  %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32>
+  %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32>
+
+  // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]]
+  return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_1d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
+  // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32>
+  %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32>
+  %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<3xf32>, vector<3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_2d_from_from_elements(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
+  %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
+  // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32>
+  %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32>
+  %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32>
+  // CHECK: return %[[splat1]], %[[splat2]]
+  return %1, %2 : vector<2x2xf32>, vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_to_splat(
+//  CHECK-SAME:     %[[a:.*]]: f32, %[[b:.*]]: f32)
+func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
+  // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32>
+  %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
+  // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
+  %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
+  // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector<f32>
+  %2 = vector.from_elements %a : vector<f32>
+  // CHECK: return %[[splat]], %[[from_el]], %[[splat2]]
+  return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
+}
+
+// -----
+
+///===----------------------------------------------===//
+///  Tests of `FromElementsToShapeCast`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
+//  CHECK-SAME:       %[[a:.*]]: vector<1x2xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2xi8>
+func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %4 = vector.from_elements %0, %1 : vector<2xi8>
+  return %4 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
+//  CHECK-SAME:       %[[a:.*]]: vector<8xi8>)
+//       CHECK:       %[[shape_cast:.*]] = vector.shape_cast %[[a]] : vector<8xi8> to vector<2x2x2xi8>
+//       CHECK:       return %[[shape_cast]] : vector<2x2x2xi8>
+func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
+  %0 = vector.extract %arg0[0] : i8 from vector<8xi8>
+  %1 = vector.extract %arg0[1] : i8 from vector<8xi8>
+  %2 = vector.extract %arg0[2] : i8 from vector<8xi8>
+  %3 = vector.extract %arg0[3] : i8 from vector<8xi8>
+  %4 = vector.extract %arg0[4] : i8 from vector<8xi8>
+  %5 = vector.extract %arg0[5] : i8 from vector<8xi8>
+  %6 = vector.extract %arg0[6] : i8 from vector<8xi8>
+  %7 = vector.extract %arg0[7] : i8 from vector<8xi8>
+  %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8>
+  return %8 : vector<2x2x2xi8>
+}
+
+// -----
+
+// The extracted elements are recombined into a single vector, but in a new order.
+// CHECK-LABEL: func @negative_nonascending_order(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_nonstatic_extract(
+//   CHECK-NOT: shape_cast
+func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_different_sources(
+//   CHECK-NOT: shape_cast
+func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
+  %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_source_too_large(
+//   CHECK-NOT: shape_cast
+func.func @negative_source_too_large(%arg0: vector<1x3xi8>) -> vector<2xi8> {
+  %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
+  %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
+  %2 = vector.from_elements %0, %1 : vector<2xi8>
+  return %2 : vector<2xi8>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/139819


More information about the Mlir-commits mailing list