[Mlir-commits] [mlir] [mlir][func][bufferization] Fix cast incompatible when bufferize callOp (PR #105929)

Longsheng Mou llvmlistbot at llvm.org
Sun Aug 25 18:12:47 PDT 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/105929

>From e5b92e0a51b10dc244ad7e584ca9f019a337897b Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Sat, 24 Aug 2024 15:11:52 +0800
Subject: [PATCH] [mlir][func][bufferization] Fix cast incompatible when
 bufferize callOp

Handle caller/callee type mismatch using `castOrReallocMemRefValue` instead
of just a `CastOp`. The method inserts a reallocation + copy if it cannot
be statically guaranteed that a direct cast would be valid.
---
 .../FuncBufferizableOpInterfaceImpl.cpp       | 19 ++++++++-------
 .../Transforms/one-shot-module-bufferize.mlir | 24 +++++++++++++++++++
 2 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..9fbe574ec392dc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,23 @@ struct CallOpInterface
         return failure();
       Value buffer = *maybeBuffer;
 
-      // Caller / callee type mismatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
-      // something better.
+      // something better. Insert a reallocation + copy if it cannot be
+      // statically guaranteed that a direct cast would be valid.
       if (buffer.getType() != memRefType) {
-        assert(
-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
-            "CallOp::bufferize: cast incompatible");
-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
-                                                           memRefType, buffer);
-        buffer = castBuffer;
+        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+        assert(memrefDstType &&
+               "buffer layout not supported on unranked tensors");
+        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+            rewriter, buffer, memrefDstType, options);
+        if (failed(replacement))
+          return failure();
+        buffer = *replacement;
       }
       newOperands.push_back(buffer);
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..522cb2c4537ce2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
 
 // -----
 
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME:                   %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+  return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME:                                  %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_2:.*]] = memref.alloc() : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+  return %1 : tensor<3x8xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func private @private_func
 // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
 func.func private @private_func(tensor<?xf32>) -> (f32)



More information about the Mlir-commits mailing list