[Mlir-commits] [mlir] [mlir][bufferization] Support custom types at function boundaries (PR #159766)

Andrei Golubev llvmlistbot at llvm.org
Wed Sep 24 02:47:55 PDT 2025


https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/159766

>From 16056beca1ff91853bec248d49eb424c26cc4d5b Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Wed, 9 Jul 2025 09:56:39 +0000
Subject: [PATCH 1/3] [mlir][bufferization] Support custom types at function
 boundaries

Support custom types (3/N): allow custom tensor and buffer types in
function signatures and at call-sites. This is one of the major building
blocks to move in the direction of module-level one-shot-bufferization
support.

In order to enable this, TensorLikeType is extended with a new interface
method that is invoked solely within the function boundary
bufferization.
---
 .../IR/BufferizationTypeInterfaces.h          |  1 +
 .../IR/BufferizationTypeInterfaces.td         | 12 +++
 .../Bufferization/IR/BufferizationDialect.cpp | 13 +++
 .../Bufferization/Transforms/Bufferize.cpp    |  2 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       | 90 ++++++++++---------
 .../Transforms/one-shot-module-bufferize.mlir | 56 ++++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  5 ++
 mlir/test/lib/Dialect/Test/TestTypes.cpp      |  8 ++
 8 files changed, 146 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index a2bfcb7ed2b75..9b052b8bb7e14 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index fb6fc4f5ad964..c4235cd067999 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,6 +43,18 @@ def Bufferization_TensorLikeTypeInterface
       /*args=*/(ins
         "::mlir::bufferization::BufferLikeType":$bufferType,
         "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >,
+    InterfaceMethod<[{
+        Returns a BufferLike type for this TensorLike type in the context of
+        this type being function argument or result.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+      /*methodName=*/"getBufferTypeAtFunctionBoundary",
+      /*args=*/(ins
+        "::mlir::func::FuncOp":$funcOp,
+        "const ::mlir::bufferization::BufferizationOptions &":$options,
+        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
+      )
     >
   ];
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..9b907922a24c4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,6 +87,19 @@ struct BuiltinTensorExternalModel
 
     return mlir::success();
   }
+
+  llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
+      mlir::Type tensor, mlir::func::FuncOp funcOp,
+      const BufferizationOptions &options,
+      llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    auto tensorType = cast<TensorType>(tensor);
+    auto memSpace = options.defaultMemorySpaceFn(tensorType);
+    if (!memSpace.has_value())
+      return emitError() << "could not infer memory space";
+
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *memSpace, funcOp, options));
+  }
 };
 
 template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 68ef51992efee..701ab52a491a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
   // Compute the new signature.
   SmallVector<Type> newTypes;
   for (BlockArgument &bbArg : block->getArguments()) {
-    auto tensorType = dyn_cast<TensorType>(bbArg.getType());
+    auto tensorType = dyn_cast<TensorLikeType>(bbArg.getType());
     if (!tensorType) {
       newTypes.push_back(bbArg.getType());
       continue;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index f69efd1b3fa8c..b7bac9f4623f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -52,26 +52,35 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
-static BaseMemRefType
+static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
-      dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorType");
-
-  BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
-
-  auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
-      index, BufferizationDialect::kBufferLayoutAttrName);
-  if (!layoutAttr)
-    return memrefType;
-
-  auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
-  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
-  return MemRefType::get(rankedMemrefType.getShape(),
-                         rankedMemrefType.getElementType(), layoutAttr,
-                         rankedMemrefType.getMemorySpace());
+      dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
+  assert(tensorType && "expected TensorLikeType");
+  auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
+      funcOp, options, [&]() { return funcOp->emitError(); });
+  assert(mlir::succeeded(maybeBufferType) &&
+         "a valid buffer is always expected");
+
+  auto bufferType = *maybeBufferType;
+
+  // Note: For builtin tensors there is additional logic related to layout.
+  if (isa<TensorType>(tensorType)) {
+    auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
+        index, BufferizationDialect::kBufferLayoutAttrName);
+    if (!layoutAttr)
+      return bufferType;
+
+    auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+    assert(rankedMemrefType &&
+           "buffer layout not supported on unranked tensors");
+    return cast<BufferLikeType>(MemRefType::get(
+        rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+        layoutAttr, rankedMemrefType.getMemorySpace()));
+  }
+
+  return bufferType;
 }
 
 /// Return the FuncOp called by `callOp`.
@@ -227,14 +236,13 @@ struct CallOpInterface
     FunctionType funcType = funcOp.getFunctionType();
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
-    if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return cast<BufferLikeType>(bufferizedType);
+    if (auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
+      return bufferizedType;
 
     // Otherwise, call the type converter to compute the bufferized type.
-    auto tensorType = cast<TensorType>(resultType);
-    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-        options));
+    auto tensorType = cast<TensorLikeType>(resultType);
+    return tensorType.getBufferTypeAtFunctionBoundary(
+        funcOp, options, [&]() { return funcOp->emitError(); });
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -248,7 +256,7 @@ struct CallOpInterface
     SmallVector<Type> resultTypes;
     for (Value result : callOp.getResults()) {
       Type returnType = result.getType();
-      if (!isa<TensorType>(returnType)) {
+      if (!isa<TensorLikeType>(returnType)) {
         // Non-tensor values are returned.
         resultTypes.push_back(returnType);
         continue;
@@ -272,7 +280,7 @@ struct CallOpInterface
 
     for (OpOperand &opOperand : callOp->getOpOperands()) {
       // Non-tensor operands are just copied.
-      if (!isa<TensorType>(opOperand.get().getType())) {
+      if (!isa<TensorLikeType>(opOperand.get().getType())) {
         newOperands.push_back(opOperand.get());
         continue;
       }
@@ -285,8 +293,8 @@ struct CallOpInterface
       Value buffer = *maybeBuffer;
 
       // Caller / callee type mismatch is handled with castOrReallocMemRefValue.
-      auto memRefType = funcType.getInput(opOperand.getOperandNumber());
-      if (!isa<BaseMemRefType>(memRefType)) {
+      auto bufferType = funcType.getInput(opOperand.getOperandNumber());
+      if (!isa<BufferLikeType>(bufferType)) {
         // The called function was not bufferized yet. This can happen when
         // there cycles in the function call graph. Compute the bufferized
         // result type.
@@ -296,7 +304,7 @@ struct CallOpInterface
                 state);
         if (failed(maybeBufferType))
           return failure();
-        memRefType = *maybeBufferType;
+        bufferType = *maybeBufferType;
       }
 
       // Since we don't yet have a clear layout story, to_buffer may
@@ -305,8 +313,8 @@ struct CallOpInterface
       // that will either canonicalize away or fail compilation until we can do
       // something better. Insert a reallocation + copy if it cannot be
       // statically guaranteed that a direct cast would be valid.
-      if (buffer.getType() != memRefType) {
-        auto memrefDstType = dyn_cast<MemRefType>(memRefType);
+      if (buffer.getType() != bufferType) {
+        auto memrefDstType = dyn_cast<MemRefType>(bufferType);
         assert(memrefDstType &&
                "buffer layout not supported on unranked tensors");
         FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
@@ -370,7 +378,7 @@ struct FuncOpInterface
   static bool supportsUnstructuredControlFlow() { return true; }
 
   bool hasTensorSemantics(Operation *op) const {
-    auto isaTensor = llvm::IsaPred<TensorType>;
+    auto isaTensor = llvm::IsaPred<TensorLikeType>;
 
     // A function has tensor semantics if it has tensor arguments/results.
     auto funcOp = cast<FuncOp>(op);
@@ -406,8 +414,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return cast<BufferLikeType>(
-          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
+      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
+                                          options);
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
@@ -430,7 +438,7 @@ struct FuncOpInterface
     SmallVector<Type> argTypes;
     for (const auto &it : llvm::enumerate(funcType.getInputs())) {
       Type argType = it.value();
-      if (isa<TensorType>(argType)) {
+      if (isa<TensorLikeType>(argType)) {
         argTypes.push_back(
             getBufferizedFunctionArgType(funcOp, it.index(), options));
         continue;
@@ -441,11 +449,13 @@ struct FuncOpInterface
     // Compute the result types.
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
-      if (auto tensorType = dyn_cast<TensorType>(resultType)) {
-        BaseMemRefType resultType = options.functionArgTypeConverterFn(
-            tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
-            options);
-        retTypes.push_back(resultType);
+      if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
+        FailureOr<BufferLikeType> resultType =
+            tensorType.getBufferTypeAtFunctionBoundary(
+                funcOp, options, [&]() { return funcOp->emitError(); });
+        assert(mlir::succeeded(resultType) &&
+               "a valid buffer is always expected");
+        retTypes.push_back(*resultType);
         continue;
       }
       retTypes.push_back(resultType);
@@ -473,7 +483,7 @@ struct FuncOpInterface
       SmallVector<Value> returnValues;
       for (auto [returnVal, bufferizedType] :
            llvm::zip_equal(returnOp->getOperands(), retTypes)) {
-        auto tensorType = dyn_cast<TensorType>(returnVal.getType());
+        auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
         rewriter.setInsertionPoint(returnOp);
 
         // If not a tensor type just forward it.
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index 2efb5893c8511..eb0093106dc11 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -810,3 +810,59 @@ module @inner_module {
     return %t : tensor<5xf32>
   }
 }
+
+// -----
+
+// CHECK:   func.func @custom_types(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> (!test.test_memref<[4, 8], f64>,
+// CHECK-SAME:        !test.test_memref<[4, 8], f64>)
+func.func @custom_types(%arg: !test.test_tensor<[4, 4], f64>)
+    -> (!test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>) {
+  // CHECK: %[[out1:.*]] = "test.dummy_memref_op"(%[[arg]]) :
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %out1 = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: %[[alloc:.*]] = "test.create_memref_op"
+  // CHECK: %[[out2:.*]] = "test.dummy_memref_op"(%[[alloc]])
+  // CHECK-SAME: (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 8], f64>
+  %alloc = "test.create_tensor_op"() : () -> !test.test_tensor<[4, 4], f64>
+  %out2 = "test.dummy_tensor_op"(%alloc) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out1]], %[[out2]]
+  return %out1, %out2 :
+    !test.test_tensor<[4, 8], f64>, !test.test_tensor<[4, 8], f64>
+}
+
+// -----
+
+// CHECK:   func.func @custom_types_foo(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_foo(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[arg]])
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK:   func.func @custom_types_bar(
+// CHECK-SAME:    %[[arg:.*]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME:  ) -> !test.test_memref<[4, 8], f64>
+func.func @custom_types_bar(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64> {
+  // CHECK: %[[call:.*]] = call @custom_types_foo(%[[arg]])
+  %call = func.call @custom_types_foo(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+
+  // CHECK: %[[out:.*]] = "test.dummy_memref_op"(%[[call]])
+  %out = "test.dummy_tensor_op"(%call) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 8], f64>
+
+  // CHECK: return %[[out]]
+  return %out : !test.test_tensor<[4, 8], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597231d58..562fc66acea2a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,6 +444,11 @@ def TestTensorType : Test_Type<"TestTensor",
     ::mlir::LogicalResult verifyCompatibleBufferType(
         ::mlir::bufferization::BufferLikeType bufferType,
         ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
+
+    ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+    getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
+        const ::mlir::bufferization::BufferizationOptions& options,
+        ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
   }];
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..3c92fb94aebee 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,3 +573,11 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
                      getElementType() == testMemref.getElementType();
   return mlir::success(valid);
 }
+
+::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
+TestTensorType::getBufferTypeAtFunctionBoundary(
+    mlir::func::FuncOp,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
+  return getBufferType(options, emitError);
+}

>From e40d66a643d9b69e073e0c7cb9914a32445b3160 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 22 Sep 2025 11:22:53 +0000
Subject: [PATCH 2/3] Rely on BufferizationOptions::FunctionArgTypeConverterFn

Transform FunctionArgTypeConverterFn into a tensor-like -> buffer-like
converter so that it can be used as a generic function-boundary
conversion utility.
---
 .../IR/BufferizableOpInterface.h              | 16 ++++---
 .../IR/BufferizationTypeInterfaces.h          |  1 -
 .../IR/BufferizationTypeInterfaces.td         | 12 -----
 .../IR/BufferizableOpInterface.cpp            | 39 +++++++++++----
 .../Bufferization/IR/BufferizationDialect.cpp | 13 -----
 .../FuncBufferizableOpInterfaceImpl.cpp       | 48 +++++++++++--------
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  5 --
 mlir/test/lib/Dialect/Test/TestTypes.cpp      |  8 ----
 8 files changed, 67 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f3b34f9fded7f..5bf3916630158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -260,10 +260,10 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// Tensor-like -> Buffer-like type converter.
+  /// Parameters: tensor-like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
   /// Tensor -> MemRef type converter.
   /// Parameters: tensor type, memory space, bufferization options
@@ -335,10 +335,12 @@ struct BufferizationOptions {
   /// predictable.
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
 
-  /// Type converter from tensors to memrefs. This type converter is used to
-  /// determine bufferized function argument and result types. By default, a
-  /// type converter that returns a memref type with a fully dynamic layout map
-  /// is used.
+  /// Type converter from tensors to buffers. This type converter is used to
+  /// determine bufferized function argument and result types.
+  ///
+  /// By default, if tensor is a (builtin) tensor type, a type converter that
+  /// returns a memref type with a fully dynamic layout map is used; if tensor
+  /// is a (generic) tensor-like type, TensorLikeType::getBufferType() is used.
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 9b052b8bb7e14..a2bfcb7ed2b75 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,7 +13,6 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Func/IR/FuncOps.h" // to access mlir::func::FuncOp
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index c4235cd067999..fb6fc4f5ad964 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -43,18 +43,6 @@ def Bufferization_TensorLikeTypeInterface
       /*args=*/(ins
         "::mlir::bufferization::BufferLikeType":$bufferType,
         "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
-    >,
-    InterfaceMethod<[{
-        Returns a BufferLike type for this TensorLike type in the context of
-        this type being function argument or result.
-      }],
-      /*retTy=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
-      /*methodName=*/"getBufferTypeAtFunctionBoundary",
-      /*args=*/(ins
-        "::mlir::func::FuncOp":$funcOp,
-        "const ::mlir::bufferization::BufferizationOptions &":$options,
-        "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError
-      )
     >
   ];
 }
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f7b0b87085f3d..fae1df69ed3e3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -338,11 +338,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
+BufferLikeType
+defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
                                 func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+    return cast<BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
+  }
+
+  // If not builtin, fallback to TensorLikeType::getBufferType()
+  auto bufferType =
+      type.getBufferType(options, [&]() { return funcOp->emitError(); });
+  assert(mlir::succeeded(bufferType) &&
+         "a valid buffer is always expected at function boundary");
+  return *bufferType;
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
 BaseMemRefType
@@ -385,14 +395,25 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
+  functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
                                    func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
-    if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+    if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
+      if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
+        return cast<BufferLikeType>(
+            bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
+                                                                 memorySpace));
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                             memorySpace));
+    }
+
+    // If not builtin, fallback to TensorLikeType::getBufferType()
+    auto bufferType =
+        type.getBufferType(options, [&]() { return funcOp->emitError(); });
+    assert(mlir::succeeded(bufferType) &&
+           "a valid buffer is always expected at function boundary");
+    return *bufferType;
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 9b907922a24c4..6c08cdfb669f3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -87,19 +87,6 @@ struct BuiltinTensorExternalModel
 
     return mlir::success();
   }
-
-  llvm::FailureOr<BufferLikeType> getBufferTypeAtFunctionBoundary(
-      mlir::Type tensor, mlir::func::FuncOp funcOp,
-      const BufferizationOptions &options,
-      llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
-    auto tensorType = cast<TensorType>(tensor);
-    auto memSpace = options.defaultMemorySpaceFn(tensorType);
-    if (!memSpace.has_value())
-      return emitError() << "could not infer memory space";
-
-    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
-        tensorType, *memSpace, funcOp, options));
-  }
 };
 
 template <typename MemRef>
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index b7bac9f4623f1..d9d69342e42a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -49,30 +49,38 @@ void FuncAnalysisState::startFunctionAnalysis(FuncOp funcOp) {
 #endif // NDEBUG
 }
 
+// Note: this is a local adaptor to unify TensorType and TensorLikeType code
+// paths that both work with BufferizationOptions.
+static mlir::Attribute
+getDefaultMemorySpace(const BufferizationOptions &options,
+                      TensorLikeType type) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    return *options.defaultMemorySpaceFn(tensorType);
+  }
+  return nullptr;
+}
+
 /// Return the index-th bufferized function argument type. This assumes that the
 /// specified argument is a tensor. If the tensor is ranked, a layout map may be
 /// specified by the user (as per `options.functionArgTypeConverterFn`).
 static BufferLikeType
 getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
-  auto tensorType =
+  auto type =
       dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(index));
-  assert(tensorType && "expected TensorLikeType");
-  auto maybeBufferType = tensorType.getBufferTypeAtFunctionBoundary(
-      funcOp, options, [&]() { return funcOp->emitError(); });
-  assert(mlir::succeeded(maybeBufferType) &&
-         "a valid buffer is always expected");
-
-  auto bufferType = *maybeBufferType;
+  assert(type && "expected TensorLikeType");
 
   // Note: For builtin tensors there is additional logic related to layout.
-  if (isa<TensorType>(tensorType)) {
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    BufferLikeType memrefType = options.functionArgTypeConverterFn(
+        type, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+
     auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
         index, BufferizationDialect::kBufferLayoutAttrName);
     if (!layoutAttr)
-      return bufferType;
+      return memrefType;
 
-    auto rankedMemrefType = dyn_cast<MemRefType>(bufferType);
+    auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
     assert(rankedMemrefType &&
            "buffer layout not supported on unranked tensors");
     return cast<BufferLikeType>(MemRefType::get(
@@ -80,7 +88,8 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
         layoutAttr, rankedMemrefType.getMemorySpace()));
   }
 
-  return bufferType;
+  return options.functionArgTypeConverterFn(type, /*memSpace=*/nullptr, funcOp,
+                                            options);
 }
 
 /// Return the FuncOp called by `callOp`.
@@ -241,8 +250,9 @@ struct CallOpInterface
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorLikeType>(resultType);
-    return tensorType.getBufferTypeAtFunctionBoundary(
-        funcOp, options, [&]() { return funcOp->emitError(); });
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -450,12 +460,10 @@ struct FuncOpInterface
     SmallVector<Type> retTypes;
     for (Type resultType : funcType.getResults()) {
       if (auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
-        FailureOr<BufferLikeType> resultType =
-            tensorType.getBufferTypeAtFunctionBoundary(
-                funcOp, options, [&]() { return funcOp->emitError(); });
-        assert(mlir::succeeded(resultType) &&
-               "a valid buffer is always expected");
-        retTypes.push_back(*resultType);
+        BufferLikeType resultType = options.functionArgTypeConverterFn(
+            tensorType, getDefaultMemorySpace(options, tensorType), funcOp,
+            options);
+        retTypes.push_back(resultType);
         continue;
       }
       retTypes.push_back(resultType);
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 562fc66acea2a..ea20597231d58 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -444,11 +444,6 @@ def TestTensorType : Test_Type<"TestTensor",
     ::mlir::LogicalResult verifyCompatibleBufferType(
         ::mlir::bufferization::BufferLikeType bufferType,
         ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
-
-    ::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
-    getBufferTypeAtFunctionBoundary(mlir::func::FuncOp funcOp,
-        const ::mlir::bufferization::BufferizationOptions& options,
-        ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError);
   }];
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 3c92fb94aebee..bea043f56fe21 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -573,11 +573,3 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
                      getElementType() == testMemref.getElementType();
   return mlir::success(valid);
 }
-
-::mlir::FailureOr<::mlir::bufferization::BufferLikeType>
-TestTensorType::getBufferTypeAtFunctionBoundary(
-    mlir::func::FuncOp,
-    const ::mlir::bufferization::BufferizationOptions &options,
-    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
-  return getBufferType(options, emitError);
-}

>From 95e4724090591d6b034e68c2d624fe22fdda876b Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Wed, 24 Sep 2025 09:50:04 +0000
Subject: [PATCH 3/3] Change documentation wording for type converters

---
 .../IR/BufferizableOpInterface.h              | 19 +++++++++----------
 .../IR/BufferizableOpInterface.cpp            |  4 ++--
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 5bf3916630158..dd693a25fd54f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -260,12 +260,12 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor-like -> Buffer-like type converter.
+  /// Tensor-like -> Buffer-like type conversion.
   /// Parameters: tensor-like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
       std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+  /// Tensor -> MemRef type conversion.
   /// Parameters: tensor type, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       TensorType, Attribute memorySpace, const BufferizationOptions &)>;
@@ -335,12 +335,12 @@ struct BufferizationOptions {
   /// predictable.
   void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
 
-  /// Type converter from tensors to buffers. This type converter is used to
+  /// Type conversion from tensors to buffers. This type conversion is used to
   /// determine bufferized function argument and result types.
   ///
-  /// By default, if tensor is a (builtin) tensor type, a type converter that
-  /// returns a memref type with a fully dynamic layout map is used; if tensor
-  /// is a (generic) tensor-like type, TensorLikeType::getBufferType() is used.
+  /// By default, if tensor is a (builtin) tensor type, it is converted to a
+  /// memref type with a fully dynamic layout map; if tensor is a (generic)
+  /// tensor-like type, it is converted using TensorLikeType::getBufferType().
   ///
   /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
   FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
@@ -352,10 +352,9 @@ struct BufferizationOptions {
   /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
   bool inferFunctionResultLayout = true;
 
-  /// Type converter from tensors to memrefs. This type converter is used if no
-  /// memref type could be inferred during bufferization. By default, a type
-  /// converter that returns a memref type with a fully dynamic layout map is
-  /// used.
+  /// Type conversion from tensors to memrefs. This type conversion is used if
+  /// no memref type could be inferred during bufferization. By default, returns
+  /// a memref type with a fully dynamic layout map.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
   // Use during type conversion to determine the memory space for memref based
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index fae1df69ed3e3..e0cf353da207f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -350,7 +350,7 @@ defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
   // If not builtin, fallback to TensorLikeType::getBufferType()
   auto bufferType =
       type.getBufferType(options, [&]() { return funcOp->emitError(); });
-  assert(mlir::succeeded(bufferType) &&
+  assert(succeeded(bufferType) &&
          "a valid buffer is always expected at function boundary");
   return *bufferType;
 }
@@ -411,7 +411,7 @@ void BufferizationOptions::setFunctionBoundaryTypeConversion(
     // If not builtin, fallback to TensorLikeType::getBufferType()
     auto bufferType =
         type.getBufferType(options, [&]() { return funcOp->emitError(); });
-    assert(mlir::succeeded(bufferType) &&
+    assert(succeeded(bufferType) &&
            "a valid buffer is always expected at function boundary");
     return *bufferType;
   };



More information about the Mlir-commits mailing list