[Mlir-commits] [mlir] [mlir][tensor] Fix bufferization interface for 'tensor.reshape' (PR #128590)

Christopher Bate llvmlistbot at llvm.org
Wed Mar 12 08:44:46 PDT 2025


https://github.com/christopherbate updated https://github.com/llvm/llvm-project/pull/128590

>From 48867b0ee32d85f948c81615ba297cc800f93389 Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Mon, 24 Feb 2025 22:45:49 +0000
Subject: [PATCH] [mlir][tensor] Fix bufferization interface for
 'tensor.reshape'

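Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
reported the 'shape' operand as aliasing the result tensor. This caused
spurious conflicts with ops that write to the shape operand. One-Shot
Bufferize resolves such conflicts by bufferizing out-of-place, which inserts
unnecessary copies. Now only the 'source' operand aliases the result.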
---
 .../BufferizableOpInterfaceImpl.cpp           |  4 +++
 .../Dialect/Tensor/one-shot-bufferize.mlir    | 27 +++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..8b7aee67ea5c2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
 
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
+    // Only 'source' aliases the result; the 'shape' operand aliases nothing.
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    if (reshapeOp.getSourceMutable() != opOperand)
+      return {};
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index af4f84640890b..2983cd30258a5 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
 
 // -----
 
+// CHECK-LABEL: func @tensor_reshape_aliasing
+//  CHECK-SAME:  (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
+  %t1_static = arith.constant dense<0.> : tensor<10xf32>
+  // CHECK-DAG: %[[T1:.+]] = memref.cast
+  %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
+
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // Build the 'shape' operand by inserting the dynamic sizes into a fresh tensor.
+  // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
+  %shape = bufferization.alloc_tensor() : tensor<2xindex>
+  // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
+  %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
+  // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
+  %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
+  // 'shape' must not alias the result: the reshape reads the buffer written above.
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
+  %reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  // CHECK: return %[[RESHAPED]]
+  return %reshaped : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @reshape_with_non_identity_layout(
 // CHECK-SAME:    %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
 // CHECK-SAME:    %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,



More information about the Mlir-commits mailing list