[Mlir-commits] [mlir] [mlir][func][bufferization] Fix cast incompatible when bufferize callOp (PR #105929)
Longsheng Mou
llvmlistbot at llvm.org
Sun Aug 25 19:14:36 PDT 2024
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/105929
>From f00a89dd5d3a3a56da8537c8e4a5a5c5ab04238e Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Sat, 24 Aug 2024 15:11:52 +0800
Subject: [PATCH] [mlir][func][bufferization] Fix cast incompatible when
bufferize callOp
Handle caller/callee type mismatch using `castOrReallocMemRefValue` instead
of just a `CastOp`. The method inserts a reallocation + copy if it cannot
be statically guaranteed that a direct cast would be valid.
---
.../FuncBufferizableOpInterfaceImpl.cpp | 19 ++++++++-------
.../Transforms/one-shot-module-bufferize.mlir | 24 +++++++++++++++++++
2 files changed, 35 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..9fbe574ec392dc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,23 @@ struct CallOpInterface
return failure();
Value buffer = *maybeBuffer;
- // Caller / callee type mismatch is handled with a CastOp.
+ // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
- // something better.
+ // something better. Insert a reallocation + copy if it cannot be
+ // statically guaranteed that a direct cast would be valid.
if (buffer.getType() != memRefType) {
- assert(
- memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
- "CallOp::bufferize: cast incompatible");
- Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
- memRefType, buffer);
- buffer = castBuffer;
+ auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+ assert(memrefDstType &&
+ "buffer layout not supported on unranked tensors");
+ FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+ rewriter, buffer, memrefDstType, options);
+ if (failed(replacement))
+ return failure();
+ buffer = *replacement;
}
newOperands.push_back(buffer);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..0d5224514e3a02 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
// -----
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+ return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP: }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+ %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+ %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+ return %1 : tensor<3x8xf16>
+}
+
+// -----
+
// CHECK-LABEL: func private @private_func
// CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
func.func private @private_func(tensor<?xf32>) -> (f32)
More information about the Mlir-commits
mailing list