[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 11 12:59:57 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-shape
Author: Rafael Ubal (rafaelubalmw)
<details>
<summary>Changes</summary>
The proposed canonicalization pattern converts
```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
%reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
return %0 : tensor<?xindex>
}
```
to
```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
return %arg1 : tensor<?xindex>
}
```
When lowering element-wise ops with unranked tensor operands, it may be necessary to reshape inputs into a 1D tensor. The following op pattern emerges:
```
%unranked_shape = shape.shape_of %unranked_input
%ranked_shape = shape.num_elements %unranked_shape
%ranked_input = tensor.reshape %unranked_input, %ranked_shape
%ranked_result = ... %ranked_input ...
%unranked_result = tensor.reshape %ranked_result, %unranked_shape
```
When two consecutive element-wise operations `op1` and `op2` with unranked inputs are lowered into such a pattern, the proposed canonicalization pattern fuses the last `tensor.reshape` from `op1` with the first `shape.shape_of` from `op2`. CSE may then fuse both occurrences of `shape.num_elements` from `op1` and `op2`.
---
Full diff: https://github.com/llvm/llvm-project/pull/98531.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+16-6)
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+26)
``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..639bd7851c35d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
return failure();
- if (llvm::isa<ShapedType>(op.getType()))
+ if (op.getType() != tensorReshapeOp.getShape().getType())
return failure();
- rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
- op.getArg());
+ rewriter.replaceOp(op, tensorReshapeOp.getShape());
return success();
}
};
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..a17a7d1499935 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
// -----
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+ // CHECK: return %[[SHAPE]] : tensor<?xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+ // CHECK: return %[[SHAPE_OF]] : !shape.shape
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+ return %1 : !shape.shape
+}
+
+// -----
+
// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
``````````
</details>
https://github.com/llvm/llvm-project/pull/98531
More information about the Mlir-commits mailing list