[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)

Rafael Ubal llvmlistbot at llvm.org
Fri Jul 12 15:46:58 PDT 2024


https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/98531

From 1a2bffdfc9b824cf760bc01fae86c9ed1e9fa889 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 11 Jul 2024 12:39:42 -0400
Subject: [PATCH 1/2] Canonicalization pattern 'ShapeOfFromReshape'

---
 mlir/lib/Dialect/Shape/IR/Shape.cpp       | 22 +++++++++++++------
 mlir/test/Dialect/Shape/canonicalize.mlir | 26 +++++++++++++++++++++++
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c33457..639bd7851c35 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Fold 'shape.shape_of' of a 'tensor.reshape' result into the reshape's
+// 'shape' operand:
+//
+//   %0 = tensor.reshape %input(%shape)
+//       : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+//   %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// All uses of %1 are replaced with %shape. This only fires when the type of
+// %1 is identical to that of %shape, so '!shape.shape' results and mismatched
+// extent tensors (e.g. tensor<?xindex> vs. tensor<3xindex>) are left as is.
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
       return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
+    if (op.getType() != tensorReshapeOp.getShape().getType())
       return failure();
 
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
-                                                  op.getArg());
+    rewriter.replaceOp(op, tensorReshapeOp.getShape());
     return success();
   }
 };
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
                ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
       context);
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36..a17a7d149993 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
 
 // -----
 
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+  // CHECK-NEXT: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK-NEXT: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+  // CHECK-NEXT: return %[[SHAPE_OF]] : !shape.shape
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+}
+
+// -----
+
 // CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
 func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {

From d11a5d75266c94207a03569e715543788022ddbf Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Fri, 12 Jul 2024 18:46:34 -0400
Subject: [PATCH 2/2] Canonicalization pattern to fold chains of
 'tensor.reshape' ops

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  8 ++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0e840da9530e..676a10dc7ba3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1585,6 +1585,14 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
           getResult().getType()))
     return reshapedSource;
 
+  // Fold reshape chains: reshape never reorders elements, so reshaping the
+  // result of another 'tensor.reshape' equals reshaping that op's input. The
+  // fold is in place (returning getResult()); the producer may become dead.
+  if (auto producer = getSource().getDefiningOp<ReshapeOp>()) {
+    getSourceMutable().assign(producer.getSource());
+    return getResult();
+  }
+
   auto source = getSource();
   auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
   auto resultTy = dyn_cast<RankedTensorType>(getType());
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index baa205b9f42c..e9fbb40da10f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -847,6 +847,22 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>
 
 // -----
 
+// CHECK-LABEL: func @fold_reshape_chain
+//  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+//  CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//  CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//  CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//  CHECK-NEXT: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
+//  CHECK-NEXT: return %[[RESULT]]
+func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
+  %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>



More information about the Mlir-commits mailing list