[Mlir-commits] [mlir] cd85f5d - [mlir] canonicalizer: shape_cast(poison) -> poison (#133988)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 11 07:13:07 PDT 2025
Author: James Newling
Date: 2025-04-11T15:13:03+01:00
New Revision: cd85f5dbdf135347a9912dde148ec9fd325ba8c1
URL: https://github.com/llvm/llvm-project/commit/cd85f5dbdf135347a9912dde148ec9fd325ba8c1
DIFF: https://github.com/llvm/llvm-project/commit/cd85f5dbdf135347a9912dde148ec9fd325ba8c1.diff
LOG: [mlir] canonicalizer: shape_cast(poison) -> poison (#133988)
Based on the ShapeCastConstantFolder, this pattern replaces
%0 = ub.poison : vector<2x3xf32>
%1 = vector.shape_cast %0 : vector<2x3xf32> to vector<6xf32>
with
%1 = ub.poison : vector<6xf32>
---------
Signed-off-by: James Newling <james.newling at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 98d98f067de14..59f3b788cebed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -42,6 +42,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
#include <cassert>
#include <cstdint>
@@ -5611,18 +5612,20 @@ LogicalResult ShapeCastOp::verify() {
}
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
+
// No-op shape cast.
- if (getSource().getType() == getResult().getType())
+ if (getSource().getType() == getType())
return getSource();
+ VectorType resultType = getType();
+
// Canceling shape casts.
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
- if (getResult().getType() == otherOp.getSource().getType())
- return otherOp.getSource();
- // Only allows valid transitive folding.
- VectorType srcType = llvm::cast<VectorType>(otherOp.getSource().getType());
- VectorType resultType = llvm::cast<VectorType>(getResult().getType());
+ // Only allows valid transitive folding (expand/collapse dimensions).
+ VectorType srcType = otherOp.getSource().getType();
+ if (resultType == srcType)
+ return otherOp.getSource();
if (srcType.getRank() < resultType.getRank()) {
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
return {};
@@ -5632,43 +5635,32 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
} else {
return {};
}
-
setOperand(otherOp.getSource());
return getResult();
}
// Cancelling broadcast and shape cast ops.
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
- if (bcastOp.getSourceType() == getType())
+ if (bcastOp.getSourceType() == resultType)
return bcastOp.getSource();
}
+ // shape_cast(constant) -> constant
+ if (auto splatAttr =
+ llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
+ return DenseElementsAttr::get(resultType,
+ splatAttr.getSplatValue<Attribute>());
+ }
+
+ // shape_cast(poison) -> poison
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ return ub::PoisonAttr::get(getContext());
+ }
+
return {};
}
namespace {
-// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
-class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto constantOp =
- shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
- if (!constantOp)
- return failure();
- // Only handle splat for now.
- auto dense = llvm::dyn_cast<SplatElementsAttr>(constantOp.getValue());
- if (!dense)
- return failure();
- auto newAttr =
- DenseElementsAttr::get(llvm::cast<VectorType>(shapeCastOp.getType()),
- dense.getSplatValue<Attribute>());
- rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
- return success();
- }
-};
/// Helper function that computes a new vector type based on the input vector
/// type by removing the trailing one dims:
@@ -5828,8 +5820,9 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
- ShapeCastBroadcastFolder>(context);
+ results
+ .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7db8ec834be7..72064fb42741a 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1167,6 +1167,20 @@ func.func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
// -----
+// CHECK-LABEL: shape_cast_poison
+// CHECK-DAG: %[[CST1:.*]] = ub.poison : vector<3x4x2xi32>
+// CHECK-DAG: %[[CST0:.*]] = ub.poison : vector<20x2xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
+func.func @shape_cast_poison() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
+ %poison = ub.poison : vector<5x4x2xf32>
+ %poison_1 = ub.poison : vector<12x2xi32>
+ %0 = vector.shape_cast %poison : vector<5x4x2xf32> to vector<20x2xf32>
+ %1 = vector.shape_cast %poison_1 : vector<12x2xi32> to vector<3x4x2xi32>
+ return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_constant
// CHECK-DAG: %[[CST1:.*]] = arith.constant dense<1> : vector<2x13x3xi32>
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<12x2xf32>
More information about the Mlir-commits
mailing list