[Mlir-commits] [mlir] [mlir][func][bufferization] Fix cast incompatible when bufferize callOp (PR #105929)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Aug 24 00:28:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

Handle caller/callee type mismatch using `castOrReallocMemRefValue` instead of just a `CastOp`. The method inserts a reallocation + copy if it cannot be statically guaranteed that a direct cast would be valid. Fixes #<!-- -->105916.

---
Full diff: https://github.com/llvm/llvm-project/pull/105929.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+13-8) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+24) 


``````````diff
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 053ea7935260a2..f85d0c35c5af33 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,20 +258,25 @@ struct CallOpInterface
         return failure();
       Value buffer = *maybeBuffer;
 
-      // Caller / callee type mismatch is handled with a CastOp.
+      // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
       auto memRefType = funcType.getInput(opOperand.getOperandNumber());
       // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
-      // something better.
+      // something better. Insert a reallocation + copy if it cannot be
+      // statically guaranteed that a direct cast would be valid.
       if (buffer.getType() != memRefType) {
-        assert(
-            memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
-            "CallOp::bufferize: cast incompatible");
-        Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
-                                                           memRefType, buffer);
-        buffer = castBuffer;
+        auto memrefDestType = dyn_cast<MemRefType>(memRefType);
+        assert(memrefDestType &&
+               "buffer layout not supported on unranked tensors");
+        BufferizationOptions options;
+        options.bufferAlignment = 0;
+        FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
+            rewriter, buffer, memrefDestType, options);
+        if (failed(replacement))
+          return failure();
+        buffer = *replacement;
       }
       newOperands.push_back(buffer);
     }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 0248afb11f1672..522cb2c4537ce2 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)
 
 // -----
 
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @foo(
+// CHECK-NO-LAYOUT-MAP-SAME:                   %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_0]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
+  return %arg0 : tensor<3x8xf16>
+}
+
+// CHECK-NO-LAYOUT-MAP-LABEL:   func.func @call_extract_slice(
+// CHECK-NO-LAYOUT-MAP-SAME:                                  %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_2:.*]] = memref.alloc() : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:           return %[[VAL_3]] : memref<3x8xf16>
+// CHECK-NO-LAYOUT-MAP:         }
+func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
+  %0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
+  %1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
+  return %1 : tensor<3x8xf16>
+}
+
+// -----
+
 // CHECK-LABEL: func private @private_func
 // CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
 func.func private @private_func(tensor<?xf32>) -> (f32)

``````````

</details>


https://github.com/llvm/llvm-project/pull/105929


More information about the Mlir-commits mailing list