[Mlir-commits] [mlir] 7f04a8a - [mlir][func][bufferization] Fix cast incompatible when bufferize callOp (#105929)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 26 16:06:04 PDT 2024


Author: Longsheng Mou
Date: 2024-08-27T07:06:00+08:00
New Revision: 7f04a8ad131881b5a58b97c8191733ed42d18e20

URL: https://github.com/llvm/llvm-project/commit/7f04a8ad131881b5a58b97c8191733ed42d18e20
DIFF: https://github.com/llvm/llvm-project/commit/7f04a8ad131881b5a58b97c8191733ed42d18e20.diff

LOG: [mlir][func][bufferization] Fix cast incompatible when bufferize callOp (#105929)

Handle caller/callee type mismatches with `castOrReallocMemRefValue`
instead of a plain `CastOp`. The method inserts a reallocation + copy if
it cannot be statically guaranteed that a direct cast would be valid.
Fixes #105916.
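
For illustration, a minimal sketch of the bufferized IR after this change, mirroring the `@call_extract_slice` test added below (SSA names are illustrative): the callee expects an identity-layout `memref<3x8xf16>`, and a direct `memref.cast` from the strided subview would not be valid, so a reallocation + copy is emitted instead.

    %subview = memref.subview %arg0[1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
    memref.copy %subview, %alloc : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
    %result = call @foo(%alloc) : (memref<3x8xf16>) -> memref<3x8xf16>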

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..9fbe574ec392dc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,23 @@ struct CallOpInterface
         return failure();
       Value buffer = *maybeBuffer;
 
-      // Caller / callee type mismatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
-      // something better.
+      // something better. Insert a reallocation + copy if it cannot be
+      // statically guaranteed that a direct cast would be valid.
       if (buffer.getType() != memRefType) {
-        assert(
-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
-            "CallOp::bufferize: cast incompatible");
-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
-                                                           memRefType, buffer);
-        buffer = castBuffer;
+        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+        assert(memrefDstType &&
+               "buffer layout not supported on unranked tensors");
+        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+            rewriter, buffer, memrefDstType, options);
+        if (failed(replacement))
+          return failure();
+        buffer = *replacement;
       }
       newOperands.push_back(buffer);
     }

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..0d5224514e3a02 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
 
 // -----
 
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME:                   %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+  return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME:                                  %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+  return %1 : tensor<3x8xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func private @private_func
 // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
 func.func private @private_func(tensor<?xf32>) -> (f32)

More information about the Mlir-commits mailing list