[Mlir-commits] [mlir] [mlir] canonicalizer: shape_cast(poison) -> poison (PR #133988)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 1 15:34:05 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

<details>
<summary>Changes</summary>

Modeled on the existing ShapeCastConstantFolder, this new canonicalization pattern replaces

%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>

with 

%1 = ub.poison : vector<6xf32>

---
Full diff: https://github.com/llvm/llvm-project/pull/133988.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+21-2) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..ee7df8a943d24 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5646,6 +5646,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
   }
 };
 
+/// Rewrites shape_cast(ub.poison) -> ub.poison of the cast's result type.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match when the source is produced directly by a ub.poison op.
+    if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+                                              shapeCastOp.getType());
+    return success();
+  }
+};
+
 /// Helper function that computes a new vector type based on the input vector
 /// type by removing the trailing one dims:
 ///
@@ -5804,8 +5821,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
-              ShapeCastBroadcastFolder>(context);
+  results.add<ShapeCastConstantFolder>(context);
+  results.add<ShapeCastPoisonFolder>(context);
+  results.add<ShapeCastCreateMaskFolderTrailingOneDim>(context);
+  results.add<ShapeCastBroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
 
 // -----
 
+// CHECK-LABEL: shape_cast_poison
+//       CHECK-DAG: %[[POISON_F32:.*]] = ub.poison : vector<20x2xf32>
+//       CHECK-DAG: %[[POISON_I32:.*]] = ub.poison : vector<3x4x2xi32>
+//       CHECK: return %[[POISON_F32]], %[[POISON_I32]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+  %poison_f32 = ub.poison : vector<5x4x2xf32>
+  %poison_i32 = ub.poison : vector<12x2xi32>
+  %cast_f32 = vector.shape_cast %poison_f32 : vector<5x4x2xf32> to vector<20x2xf32>
+  %cast_i32 = vector.shape_cast %poison_i32 : vector<12x2xi32> to vector<3x4x2xi32>
+  return %cast_f32, %cast_i32 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_constant
 //       CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
 //       CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/133988


More information about the Mlir-commits mailing list