[Mlir-commits] [mlir] dbb782d - [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (#74438)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 5 16:41:29 PST 2023
Author: Matthias Springer
Date: 2023-12-06T09:41:24+09:00
New Revision: dbb782dffdbd37e4aafa745eba9ba0f2831e21e8
URL: https://github.com/llvm/llvm-project/commit/dbb782dffdbd37e4aafa745eba9ba0f2831e21e8
DIFF: https://github.com/llvm/llvm-project/commit/dbb782dffdbd37e4aafa745eba9ba0f2831e21e8.diff
LOG: [mlir][shape] Turn `ShapeOfOp` folding into canonicalization pattern (#74438)
The `ShapeOfOp` folder could generate invalid IR when the type of the folded constant differed from the op's declared result type.
Input:
```
%0 = shape.shape_of %arg1 : tensor<index> -> tensor<?xindex>
```
Output:
```
%0 = "shape.const_shape"() <{shape = dense<> : tensor<0xindex>}> : () -> tensor<?xindex>
error: 'shape.const_shape' op inferred type(s) 'tensor<0xindex>' are incompatible with return type(s) of operation 'tensor<?xindex>'
```
This rewrite cannot be implemented as a folder because the result type
may have to change. In the above example, the original `shape.shape_of`
op had a return type of `tensor<?xindex>`, but the folded attribute
(materialized as a `shape.const_shape` op) must have a type of
`tensor<0xindex>` to be valid.
This commit fixes failures in tests such as
`mlir/test/Dialect/Shape/canonicalize.mlir` that occur when the IR is
verified after each pattern application (#74270).
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 3c9f45366fa2b..08a0398e74b0c 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -566,7 +566,6 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
let hasCanonicalizer = 1;
- let hasFolder = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 2444556a45635..4f829db1305c8 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1678,15 +1678,30 @@ LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
// ShapeOfOp
//===----------------------------------------------------------------------===//
-OpFoldResult ShapeOfOp::fold(FoldAdaptor) {
- auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
- if (!type || !type.hasStaticShape())
- return nullptr;
- Builder builder(getContext());
- return builder.getIndexTensorAttr(type.getShape());
-}
-
namespace {
+/// Replace shape_of(x), where x has a static shape, with a const_shape op.
+///
+/// `shape_of` may declare either an extent tensor or a `!shape.shape` result.
+/// A `!shape.shape` result is produced by const_shape directly, because
+/// tensor.cast only operates on tensor types. An extent-tensor result whose
+/// type differs from the inferred `tensor<Nxindex>` (e.g. `tensor<?xindex>`)
+/// is bridged with a tensor.cast so the replacement keeps the original type.
+struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
+  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
+                                PatternRewriter &rewriter) const override {
+    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
+    if (!type || !type.hasStaticShape())
+      return failure();
+    Location loc = op.getLoc();
+    Type resultType = op.getResult().getType();
+    DenseIntElementsAttr extents = rewriter.getIndexTensorAttr(type.getShape());
+
+    // const_shape can build a `!shape.shape` value itself; casting the
+    // inferred extent tensor to `!shape.shape` with tensor.cast is invalid IR.
+    if (llvm::isa<ShapeType>(resultType)) {
+      rewriter.replaceOpWithNewOp<ConstShapeOp>(op, resultType, extents);
+      return success();
+    }
+
+    // Extent-tensor result: build with the inferred static type, then cast
+    // if the op declared a different (e.g. dynamically sized) tensor type.
+    Value constShape = rewriter.create<ConstShapeOp>(loc, extents).getResult();
+    if (constShape.getType() != resultType)
+      constShape =
+          rewriter.create<tensor::CastOp>(loc, resultType, constShape);
+    rewriter.replaceOp(op, constShape);
+    return success();
+  }
+};
+
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@@ -1739,7 +1754,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
+  // ShapeOfOpToConstShapeOp takes over the rewrite previously done by the
+  // removed ShapeOfOp folder.
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
- ExtractFromShapeOfExtentTensor>(context);
+ ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
+ context);
}
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 8edbae3baf52e..40b137f1fa36e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1492,3 +1492,15 @@ func.func @add_poison() -> !shape.size {
%result = shape.add %1, %2 : !shape.size, !shape.size -> !shape.size
return %result : !shape.size
}
+
+// -----
+
+// shape_of on a statically shaped 0-d tensor is rewritten to a const_shape
+// of type tensor<0xindex>; because the declared result type is
+// tensor<?xindex>, a tensor.cast restores the original type.
+// CHECK-LABEL: func @shape_of_0d(
+// CHECK-SAME: %[[arg0:.*]]: tensor<f32>
+// CHECK: %[[const:.*]] = shape.const_shape [] : tensor<0xindex>
+// CHECK: %[[cast:.*]] = tensor.cast %[[const]] : tensor<0xindex> to tensor<?xindex>
+// CHECK: return %[[cast]]
+func.func @shape_of_0d(%arg0: tensor<f32>) -> tensor<?xindex> {
+ %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<?xindex>
+ return %0 : tensor<?xindex>
+}
More information about the Mlir-commits
mailing list