[Mlir-commits] [mlir] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (PR #74438)

Matthias Springer llvmlistbot at llvm.org
Tue Dec 5 16:36:12 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/74438

>From 18ab550e0dabc6eb76aa290dc474ce5fedf9ed75 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Dec 2023 09:34:52 +0900
Subject: [PATCH] [mlir][shape] Turn `ShapeOfOp` folding into canonicalization
 pattern

The `ShapeOfOp` folder used to generate invalid IR.

Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```

Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```

This rewrite cannot be implemented as a folder because the result type may have to change. In the above example, the original `shape.shape_of` op had a return type of `tensor<?xindex>`, but the folded attribute (materialized as a `shape.const_shape` op) must have a type of `tensor<0xindex>` to be valid.

This commit fixes tests such as `mlir/test/Dialect/Shape/canonicalize.mlir` when verifying the IR after each pattern application (#74270).
---
 .../include/mlir/Dialect/Shape/IR/ShapeOps.td |  1 -
 mlir/lib/Dialect/Shape/IR/Shape.cpp           | 34 ++++++++++++++-----
 mlir/test/Dialect/Shape/canonicalize.mlir     | 12 +++++++
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
 // ShapeOfOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
-  auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
-  if (!type || !type.hasStaticShape())
-    return nullptr;
-  Builder builder(getContext());
-  return builder.getIndexTensorAttr(type.getShape());
-}
-
 namespace {
+/// Replace shape_of(x) where x has a constant shape with a const_shape op.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Value constShape =
+        rewriter
+            .create<ConstShapeOp>(loc,
+                                  rewriter.getIndexTensorAttr(type.getShape()))
+            .getResult();
+    if (constShape.getType() != op.getResult().getType())
+      constShape = rewriter.create<tensor::CastOp>(
+          loc, op.getResult().getType(), constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
 struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
   patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
-               ExtractFromShapeOfExtentTensor>(context);
+               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+      context);
 }
 
 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 8edbae3baf52e..40b137f1fa36e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
   %result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
   return %result : !shape.size
 }
+
+// -----
+
+// CHECK-LABEL: func @shape_of_0d(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<f32>
+//       CHECK:   %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
+//       CHECK:   %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
+//       CHECK:   return %[[cast]]
+func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
+  %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
+}



More information about the Mlir-commits mailing list