[Mlir-commits] [mlir] [mlir] canonicalizer: shape_cast(poison) -> poison (PR #133988)
James Newling
llvmlistbot at llvm.org
Wed Apr 9 13:01:50 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/133988
>From 2ddaa8ecb824e263730b1747b2762c2526b0773e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 1 Apr 2025 14:25:16 -0700
Subject: [PATCH 1/4] add shape_cast(poison) -> poison canonicalization pattern
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 23 ++++++++++++++++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 14 +++++++++++++
2 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..68b4c26880141 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5670,6 +5670,23 @@ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
}
};
+// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
+class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
+ shapeCastOp.getType());
+ return success();
+ }
+};
+
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
///
@@ -5828,8 +5845,10 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ results
+ .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
+ ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// CHECK-LABEL: shape_cast_poison
+// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+ %poison = ub.poison : vector<5x4x2xf32>
+ %poison_1 = ub.poison : vector<12x2xi32>
+ %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+ %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+ return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>
>From 99ed6355272915d893db4a38fdfd7c3ba0fa0daa Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 7 Apr 2025 09:04:44 -0700
Subject: [PATCH 2/4] use folders where possible (replace 2 canonicalization patterns)
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 56 +++++++-----------------
1 file changed, 15 insertions(+), 41 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 68b4c26880141..53fc47a6d6ef5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5643,49 +5643,24 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- return {};
-}
-
-namespace {
-// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
-class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto constantOp =
- shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
- if (!constantOp)
- return failure();
- // Only handle splat for now.
- auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
- if (!dense)
- return failure();
- auto newAttr =
- DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
- dense.getSplatValue<Attribute>());
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
- return success();
+ // Replace shape_cast(arith.constant) with arith.constant. Currently only
+ // handles splat constants.
+ if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
+ if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
+ return DenseElementsAttr::get(cast<VectorType>(getType()),
+ dense.getSplatValue<Attribute>());
+ }
}
-};
-// Pattern to rewrite a ShapeCast(PoisonOp) -> PoisonOp.
-class ShapeCastPoisonFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
+ // Replace shape_cast(poison) with poison.
+ if (getSource().getDefiningOp<ub::PoisonOp>()) {
+ return ub::PoisonAttr::get(getContext());
+ }
- if (!shapeCastOp.getSource().getDefiningOp<ub::PoisonOp>())
- return failure();
+ return {};
+}
- rewriter.replaceOpWithNewOp<ub::PoisonOp>(shapeCastOp,
- shapeCastOp.getType());
- return success();
- }
-};
+namespace {
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
@@ -5846,8 +5821,7 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
- .add<ShapeCastConstantFolder, ShapeCastPoisonFolder,
- ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+ .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
context);
}
>From eb5b9d7046b8c7ae0b30f8872ac4a593074af890 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Mon, 7 Apr 2025 10:16:20 -0700
Subject: [PATCH 3/4] simplify folding of const/poison
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 +++++++++++-------------
1 file changed, 18 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 53fc47a6d6ef5..0ac22969629da 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
@@ -5611,28 +5612,27 @@ LogicalResult ShapeCastOp::verify() {
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+
// No-op shape cast.
- if (getSource().getType() == getResult().getType())
+ if (getSource().getType() == getType())
return getSource();
// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
- if (getResult().getType() == otherOp.getSource().getType())
- return otherOp.getSource();
- // Only allows valid transitive folding.
- VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
- VectorType resultType = llvm::cast<VectorType>(getResult().getType());
- if (srcType.getRank() < resultType.getRank()) {
- if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
+ // Only allows valid transitive folding (expand/collapse dimensions).
+ VectorType srcType = otherOp.getSource().getType();
+ if (getType() == srcType)
+ return otherOp.getSource();
+ if (srcType.getRank() < getType().getRank()) {
+ if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
return {};
- } else if (srcType.getRank() > resultType.getRank()) {
- if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
+ } else if (srcType.getRank() > getType().getRank()) {
+ if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
return {};
} else {
return {};
}
-
setOperand(otherOp.getSource());
return getResult();
}
@@ -5643,17 +5643,15 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
return bcastOp.getSource();
}
- // Replace shape_cast(arith.constant) with arith.constant. Currently only
- // handles splat constants.
- if (auto constantOp = getSource().getDefiningOp<arith::ConstantOp>()) {
- if (auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue())) {
- return DenseElementsAttr::get(cast<VectorType>(getType()),
- dense.getSplatValue<Attribute>());
- }
+ // shape_cast(constant) -> constant
+ if (auto splatAttr =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ return DenseElementsAttr::get(getType(),
+ splatAttr.getSplatValue<Attribute>());
}
- // Replace shape_cast(poison) with poison.
- if (getSource().getDefiningOp<ub::PoisonOp>()) {
+ // shape_cast(poison) -> poison
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
return ub::PoisonAttr::get(getContext());
}
>From d47886c4c6516b877baeacade6a29b80403aad9b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 9 Apr 2025 13:06:13 -0700
Subject: [PATCH 4/4] use resultType
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 0ac22969629da..59f3b788cebed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,20 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
if (getSource().getType() == getType())
return getSource();
+ VectorType resultType = getType();
+
// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
// Only allows valid transitive folding (expand/collapse dimensions).
VectorType srcType = otherOp.getSource().getType();
- if (getType() == srcType)
+ if (resultType == srcType)
return otherOp.getSource();
- if (srcType.getRank() < getType().getRank()) {
- if (!isValidShapeCast(srcType.getShape(), getType().getShape()))
+ if (srcType.getRank() < resultType.getRank()) {
+ if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
- } else if (srcType.getRank() > getType().getRank()) {
- if (!isValidShapeCast(getType().getShape(), srcType.getShape()))
+ } else if (srcType.getRank() > resultType.getRank()) {
+ if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
return {};
} else {
return {};
@@ -5639,14 +5641,14 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// Cancelling broadcast and shape cast ops.
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
- if (bcastOp.getSourceType() == getType())
+ if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
}
// shape_cast(constant) -> constant
if (auto splatAttr =
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
- return DenseElementsAttr::get(getType(),
+ return DenseElementsAttr::get(resultType,
splatAttr.getSplatValue<Attribute>());
}
More information about the Mlir-commits
mailing list