[Mlir-commits] [mlir] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (PR #74438)

Sean Silva llvmlistbot at llvm.org
Tue Dec 5 01:46:44 PST 2023


================
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
----------------
silvasean wrote:

Do we need a new test that covers the creation of the cast op?

https://github.com/llvm/llvm-project/pull/74438


More information about the Mlir-commits mailing list