[Mlir-commits] [mlir] [MLIR][Canonicalization] Added shape_cast folding patterns (PR #183061)
Alexandra Sidorova
llvmlistbot at llvm.org
Tue Feb 24 06:14:31 PST 2026
https://github.com/a-sidorova created https://github.com/llvm/llvm-project/pull/183061
### Details
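- Add canonicalization patterns that fold `ToElements(ShapeCast(X)) -> ToElements(X)` and `ShapeCast(FromElements(X)) -> FromElements(X)`. Both folds are valid because `vector.shape_cast` preserves the row-major order of the elements, so the same elements are read or written in the same order before and after the rewrite.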
>From 78d9b57c3ba903ee083c143db90df04d0ff3a944 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <asidorov at amd.com>
Date: Tue, 24 Feb 2026 18:17:37 +0400
Subject: [PATCH] [MLIR][Canonicalization] Added shape_cast folding patterns
Signed-off-by: Alexandra Sidorova <asidorov at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 53 ++++++++++++++++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 26 +++++++++++
2 files changed, 75 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 613adeb5eeaaf..f3faf275f8f58 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2591,9 +2591,31 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
}
};
+/// Pattern to rewrite Y = ToElements(ShapeCast(X)) as Y = ToElements(X).
+///
+/// vector.shape_cast changes only the shape of a vector, not the row-major
+/// order of its elements, so unpacking X yields the same scalars in the same
+/// order as unpacking the shape_cast result. Because the results are scalars
+/// of the shared element type, only the operand has to change: the op is
+/// updated in place instead of being rebuilt (which would also require
+/// replacing every one of its results).
+///
+/// BEFORE:
+/// %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
+/// %2:6 = vector.to_elements %1 : vector<2x3xf32>
+/// AFTER:
+/// %2:6 = vector.to_elements %0 : vector<6xf32>
+struct ToElementsOfShapeCast final : OpRewritePattern<ToElementsOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto shapeCast = toElementsOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!shapeCast)
+      return failure();
+
+    // Bypass the shape_cast by reading directly from its source.
+    rewriter.modifyOpInPlace(toElementsOp, [&] {
+      toElementsOp.setOperand(shapeCast.getSource());
+    });
+    return success();
+  }
+};
+
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ToElementsOfBroadcast>(context);
+  results.add<ToElementsOfBroadcast, ToElementsOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -6660,13 +6682,36 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X),
+/// where the new vector.from_elements directly produces the shape_cast result
+/// type.
+///
+/// vector.from_elements fills its result in row-major order and
+/// vector.shape_cast preserves that order, so the same element list builds the
+/// reshaped vector. Only the shape_cast is replaced; the original
+/// from_elements op is not erased by this pattern, so any other users keep
+/// using it.
+///
+/// BEFORE:
+/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
+/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
+/// AFTER:
+/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
+class ShapeCastFromElementsFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
+    if (!fromElements)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(
+        shapeCastOp, shapeCastOp.getResultVectorType(),
+        fromElements.getElements());
+    return success();
+  }
+};
+
} // namespace
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
- context);
+  results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
+              ShapeCastFromElementsFolder>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8126389212ce6..4740dc93accb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3209,6 +3209,19 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
// -----
+// CHECK-LABEL: func.func @fold_shape_cast_from_elements(
+// CHECK-SAME: %[[C1:.*]]: f32, %[[C2:.*]]: f32, %[[C3:.*]]: f32, %[[C4:.*]]: f32
+// CHECK: %[[VAL_0:.*]] = vector.from_elements %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2x2xf32>
+// CHECK-NOT: vector.shape_cast
+// CHECK: return %[[VAL_0]] : vector<2x2xf32>
+func.func @fold_shape_cast_from_elements(%c1: f32, %c2: f32, %c3: f32, %c4: f32) -> vector<2x2xf32> {
+  %1 = vector.from_elements %c1, %c2, %c3, %c4 : vector<4xf32>
+  %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  return %2 : vector<2x2xf32>
+}
+
+// -----
+
// TODO: This IR could be canonicalized but the canonicalization pattern is not
// smart enough. For now, just make sure that we do not crash.
@@ -3354,6 +3367,19 @@ func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32
// -----
+// Verify that to_elements reads straight from the shape_cast operand.
+// CHECK-LABEL: func @fold_to_elements_of_shape_cast
+// CHECK-SAME: (%[[SRC:.*]]: vector<4xf32>) -> (f32, f32, f32, f32)
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[ELTS:.*]]:4 = vector.to_elements %[[SRC]] : vector<4xf32>
+// CHECK: return %[[ELTS]]#0, %[[ELTS]]#1, %[[ELTS]]#2, %[[ELTS]]#3 : f32, f32, f32, f32
+func.func @fold_to_elements_of_shape_cast(%src: vector<4xf32>) -> (f32, f32, f32, f32) {
+  %cast = vector.shape_cast %src : vector<4xf32> to vector<2x2xf32>
+  %elts:4 = vector.to_elements %cast : vector<2x2xf32>
+  return %elts#0, %elts#1, %elts#2, %elts#3 : f32, f32, f32, f32
+}
+
+// -----
+
// CHECK-LABEL: func @to_elements_of_vector_broadcast
// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
More information about the Mlir-commits
mailing list