[Mlir-commits] [mlir] [mlir] Canonicalization pattern for 'shape.shape_of' (PR #98531)
Rafael Ubal
llvmlistbot at llvm.org
Fri Jul 12 15:46:58 PDT 2024
https://github.com/rafaelubalmw updated https://github.com/llvm/llvm-project/pull/98531
>From 1a2bffdfc9b824cf760bc01fae86c9ed1e9fa889 Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Thu, 11 Jul 2024 12:39:42 -0400
Subject: [PATCH 1/2] Canonicalization pattern 'ShapeOfFromReshape'
---
mlir/lib/Dialect/Shape/IR/Shape.cpp | 22 +++++++++++++------
mlir/test/Dialect/Shape/canonicalize.mlir | 26 +++++++++++++++++++++++
2 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c33457..639bd7851c35 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,28 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
return failure();
- if (llvm::isa<ShapedType>(op.getType()))
+ if (op.getType() != tensorReshapeOp.getShape().getType())
return failure();
- rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
- op.getArg());
+ rewriter.replaceOp(op, tensorReshapeOp.getShape());
return success();
}
};
@@ -1753,7 +1763,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36..a17a7d149993 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,32 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
// -----
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+ // CHECK: return %[[SHAPE]] : tensor<?xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+ // CHECK: return %[[SHAPE_OF]] : !shape.shape
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+ return %1 : !shape.shape
+}
+
+// -----
+
// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
>From d11a5d75266c94207a03569e715543788022ddbf Mon Sep 17 00:00:00 2001
From: Rafael Ubal Tena <rubal at mathworks.com>
Date: Fri, 12 Jul 2024 18:46:34 -0400
Subject: [PATCH 2/2] Canonicalization pattern to fold chains of 'tensor.reshape' ops
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 8 ++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ++++++++++++++++
2 files changed, 24 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0e840da9530e..676a10dc7ba3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1585,6 +1585,14 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
getResult().getType()))
return reshapedSource;
+ // If the producer of operand 'source' is another 'tensor.reshape' op, use the
+ // producer's input instead as the original tensor to reshape. This could
+ // render such producer dead code.
+ if (auto producer = getSource().getDefiningOp<ReshapeOp>()) {
+ setOperand(0, producer.getSource());
+ return getResult();
+ }
+
auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index baa205b9f42c..e9fbb40da10f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -847,6 +847,22 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>
// -----
+// CHECK-LABEL: func @fold_reshape_chain
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
+// CHECK: return %[[RESULT]]
+func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
+ %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
More information about the Mlir-commits mailing list