[Mlir-commits] [mlir] [MLIR][Canonicalization] Added shape_cast folding patterns (PR #183061)

Alexandra Sidorova llvmlistbot at llvm.org
Tue Feb 24 06:14:31 PST 2026


https://github.com/a-sidorova created https://github.com/llvm/llvm-project/pull/183061

### Details

- Added canonicalization patterns that fold `ToElements(ShapeCast(X)) -> ToElements(X)` and `ShapeCast(FromElements(X)) -> FromElements(X)` (the new `from_elements` is built directly with the `shape_cast` result type)

>From 78d9b57c3ba903ee083c143db90df04d0ff3a944 Mon Sep 17 00:00:00 2001
From: Alexandra Sidorova <asidorov at amd.com>
Date: Tue, 24 Feb 2026 18:17:37 +0400
Subject: [PATCH] [MLIR][Canonicalization] Added shape_cast folding patterns

Signed-off-by: Alexandra Sidorova <asidorov at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 53 ++++++++++++++++++++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 26 +++++++++++
 2 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 613adeb5eeaaf..f3faf275f8f58 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2591,9 +2591,31 @@ struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
   }
 };
 
+/// Rewrites Y = ToElements(ShapeCast(X)) into Y = ToElements(X).
+/// shape_cast keeps the row-major element order, so both yield the same list.
+/// Example:
+///    %1 = vector.shape_cast %0 : vector<6xf32> to vector<2x3xf32>
+///    %2:6 = vector.to_elements %1 : vector<2x3xf32>
+/// becomes:
+///    %2:6 = vector.to_elements %0 : vector<6xf32>
+struct FoldToElementsOfShapeCast final : OpRewritePattern<ToElementsOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only applies when the source vector is produced by a shape_cast.
+    if (auto castOp = op.getSource().getDefiningOp<ShapeCastOp>()) {
+      // Read the elements straight from the pre-cast vector.
+      rewriter.replaceOpWithNewOp<ToElementsOp>(op, castOp.getSource());
+      return success();
+    }
+    return failure();
+  }
+};
+
 void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<ToElementsOfBroadcast>(context);
+  results.add<ToElementsOfBroadcast, FoldToElementsOfShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -6660,13 +6682,36 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Rewrites Y = ShapeCast(FromElements(X)) into Y = FromElements(X), building
+/// the from_elements directly with the shape_cast's result type.
+/// Example:
+///    %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
+///    %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
+/// becomes:
+///    %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
+class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto producer = op.getSource().getDefiningOp<FromElementsOp>();
+    if (!producer)
+      return failure();
+    // Same scalars, in the same order, retyped to the cast's result type.
+    VectorType resultType = op.getResultVectorType();
+    rewriter.replaceOpWithNewOp<FromElementsOp>(op, resultType,
+                                                producer.getElements());
+    return success();
+  }
+};
+
 } // namespace

 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results
-      .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
-          context);
+  results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
+              FoldShapeCastOfFromElements>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8126389212ce6..4740dc93accb5 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3209,6 +3209,19 @@ func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1>{
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_shape_cast_from_elements(
+// CHECK-SAME:    %[[C1:.*]]: f32, %[[C2:.*]]: f32, %[[C3:.*]]: f32, %[[C4:.*]]: f32
+func.func @fold_shape_cast_from_elements(%c1: f32, %c2: f32, %c3: f32, %c4: f32) -> vector<2x2xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.from_elements %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2x2xf32>
+// CHECK-NOT:       vector.shape_cast
+// CHECK:           return %[[VAL_0]] : vector<2x2xf32>
+  %1 = vector.from_elements %c1, %c2, %c3, %c4 : vector<4xf32>
+  %2 = vector.shape_cast %1 : vector<4xf32> to vector<2x2xf32>
+  return %2 : vector<2x2xf32>
+}
+
+// -----
+
 // TODO: This IR could be canonicalized but the canonicalization pattern is not
 // smart enough. For now, just make sure that we do not crash.
 
@@ -3354,6 +3367,19 @@ func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32
 
 // -----
 
+// CHECK-LABEL: func @fold_to_elements_of_shape_cast
+// CHECK-SAME: (%[[SRC:.*]]: vector<4xf32>) -> (f32, f32, f32, f32)
+func.func @fold_to_elements_of_shape_cast(%v: vector<4xf32>) -> (f32, f32, f32, f32) {
+  // CHECK-NOT: vector.shape_cast
+  // CHECK: %[[ELTS:.*]]:4 = vector.to_elements %[[SRC]] : vector<4xf32>
+  // CHECK: return %[[ELTS]]#0, %[[ELTS]]#1, %[[ELTS]]#2, %[[ELTS]]#3 : f32, f32, f32, f32
+  %cast = vector.shape_cast %v : vector<4xf32> to vector<2x2xf32>
+  %elts:4 = vector.to_elements %cast : vector<2x2xf32>
+  return %elts#0, %elts#1, %elts#2, %elts#3 : f32, f32, f32, f32
+}
+
+// -----
+
 // CHECK-LABEL: func @to_elements_of_vector_broadcast
 // CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
 func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {



More information about the Mlir-commits mailing list