[Mlir-commits] [mlir] [MLIR][Canonicalization] Added shape_cast folding patterns (PR #183061)
Alexandra Sidorova
llvmlistbot at llvm.org
Tue Feb 24 22:56:57 PST 2026
https://github.com/a-sidorova updated https://github.com/llvm/llvm-project/pull/183061
>From 78d9b57c3ba903ee083c143db90df04d0ff3a944 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <asidorov at amd.com>
Date: Tue, 24 Feb 2026 18:17:37 +0400
Subject: [PATCH 1/3] [MLIR][Canonicalization] Added shape_cast folding
patterns
Signed-off-by: Alexandra Sidorova <asidorov at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 53 ++++++++++++++++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 26 +++++++++++
2 files changed, 75 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 613adeb5eeaaf..f3faf275f8f58 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2591,9 +2591,31 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
}
};
+/// Pattern to rewrite Y = ToElements(ShapeCast(X)) as Y = ToElements(X)
+///
+/// BEFORE:
+/// %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
+/// %2:6 = vector.to_elements %1 : vector<2x3xf32>
+/// AFTER:
+/// %2:6 = vector.to_elements %0 : vector<6xf32>
+struct FoldToElementsOfShapeCast final : public OpRewritePattern<ToElementsOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+ PatternRewriter &rewriter) const override {
+ auto shapeCast = toElementsOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!shapeCast)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ToElementsOp>(toElementsOp,
+ shapeCast.getSource());
+ return success();
+ }
+};
+
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ToElementsOfBroadcast>(context);
+ results.add<ToElementsOfBroadcast, FoldToElementsOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -6660,13 +6682,36 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X) : type(Y)
+///
+/// BEFORE:
+/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
+/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
+/// AFTER (same elements in the same order, built directly with Y's type):
+/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
+class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
+public:
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
+ if (!fromElements)
+ return failure();
+ // Only the shape_cast is replaced; fromElements is not erased here and stays if it has other users.
+ rewriter.replaceOpWithNewOp<FromElementsOp>(
+ shapeCastOp, shapeCastOp.getResultVectorType(),
+ fromElements.getElements());
+ return success();
+ }
+};
+
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
- context);
+ results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
+ FoldShapeCastOfFromElements>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8126389212ce6..4740dc93accb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3209,6 +3209,19 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
// -----
+// CHECK-LABEL: func.func @fold_shape_cast_from_elements(
+// CHECK-SAME: %[[C1:.*]]: f32, %[[C2:.*]]: f32, %[[C3:.*]]: f32, %[[C4:.*]]: f32
+func.func @fold_shape_cast_from_elements(%c1: f32, %c2: f32, %c3: f32, %c4: f32) -> vector<2x2xf32>{
+// CHECK: %[[VAL_0:.*]] = vector.from_elements %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2x2xf32>
+// CHECK: return %[[VAL_0]] : vector<2x2xf32>
+// CHECK-NOT: vector.shape_cast
+ %1 = vector.from_elements %c1, %c2, %c3, %c4 : vector<4xf32>
+ %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+ return %2 : vector<2x2xf32>
+}
+
+// -----
+
// TODO: This IR could be canonicalized but the canonicalization pattern is not
// smart enough. For now, just make sure that we do not crash.
@@ -3354,6 +3367,19 @@ func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32
// -----
+// CHECK-LABEL: func @fold_to_elements_of_shape_cast
+// CHECK-SAME: (%[[VEC:.*]]: vector<4xf32>) -> (f32, f32, f32, f32)
+func.func @fold_to_elements_of_shape_cast(%v: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %sc = vector.shape_cast %v : vector<4xf32> to vector<2x2xf32>
+ %e:4 = vector.to_elements %sc : vector<2x2xf32>
+ // CHECK-NOT: vector.shape_cast
+ // CHECK: %[[E:.*]]:4 = vector.to_elements %[[VEC]] : vector<4xf32>
+ // CHECK: return %[[E]]#0, %[[E]]#1, %[[E]]#2, %[[E]]#3 : f32, f32, f32, f32
+ return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
>From 7566851d2848a3d6bfd2336b5220865d073067c5 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <asidorov at amd.com>
Date: Wed, 25 Feb 2026 10:53:55 +0400
Subject: [PATCH 2/3] [MLIR][Canonicalization] Moved FoldToElementsOfShapeCast
to ToElementsOp::fold
Signed-off-by: Alexandra Sidorova <asidorov at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 31 ++++++------------------
1 file changed, 8 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f3faf275f8f58..9ca5637796743 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2496,6 +2496,13 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (succeeded(foldToElementsFromElements(*this, results)))
return success();
+
+ // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X), folded in place: the operand is rewired and no results are pushed.
+ if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
+ setOperand(shapeCast.getSource());
+ return success();
+ }
+
return foldToElementsOfBroadcast(*this, results);
}
@@ -2591,31 +2598,9 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
}
};
-/// Pattern to rewrite Y = ToElements(ShapeCast(X)) as Y = ToElements(X)
-///
-/// BEFORE:
-/// %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
-/// %2:6 = vector.to_elements %1 : vector<2x3xf32>
-/// AFTER:
-/// %2:6 = vector.to_elements %0 : vector<6xf32>
-struct FoldToElementsOfShapeCast final : public OpRewritePattern<ToElementsOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
- PatternRewriter &rewriter) const override {
- auto shapeCast = toElementsOp.getSource().getDefiningOp<ShapeCastOp>();
- if (!shapeCast)
- return failure();
-
- rewriter.replaceOpWithNewOp<ToElementsOp>(toElementsOp,
- shapeCast.getSource());
- return success();
- }
-};
-
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ToElementsOfBroadcast, FoldToElementsOfShapeCast>(context);
+ results.add<ToElementsOfBroadcast>(context);
}
//===----------------------------------------------------------------------===//
>From 3a1eed3da9bc39f80f482377d246432c3b59c0b6 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <asidorov at amd.com>
Date: Wed, 25 Feb 2026 11:01:53 +0400
Subject: [PATCH 3/3] [MLIR][Canonicalization] Moved lit tests to the
specialized files
Signed-off-by: Alexandra Sidorova <asidorov at amd.com>
---
mlir/test/Dialect/Vector/canonicalize.mlir | 26 -------------------
.../canonicalize/vector-from-elements.mlir | 16 ++++++++++++
.../canonicalize/vector-to-elements.mlir | 18 +++++++++++++
3 files changed, 34 insertions(+), 26 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-to-elements.mlir
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 4740dc93accb5..8126389212ce6 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3209,19 +3209,6 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
// -----
-// CHECK-LABEL: func.func @fold_shape_cast_from_elements(
-// CHECK-SAME: %[[C1:.*]]: f32, %[[C2:.*]]: f32, %[[C3:.*]]: f32, %[[C4:.*]]: f32
-func.func @fold_shape_cast_from_elements(%c1: f32, %c2: f32, %c3: f32, %c4: f32) -> vector<2x2xf32>{
-// CHECK: %[[VAL_0:.*]] = vector.from_elements %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2x2xf32>
-// CHECK: return %[[VAL_0]] : vector<2x2xf32>
-// CHECK-NOT: vector.shape_cast
- %1 = vector.from_elements %c1, %c2, %c3, %c4 : vector<4xf32>
- %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
- return %2 : vector<2x2xf32>
-}
-
-// -----
-
// TODO: This IR could be canonicalized but the canonicalization pattern is not
// smart enough. For now, just make sure that we do not crash.
@@ -3367,19 +3354,6 @@ func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32
// -----
-// CHECK-LABEL: func @fold_to_elements_of_shape_cast
-// CHECK-SAME: (%[[VEC:.*]]: vector<4xf32>) -> (f32, f32, f32, f32)
-func.func @fold_to_elements_of_shape_cast(%v: vector<4xf32>) -> (f32, f32, f32, f32) {
- %sc = vector.shape_cast %v : vector<4xf32> to vector<2x2xf32>
- %e:4 = vector.to_elements %sc : vector<2x2xf32>
- // CHECK-NOT: vector.shape_cast
- // CHECK: %[[E:.*]]:4 = vector.to_elements %[[VEC]] : vector<4xf32>
- // CHECK: return %[[E]]#0, %[[E]]#1, %[[E]]#2, %[[E]]#3 : f32, f32, f32, f32
- return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
-}
-
-// -----
-
// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index aa6539e466c95..c64c3888dc779 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -266,3 +266,19 @@ func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> {
return %2 : vector<4xi8>
}
+// -----
+
+///===----------------------------------------------===//
+/// Test of `FoldShapeCastOfFromElements`
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func.func @fold_shape_cast_from_elements(
+// CHECK-SAME: %[[C1:.*]]: f32, %[[C2:.*]]: f32, %[[C3:.*]]: f32, %[[C4:.*]]: f32
+func.func @fold_shape_cast_from_elements(%c1: f32, %c2: f32, %c3: f32, %c4: f32) -> vector<2x2xf32> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_0:.*]] = vector.from_elements %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2x2xf32>
+// CHECK-NEXT: return %[[VAL_0]] : vector<2x2xf32>
+ %1 = vector.from_elements %c1, %c2, %c3, %c4 : vector<4xf32>
+ %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+ return %2 : vector<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-elements.mlir
new file mode 100644
index 0000000000000..76be76a5312c9
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-elements.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// Tests for folding and canonicalization of vector.to_elements.
+
+///===----------------------------------------------===//
+/// Tests of `ToElementsOp::fold`: to_elements(shape_cast(x)) -> to_elements(x)
+///===----------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_of_shape_cast_folds
+// CHECK-SAME: (%[[VEC:.*]]: vector<4xf32>) -> (f32, f32, f32, f32)
+func.func @to_elements_of_shape_cast_folds(%v: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %sc = vector.shape_cast %v : vector<4xf32> to vector<2x2xf32>
+ %e:4 = vector.to_elements %sc : vector<2x2xf32>
+ // CHECK-NOT: vector.shape_cast
+ // CHECK: %[[E:.*]]:4 = vector.to_elements %[[VEC]] : vector<4xf32>
+ // CHECK: return %[[E]]#0, %[[E]]#1, %[[E]]#2, %[[E]]#3 : f32, f32, f32, f32
+ return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
More information about the Mlir-commits
mailing list