[Mlir-commits] [mlir] [mlir] canonicalizer: shape_cast(poison) -> poison (PR #133988)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 1 15:34:05 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: James Newling (newling)
<details>
<summary>Changes</summary>
Modeled on ShapeCastConstantFolder, this canonicalization pattern replaces
%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>
with
%1 = ub.poison : vector<6xf32>
---
Full diff: https://github.com/llvm/llvm-project/pull/133988.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+21-2)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5a3983699d5a3..ee7df8a943d24 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5646,6 +5646,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+ shapeCastOp.getType());
+ return success();
+ }
+};
+
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
@@ -5804,8 +5821,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ results
+ .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
+ ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// CHECK-LABEL: shape_cast_poison
+// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+ %poison = ub.poison : vector<5x4x2xf32>
+ %poison_1 = ub.poison : vector<12x2xi32>
+ %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+ %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+ return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/133988
More information about the Mlir-commits mailing list