[Mlir-commits] [mlir] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (PR #74438)
Sean Silva
llvmlistbot at llvm.org
Tue Dec 5 01:46:44 PST 2023
================
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
// ShapeOfOp
//===----------------------------------------------------------------------===//
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
- auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
- if (!type || !type.hasStaticShape())
- return nullptr;
- Builder builder(getContext());
- return builder.getIndexTensorAttr(type.getShape());
-}
-
namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+ using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+ PatternRewriter &rewriter) const override {
+ auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+ if (!type || !type.hasStaticShape())
+ return failure();
+ Location loc = op.getLoc();
+ Value constShape =
+ rewriter
+ .create<ConstShapeOp>(loc,
+ rewriter.getIndexTensorAttr(type.getShape()))
+ .getResult();
+ if (constShape.getType() != op.getResult().getType())
+ constShape = rewriter.create<tensor::CastOp>(
----------------
silvasean wrote:
Do we need a new test that covers the creation of the cast op?
https://github.com/llvm/llvm-project/pull/74438
More information about the Mlir-commits mailing list