[Mlir-commits] [mlir] [mlir][tensor] Fix bufferization interface for 'tensor.reshape' (PR #128590)

llvmlistbot at llvm.org
Mon Feb 24 14:50:41 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

<details>
<summary>Changes</summary>

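Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
reported the 'shape' operand as an alias of the result tensor. This caused
unnecessary conflicts with ops that write to the shape operand, such as the
'tensor.insert' ops that build the shape in place. With this change, only the
'source' operand aliases the result.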
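The rule change can be illustrated with a small stand-alone C++ model. This is a sketch and does not use MLIR. `ReshapeOperand`, `Relation`, `AliasingResult`, `aliasingValuesBefore`, and `aliasingValuesAfter` are hypothetical stand-ins for `OpOperand`, `BufferRelation`, and the two versions of `getAliasingValues`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins; these are NOT MLIR types.
enum class ReshapeOperand { Source, Shape };  // operands of tensor.reshape
enum class Relation { Equivalent };           // mirrors BufferRelation::Equivalent

struct AliasingResult {
  unsigned resultIndex;  // index of the op result that aliases the operand
  Relation relation;
};

// Old behavior: every operand, including 'shape', aliases result #0.
std::vector<AliasingResult> aliasingValuesBefore(ReshapeOperand) {
  return {{0, Relation::Equivalent}};
}

// New behavior: only 'source' aliases result #0; 'shape' is only read.
std::vector<AliasingResult> aliasingValuesAfter(ReshapeOperand operand) {
  if (operand != ReshapeOperand::Source)
    return {};
  return {{0, Relation::Equivalent}};
}

// Usage: under the new rule the shape operand no longer aliases the result.
void checkAliasingRules() {
  assert(aliasingValuesBefore(ReshapeOperand::Shape).size() == 1);
  assert(aliasingValuesAfter(ReshapeOperand::Shape).empty());
  assert(aliasingValuesAfter(ReshapeOperand::Source).size() == 1);
  assert(aliasingValuesAfter(ReshapeOperand::Source)[0].relation ==
         Relation::Equivalent);
}
```

An empty result for 'shape' tells the analysis that writes to the shape buffer cannot be observed through the reshaped result. Those writes therefore no longer need to be treated as potential conflicts.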


---
Full diff: https://github.com/llvm/llvm-project/pull/128590.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+4) 
- (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+27) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..8b7aee67ea5c2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
 
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
+    // Only the 'source' operand aliases the result.
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    if (reshapeOp.getSourceMutable() != opOperand)
+      return {};
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index af4f84640890b..2983cd30258a5 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
 
 // -----
 
+// CHECK-LABEL: func @tensor_reshape_aliasing
+//  CHECK-SAME:  (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
+  %t1_static = arith.constant dense<0.> : tensor<10xf32>
+  // CHECK-DAG: %[[T1:.+]] = memref.cast
+  %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
+
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
+  %shape = bufferization.alloc_tensor() : tensor<2xindex>
+  // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
+  %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
+  // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
+  %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
+
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
+  %reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  // CHECK: return %[[RESHAPED]]
+  return %reshaped : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @reshape_with_non_identity_layout(
 // CHECK-SAME:    %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
 // CHECK-SAME:    %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,

``````````
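To see why treating only 'source' as aliasing is sound, here is a toy C++ model of the bufferized `memref.reshape` in the new test. The result reuses the source's data buffer but only reads the sizes out of the shape buffer. `ToyMemRef`, `toyReshape`, and `checkToyReshape` are hypothetical, and the concrete sizes 2 and 5 stand in for `%arg0` and `%arg1`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy memref: a non-owning data pointer plus its sizes.
struct ToyMemRef {
  float *data;
  std::vector<int64_t> sizes;
};

// Toy model of memref.reshape: the result shares the source's data pointer
// (the source aliases the result), while the sizes are copied out of the
// shape buffer (the shape is only read). Returns nullopt on a size mismatch.
std::optional<ToyMemRef> toyReshape(const ToyMemRef &source,
                                    const std::vector<int64_t> &shape) {
  int64_t srcCount = 1;
  for (int64_t s : source.sizes) srcCount *= s;
  int64_t dstCount = 1;
  for (int64_t s : shape) dstCount *= s;
  if (srcCount != dstCount)
    return std::nullopt;
  return ToyMemRef{source.data, shape};  // copies the sizes, shares the data
}

// Mirrors the test: fill a 2-element shape buffer, then reshape 10 -> 2x5.
void checkToyReshape() {
  std::vector<float> storage(10, 0.0f);  // like dense<0.> : tensor<10xf32>
  ToyMemRef t1{storage.data(), {10}};
  std::vector<int64_t> shape(2);         // like memref.alloc() : memref<2xindex>
  shape[0] = 2;                          // like memref.store %arg0
  shape[1] = 5;                          // like memref.store %arg1
  std::optional<ToyMemRef> reshaped = toyReshape(t1, shape);
  assert(reshaped.has_value());
  assert(reshaped->data == t1.data);     // the source buffer is aliased
  shape[0] = 5;                          // a later write to the shape buffer
  assert(reshaped->sizes[0] == 2);       // does not affect the result
}
```

Because the shape values are consumed when the reshape executes, the result never refers back to the shape buffer. Both stores can therefore go into the single `memref.alloc`, which is what the CHECK lines in the new test expect.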

</details>


https://github.com/llvm/llvm-project/pull/128590

