[Mlir-commits] [mlir] [mlir][bufferization] Convert tensor encoding -> memref layout (PR #161166)
Andrei Golubev
llvmlistbot at llvm.org
Mon Sep 29 03:44:29 PDT 2025
https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/161166
>From 94419f20c22c9645f2bb5bbd87fa66d94f34f665 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Fri, 26 Sep 2025 14:50:09 +0000
Subject: [PATCH] [mlir][bufferization] Convert tensor encoding -> memref
layout
Support custom types (4/N): allow user-specified bufferization of tensor
encoding into memref layout.
Both the tensor encoding and the memref layout can be user-specified
attributes that store arbitrary information. Often, this information has
to be preserved during tensor -> memref bufferization. Thus, provide an
optional function that creates the memref layout during
TensorType::getBufferType() execution.
As a drive-by, update AllocTensorOp::getBufferType() to work via
TensorLikeType::getBufferType() when the memref layout is user-specified.
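For illustration, a minimal sketch of how a downstream user might wire up
the new hook (MyEncodingAttr is a hypothetical attribute that implements
MemRefLayoutAttrInterface; the unit test below shows a complete example):

  bufferization::OneShotBufferizationOptions options;
  options.constructMemRefLayoutFn =
      [](TensorType tensor) -> MemRefLayoutAttrInterface {
    // Propagate a custom tensor encoding into the memref layout; returning
    // a null attribute falls back to the default layout construction.
    auto rankedType = dyn_cast<RankedTensorType>(tensor);
    if (!rankedType)
      return {};
    if (auto encoding =
            dyn_cast_or_null<MyEncodingAttr>(rankedType.getEncoding()))
      return cast<MemRefLayoutAttrInterface>(encoding);
    return {};
  };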
---
.../IR/BufferizableOpInterface.h | 9 +
.../Bufferization/IR/BufferizationDialect.cpp | 5 +-
.../Bufferization/IR/BufferizationOps.cpp | 9 +
mlir/unittests/Transforms/CMakeLists.txt | 5 +-
.../Transforms/OneShotBufferization.cpp | 171 ++++++++++++++++++
5 files changed, 197 insertions(+), 2 deletions(-)
create mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
+ /// Construct a MemRefLayoutAttrInterface from a tensor type.
+ using ConstructMemRefLayoutFn =
+ std::function<MemRefLayoutAttrInterface(TensorType t)>;
BufferizationOptions();
@@ -364,6 +367,12 @@ struct BufferizationOptions {
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ /// Construction function used to determine the memref layout based on the
+ /// original tensor type. Can be used to specialize tensor encoding -> memref
+ /// layout conversion. By default, it is unset, in which case the layout
+ /// construction behavior depends on the place where it is used.
+ ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..db06edcbe3d59 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
auto memSpace = options.defaultMemorySpaceFn(tensorType);
if (!memSpace.has_value())
return emitError() << "could not infer memory space";
+ MemRefLayoutAttrInterface layout = {};
+ if (options.constructMemRefLayoutFn)
+ layout = options.constructMemRefLayoutFn(tensorType);
return cast<BufferLikeType>(
- getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+ getMemRefType(tensorType, options, layout, *memSpace));
}
mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
+ // Note: Only rely on TensorLikeType::getBufferType() if the memref layout
+ // is explicitly specified by the user. Otherwise, TensorLikeType's default
+ // is to return a fully dynamic layout map, whereas this function defaults
+ // to a static identity layout.
+ if (options.constructMemRefLayoutFn) {
+ return cast<TensorLikeType>(getType()).getBufferType(
+ options, [&]() { return emitError(); });
+ }
+
return cast<BufferLikeType>(
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformsTests
Canonicalizer.cpp
DialectConversion.cpp
+ OneShotBufferization.cpp
)
mlir_target_link_libraries(MLIRTransformsTests
PRIVATE
MLIRParser
- MLIRTransforms)
+ MLIRTransforms
+ MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestTensorAttr : public StringAttr {
+ using mlir::StringAttr::StringAttr;
+
+ static bool classof(mlir::Attribute attr) {
+ return mlir::isa<mlir::StringAttr>(attr);
+ }
+
+ static TestTensorAttr fromStringAttr(StringAttr attr) {
+ return mlir::dyn_cast<TestTensorAttr>(attr);
+ }
+};
+
+class TestTensorEncodingVerifier final
+ : public mlir::VerifiableTensorEncoding::ExternalModel<
+ TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+ using ConcreteEntity = mlir::StringAttr;
+
+ mlir::LogicalResult verifyEncoding(
+ mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
+ mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ std::ignore = shape;
+
+ if (mlir::isa<TestTensorAttr>(attr)) {
+ return mlir::success();
+ }
+ return emitError() << "Unknown Tensor enconding: " << attr;
+ }
+};
+
+struct TestMemRefAttr : public mlir::StringAttr {
+ using mlir::StringAttr::StringAttr;
+
+ static bool classof(mlir::Attribute attr) {
+ return mlir::isa<mlir::StringAttr>(attr);
+ }
+
+ mlir::AffineMap getAffineMap() const {
+ return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+ }
+};
+
+class TestMemRefAttrLayout final
+ : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+ TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+ using ConcreteEntity = mlir::StringAttr;
+
+ bool isIdentity(mlir::Attribute) const { return true; }
+ mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+ return cast<TestMemRefAttr>(attr).getAffineMap();
+ }
+ mlir::LogicalResult
+ verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
+ mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ std::ignore = shape;
+
+ if (mlir::isa<TestMemRefAttr>(attr)) {
+ return mlir::success();
+ }
+ return emitError() << "Unknown MemRef layout: " << attr;
+ }
+};
+
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+ MLIRContext context;
+ context.getOrLoadDialect<BuiltinDialect>();
+ context.getOrLoadDialect<func::FuncDialect>();
+ context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+ DialectRegistry registry;
+ registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+ TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+ TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+ });
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ context.appendDialectRegistry(registry);
+
+ const char *const code = R"mlir(
+ func.func @foo(%t: tensor<42xf32, "hello">)
+ -> tensor<42xf32, "hello"> {
+ return %t : tensor<42xf32, "hello">
+ }
+
+ func.func @bar(%t1: tensor<42xf32, "hello">)
+ -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+ %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+ -> tensor<42xf32, "hello">
+
+ %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+ return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+ }
+ )mlir";
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+ ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+
+ bufferization::OneShotBufferizationOptions options{};
+ options.bufferizeFunctionBoundaries = true;
+ options.constructMemRefLayoutFn =
+ [](TensorType tensor) -> MemRefLayoutAttrInterface {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
+ return cast<MemRefLayoutAttrInterface>(
+ TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
+ }
+ return {};
+ };
+ options.functionArgTypeConverterFn =
+ [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+ func::FuncOp, const bufferization::BufferizationOptions &) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ auto layout = options.constructMemRefLayoutFn(tensorType);
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, memSpace));
+ };
+
+ bufferization::BufferizationState state;
+ ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+ module->getOperation(), options, state)));
+
+ const auto checkType = [](Type type, StringRef expectedLayoutValue) {
+ if (auto memref = dyn_cast<MemRefType>(type)) {
+ if (auto layout = memref.getLayout();
+ isa_and_nonnull<TestMemRefAttr>(layout)) {
+ return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
+ }
+ }
+ return false;
+ };
+
+ auto fooOp = *module->getOps<func::FuncOp>().begin();
+ ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
+
+ auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
+ ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace