[Mlir-commits] [mlir] 38d0b2d - [mlir] New canonicalization patterns for shape.shape_of and tensor.reshape (#98531)
llvmlistbot at llvm.org
Fri Jul 19 07:09:34 PDT 2024
Author: Rafael Ubal
Date: 2024-07-19T10:09:31-04:00
New Revision: 38d0b2d174efe05504a18988299b4d78d37999b7
URL: https://github.com/llvm/llvm-project/commit/38d0b2d174efe05504a18988299b4d78d37999b7
DIFF: https://github.com/llvm/llvm-project/commit/38d0b2d174efe05504a18988299b4d78d37999b7.diff
LOG: [mlir] New canonicalization patterns for shape.shape_of and tensor.reshape (#98531)
This PR includes three new canonicalization patterns:
- Operation `shape.shape_of`: shape of reshape
```
// Before
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
%reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
return %0 : tensor<?xindex>
}
// After
func.func @f(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
return %arg1 : tensor<?xindex>
}
```
- Operation `tensor.reshape`: reshape of reshape
```
// Before
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
%0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
%1 = tensor.reshape %0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}
// After
func.func @fold_tensor_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: tensor<?xindex>) -> tensor<*xf32> {
%reshape = tensor.reshape %arg0(%arg2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
return %reshape : tensor<*xf32>
}
```
- Operation `tensor.reshape`: reshape 1D to 1D
```
// Before
func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
%0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// After
func.func @fold_reshape_1d(%arg0: tensor<?xf32>, %arg1: tensor<1xindex>) -> tensor<?xf32> {
return %arg0 : tensor<?xf32>
}
```
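One subtlety of the `shape.shape_of` pattern: the `shape` operand of `tensor.reshape` is only guaranteed to be compatible with the result type of `shape.shape_of`, not identical, so the rewrite inserts a `tensor.cast` when the static shapes differ. A sketch of the expected rewrite (mirroring the new test `shape_of_from_reshape_compatible_types` below; the function name `@g` is illustrative):
```
// Before
func.func @g(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
  %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
  %0 = shape.shape_of %reshape : tensor<*xf32> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}
// After
func.func @g(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
  %0 = tensor.cast %arg1 : tensor<5xindex> to tensor<?xindex>
  return %0 : tensor<?xindex>
}
```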
These three canonicalization patterns work together to simplify the IR that
emerges when lowering certain element-wise ops with unranked tensor inputs. See
the new file `unranked-tensor-lowering.mlir` in this change for a detailed
example and description.
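As a rough sketch of how the patterns compose on that example (value names
borrowed from the test; types and surrounding ops omitted for brevity):
```
// Redundant glue emitted between the two lowered ops:
%sum       = tensor.reshape %sum_collapsed(%input_shape)  // expand 'arith.addf' result
%sum_shape = shape.shape_of %sum           // shape-of-reshape folds this to %input_shape
%sum_size  = shape.num_elements %sum_shape // ...which then CSEs with %input_size
%collapsed = tensor.reshape %sum(%sum_collapsed_shape)
// reshape-of-reshape retargets the last reshape to %sum_collapsed, turning it
// into a 1D-to-1D no-op that folds away, so 'arith.mulf' ends up reading
// %sum_collapsed directly.
```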
For context, this PR is meant to enable optimizations of the code generated
when lowering ops `quant.qcast` and `quant.dcast` with unranked tensors, as
proposed in
https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942
(implementation currently in progress).
Added:
mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 58c3f4c334577..8eb8e579954fa 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -1702,18 +1702,36 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
}
};
-struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
+// Canonicalize
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+//
+// to
+//
+// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+// %1 = %shape
+//
+struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(shape::ShapeOfOp op,
PatternRewriter &rewriter) const override {
- if (!llvm::isa<ShapedType>(op.getArg().getType()))
- return failure();
- if (llvm::isa<ShapedType>(op.getType()))
- return failure();
-
- rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op.getOperation(),
- op.getArg());
+ auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
+ if (!tensorReshapeOp)
+ return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
+ if (!isa<TensorType>(op.getType()))
+ return rewriter.notifyMatchFailure(op, "result is not a tensor");
+
+ // Operand 'shape' of 'tensor.reshape' may now be used as the result of
+ // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
+ // formed IR, it may not be identical (dynamically vs statically shaped),
+ // in which case it needs to be cast first.
+ Value shape = tensorReshapeOp.getShape();
+ if (op.getType() != shape.getType())
+ shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), shape);
+
+ rewriter.replaceOp(op, shape);
return success();
}
};
@@ -1753,7 +1771,7 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
+ patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
context);
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0e840da9530ed..0751ffc419cbf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1585,13 +1585,25 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
getResult().getType()))
return reshapedSource;
+ // If the producer of operand 'source' is another 'tensor.reshape' op, use the
+ // producer's input instead as the original tensor to reshape. This could
+ // render such producer dead code.
+ if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
+ getSourceMutable().assign(reshapeOpProducer.getSource());
+ return getResult();
+ }
+
auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
-
if (!sourceTy || !resultTy || sourceTy != resultTy)
return {};
+ // If the source and result are both 1D tensors and have the same type, the
+ // reshape has no effect, even if the tensor is dynamically shaped.
+ if (sourceTy.getRank() == 1)
+ return source;
+
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
auto elements = fromElements.getElements();
bool dynamicNoop =
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 40b137f1fa36e..5b98a7790debf 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1361,6 +1361,45 @@ func.func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape
// -----
+// CHECK-LABEL: func @shape_of_from_reshape
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> tensor<?xindex> {
+ // CHECK: return %[[SHAPE]] : tensor<?xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_compatible_types
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<5xindex>
+func.func @shape_of_from_reshape_compatible_types(%arg0: tensor<*xf32>, %arg1: tensor<5xindex>) -> tensor<?xindex> {
+ // CHECK: %[[CAST_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<5xindex> to tensor<?xindex>
+ // CHECK: return %[[CAST_SHAPE]] : tensor<?xindex>
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<5xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
+ return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @shape_of_from_reshape_nofold
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>
+func.func @shape_of_from_reshape_nofold(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> !shape.shape {
+ // CHECK: %[[RESHAPED:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE]]) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<*xf32> -> !shape.shape
+ // CHECK: return %[[SHAPE_OF]] : !shape.shape
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
+ return %1 : !shape.shape
+}
+
+// -----
+
// CHECK-LABEL: @cast_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func.func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
diff --git a/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir b/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
new file mode 100644
index 0000000000000..c1fcbd1ad045f
--- /dev/null
+++ b/mlir/test/Dialect/Shape/unranked-tensor-lowering.mlir
@@ -0,0 +1,90 @@
+// RUN: mlir-opt -split-input-file -canonicalize -cse %s | FileCheck %s
+
+// This test verifies the simplification of IR patterns that emerge when
+// lowering high-level element-wise ops with unranked tensor inputs. Consider
+// the following function incrementing and doubling the value of an input
+// unranked tensor using ops in a hypothetical high-level dialect called 'hl':
+//
+// func.func @f(%input: tensor<*xf32>) -> tensor<*xf32> {
+// %0 = hl.inc %input : tensor<*xf32>
+// %1 = hl.double %0 : tensor<*xf32>
+// return %1 : tensor<*xf32>
+// }
+//
+// A possible strategy to lower 'hl.inc' consists in reshaping its operand into
+// a 1D tensor, creating a 1D tensor splat with the same total size as the input
+// operand and with value 1.0, adding both 1D tensors using 'arith.addf', and
+// reshaping the result back into the original input shape. A similar process
+// applies for 'hl.double', except with a tensor splat with value 2.0 and an
+// 'arith.mulf' op. The body of the function in the test below contains the full
+// sequence.
+//
+// Since such a lowering process would operate on individual 'hl' ops in a
+// context-oblivious manner, the emitted code produces a redundant IR pattern
+// where the result of 'arith.addf' is reshaped into an unranked tensor, just
+// for it to be immediately reshaped back into the 1D tensor consumed by
+// 'arith.mulf'. This entails the overhead of re-computing the unranked tensor
+// shape ('shape.shape_of') and size ('shape.num_elements').
+//
+// This test verifies that the consecutive application of a canonicalization and
+// a CSE pass successfully simplifies this emerging pattern, leading to a
+// version of the code in which the result of the emitted 'arith.addf' op
+// associated with 'hl.inc' is directly consumed by the 'arith.mulf' op
+// associated with 'hl.double', as observed in the FileCheck directives. The
+// main rewrite patterns at play are 'shape.shape_of' canonicalization,
+// 'tensor.reshape' canonicalization, and 'shape.num_elements' subexpression
+// elimination.
+//
+
+// CHECK-LABEL: @unranked_tensor_lowering
+// CHECK-SAME: %[[INPUT:.*]]: tensor<*xf32>
+
+// CHECK-DAG: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[TWO:.*]] = arith.constant 2.000000e+00 : f32
+
+// CHECK: %[[INPUT_SHAPE:.*]] = shape.shape_of %[[INPUT]] : tensor<*xf32> -> tensor<?xindex>
+// CHECK: %[[INPUT_SIZE:.*]] = shape.num_elements %[[INPUT_SHAPE]] : tensor<?xindex> -> index
+// CHECK: %[[INPUT_COLLAPSED_SHAPE:.*]] = tensor.from_elements %[[INPUT_SIZE]] : tensor<1xindex>
+// CHECK: %[[INPUT_COLLAPSED:.*]] = tensor.reshape %[[INPUT]](%[[INPUT_COLLAPSED_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+// CHECK: %[[ONE_SPLAT:.*]] = tensor.splat %[[ONE]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
+// CHECK: %[[SUM_COLLAPSED:.*]] = arith.addf %[[INPUT_COLLAPSED]], %[[ONE_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[TWO_SPLAT:.*]] = tensor.splat %[[TWO]]{{\[}}%[[INPUT_SIZE]]] : tensor<?xf32>
+// CHECK: %[[PRODUCT_COLLAPSED:.*]] = arith.mulf %[[SUM_COLLAPSED]], %[[TWO_SPLAT]] : tensor<?xf32>
+
+// CHECK: %[[PRODUCT:.*]] = tensor.reshape %[[PRODUCT_COLLAPSED]](%[[INPUT_SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+// CHECK: return %[[PRODUCT]] : tensor<*xf32>
+
+func.func @unranked_tensor_lowering(%input: tensor<*xf32>) -> tensor<*xf32> {
+
+ // Collapse input
+ %input_shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
+ %input_size = shape.num_elements %input_shape : tensor<?xindex> -> index
+ %input_collapsed_shape = tensor.from_elements %input_size : tensor<1xindex>
+ %input_collapsed = tensor.reshape %input(%input_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+ // Second operand for sum
+ %one = arith.constant 1.0 : f32
+ %one_splat = tensor.splat %one[%input_size] : tensor<?xf32>
+
+ // Compute sum and expand it
+ %sum_collapsed = arith.addf %input_collapsed, %one_splat : tensor<?xf32>
+ %sum = tensor.reshape %sum_collapsed(%input_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ // Collapse sum
+ %sum_shape = shape.shape_of %sum : tensor<*xf32> -> tensor<?xindex>
+ %sum_size = shape.num_elements %sum_shape : tensor<?xindex> -> index
+ %sum_collapsed_shape = tensor.from_elements %sum_size : tensor<1xindex>
+ %sum_collapsed_0 = tensor.reshape %sum(%sum_collapsed_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+ // Second operand for product
+ %two = arith.constant 2.0 : f32
+ %two_splat = tensor.splat %two[%sum_size] : tensor<?xf32>
+
+ // Compute product and expand it
+ %product_collapsed = arith.mulf %sum_collapsed_0, %two_splat : tensor<?xf32>
+ %product = tensor.reshape %product_collapsed(%sum_shape) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ return %product : tensor<*xf32>
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index baa205b9f42c6..4b8efde78cc23 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -847,6 +847,33 @@ func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32>
// -----
+// CHECK-LABEL: func @fold_reshape_chain
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
+// CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
+// CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
+// CHECK: return %[[RESULT]]
+func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
+ %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ return %2 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_1d
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
+// CHECK: return %[[INPUT]]
+func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
+ %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>