[Mlir-commits] [mlir] [mlir][tensor] Fix bufferization interface for 'tensor.reshape' (PR #128590)
Christopher Bate
llvmlistbot at llvm.org
Mon Feb 24 14:50:10 PST 2025
https://github.com/christopherbate created https://github.com/llvm/llvm-project/pull/128590
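Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
reported every operand -- including the 'shape' operand -- as aliasing the
result tensor. Only the 'source' operand actually aliases the result, so the
extra alias caused unnecessary bufferization conflicts with ops that "write"
to the shape operand (for example, 'tensor.insert' ops that build the shape
in place).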
From 48867b0ee32d85f948c81615ba297cc800f93389 Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Mon, 24 Feb 2025 22:45:49 +0000
Subject: [PATCH] [mlir][tensor] Fix bufferization interface for 'tensor.reshape'
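Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
reported every operand -- including the 'shape' operand -- as aliasing the
result tensor. Only the 'source' operand actually aliases the result, so the
extra alias caused unnecessary bufferization conflicts with ops that "write"
to the shape operand (for example, 'tensor.insert' ops that build the shape
in place).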
---
.../BufferizableOpInterfaceImpl.cpp | 4 +++
.../Dialect/Tensor/one-shot-bufferize.mlir | 27 +++++++++++++++++++
2 files changed, 31 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..8b7aee67ea5c2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
+ // Only the 'source' operand aliases the result.
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ if (reshapeOp.getSourceMutable() != opOperand)
+ return {};
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index af4f84640890b..2983cd30258a5 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
// -----
+// CHECK-LABEL: func @tensor_reshape_aliasing
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
+ %t1_static = arith.constant dense<0.> : tensor<10xf32>
+ // CHECK-DAG: %[[T1:.+]] = memref.cast
+ %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
+
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+
+ // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
+ %shape = bufferization.alloc_tensor() : tensor<2xindex>
+ // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
+ %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
+ // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
+ %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
+
+ // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
+ %reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+ // CHECK: return %[[RESHAPED]]
+ return %reshaped : tensor<?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: @reshape_with_non_identity_layout(
// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,
More information about the Mlir-commits mailing list