[Mlir-commits] [mlir] [mlir][vector] Patterns to convert to shape_cast, where possible (PR #138777)
James Newling
llvmlistbot at llvm.org
Tue May 6 16:06:07 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/138777
>From 9e349f31609ccc854b2c40e36fce896dd8e3434d Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:46:36 -0700
Subject: [PATCH 1/4] first
---
.../mlir/Dialect/Vector/IR/VectorOps.h | 7 ++
.../Vector/Transforms/VectorRewritePatterns.h | 20 ++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 +-
.../Vector/Transforms/VectorTransforms.cpp | 92 +++++++++++++++++++
.../Dialect/Vector/convert-to-shape-cast.mlir | 65 +++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 22 +++++
6 files changed, 207 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 98fb6075cbf32..be9839ce26339 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -50,6 +50,7 @@ namespace vector {
class ContractionOp;
class TransferReadOp;
class TransferWriteOp;
+class TransposeOp;
class VectorDialect;
namespace detail {
@@ -171,6 +172,12 @@ SmallVector<Value> getAsValues(OpBuilder &builder, Location loc,
/// `std::nullopt`.
std::optional<int64_t> getConstantVscaleMultiplier(Value value);
+/// Return true if `transpose` does not permute a pair of non-unit dims.
+/// By `order preserving` we mean that the flattened versions of the input and
+/// output vectors are (numerically) identical. In other words `transpose` is
+/// effectively a shape cast.
+bool isOrderPreserving(TransposeOp transpose);
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index f1100d5cf8b68..a6a221b2e3a67 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -406,6 +406,26 @@ void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
void populateVectorTransposeNarrowTypeRewritePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Add patterns that convert operations that are semantically equivalent to
+/// shape_cast, to shape_cast. Currently this includes patterns for converting
+/// transpose, extract and broadcast to shape_cast. Examples that will be
+/// converted to shape_cast are:
+///
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
+/// %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
+/// ```
+///
+/// Note that there is no pattern for vector.extract_strided_slice, because the
+/// only extract_strided_slice that is semantically equivalent to shape_cast is
+/// one that has identical input and output shapes, which is already folded.
+///
+/// These patterns can be useful to expose more folding opportunities by
+/// creating pairs of shape_casts that cancel.
+void populateConvertToShapeCastPatterns(RewritePatternSet &,
+ PatternBenefit = 1);
+
/// Initialize `typeConverter` and `conversionTarget` for vector linearization.
/// This registers (1) which operations are legal and hence should not be
/// linearized, (2) what converted types are (rank-1 vectors) and how to
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f9c7fb7799eb0..11622e1da8de1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5574,13 +5574,12 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-namespace {
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-bool isOrderPreserving(TransposeOp transpose) {
+bool mlir::vector::isOrderPreserving(TransposeOp transpose) {
ArrayRef<int64_t> permutation = transpose.getPermutation();
VectorType sourceType = transpose.getSourceVectorType();
ArrayRef<int64_t> inShape = sourceType.getShape();
@@ -5600,8 +5599,6 @@ bool isOrderPreserving(TransposeOp transpose) {
return true;
}
-} // namespace
-
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..05fc6989bf9d2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2182,6 +2182,91 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
}
};
+/// For example,
+/// ```
+/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
+/// vector<2x2x1xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
+/// ```
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+ if (!isOrderPreserving(transpose)) {
+ return rewriter.notifyMatchFailure(
+ transpose, "not order preserving, so not semantically a 'copy'");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// ```
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
+
+/// For example,
+/// ```
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// ```
+/// becomes
+/// ```
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// ```
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ // Only an all-zero constant `position` is equivalent to a shape_cast. Dynamic
+ // or negative (poison) positions cannot be converted.
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateFoldArithExtensionPatterns(
@@ -2285,6 +2370,13 @@ void mlir::vector::populateElementwiseToVectorOpsPatterns(
patterns.getContext());
}
+void mlir::vector::populateConvertToShapeCastPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns
+ .insert<TransposeToShapeCast, BroadcastToShapeCast, ExtractToShapeCast>(
+ patterns.getContext(), benefit);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
new file mode 100644
index 0000000000000..483c3e73614e0
--- /dev/null
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s
+
+
+// CHECK-LABEL: @transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
+func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_transpose_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
+func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
+ %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
+ return %0 : vector<2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
+func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+ return %0 : vector<1x1x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_broadcast_to_shape_cast
+// CHECK-NOT: shape_cast
+// CHECK: return
+func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
+ %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
+ return %0 : vector<2x3x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_to_shape_cast
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
+func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
+ %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// In this example, arg1 might be negative indicating poison.
+// CHECK-LABEL: @negative_extract_to_shape_cast
+// CHECK-NOT: shape_cast
+func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
+ %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32>
+ return %0 : vector<4xf32>
+}
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index b73c40adcffa7..aa97d6fc5dc69 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -1022,6 +1022,26 @@ struct TestEliminateVectorMasks
VscaleRange{vscaleMin, vscaleMax});
}
};
+
+struct TestConvertToShapeCast
+ : public PassWrapper<TestConvertToShapeCast, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertToShapeCast)
+
+ TestConvertToShapeCast() = default;
+
+ StringRef getArgument() const final { return "test-convert-to-shape-cast"; }
+ StringRef getDescription() const final {
+ return "Test conversion to shape_cast of semantically equivalent ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateConvertToShapeCastPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -1072,6 +1092,8 @@ void registerTestVectorLowerings() {
PassRegistration<vendor::TestVectorBitWidthLinearize>();
PassRegistration<TestEliminateVectorMasks>();
+
+ PassRegistration<TestConvertToShapeCast>();
}
} // namespace test
} // namespace mlir
>From afb11a43ab0fea5c9ed995c23a88e5ed15c81e5f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:49:24 -0700
Subject: [PATCH 2/4] whitespace
---
.../Dialect/Vector/Transforms/VectorRewritePatterns.h | 2 +-
mlir/test/Dialect/Vector/convert-to-shape-cast.mlir | 10 +++++-----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a6a221b2e3a67..3344765f4818a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -422,7 +422,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// one that has identical input and output shapes, which is already folded.
///
/// These patterns can be useful to expose more folding opportunities by
-/// creating pairs of shape_casts that cancel.
+/// creating pairs of shape_casts that cancel.
void populateConvertToShapeCastPatterns(RewritePatternSet &,
PatternBenefit = 1);
diff --git a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
index 483c3e73614e0..0ad6b3ff7d541 100644
--- a/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/convert-to-shape-cast.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-convert-to-shape-cast | FileCheck %s
// CHECK-LABEL: @transpose_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
-// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
+// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
@@ -14,7 +14,7 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
// CHECK-LABEL: @negative_transpose_to_shape_cast
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
-// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
%0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
@@ -36,7 +36,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
// CHECK-LABEL: @negative_broadcast_to_shape_cast
// CHECK-NOT: shape_cast
-// CHECK: return
+// CHECK: return
func.func @negative_broadcast_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8>
return %0 : vector<2x3x4xi8>
@@ -55,7 +55,7 @@ func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
// -----
-// In this example, arg1 might be negative indicating poison.
+// In this example, arg1 might be negative indicating poison.
// CHECK-LABEL: @negative_extract_to_shape_cast
// CHECK-NOT: shape_cast
func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {
>From 37107a4259f723ce9925d3923526ca8df516a2ce Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 15:52:32 -0700
Subject: [PATCH 3/4] spacing
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 05fc6989bf9d2..efcde8e97c0cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2184,12 +2184,13 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
/// For example,
/// ```
-/// %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to
-/// vector<2x2x1xf32>
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
/// becomes
/// ```
-/// %0 = vector.shape_cast %arg0 : vector<2x1x2xf32> to vector<2x2x1xf32>
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
/// ```
struct TransposeToShapeCast final
: public OpRewritePattern<vector::TransposeOp> {
>From 6854f9262bfecd08b70903a0a272361fe779f867 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 16:06:14 -0700
Subject: [PATCH 4/4] format
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 11622e1da8de1..562fc7d6ca110 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5574,7 +5574,6 @@ LogicalResult ShapeCastOp::verify() {
return success();
}
-
/// Return true if `transpose` does not permute a pair of non-unit dims.
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
More information about the Mlir-commits
mailing list