[Mlir-commits] [mlir] [mlir][bufferization] Introduce createMemRefLayoutFn hook (PR #195622)

Andrei Golubev llvmlistbot at llvm.org
Mon May 4 02:30:51 PDT 2026


https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/195622

Add a framework-provided hook that derives a MemRefLayoutAttrInterface from a tensor type encoding during tensor-to-memref conversion. The hook is intended to be an op-independent customization point for users wishing to preserve custom encoding information across the bufferization process.

The hook is wired into two tensor-to-memref entry points provided by the bufferization dialect itself:

* BuiltinTensorExternalModel::getBufferType (generic tensor-like fallback),
* bufferization.alloc_tensor (replaces the unconditional identity layout when the hook is set; identity is preserved as the default).

Behaviour is unchanged when the hook is not set (nullptr by default) or when it returns a null layout.

>From 6cc43c99f7d29dcee0af9565ac150f1e668bf73f Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
Date: Fri, 24 Apr 2026 09:02:02 +0000
Subject: [PATCH] [mlir][bufferization] Introduce createMemRefLayoutFn hook

Add a framework-provided hook that derives a MemRefLayoutAttrInterface
from a tensor type encoding during tensor-to-memref conversion. The hook
is intended to be an op-independent customization point for users wishing
to preserve custom encoding information across the bufferization
process.

The hook is wired into two tensor-to-memref entry points provided by
the bufferization dialect itself:

* BuiltinTensorExternalModel::getBufferType (generic tensor-like
  fallback),
* bufferization.alloc_tensor (replaces the unconditional identity layout
  when the hook is set; identity is preserved as the default).

Behaviour is unchanged when the hook is not set (nullptr by default) or
when it returns a null layout.

Co-authored-by: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
---
 .../IR/BufferizableOpInterface.h              | 11 +++
 .../Bufferization/IR/BufferizationDialect.cpp |  5 +-
 .../Bufferization/IR/BufferizationOps.cpp     |  7 ++
 .../one-shot-non-module-bufferize.mlir        | 38 --------
 .../test-one-shot-module-bufferize.mlir       | 91 +++++++++++++++++++
 .../TestOneShotModuleBufferize.cpp            |  5 +
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 52 +++--------
 7 files changed, 129 insertions(+), 80 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3f8392e3b8970..5a4a199c1b45e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,11 @@ struct BufferizationOptions {
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
+  /// Create a MemRefLayoutAttrInterface object from the tensor encoding. It is
+  /// consulted by the generic tensor-like fallback and by alloc_tensor; a null
+  /// result keeps the call site's default layout.
+  using CreateMemRefLayoutFn =
+      std::function<MemRefLayoutAttrInterface(TensorType)>;
 
   BufferizationOptions();
 
@@ -364,6 +369,12 @@ struct BufferizationOptions {
   DefaultMemorySpaceFn defaultMemorySpaceFn =
       [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
 
+  /// Optional universal bufferization hook to control a tensor-encoding ->
+  /// memref-layout conversion process. By default, it is unset (`nullptr`),
+  /// meaning that the bufferization process would perform call-site dependent
+  /// layout inference (e.g. return identity layout, or fully dynamic layout).
+  CreateMemRefLayoutFn createMemRefLayoutFn = nullptr;
+
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
   bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index bd177ba1afccd..b5a04e352e138 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -47,8 +47,11 @@ struct BuiltinTensorExternalModel
     if (!memSpace.has_value())
       return emitError() << "could not infer memory space";
 
+    MemRefLayoutAttrInterface layout = {};
+    if (options.createMemRefLayoutFn)
+      layout = options.createMemRefLayoutFn(tensorType);
     return cast<BufferLikeType>(
-        getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+        getMemRefType(tensorType, options, layout, *memSpace));
   }
 
   mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index c525ec116f699..790e09d0d01ec 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -246,6 +246,13 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
+  if (options.createMemRefLayoutFn) {
+    if (auto layout = options.createMemRefLayoutFn(getType())) {
+      return cast<BufferLikeType>(
+          getMemRefType(getType(), options, layout, memorySpace));
+    }
+  }
+
   return cast<BufferLikeType>(
       getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
index b52612d0d1f10..09fdf8231b7cc 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -29,41 +29,3 @@
   }
   "test.finish" () : () -> ()
 }) : () -> ()
-
-// -----
-
-#enc1 = #test.tensor_encoding<"hello">
-#enc2 = #test.tensor_encoding<"not hello">
-
-"test.symbol_scope_isolated"() ({
-  // CHECK: func @inner_func(
-  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
-  // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"hello">>
-  func.func @inner_func(%t: tensor<?xf32, #enc1>)
-      -> tensor<?xf32, #enc1> {
-    // CHECK: return %[[arg0]]
-    return %t : tensor<?xf32, #enc1>
-  }
-
-  // CHECK: func @outer_func(
-  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
-  // CHECK-SAME:  -> (memref<?xf32, #test.memref_layout<"hello">>,
-  // CHECK-SAME:      memref<?xf32, #test.memref_layout<"not hello">>)
-  func.func @outer_func(%t0: tensor<?xf32, #enc1>)
-      -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
-    // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
-    %0 = call @inner_func(%t0)
-      : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
-
-    // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
-    // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"not hello">>
-    %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
-    // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
-    %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
-      -> tensor<?xf32, #enc2>
-
-    // CHECK: return %[[call]], %[[dummy]]
-    return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
-  }
-  "test.finish" () : () -> ()
-}) : () -> ()
diff --git a/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir
new file mode 100644
index 0000000000000..d336d265b6d4b
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/test-one-shot-module-bufferize.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s -test-one-shot-module-bufferize -split-input-file | FileCheck %s
+
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+module @BufferizeEncodingThroughFunctionBoundaryAndCustomOps {
+  // CHECK: func @inner_func(
+  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+  // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"hello">>
+  func.func @inner_func(%t: tensor<?xf32, #enc1>)
+      -> tensor<?xf32, #enc1> {
+    // CHECK: return %[[arg0]]
+    return %t : tensor<?xf32, #enc1>
+  }
+
+  // CHECK: func @outer_func(
+  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+  // CHECK-SAME:  -> (memref<?xf32, #test.memref_layout<"hello">>,
+  // CHECK-SAME:      memref<?xf32, #test.memref_layout<"not hello">>)
+  func.func @outer_func(%t0: tensor<?xf32, #enc1>)
+      -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
+    // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+    %0 = call @inner_func(%t0)
+      : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
+
+    // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
+    // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"not hello">>
+    %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
+    // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
+    %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
+      -> tensor<?xf32, #enc2>
+
+    // CHECK: return %[[call]], %[[dummy]]
+    return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
+  }
+}
+
+// -----
+
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+// The memref's layout must come from the encoding, not from the default
+// static-identity layout.
+module @BufferizeEncodingForAlloc {
+  // CHECK: func @some_func(
+  // CHECK-SAME:  %[[arg0:.*]]: memref<42xf32, #test.memref_layout<"hello">>)
+  // CHECK-SAME:  -> (memref<42xf32, #test.memref_layout<"hello">>,
+  // CHECK-SAME:      memref<42xf32, #test.memref_layout<"not hello">>)
+  func.func @some_func(%t0: tensor<42xf32, #enc1>)
+      -> (tensor<42xf32, #enc1>, tensor<42xf32, #enc2>) {
+    // CHECK: %[[T0:.+]] = memref.alloc() {{.*}} : memref<42xf32, #test.memref_layout<"hello">>
+    %0 = bufferization.alloc_tensor() : tensor<42xf32, #enc1>
+
+    // CHECK: %[[T1:.+]] = memref.alloc() {{.*}} : memref<42xf32, #test.memref_layout<"not hello">>
+    %1 = bufferization.alloc_tensor() : tensor<42xf32, #enc2>
+
+    // CHECK: return %[[T0]], %[[T1]]
+    return %0, %1 : tensor<42xf32, #enc1>, tensor<42xf32, #enc2>
+  }
+}
+
+// -----
+
+#enc1 = #test.tensor_encoding<"custom">
+
+module @BufferizeEncodingWithDefaultBufferizationApi {
+
+  // CHECK: func.func @custom_encoding_inside_scf(
+  // CHECK-SAME:  %[[arg:.*]]: memref<42xf64, #test.memref_layout<"custom">>,
+  // CHECK-SAME:  %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index)
+  // CHECK-SAME:  -> memref<42xf64, #test.memref_layout<"custom">>
+  func.func @custom_encoding_inside_scf(
+      %arg: tensor<42xf64, #enc1>,
+      %lb: index, %ub: index, %step: index)
+      -> tensor<42xf64, #enc1> {
+    // CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]]
+    // CHECK-SAME: iter_args(%[[iter:.+]] = %[[arg]]) -> (memref<42xf64, #test.memref_layout<"custom">>) {
+    // CHECK: %[[call:.+]] = "test.dummy_memref_op"(%[[iter]])
+    // CHECK: scf.yield %[[call]] : memref<42xf64, #test.memref_layout<"custom">>
+    %loop = scf.for %i = %lb to %ub step %step
+        iter_args(%iter = %arg) -> (tensor<42xf64, #enc1>) {
+      %call = "test.dummy_tensor_op"(%iter) : (tensor<42xf64, #enc1>)
+        -> tensor<42xf64, #enc1>
+      scf.yield %call : tensor<42xf64, #enc1>
+    }
+
+    // CHECK: return %[[loop]]
+    return %loop : tensor<42xf64, #enc1>
+  }
+}
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index dead1a4b7e047..58660e1cf8f45 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -66,6 +66,11 @@ struct TestOneShotModuleBufferizePass
               MemRefType::get(tensorType.getShape(),
                               tensorType.getElementType(), layout, memSpace));
         };
+    opt.createMemRefLayoutFn = [&](TensorType tensor) {
+      assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+      auto tensorType = cast<RankedTensorType>(tensor);
+      return getMemRefLayoutForTensorEncoding(tensorType);
+    };
 
     bufferization::BufferizationState bufferizationState;
 
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 340b44b14dd96..48036ce6f167b 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1741,39 +1741,6 @@ TestMultiSlotAlloca::handleDestructuringComplete(
   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
 }
 
-namespace {
-/// Returns test dialect's memref layout for test dialect's tensor encoding when
-/// applicable.
-MemRefLayoutAttrInterface
-getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
-  if (auto encoding =
-          dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) {
-    return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
-        tensorType.getContext(), encoding.getDummy()));
-  }
-  return {};
-}
-
-/// Auxiliary bufferization function for test and builtin tensors.
-bufferization::BufferLikeType
-convertTensorToBuffer(mlir::Operation *op,
-                      const bufferization::BufferizationOptions &options,
-                      bufferization::TensorLikeType tensorLike) {
-  auto buffer =
-      *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
-  if (auto memref = dyn_cast<MemRefType>(buffer)) {
-    // Note: For the sake of testing, we want to ensure that encoding -> layout
-    // bufferization happens. This is currently achieved manually.
-    auto layout =
-        getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
-    return cast<bufferization::BufferLikeType>(
-        MemRefType::get(memref.getShape(), memref.getElementType(), layout,
-                        memref.getMemorySpace()));
-  }
-  return buffer;
-}
-} // namespace
-
 ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
     ::mlir::RewriterBase &rewriter,
     const ::mlir::bufferization::BufferizationOptions &options,
@@ -1783,12 +1750,16 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
   if (mlir::failed(buffer))
     return failure();
 
-  const auto outType = getOutput().getType();
+  // Note: mlir::bufferization::getBufferType() would internally call
+  // TestDummyTensorOp::getBufferType()
   const auto bufferizedOutType =
-      convertTensorToBuffer(getOperation(), options, outType);
+      mlir::bufferization::getBufferType(getOutput(), options, state);
+  if (mlir::failed(bufferizedOutType))
+    return failure();
+
   // replace op with memref analogy
   auto dummyMemrefOp = test::TestDummyMemrefOp::create(
-      rewriter, getLoc(), bufferizedOutType, *buffer);
+      rewriter, getLoc(), *bufferizedOutType, *buffer);
 
   mlir::bufferization::replaceOpWithBufferizedValues(rewriter, getOperation(),
                                                      dummyMemrefOp.getResult());
@@ -1798,15 +1769,14 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
 
 mlir::FailureOr<mlir::bufferization::BufferLikeType>
 test::TestDummyTensorOp::getBufferType(
-    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
     const mlir::bufferization::BufferizationState &,
     llvm::SmallVector<::mlir::Value> &) {
-  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
   if (type == nullptr)
     return failure();
 
-  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
-      getContext(), type.getShape(), type.getElementType(), nullptr));
+  return type.getBufferType(options, [&]() { return emitError(); });
 }
 
 ::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
@@ -1839,7 +1809,7 @@ test::TestCreateTensorOp::getBufferType(
   if (type == nullptr)
     return failure();
 
-  return convertTensorToBuffer(getOperation(), options, type);
+  return type.getBufferType(options, [&]() { return emitError(); });
 }
 
 // Define a custom builder for ManyRegionsOp declared in TestOps.td.



More information about the Mlir-commits mailing list