[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)

Rafael Ubal llvmlistbot at llvm.org
Sat Jul 13 08:33:28 PDT 2024

https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/98531

>From 1a2bffdfc9b824cf760bc01fae86c9ed1e9fa889 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 11 Jul 2024 12:39:42 -0400
Subject: [PATCH 1/3] Canonicalization pattern 'ShapeOfFromReshape'

 mlir/lib/Dialect/Shape/IR/Shape.cpp       | 22 +++++++++++++------
 mlir/test/Dialect/Shape/canonicalize.mlir | 26 +++++++++++++++++++++++
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..639bd7851c35d 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+// to
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
   using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!llvm::isa<ShapedType>(op.getArg().getType()))
+    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+    if (!tensorReshapeOp)
       return failure();
-    if (llvm::isa<ShapedType>(op.getType()))
+    if (op.getType() != tensorReshapeOp.getShape().getType())
       return failure();
-    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
-                                                  op.getArg());
+    rewriter.replaceOp(op, tensorReshapeOp.getShape());
     return success();
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
                ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..a17a7d1499935 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
 // -----
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+  // CHECK: return %[[SHAPE]] : tensor<?xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+// -----
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+  // CHECK: return %[[SHAPE_OF]] : !shape.shape
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+  return %1 : !shape.shape
+// -----
 // CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
 func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {

>From d11a5d75266c94207a03569e715543788022ddbf Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Fri, 12 Jul 2024 18:46:34 -0400
Subject: [PATCH 2/3] Canonicalization pattern to fold chains of
 'tensor.reshape' ops

 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   |  8 ++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0e840da9530ed..676a10dc7ba34 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1585,6 +1585,14 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
     return reshapedSource;
+  // If the producer of operand 'source' is another 'tensor.reshape' op, use the
+  // producer's input instead as the original tensor to reshape. This could
+  // render such producer dead code.
+  if (auto producer = getSource().getDefiningOp<ReshapeOp>()) {
+    setOperand(0, producer.getSource());
+    return getResult();
+  }
   auto source = getSource();
   auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
   auto resultTy = dyn_cast<RankedTensorType>(getType());
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index baa205b9f42c6..e9fbb40da10f9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -847,6 +847,22 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>
 // -----
+// CHECK-LABEL: func @fold_reshape_chain
+//  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+//  CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//  CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//  CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
+//       CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
+//       CHECK: return %[[RESULT]]
+func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
+  %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+// -----
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>

>From 5bef900df7fb3f3a17b788c54fc2d70b80d3a523 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Sat, 13 Jul 2024 11:33:16 -0400
Subject: [PATCH 3/3] Addressed review feedback. Added new 'tensor.reshape'
 fold patterns. Added new comprehensive test 'unranked-tensor-lowering.mlir'

 mlir/lib/Dialect/Shape/IR/Shape.cpp           | 18 ++--
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 10 ++-
 mlir/test/Dialect/Shape/canonicalize.mlir     | 13 +++
 .../Shape/unranked-tensor-lowering.mlir       | 90 +++++++++++++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 11 +++
 5 files changed, 134 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir

diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 639bd7851c35d..8eb8e579954fa 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1719,11 +1719,19 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
                                 PatternRewriter &rewriter) const override {
     auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
     if (!tensorReshapeOp)
-      return failure();
-    if (op.getType() != tensorReshapeOp.getShape().getType())
-      return failure();
-    rewriter.replaceOp(op, tensorReshapeOp.getShape());
+      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
+    if (!isa<TensorType>(op.getType()))
+      return rewriter.notifyMatchFailure(op, "result is not a tensor");
+    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
+    // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
+    // formed IR, it may not be identical (dynamically vs statically shaped),
+    // in which case it needs to be cast first.
+    Value shape = tensorReshapeOp.getShape();
+    if (op.getType() != shape.getType())
+      shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+    rewriter.replaceOp(op, shape);
     return success();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 676a10dc7ba34..1bc263db05a7d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1588,18 +1588,22 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   // If the producer of operand 'source' is another 'tensor.reshape' op, use the
   // producer's input instead as the original tensor to reshape. This could
   // render such producer dead code.
-  if (auto producer = getSource().getDefiningOp<ReshapeOp>()) {
-    setOperand(0, producer.getSource());
+  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
+    setOperand(0, reshapeOpProducer.getSource());
     return getResult();
   auto source = getSource();
   auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
   auto resultTy = dyn_cast<RankedTensorType>(getType());
   if (!sourceTy || !resultTy || sourceTy != resultTy)
     return {};
+  // If the source and result are both 1D tensors and have the same type, the
+  // reshape has no effect, even if the tensor is dynamically shaped.
+  if (sourceTy.getRank() == 1)
+    return source;
   if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
     auto elements = fromElements.getElements();
     bool dynamicNoop =
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index a17a7d1499935..5b98a7790debf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1373,6 +1373,19 @@ func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -
 // -----
+// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
+func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+  // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
+  // CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
+  %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
+  %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+  return %1 : tensor<?xindex>
+// -----
 // CHECK-LABEL: func @shape_of_from_reshape_nofold
 // CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
 // CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
diff --git a/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir b/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
new file mode 100644
index 0000000000000..c1fcbd1ad045f
--- /dev/null
+++ b/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s
+// This test verifies the simplification of IR patterns that emerge when
+// lowering high-level element-wise ops with unranked tensor inputs. Consider
+// the following function incrementing and doubling the value of an input
+// unranked tensor using ops in a hypothetical high-level dialect called 'hl':
+//  func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
+//    %0 = hl.inc %input : tensor<*xf32>
+//    %1 = hl.double %0 : tensor<*xf32>
+//    return %1 : tensor<*xf32>
+//  }
+// A possible strategy to lower 'hl.inc' consists in reshaping its operand into
+// a 1D tensor, creating a 1D tensor splat with the same total size as the input
+// operand and with value 1.0, adding both 1D tensors using 'arith.addf', and
+// reshaping the result back into the original input shape. A similar process
+// applies for 'hl.double', except with a tensor splat with value 2.0 and an
+// 'arith.mulf' op. The body of the function in the test below contains the full
+// sequence.
+// Since such a lowering process would operate on individual 'hl' ops in a
+// context-oblivious manner, the emitted code produces a redundant IR pattern
+// where the result of 'arith.addf' is reshaped into an unranked tensor, just
+// for it to be immediately reshaped back into the 1D tensor consumed by
+// 'arith.mulf'. This entails the overhead of re-computing the unranked tensor
+// shape ('shape.shape_of') and size ('shape.num_elements').
+// This test verifies that the consecutive application of a canonicalization and
+// a CSE pass successfully simplifies this emerging pattern, leading to a
+// version of the code in which the result of the emitted 'arith.addf' op
+// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
+// associated with 'hl.double', as observed in the FileCheck directives. The
+// main rewrite patterns at play are 'shape.shape_of' canonicalization,
+// 'tensor.reshape' canonicalization, and 'shape.num_elements' subexpression
+// elimination.
+// CHECK-LABEL: @unranked_tensor_lowering
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
+// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>
+// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
+// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>
+// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[PRODUCT]] : tensor<*xf32>
+func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {
+  // Collapse input
+  %input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
+  %input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
+  %input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
+  %input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // Second operand for sum
+  %one = arith.constant 1.0 : f32
+  %one_splat = tensor.splat %one[%input_size] : tensor<?xf32>
+  // Compute sum and expand it
+  %sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
+  %sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // Collapse sum
+  %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
+  %sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
+  %sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
+  %sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // Second operand for product
+  %two = arith.constant 2.0 : f32
+  %two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>
+  // Compute product and expand it
+  %product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
+  %product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  return %product : tensor<*xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e9fbb40da10f9..4b8efde78cc23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -863,6 +863,17 @@ func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>,
 // -----
+// CHECK-LABEL: func @fold_reshape_1d
+//  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//  CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
+//       CHECK: return %[[INPUT]]
+func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
+  %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+// -----
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>

More information about the Mlir-commits mailing list