[Mlir-commits] [mlir] [mlir][vector] Patterns to convert to shape_cast, where possible (PR #138777)

James Newling llvmlistbot at llvm.org
Tue May 6 15:59:30 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/138777

These are all semantically just copies, and can be rewritten as shape_casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

Currently the vector dialect has no strict specification of which of two equivalent forms is canonical; the unwritten rule seems to be that if it is not 'obvious' that a transformation results in something more canonical, it shouldn't be part of an op's canonicalization method. So it's probably not worthwhile discussing here whether these conversions to shape_cast should be part of op canonicalizers! Nonetheless, I've found these particular patterns useful in my work, so maybe they're a good addition upstream?

>From 9e349f31609ccc854b2c40e36fce896dd8e3434d Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:46:36 -0700
Subject: [PATCH 1/3] first

---
 .../mlir/Dialect/Vector/IR/VectorOps.h        |  7 ++
 .../Vector/Transforms/VectorRewritePatterns.h | 20 ++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |  5 +-
 .../Vector/Transforms/VectorTransforms.cpp    | 92 +++++++++++++++++++
 .../Dialect/Vector/convert-to-shape-cast.mlir | 65 +++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 22 +++++
 6 files changed, 207 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/convert-to-shape-cast.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..be9839ce26339 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -50,6 +50,7 @@ namespace vector {
 class ContractionOp;
 class TransferReadOp;
 class TransferWriteOp;
+class TransposeOp;
 class VectorDialect;
 
 namespace detail {
@@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
 /// `std::nullopt`.
 std::optional<int64_t> getConstantVscaleMultiplier(Value value);
 
+/// Return true if `transpose` never changes the relative order of two non-unit
+/// dims. Such a transpose is "order preserving": the flattened input and
+/// output vectors are (numerically) identical, so `transpose` is effectively
+/// a vector.shape_cast.
+bool isOrderPreserving(TransposeOp transpose);
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..a6a221b2e3a67 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
 void populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Add patterns that rewrite ops which are semantically just copies (the
+/// flattened result equals the flattened source) into vector.shape_cast.
+/// Currently this covers vector.transpose, vector.broadcast and
+/// vector.extract. Examples that will be converted to shape_cast are:
+///
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
+/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
+/// ```
+///
+/// Note that there is no pattern for vector.extract_strided_slice, because the
+/// only extract_strided_slice that is semantically equivalent to shape_cast is
+/// one that has idential input and output shapes, which is already folded.
+///
+/// These patterns can be useful to expose more folding opportunities by
+/// creating pairs of shape_casts that cancel. 
+void populateConvertToShapeCastPatterns(RewritePatternSet &,
+                                        PatternBenefit = 1);
+
 /// Initialize `typeConverter` and `conversionTarget` for vector linearization.
 /// This registers (1) which operations are legal and hence should not be
 /// linearized, (2) what converted types are (rank-1 vectors) and how to
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..11622e1da8de1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5574,13 +5574,11 @@ LogicalResult ShapeCastOp::verify() {
   return success();
 }
 
-namespace {
-
 /// Return true if `transpose` does not permute a pair of non-unit dims.
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
   ArrayRef<int64_t> permutation = transpose.getPermutation();
   VectorType sourceType = transpose.getSourceVectorType();
   ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5600,8 +5598,6 @@ bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
-} // namespace
-
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..05fc6989bf9d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2182,6 +2182,91 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
   }
 };
 
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
+/// vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // A transpose that never reorders two non-unit dims is just a copy.
+    if (!isOrderPreserving(transposeOp))
+      return rewriter.notifyMatchFailure(
+          transposeOp, "not order preserving, so not semantically a 'copy'");
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transposeOp, transposeOp.getType(), transposeOp.getVector());
+    return success();
+  }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType)
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+
+    // In a valid broadcast, equal element counts mean no dim is stretched and
+    // only unit dims are prepended, so the op is a pure copy.
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(broadcast, "element count changes");
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    auto outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(extractOp, "result is a scalar");
+
+    // Only an all-constant-zero position can be a copy: a dynamic position may
+    // be negative (poison), and a non-zero constant selects a sub-vector.
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return rewriter.notifyMatchFailure(extractOp, "position not all zeros");
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return rewriter.notifyMatchFailure(extractOp, "element count changes");
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateFoldArithExtensionPatterns(
@@ -2285,6 +2370,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
       patterns.getContext());
 }
 
+void mlir::vector::populateConvertToShapeCastPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  // Rewrite transpose/broadcast/extract ops that are pure copies.
+  patterns.add<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
+      patterns.getContext(), benefit);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd enum attribute definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
new file mode 100644
index 0000000000000..483c3e73614e0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast |  FileCheck %s 
+
+
+// CHECK-LABEL: @transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] 
+//  CHECK-NEXT:  return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
+//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] 
+//  CHECK-NEXT:  return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+  return %0 : vector<2x2x1xf32>
+}
+
+// -----
+// Only unit dims are prepended, so this broadcast is a pure copy.
+// CHECK-LABEL: @broadcast_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<4xi8>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+  return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_to_shape_cast
+//   CHECK-NOT: shape_cast
+//       CHECK: return 
+func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+  %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+  return %0 : vector<2x3x4xi8>
+}
+
+// -----
+// Extracting position 0 of a leading unit dim is a pure copy.
+// CHECK-LABEL: @extract_to_shape_cast
+//  CHECK-SAME:  %[[ARG0:.*]]: vector<1x4xf32>
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+//  CHECK-NEXT:  return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+  %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison. 
+// CHECK-LABEL: @negative_extract_to_shape_cast
+//   CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+  %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+  return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b73c40adcffa7..aa97d6fc5dc69 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
                          VscaleRange{vscaleMin, vscaleMax});
   }
 };
+
+struct TestConvertToShapeCast
+    : public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)
+  TestConvertToShapeCast() = default;
+  /// Command-line flag and help text for this test pass.
+  StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
+  StringRef getDescription() const final {
+    return "Test conversion to shape_cast of semantically equivalent ops";
+  }
+  /// The rewrites create vector.shape_cast ops.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
+  void runOnOperation() override {
+    RewritePatternSet convertPatterns(&getContext());
+    populateConvertToShapeCastPatterns(convertPatterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(convertPatterns));
+  }
+};
 } // namespace
 
 namespace mlir {
@@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
   PassRegistration<vendor::TestVectorBitWidthLinearize>();
 
   PassRegistration<TestEliminateVectorMasks>();
+
+  PassRegistration<TestConvertToShapeCast>();
 }
 } // namespace test
 } // namespace mlir

>From afb11a43ab0fea5c9ed995c23a88e5ed15c81e5f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:49:24 -0700
Subject: [PATCH 2/3] whitespace

---
 .../Dialect/Vector/Transforms/VectorRewritePatterns.h  |  2 +-
 mlir/test/Dialect/Vector/convert-to-shape-cast.mlir    | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a6a221b2e3a67..3344765f4818a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -422,7 +422,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// one that has idential input and output shapes, which is already folded.
 ///
 /// These patterns can be useful to expose more folding opportunities by
-/// creating pairs of shape_casts that cancel. 
+/// creating pairs of shape_casts that cancel.
 void populateConvertToShapeCastPatterns(RewritePatternSet &,
                                         PatternBenefit = 1);
 
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
index 483c3e73614e0..0ad6b3ff7d541 100644
--- a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast |  FileCheck %s 
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast |  FileCheck %s
 
 
 // CHECK-LABEL: @transpose_to_shape_cast
 //  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
-//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] 
+//  CHECK-NEXT:  %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
 //  CHECK-NEXT:  return %[[SCAST]] : vector<2x2x1xf32>
 func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
@@ -14,7 +14,7 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
 
 // CHECK-LABEL: @negative_transpose_to_shape_cast
 //  CHECK-SAME:  %[[ARG0:.*]]: vector<2x1x2xf32>
-//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] 
+//  CHECK-NEXT:  %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
 //  CHECK-NEXT:  return %[[TRANSPOSE]] : vector<2x2x1xf32>
 func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
   %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
@@ -36,7 +36,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
 
 // CHECK-LABEL: @negative_broadcast_to_shape_cast
 //   CHECK-NOT: shape_cast
-//       CHECK: return 
+//       CHECK: return
 func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
   %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
   return %0 : vector<2x3x4xi8>
@@ -55,7 +55,7 @@ func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
 
 // -----
 
-// In this example, arg1 might be negative indicating poison. 
+// In this example, arg1 might be negative indicating poison.
 // CHECK-LABEL: @negative_extract_to_shape_cast
 //   CHECK-NOT: shape_cast
 func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {

>From 37107a4259f723ce9925d3923526ca8df516a2ce Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:52:32 -0700
Subject: [PATCH 3/3] spacing

---
 mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 05fc6989bf9d2..efcde8e97c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2184,12 +2184,13 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
 
 /// For example,
 /// ```
-/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
-/// vector<2x2x1xf32>
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
 /// ```
 /// becomes
 /// ```
-/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
 /// ```
 struct TransposeToShapeCast final
     : public OpRewritePattern<vector::TransposeOp> {



More information about the Mlir-commits mailing list