[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 11 12:59:57 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-shape

Author: Rafael Ubal (rafaelubalmw)

<details>
<summary>Changes</summary>

The proposed canonicalization pattern converts

```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}
```

to

```
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
  return %arg1 : tensor<?xindex>
}
```

When lowering element-wise ops with unranked tensor operands, it may be necessary to reshape inputs into a 1D tensor. The following op pattern emerges:

```
%unranked_shape = shape.shape_of %unranked_input
%ranked_shape = shape.num_elements %unranked_shape
%ranked_input = tensor.reshape %unranked_input, %ranked_shape

%ranked_result = ... %ranked_input ...

%unranked_result = tensor.reshape %ranked_result, %unranked_shape
```

When 2 consecutive element-wise operations `op1` and `op2` with unranked inputs are lowered into such a pattern, the proposed canonicalization pattern fuses the last `tensor.reshape` from `op1` with the first `shape.shape_of` from `op2`. CSE may then fuse both occurrences of `shape.num_elements` from `op1` and `op2`.

---
Full diff: https://github.com/llvm/llvm-project/pull/98531.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+16-6) 
- (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+26) 


``````````diff
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..639bd7851c35d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
       return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
+    if (op.getType() != tensorReshapeOp.getShape().getType())
       return failure();
 
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
-                                                  op.getArg());
+    rewriter.replaceOp(op, tensorReshapeOp.getShape());
     return success();
   }
 };
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
                ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
       context);
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..a17a7d1499935 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
 
 // -----
 
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK: return %[[SHAPE]] : tensor<?xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+  // CHECK: return %[[SHAPE_OF]] : !shape.shape
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
 // CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
 func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/98531


More information about the Mlir-commits mailing list