[Mlir-commits] [mlir] [mlir][bufferization] Convert tensor encoding -> memref layout (PR #161166)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 29 03:15:57 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
Support custom types (4/N): allow user-specified bufferization of tensor encoding into memref layout.
Both the tensor encoding and the memref layout can be user-specified attributes that store arbitrary information. It is often the case that this information has to be preserved during tensor -> memref bufferization. Thus, provide an optional function to create the memref layout during TensorType::getBufferType() execution.
As a drive-by, update AllocTensorOp::getBufferType() to work via TensorLikeType::getBufferType() when the memref layout is user-specified.
---
Full diff: https://github.com/llvm/llvm-project/pull/161166.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+9)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+4-1)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+9)
- (modified) mlir/unittests/Transforms/CMakeLists.txt (+4-1)
- (added) mlir/unittests/Transforms/OneShotBufferization.cpp (+171)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
+ /// Construct a MemRefLayoutAttrInterface from a tensor type.
+ using ConstructMemRefLayoutFn =
+ std::function<MemRefLayoutAttrInterface(TensorType t)>;
BufferizationOptions();
@@ -364,6 +367,12 @@ struct BufferizationOptions {
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ /// Construction function used to determine the memref layout based on the
+ /// original tensor type. Can be used to specialize tensor encoding -> memref
+ /// layout conversion. By default, it is unset, making the layout construction
+ /// behavior depend on the place where it is used.
+ ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..74864bfd57e58 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
auto memSpace = options.defaultMemorySpaceFn(tensorType);
if (!memSpace.has_value())
return emitError() << "could not infer memory space";
+ MemRefLayoutAttrInterface layout = {};
+ if (options.constructMemRefLayoutFn)
+ layout = options.constructMemRefLayoutFn(tensorType);
return cast<BufferLikeType>(
- getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+ getMemRefType(tensorType, options, layout, *memSpace));
}
mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
+ // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
+ // explicitly specified by the user. Otherwise, the default behavior is to
+ // return a fully dynamic layout map which is the opposite of the default
+ // behavior of this function.
+ if (options.constructMemRefLayoutFn) {
+ return cast<TensorLikeType>(getType()).getBufferType(
+ options, [&]() { return emitError(); });
+ }
+
return cast<BufferLikeType>(
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformsTests
Canonicalizer.cpp
DialectConversion.cpp
+ OneShotBufferization.cpp
)
mlir_target_link_libraries(MLIRTransformsTests
PRIVATE
MLIRParser
- MLIRTransforms)
+ MLIRTransforms
+ MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestTensorAttr : public StringAttr {
+ using mlir::StringAttr::StringAttr;
+
+ static bool classof(mlir::Attribute attr) {
+ return mlir::isa<mlir::StringAttr>(attr);
+ }
+
+ static TestTensorAttr fromStringAttr(StringAttr attr) {
+ return mlir::dyn_cast<TestTensorAttr>(attr);
+ }
+};
+
+class TestTensorEncodingVerifier final
+ : public mlir::VerifiableTensorEncoding::ExternalModel<
+ TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+ using ConcreteEntity = mlir::StringAttr;
+
+ mlir::LogicalResult verifyEncoding(
+ mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
+ mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ std::ignore = shape;
+
+ if (mlir::isa<TestTensorAttr>(attr)) {
+ return mlir::success();
+ }
+ return emitError() << "Unknown Tensor enconding: " << attr;
+ }
+};
+
+struct TestMemRefAttr : public mlir::StringAttr {
+ using mlir::StringAttr::StringAttr;
+
+ static bool classof(mlir::Attribute attr) {
+ return mlir::isa<mlir::StringAttr>(attr);
+ }
+
+ mlir::AffineMap getAffineMap() const {
+ return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+ }
+};
+
+class TestMemRefAttrLayout final
+ : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+ TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+ using ConcreteEntity = mlir::StringAttr;
+
+ bool isIdentity(mlir::Attribute) const { return true; }
+ mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+ return cast<TestMemRefAttr>(attr).getAffineMap();
+ }
+ mlir::LogicalResult
+ verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
+ mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+ std::ignore = shape;
+
+ if (mlir::isa<TestMemRefAttr>(attr)) {
+ return mlir::success();
+ }
+ return emitError() << "Unknown MemRef layout: " << attr;
+ }
+};
+
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+ MLIRContext context;
+ context.getOrLoadDialect<BuiltinDialect>();
+ context.getOrLoadDialect<func::FuncDialect>();
+ context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+ DialectRegistry registry;
+ registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+ TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+ TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+ });
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ context.appendDialectRegistry(registry);
+
+ const char *const code = R"mlir(
+ func.func @foo(%t: tensor<42xf32, "hello">)
+ -> tensor<42xf32, "hello"> {
+ return %t : tensor<42xf32, "hello">
+ }
+
+ func.func @bar(%t1: tensor<42xf32, "hello">)
+ -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+ %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+ -> tensor<42xf32, "hello">
+
+ %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+ return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+ }
+ )mlir";
+
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+ ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+
+ bufferization::OneShotBufferizationOptions options{};
+ options.bufferizeFunctionBoundaries = true;
+ options.constructMemRefLayoutFn =
+ [](TensorType tensor) -> MemRefLayoutAttrInterface {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
+ return cast<MemRefLayoutAttrInterface>(
+ TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
+ }
+ return {};
+ };
+ options.functionArgTypeConverterFn =
+ [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+ func::FuncOp, const bufferization::BufferizationOptions &) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ auto layout = options.constructMemRefLayoutFn(tensorType);
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+ layout, memSpace));
+ };
+
+ bufferization::BufferizationState state;
+ ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+ module->getOperation(), options, state)));
+
+ const auto checkType = [](Type type, StringRef expectedLayoutValue) {
+ if (auto memref = dyn_cast<MemRefType>(type)) {
+ if (auto layout = memref.getLayout();
+ isa_and_nonnull<TestMemRefAttr>(layout)) {
+ return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
+ }
+ }
+ return false;
+ };
+
+ auto fooOp = *module->getOps<func::FuncOp>().begin();
+ ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
+
+ auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
+ ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
+ ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/161166
More information about the Mlir-commits
mailing list