[Mlir-commits] [mlir] 0a0c7e8 - [mlir][tensor] Bufferize tensor.reshape with non-identity layouts (#65654)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 18 17:50:47 PDT 2023


Author: Spenser Bauman
Date: 2023-09-19T09:50:43+09:00
New Revision: 0a0c7e89780a7b429ed8d63146c6caa2f1890e61

URL: https://github.com/llvm/llvm-project/commit/0a0c7e89780a7b429ed8d63146c6caa2f1890e61
DIFF: https://github.com/llvm/llvm-project/commit/0a0c7e89780a7b429ed8d63146c6caa2f1890e61.diff

LOG: [mlir][tensor] Bufferize tensor.reshape with non-identity layouts (#65654)

Bufferization of tensor.reshape generates a memref.reshape operation.
memref.reshape requires the source memref to have an identity layout.
The bufferization process may give the source memref a non-identity
layout (e.g., when the source tensor is produced by tensor.extract_slice,
which bufferizes to a strided memref.subview), causing a verification
failure.

With this change, the bufferization interface for tensor.reshape clones
the source memref into a new buffer with an identity layout (using
bufferization.clone) whenever the source has a non-identity layout.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 1535e83376edebb..b08283f0070784c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -865,6 +865,20 @@ struct ReshapeOpInterface
         bufferization::getBufferType(reshapeOp.getResult(), options);
     if (failed(maybeResultMemRefType))
       return failure();
+
+    // memref.reshape requires the source buffer to have an identity layout.
+    // If the source memref does not have an identity layout, clone the source
+    // into a new buffer with an identity layout.
+    auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
+    if (srcType && !srcType.getLayout().isIdentity()) {
+      auto identityType =
+          MemRefType::get(srcType.getShape(), srcType.getElementType());
+      srcBuffer = rewriter
+                      .create<bufferization::CloneOp>(op->getLoc(),
+                                                      identityType, *srcBuffer)
+                      .getResult();
+    }
+
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, maybeResultMemRefType.value(), *srcBuffer, *shapeBuffer);
     return success();

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 04877b1b21e1aab..9052744a1d3f984 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -380,3 +380,24 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
   // CHECK: return %[[RESHAPED]]
   return %reshaped : tensor<2x2x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @reshape_with_non_identity_layout(
+// CHECK-SAME:    %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME:    %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>)
+func.func @reshape_with_non_identity_layout(%arg0: tensor<2x2xf32>, %arg1: tensor<2xi32>) -> tensor<1x2xf32> {
+  // Reshape a tensor whose buffer is a strided subview (non-identity layout).
+  // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[INPUT]][1, 0] [1, 2] [1, 1] : memref<2x2xf32, strided<[?, ?], offset: ?>> to memref<2xf32, strided<[?], offset: ?>>
+  %extracted_slice = tensor.extract_slice %arg0[1, 0] [1, 2] [1, 1] : tensor<2x2xf32> to tensor<2xf32>
+
+  // To satisfy the constraints of memref.reshape, the subview must be cloned into
+  // a buffer with an identity layout.
+  // CHECK: %[[CLONED:.+]] = bufferization.clone %[[SUBVIEW]] : memref<2xf32, strided<[?], offset: ?>> to memref<2xf32>
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[CLONED]](%[[LAYOUT]]) : (memref<2xf32>, memref<2xi32, strided<[?], offset: ?>>) -> memref<1x2xf32>
+
+  %reshape = tensor.reshape %extracted_slice(%arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x2xf32>
+
+  // CHECK: return %[[RESHAPED]] : memref<1x2xf32>
+  return %reshape : tensor<1x2xf32>
+}


        


More information about the Mlir-commits mailing list