[Mlir-commits] [mlir] [mlir][tensor] Fix bufferization interface for 'tensor.reshape' (PR #128590)
llvmlistbot at llvm.org
Mon Feb 24 14:50:41 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir
Author: Christopher Bate (christopherbate)
<details>
<summary>Changes</summary>
Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
reported the 'shape' operand as aliasing the result tensor, which caused
unnecessary conflicts with ops that "write" to the shape operand. With this
change, only the 'source' operand is reported as aliasing the result.
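To illustrate the change in aliasing behavior described above, here is a minimal standalone sketch. It does not use MLIR. The types `Relation` and `AliasingValue`, the constants `kSourceOperand` and `kShapeOperand`, and the two functions are hypothetical stand-ins. The sketch assumes operand 0 is 'source' and operand 1 is 'shape'.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for illustration only; these are NOT MLIR types.
enum class Relation { Equivalent };

struct AliasingValue {
  std::size_t resultIndex; // which op result the operand may alias
  Relation relation;
};

// Assumed operand positions of a reshape-like op (hypothetical constants).
constexpr std::size_t kSourceOperand = 0;
constexpr std::size_t kShapeOperand = 1;

// Behavior before the fix: every operand, including 'shape', is reported
// as aliasing result 0.
std::vector<AliasingValue> aliasingValuesBefore(std::size_t operandIndex) {
  (void)operandIndex; // the operand was not inspected
  return {AliasingValue{0, Relation::Equivalent}};
}

// Behavior after the fix: only the 'source' operand aliases result 0;
// any other operand gets an empty alias list.
std::vector<AliasingValue> aliasingValuesAfter(std::size_t operandIndex) {
  if (operandIndex != kSourceOperand)
    return {};
  return {AliasingValue{0, Relation::Equivalent}};
}
```

In this model, querying `kShapeOperand` returns one alias before the fix and none after it, so a write to the shape tensor no longer appears to touch the reshape result.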
---
Full diff: https://github.com/llvm/llvm-project/pull/128590.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+4)
- (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+27)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..8b7aee67ea5c2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
+ // Only the 'source' operand aliases the result.
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ if (reshapeOp.getSourceMutable() != opOperand)
+ return {};
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index af4f84640890b..2983cd30258a5 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
// -----
+// CHECK-LABEL: func @tensor_reshape_aliasing
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
+ %t1_static = arith.constant dense<0.> : tensor<10xf32>
+ // CHECK-DAG: %[[T1:.+]] = memref.cast
+ %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
+
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
+ %shape = bufferization.alloc_tensor() : tensor<2xindex>
+ // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
+ %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
+ // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
+ %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
+
+ // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
+ %reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+ // CHECK: return %[[RESHAPED]]
+ return %reshaped : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @reshape_with_non_identity_layout(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,
``````````
</details>
https://github.com/llvm/llvm-project/pull/128590