[Mlir-commits] [mlir] [mlir][tensor] Bufferize tensor.reshape with non-identity layouts (PR #65654)
Spenser Bauman
llvmlistbot at llvm.org
Thu Sep 7 11:55:28 PDT 2023
https://github.com/sabauma created https://github.com/llvm/llvm-project/pull/65654:
Bufferization of tensor.reshape generates a memref.reshape operation, which requires its source memref to have an identity layout. However, bufferization may give the source memref a non-identity layout (for example, when the source is produced by a memref.subview). In that case the generated memref.reshape fails verification.
This change makes the bufferization interface for tensor.reshape copy the source memref into a new buffer with an identity layout (using bufferization.clone) whenever the source has a non-identity layout.
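The sketch below shows why the copy is needed. It is a minimal C++ model, not MLIR code, in which a memref is shared storage plus a strided layout. `View`, `cloneToIdentity`, `reshape`, and `bufferizedReshape` are hypothetical stand-ins for a memref, `bufferization.clone`, `memref.reshape`, and the patched bufferization of `tensor.reshape`. The model assumes that "identity layout" means zero offset and row-major strides. A reshape that only reinterprets the shape is meaningful only for such data, so a strided source is first copied into fresh contiguous storage. Note one difference from MLIR: the patch decides from the memref *type* (`getLayout().isIdentity()`), not from runtime stride values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical model of a memref (not MLIR code): shared storage plus a
// strided layout. Element (i0, ..., iN-1) lives at
// storage[offset + sum_k i_k * strides[k]].
struct View {
  std::shared_ptr<std::vector<float>> storage;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset = 0;
};

static int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t s : shape)
    n *= s;
  return n;
}

// Row-major strides: each stride is the product of all later dimension sizes.
static std::vector<int64_t> identityStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int k = static_cast<int>(shape.size()) - 2; k >= 0; --k)
    strides[k] = strides[k + 1] * shape[k + 1];
  return strides;
}

// Assumption of this model: identity layout == zero offset + row-major strides.
static bool hasIdentityLayout(const View &v) {
  return v.offset == 0 && v.strides == identityStrides(v.shape);
}

// Reads the element at row-major position `linear` of the view's shape.
static float at(const View &v, int64_t linear) {
  int64_t pos = v.offset;
  for (int k = static_cast<int>(v.shape.size()) - 1; k >= 0; --k) {
    pos += (linear % v.shape[k]) * v.strides[k];
    linear /= v.shape[k];
  }
  return (*v.storage)[static_cast<std::size_t>(pos)];
}

// Stand-in for bufferization.clone to an identity-layout type: allocate
// fresh contiguous storage and copy the elements in row-major order.
static View cloneToIdentity(const View &v) {
  auto data = std::make_shared<std::vector<float>>();
  const int64_t n = numElements(v.shape);
  data->reserve(static_cast<std::size_t>(n));
  for (int64_t i = 0; i < n; ++i)
    data->push_back(at(v, i));
  return {data, v.shape, identityStrides(v.shape), 0};
}

// Stand-in for memref.reshape: reinterprets the same storage with a new
// shape, which is only valid for an identity-layout source.
static View reshape(const View &src, const std::vector<int64_t> &newShape) {
  assert(hasIdentityLayout(src) && "reshape requires an identity layout");
  assert(numElements(src.shape) == numElements(newShape));
  return {src.storage, newShape, identityStrides(newShape), 0};
}

// The policy this patch adds: copy a non-identity source before reshaping.
static View bufferizedReshape(View src, const std::vector<int64_t> &newShape) {
  if (!hasIdentityLayout(src))
    src = cloneToIdentity(src);
  return reshape(src, newShape);
}
```

Take a 2x2 buffer holding `{0, 1, 2, 3}`. Its row-1 view (shape `{2}`, stride `1`, offset `2`) is copied into new storage holding `{2, 3}` and then reshaped to 1x2. A source that already has an identity layout is reshaped in place, without a copy.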
From b3d8b600f0ac68c246c0b44ad41f1f68a92a1752 Mon Sep 17 00:00:00 2001
From: Spenser Bauman <sbauman at mathworks.com>
Date: Thu, 7 Sep 2023 14:12:21 -0400
Subject: [PATCH] [mlir][tensor] Bufferize tensor.reshape with non-identity layouts
Bufferization of tensor.reshape generates a memref.reshape operation.
memref.reshape requires the source memref to have an identity layout.
The bufferization process may result in the source memref having
a non-identity layout, resulting in a verification failure.
This change causes the bufferization interface for tensor.reshape to
copy the source memref to a new buffer when the source has
a non-identity layout.
---
.../BufferizableOpInterfaceImpl.cpp | 14 +++++++++++++
.../Dialect/Tensor/one-shot-bufferize.mlir | 21 +++++++++++++++++++
2 files changed, 35 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index a67ea0334b22b9..33ebebbf53991d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -995,6 +995,20 @@ struct ReshapeOpInterface
         bufferization::getBufferType(reshapeOp.getResult(), options);
     if (failed(maybeResultMemRefType))
       return failure();
+
+    // memref.reshape requires the source buffer to have an identity layout.
+    // If the source memref does not have an identity layout, clone the source
+    // into a new buffer with an identity layout.
+    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
+    if (srcType && !srcType.getLayout().isIdentity()) {
+      auto identityType =
+          MemRefType::get(srcType.getShape(), srcType.getElementType());
+      srcBuffer = rewriter
+                      .create<bufferization::CloneOp>(op->getLoc(),
+                                                      identityType, *srcBuffer)
+                      .getResult();
+    }
+
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
     return success();
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 2aeb5a820812ea..13d520aa40723b 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -418,3 +418,24 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
   // CHECK: return %[[RESHAPED]]
   return %reshaped : tensor<2x2x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @reshape_with_non_identity_layout(
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME: %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>)
+func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> {
+
+  // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>>
+  %extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32>
+
+  // To satisfy the constraints of memref.reshape, the subview must be cloned into
+  // a buffer with an identity layout.
+  // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32>
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32>
+
+  %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>
+
+  // CHECK: return %[[RESHAPED]] : memref<1x2xf32>
+  return %reshape : tensor<1x2xf32>
+}
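The subview in the new test selects row 1 of the 2x2 input. The sketch below applies the standard strided-layout rule for `memref.subview`: the result offset is the source offset plus `offsets[k] * strides[k]` summed over dimensions, each stride is scaled by its step, and unit dimensions are dropped for a rank-reducing subview. This shows why the slice cannot keep an identity layout. It is a minimal model with concrete values, assuming the input is row-major with zero offset. In the test, the input's strides and offset are dynamic because of the function-boundary type, which is why the subview prints as `strided<[?], offset: ?>`. `StridedLayout`, `rowMajorStrides`, and `subviewLayout` are hypothetical helpers, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of a strided memref layout with concrete values.
struct StridedLayout {
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t offset = 0;
};

// Row-major strides, which is what an identity layout implies.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int k = static_cast<int>(sizes.size()) - 2; k >= 0; --k)
    strides[k] = strides[k + 1] * sizes[k + 1];
  return strides;
}

// Assumption of this model: identity == zero offset + row-major strides.
bool isIdentity(const StridedLayout &l) {
  return l.offset == 0 && l.strides == rowMajorStrides(l.sizes);
}

// Layout of memref.subview %src[offsets] [sizes] [steps]: the offset
// accumulates offsets[k] * strides[k], each kept stride is scaled by its
// step, and (when rankReduce is set) unit-size dimensions are dropped.
StridedLayout subviewLayout(const StridedLayout &src,
                            const std::vector<int64_t> &offsets,
                            const std::vector<int64_t> &sizes,
                            const std::vector<int64_t> &steps,
                            bool rankReduce) {
  const std::size_t rank = src.sizes.size();
  assert(offsets.size() == rank && sizes.size() == rank &&
         steps.size() == rank);
  StridedLayout out;
  out.offset = src.offset;
  for (std::size_t k = 0; k < rank; ++k) {
    out.offset += offsets[k] * src.strides[k];
    if (rankReduce && sizes[k] == 1)
      continue; // dimension removed by the rank-reducing subview
    out.sizes.push_back(sizes[k]);
    out.strides.push_back(steps[k] * src.strides[k]);
  }
  return out;
}
```

With a row-major 2x2 input, `[1, 0] [1, 2] [1, 1]` yields size 2, stride 1, and offset 2. Even this contiguous slice is therefore not an identity layout, because its offset is nonzero. Selecting a column instead (`[0, 0] [2, 1] [1, 1]`) yields stride 2.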