[Mlir-commits] [mlir] dbb782d - [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (#74438)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 5 16:41:29 PST 2023


Author: Matthias Springer
Date: 2023-12-06T09:41:24+09:00
New Revision: dbb782dffdbd37e4aafa745eba9ba0f2831e21e8

URL: https://github.com/llvm/llvm-project/commit/dbb782dffdbd37e4aafa745eba9ba0f2831e21e8
DIFF: https://github.com/llvm/llvm-project/commit/dbb782dffdbd37e4aafa745eba9ba0f2831e21e8.diff

LOG: [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (#74438)

The `ShapeOfOp` folder could generate invalid IR when the operand had a static shape but the op's declared result type was dynamic.

Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

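This rewrite cannot be implemented as a folder because the result type
may have to change. In the above example, the original `shape.shape_of`
op had a return type of `tensor<?xindex>`, but the folded attribute
(materialized as a `shape.const_shape` op) must have a type of
`tensor<0xindex>` to be valid. The new canonicalization pattern therefore
creates the `shape.const_shape` op with its inferred static type and, if
that type differs from the original result type, inserts a `tensor.cast`
back to the original type.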

This commit fixes tests such as
`mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after
each pattern application (#74270).

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
+          loc, op.getResult().getType(), constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
-               ExtractFromShapeOfExtentTensor>(context);
+               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+      context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 8edbae3baf52e..40b137f1fa36e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
   %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
   return %result : !shape.size
 }
+
+// -----
+
+// CHECK-LABEL: func @shape_of_0d(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<f32>
+//       CHECK:   %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
+//       CHECK:   %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
+//       CHECK:   return %[[cast]]
+func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
+  %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}


        


More information about the Mlir-commits mailing list