[Mlir-commits] [mlir] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (PR #74438)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 01:29:28 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The `ShapeOfOp` folder could generate invalid IR when the op's declared result type differed from the type of the folded constant.

Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

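This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xindex>` to be valid. The new canonicalization pattern therefore creates the `shape.const_shape` op with its inferred type and inserts a `tensor.cast` back to the original result type whenever the two types differ.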
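
For illustration, the IR that the new pattern would be expected to produce for the input above is sketched below. This is an assumption based on the pattern's code in the diff (a `ConstShapeOp` followed by a `tensor::CastOp` when the types differ) and on the usual custom assembly of both ops; it is not output copied from the PR, and the SSA names are arbitrary.

```
%0 = shape.const_shape [] : tensor<0xindex>
%1 = tensor.cast %0 : tensor<0xindex> to tensor<?xindex>
```

Uses of the original `shape.shape_of` result would then refer to `%1`, which keeps the `tensor<?xindex>` type that users expect, while the constant itself carries the static type `tensor<0xindex>`.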

This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (#74270).

---
Full diff: https://github.com/llvm/llvm-project/pull/74438.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td (-1) 
- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+25-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
+          loc, op.getResult().getType(), constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
-               ExtractFromShapeOfExtentTensor>(context);
+               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+      context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(

``````````

</details>


https://github.com/llvm/llvm-project/pull/74438

