[Mlir-commits] [mlir] [mlir][bufferization] Introduce createMemRefLayoutFn hook (PR #195622)
Andrei Golubev
llvmlistbot at llvm.org
Mon May 4 02:30:51 PDT 2026
https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/195622
Add a framework-provided hook that derives a MemRefLayoutAttrInterface from a tensor type encoding during tensor-to-memref conversion. The hook is intended to be an op-independent customization point for users wishing to preserve custom encoding information across the bufferization process.
The hook is wired into the two main tensor-to-memref entry points that are not driven by an op-specific BufferizableOpInterface::getBufferType:
* BuiltinTensorExternalModel::getBufferType (generic tensor-like fallback),
* bufferization.alloc_tensor (replaces the unconditional identity layout when the hook is set; identity is preserved as the default).
Behaviour is unchanged when the hook is not set (nullptr by default).
>From 6cc43c99f7d29dcee0af9565ac150f1e668bf73f Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
Date: Fri, 24 Apr 2026 09:02:02 +0000
Subject: [PATCH] [mlir][bufferization] Introduce createMemRefLayoutFn hook
Add a framework-provided hook that derives a MemRefLayoutAttrInterface
from a tensor type encoding during tensor-to-memref conversion. The hook
is intended to be an op-independent customization point for users wishing
to preserve custom encoding information across the bufferization
process.
The hook is wired into the two main tensor-to-memref entry points that
are not driven by an op-specific BufferizableOpInterface::getBufferType:
* BuiltinTensorExternalModel::getBufferType (generic tensor-like
fallback),
* bufferization.alloc_tensor (replaces the unconditional identity layout
when the hook is set; identity is preserved as the default).
Behaviour is unchanged when the hook is not set (nullptr by default).
Co-authored-by: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
---
.../IR/BufferizableOpInterface.h | 11 +++
.../Bufferization/IR/BufferizationDialect.cpp | 5 +-
.../Bufferization/IR/BufferizationOps.cpp | 7 ++
.../one-shot-non-module-bufferize.mlir | 38 --------
.../test-one-shot-module-bufferize.mlir | 91 +++++++++++++++++++
.../TestOneShotModuleBufferize.cpp | 5 +
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 52 +++--------
7 files changed, 129 insertions(+), 80 deletions(-)
create mode 100644 mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3f8392e3b8970..5a4a199c1b45e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,11 @@ struct BufferizationOptions {
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
+ /// Create a MemRefLayoutAttrInterface object from the tensor encoding. This
+ /// is called whenever bufferization materializes a memref from a tensor value
+ /// (allocations, generic tensor-like conversion fallbacks, etc.).
+ using CreateMemRefLayoutFn =
+ std::function<MemRefLayoutAttrInterface(TensorType)>;
BufferizationOptions();
@@ -364,6 +369,12 @@ struct BufferizationOptions {
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ /// Optional universal bufferization hook to control a tensor-encoding ->
+ /// memref-layout conversion process. By default, it is unset (`nullptr`),
+ /// meaning that the bufferization process would perform call-site dependent
+ /// layout inference (e.g. return identity layout, or fully dynamic layout).
+ CreateMemRefLayoutFn createMemRefLayoutFn = nullptr;
+
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index bd177ba1afccd..b5a04e352e138 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -47,8 +47,11 @@ struct BuiltinTensorExternalModel
if (!memSpace.has_value())
return emitError() << "could not infer memory space";
+ MemRefLayoutAttrInterface layout = {};
+ if (options.createMemRefLayoutFn)
+ layout = options.createMemRefLayoutFn(tensorType);
return cast<BufferLikeType>(
- getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+ getMemRefType(tensorType, options, layout, *memSpace));
}
mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index c525ec116f699..790e09d0d01ec 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -246,6 +246,13 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
+ if (options.createMemRefLayoutFn) {
+ if (auto layout = options.createMemRefLayoutFn(getType())) {
+ return cast<BufferLikeType>(
+ getMemRefType(getType(), options, layout, memorySpace));
+ }
+ }
+
return cast<BufferLikeType>(
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
index b52612d0d1f10..09fdf8231b7cc 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -29,41 +29,3 @@
}
"test.finish" () : () -> ()
}) : () -> ()
-
-// -----
-
-#enc1 = #test.tensor_encoding<"hello">
-#enc2 = #test.tensor_encoding<"not hello">
-
-"test.symbol_scope_isolated"() ({
- // CHECK: func @inner_func(
- // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
- // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">>
- func.func @inner_func(%t: tensor<?xf32, #enc1>)
- -> tensor<?xf32, #enc1> {
- // CHECK: return %[[arg0]]
- return %t : tensor<?xf32, #enc1>
- }
-
- // CHECK: func @outer_func(
- // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
- // CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>,
- // CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>)
- func.func @outer_func(%t0: tensor<?xf32, #enc1>)
- -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
- // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
- %0 = call @inner_func(%t0)
- : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
-
- // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
- // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">>
- %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
- // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
- %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
- -> tensor<?xf32, #enc2>
-
- // CHECK: return %[[call]], %[[dummy]]
- return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
- }
- "test.finish" () : () -> ()
-}) : () -> ()
diff --git a/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir
new file mode 100644
index 0000000000000..d336d265b6d4b
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s -test-one-shot-module-bufferize -split-input-file | FileCheck %s
+
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+module @BufferizeEncodingThroughFunctionBoundaryAndCustomOps {
+ // CHECK: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">>
+ func.func @inner_func(%t: tensor<?xf32, #enc1>)
+ -> tensor<?xf32, #enc1> {
+ // CHECK: return %[[arg0]]
+ return %t : tensor<?xf32, #enc1>
+ }
+
+ // CHECK: func @outer_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>,
+ // CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>)
+ func.func @outer_func(%t0: tensor<?xf32, #enc1>)
+ -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
+ // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+ %0 = call @inner_func(%t0)
+ : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
+
+ // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">>
+ %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
+ // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
+ %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
+ -> tensor<?xf32, #enc2>
+
+ // CHECK: return %[[call]], %[[dummy]]
+ return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
+ }
+}
+
+// -----
+
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+// The memref's layout must come from the encoding, not from the default
+// static-identity layout.
+module @BufferizeEncodingForAlloc {
+ // CHECK: func @some_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<42xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> (memref<42xf32, #test.memref_layout<"hello">>,
+ // CHECK-SAME: memref<42xf32, #test.memref_layout<"not hello">>)
+ func.func @some_func(%t0: tensor<42xf32, #enc1>)
+ -> (tensor<42xf32, #enc1>, tensor<42xf32, #enc2>) {
+ // CHECK: %[[T0:.+]] = memref.alloc() {{.*}} : memref<42xf32, #test.memref_layout<"hello">>
+ %0 = bufferization.alloc_tensor() : tensor<42xf32, #enc1>
+
+ // CHECK: %[[T1:.+]] = memref.alloc() {{.*}} : memref<42xf32, #test.memref_layout<"not hello">>
+ %1 = bufferization.alloc_tensor() : tensor<42xf32, #enc2>
+
+ // CHECK: return %[[T0]], %[[T1]]
+ return %0, %1 : tensor<42xf32, #enc1>, tensor<42xf32, #enc2>
+ }
+}
+
+// -----
+
+#enc1 = #test.tensor_encoding<"custom">
+
+module @BufferizeEncodingWithDefaultBufferizationApi {
+
+ // CHECK: func.func @custom_encoding_inside_scf(
+ // CHECK-SAME: %[[arg:.*]]: memref<42xf64, #test.memref_layout<"custom">>,
+ // CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index)
+ // CHECK-SAME: -> memref<42xf64, #test.memref_layout<"custom">>
+ func.func @custom_encoding_inside_scf(
+ %arg: tensor<42xf64, #enc1>,
+ %lb: index, %ub: index, %step: index)
+ -> tensor<42xf64, #enc1> {
+ // CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]]
+ // CHECK-SAME: iter_args(%[[iter:.+]] = %[[arg]]) -> (memref<42xf64, #test.memref_layout<"custom">>) {
+ // CHECK: %[[call:.+]] = "test.dummy_memref_op"(%[[iter]])
+ // CHECK: scf.yield %[[call]] : memref<42xf64, #test.memref_layout<"custom">>
+ %loop = scf.for %i = %lb to %ub step %step
+ iter_args(%iter = %arg) -> (tensor<42xf64, #enc1>) {
+ %call = "test.dummy_tensor_op"(%iter) : (tensor<42xf64, #enc1>)
+ -> tensor<42xf64, #enc1>
+ scf.yield %call : tensor<42xf64, #enc1>
+ }
+
+ // CHECK: return %[[loop]]
+ return %loop : tensor<42xf64, #enc1>
+ }
+}
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index dead1a4b7e047..58660e1cf8f45 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -66,6 +66,11 @@ struct TestOneShotModuleBufferizePass
MemRefType::get(tensorType.getShape(),
tensorType.getElementType(), layout, memSpace));
};
+ opt.createMemRefLayoutFn = [&](TensorType tensor) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ return getMemRefLayoutForTensorEncoding(tensorType);
+ };
bufferization::BufferizationState bufferizationState;
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 340b44b14dd96..48036ce6f167b 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1741,39 +1741,6 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
-namespace {
-/// Returns test dialect's memref layout for test dialect's tensor encoding when
-/// applicable.
-MemRefLayoutAttrInterface
-getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
- if (auto encoding =
- dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) {
- return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
- tensorType.getContext(), encoding.getDummy()));
- }
- return {};
-}
-
-/// Auxiliary bufferization function for test and builtin tensors.
-bufferization::BufferLikeType
-convertTensorToBuffer(mlir::Operation *op,
- const bufferization::BufferizationOptions &options,
- bufferization::TensorLikeType tensorLike) {
- auto buffer =
- *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
- if (auto memref = dyn_cast<MemRefType>(buffer)) {
- // Note: For the sake of testing, we want to ensure that encoding -> layout
- // bufferization happens. This is currently achieved manually.
- auto layout =
- getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
- return cast<bufferization::BufferLikeType>(
- MemRefType::get(memref.getShape(), memref.getElementType(), layout,
- memref.getMemorySpace()));
- }
- return buffer;
-}
-} // namespace
-
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
@@ -1783,12 +1750,16 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
if (mlir::failed(buffer))
return failure();
- const auto outType = getOutput().getType();
+ // Note: mlir::bufferization::getBufferType() would internally call
+ // TestDummyTensorOp::getBufferType()
const auto bufferizedOutType =
- convertTensorToBuffer(getOperation(), options, outType);
+ mlir::bufferization::getBufferType(getOutput(), options, state);
+ if (mlir::failed(bufferizedOutType))
+ return failure();
+
// replace op with memref analogy
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
- rewriter, getLoc(), bufferizedOutType, *buffer);
+ rewriter, getLoc(), *bufferizedOutType, *buffer);
mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
dummyMemrefOp.getResult());
@@ -1798,15 +1769,14 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
mlir::FailureOr<mlir::bufferization::BufferLikeType>
test::TestDummyTensorOp::getBufferType(
- mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
const mlir::bufferization::BufferizationState &,
llvm::SmallVector<::mlir::Value> &) {
- const auto type = dyn_cast<test::TestTensorType>(value.getType());
+ const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
if (type == nullptr)
return failure();
- return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
- getContext(), type.getShape(), type.getElementType(), nullptr));
+ return type.getBufferType(options, [&]() { return emitError(); });
}
::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
@@ -1839,7 +1809,7 @@ test::TestCreateTensorOp::getBufferType(
if (type == nullptr)
return failure();
- return convertTensorToBuffer(getOperation(), options, type);
+ return type.getBufferType(options, [&]() { return emitError(); });
}
// Define a custom builder for ManyRegionsOp declared in TestOps.td.
More information about the Mlir-commits
mailing list