[Mlir-commits] [mlir] [mlir][bufferization] Convert tensor encoding into memref layout (PR #161166)
Andrei Golubev
llvmlistbot at llvm.org
Mon Oct 6 05:41:42 PDT 2025
https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/161166
>From ece4805a82221b7b8cca6e6352980d424a9a8e41 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Fri, 26 Sep 2025 14:50:09 +0000
Subject: [PATCH 1/3] [mlir][bufferization] Convert tensor encoding -> memref
layout
Support custom types (4/N): allow user-specified bufferization of tensor
encoding into memref layout.
Both tensor encoding and memref layout could be user-specified
attributes to store arbitrary information. It is often the case that
this information has to be preserved during tensor -> memref
bufferization. Thus, provide an optional callback to construct the memref
layout during TensorType::getBufferType() execution.
As a drive-by, update AllocTensorOp::getBufferType() to go through
TensorLikeType::getBufferType() when the memref layout is user-specified.
---
.../IR/BufferizableOpInterface.h | 9 +
.../Bufferization/IR/BufferizationDialect.cpp | 5 +-
.../Bufferization/IR/BufferizationOps.cpp | 9 +
mlir/unittests/Transforms/CMakeLists.txt | 5 +-
.../Transforms/OneShotBufferization.cpp | 171 ++++++++++++++++++
5 files changed, 197 insertions(+), 2 deletions(-)
create mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
+ /// Construct a MemRefLayoutAttrInterface from a tensor type.
+ using ConstructMemRefLayoutFn =
+ std::function<MemRefLayoutAttrInterface(TensorType t)>;
BufferizationOptions();
@@ -364,6 +367,12 @@ struct BufferizationOptions {
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ /// Construction function used to determine the memref layout based on the
+ /// original tensor type. Can be used to specialize tensor encoding -> memref
+ /// layout conversion. By default, it is unset, making the layout construction
+ /// behavior depend on the place where it is used.
+ ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..db06edcbe3d59 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
auto memSpace = options.defaultMemorySpaceFn(tensorType);
if (!memSpace.has_value())
return emitError() << "could not infer memory space";
+ MemRefLayoutAttrInterface layout = {};
+ if (options.constructMemRefLayoutFn)
+ layout = options.constructMemRefLayoutFn(tensorType);
return cast<BufferLikeType>(
- getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+ getMemRefType(tensorType, options, layout, *memSpace));
}
mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
+ // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
+ // explicitly specified by the user. Otherwise, the default behavior is to
+ // return a fully dynamic layout map which is the opposite of the default
+ // behavior of this function.
+ if (options.constructMemRefLayoutFn) {
+ return cast<TensorLikeType>(getType()).getBufferType(
+ options, [&]() { return emitError(); });
+ }
+
return cast<BufferLikeType>(
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
add_mlir_unittest(MLIRTransformsTests
Canonicalizer.cpp
DialectConversion.cpp
+ OneShotBufferization.cpp
)
mlir_target_link_libraries(MLIRTransformsTests
PRIVATE
MLIRParser
- MLIRTransforms)
+ MLIRTransforms
+ MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test-only view of a builtin `StringAttr` used as a tensor encoding.
+/// Every `StringAttr` qualifies as a `TestTensorAttr` (see `classof`).
+struct TestTensorAttr : public mlir::StringAttr {
+  using StringAttr::StringAttr;
+
+  /// LLVM-style RTTI hook: any string attribute is accepted.
+  static bool classof(mlir::Attribute candidate) {
+    return llvm::isa<StringAttr>(candidate);
+  }
+
+  /// Reinterprets a string attribute as a `TestTensorAttr`.
+  static TestTensorAttr fromStringAttr(StringAttr strAttr) {
+    return llvm::dyn_cast<TestTensorAttr>(strAttr);
+  }
+};
+
+/// External model attaching `VerifiableTensorEncoding` to `TestTensorAttr`.
+class TestTensorEncodingVerifier final
+    : public mlir::VerifiableTensorEncoding::ExternalModel<
+          TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  /// Accepts any `TestTensorAttr` encoding regardless of the tensor's shape and
+  /// element type; any other attribute is rejected with a diagnostic.
+  mlir::LogicalResult
+  verifyEncoding(mlir::Attribute attr, mlir::ArrayRef<int64_t> /*shape*/,
+                 mlir::Type /*elementType*/,
+                 mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    if (mlir::isa<TestTensorAttr>(attr))
+      return mlir::success();
+    return emitError() << "Unknown Tensor encoding: " << attr;
+  }
+};
+
+/// Test-only view of a builtin `StringAttr` used as a memref layout.
+/// Every `StringAttr` qualifies as a `TestMemRefAttr` (see `classof`).
+struct TestMemRefAttr : public mlir::StringAttr {
+  using StringAttr::StringAttr;
+
+  /// LLVM-style RTTI hook: any string attribute is accepted.
+  static bool classof(mlir::Attribute candidate) {
+    return llvm::isa<StringAttr>(candidate);
+  }
+
+  /// Layout map, independent of the string payload: always the 1-D identity
+  /// `(d0) -> (d0)`. NOTE(review): this only matches rank-1 memrefs.
+  mlir::AffineMap getAffineMap() const {
+    constexpr unsigned kNumDims = 1;
+    return mlir::AffineMap::getMultiDimIdentityMap(kNumDims, getContext());
+  }
+};
+
+/// External model attaching `MemRefLayoutAttrInterface` to `TestMemRefAttr`.
+class TestMemRefAttrLayout final
+    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+          TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  /// The test layout always reports itself as an identity layout.
+  bool isIdentity(mlir::Attribute) const { return true; }
+
+  /// Forwards to `TestMemRefAttr::getAffineMap()`.
+  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+    auto testLayout = mlir::cast<TestMemRefAttr>(attr);
+    return testLayout.getAffineMap();
+  }
+
+  /// Accepts any `TestMemRefAttr` for any shape; rejects other attributes.
+  mlir::LogicalResult
+  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> /*shape*/,
+               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    if (!mlir::isa<TestMemRefAttr>(attr))
+      return emitError() << "Unknown MemRef layout: " << attr;
+    return mlir::success();
+  }
+};
+
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+  MLIRContext context;
+  context.getOrLoadDialect<BuiltinDialect>();
+  context.getOrLoadDialect<func::FuncDialect>();
+  context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+  // Make builtin string attributes usable as tensor encodings and memref
+  // layouts by attaching the test external models defined above.
+  DialectRegistry registry;
+  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+  });
+  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+      registry);
+  context.appendDialectRegistry(registry);
+
+  const char *const code = R"mlir(
+    func.func @foo(%t: tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello"> {
+      return %t : tensor<42xf32, "hello">
+    }
+
+    func.func @bar(%t1: tensor<42xf32, "hello">)
+        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+          -> tensor<42xf32, "hello">
+
+      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+    }
+  )mlir";
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+
+  bufferization::OneShotBufferizationOptions options{};
+  options.bufferizeFunctionBoundaries = true;
+  // String tensor encoding "<s>" -> string memref layout "<s>". Tensors without
+  // a string encoding get a null layout, i.e. the default one.
+  options.constructMemRefLayoutFn =
+      [](TensorType tensor) -> MemRefLayoutAttrInterface {
+    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+    auto tensorType = cast<RankedTensorType>(tensor);
+    // The encoding is optional (may be null), hence `dyn_cast_if_present`.
+    if (auto encoding =
+            dyn_cast_if_present<TestTensorAttr>(tensorType.getEncoding())) {
+      return cast<MemRefLayoutAttrInterface>(
+          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
+    }
+    return {};
+  };
+  // Reuse the same layout construction for function arguments and results.
+  options.functionArgTypeConverterFn =
+      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+          func::FuncOp, const bufferization::BufferizationOptions &) {
+        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+        auto tensorType = cast<RankedTensorType>(tensor);
+        auto layout = options.constructMemRefLayoutFn(tensorType);
+        return cast<bufferization::BufferLikeType>(
+            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                            layout, memSpace));
+      };
+
+  bufferization::BufferizationState state;
+  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+      module->getOperation(), options, state)));
+
+  // True iff `type` is a memref whose layout is a string attribute equal to
+  // `expectedLayoutValue`.
+  const auto hasLayout = [](Type type, StringRef expectedLayoutValue) {
+    auto memref = dyn_cast<MemRefType>(type);
+    if (!memref)
+      return false;
+    auto layout = dyn_cast_if_present<TestMemRefAttr>(memref.getLayout());
+    return layout && layout.getValue() == expectedLayoutValue;
+  };
+
+  // Check the module shape first instead of dereferencing possibly
+  // past-the-end iterators.
+  auto funcs = llvm::to_vector(module->getOps<func::FuncOp>());
+  ASSERT_EQ(funcs.size(), 2u) << "expected exactly @foo and @bar";
+
+  func::FuncOp fooOp = funcs[0];
+  ASSERT_TRUE(hasLayout(fooOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(hasLayout(fooOp.getResultTypes()[0], "hello"));
+
+  func::FuncOp barOp = funcs[1];
+  ASSERT_TRUE(hasLayout(barOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(hasLayout(barOp.getResultTypes()[0], "hello"));
+  ASSERT_TRUE(hasLayout(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace
>From a6c3ef2d965eb76fae4d206bb5b94b2ee900f83b Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 6 Oct 2025 11:50:34 +0000
Subject: [PATCH 2/3] [mlir][bufferization] Test tensor encoding -> memref
layout conversion
Support custom types (4/N): test that it is possible to customize memref
layout specification for custom operations and function boundaries.
This is purely a test setup (no API modifications) to ensure users are
able to pass information from tensors to memrefs within the bufferization
process. To achieve this, a test pass is required (since bufferization
options have to be set manually). As there is already a
--test-one-shot-module-bufferize pass present, it is extended for this
purpose.
---
.../one-shot-non-module-bufferize.mlir | 38 +++++++++++++++-
.../TestOneShotModuleBufferize.cpp | 26 +++++++++++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 17 +++++++
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 18 ++++++++
mlir/test/lib/Dialect/Test/TestAttributes.h | 1 +
mlir/test/lib/Dialect/Test/TestDialect.h | 1 +
mlir/test/lib/Dialect/Test/TestDialect.td | 5 ++-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 44 ++++++++++++++++---
mlir/test/lib/Dialect/Test/TestOps.td | 15 ++++---
9 files changed, 150 insertions(+), 15 deletions(-)
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
index e2ab876f8b46a..b52612d0d1f10 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -24,10 +24,46 @@
// CHECK-NOT: copy
// CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
- // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+ // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}>
return %1, %0 : f32, tensor<?xf32>
}
"test.finish" () : () -> ()
}) : () -> ()
+// -----
+// Verifies that #test.tensor_encoding<"..."> is bufferized into
+// #test.memref_layout<"..."> carrying the same string, both in function
+// signatures and for values produced/consumed by test dialect ops.
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+"test.symbol_scope_isolated"() ({
+ // CHECK: func @inner_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"hello">>
+ func.func @inner_func(%t: tensor<?xf32, #enc1>)
+ -> tensor<?xf32, #enc1> {
+ // CHECK: return %[[arg0]]
+ return %t : tensor<?xf32, #enc1>
+ }
+
+ // CHECK: func @outer_func(
+ // CHECK-SAME: %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+ // CHECK-SAME: -> (memref<?xf32, #test.memref_layout<"hello">>,
+ // CHECK-SAME: memref<?xf32, #test.memref_layout<"not hello">>)
+ func.func @outer_func(%t0: tensor<?xf32, #enc1>)
+ -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
+ // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+ %0 = call @inner_func(%t0)
+ : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
+
+ // Locally created tensors must also pick up the layout from the encoding.
+ // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
+ // CHECK-SAME: -> memref<?xf32, #test.memref_layout<"not hello">>
+ %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
+ // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
+ %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
+ -> tensor<?xf32, #enc2>
+
+ // CHECK: return %[[call]], %[[dummy]]
+ return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
+ }
+ "test.finish" () : () -> ()
+}) : () -> ()
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index 1e2d4a7c8f08d..4069a74dbf63c 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -11,11 +11,25 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
+#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr
+#include "TestDialect.h"
+
using namespace mlir;
namespace {
+/// Translates a `#test.tensor_encoding<...>` on `tensorType` into the matching
+/// `#test.memref_layout<...>` carrying the same string payload. Returns a null
+/// layout when the tensor has no encoding or a different kind of encoding.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+  auto testEncoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+      tensorType.getEncoding());
+  if (!testEncoding)
+    return {};
+
+  MLIRContext *ctx = tensorType.getContext();
+  auto testLayout =
+      test::TestMemRefLayoutAttr::get(ctx, testEncoding.getDummy());
+  return cast<MemRefLayoutAttrInterface>(testLayout);
+}
+
struct TestOneShotModuleBufferizePass
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
@@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<test::TestDialect>();
registry.insert<bufferization::BufferizationDialect>();
}
StringRef getArgument() const final {
@@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass
bufferization::OneShotBufferizationOptions opt;
opt.bufferizeFunctionBoundaries = true;
+ opt.functionArgTypeConverterFn =
+ [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+ func::FuncOp, const bufferization::BufferizationOptions &) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ auto layout = getMemRefLayoutForTensorEncoding(tensorType);
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType(), layout, memSpace));
+ };
+
bufferization::BufferizationState bufferizationState;
if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 5685004bbbd25..9e7e4f883b576 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/TensorEncoding.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> {
let hasStorageCustomConstructor = 1;
}
+// Test tensor encoding wrapping an arbitrary string (`dummy`), printed as
+// `#test.tensor_encoding<"...">`. Implements VerifiableTensorEncoding; the C++
+// verifyEncoding() is defined in TestAttributes.cpp.
+def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding",
+ [DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> {
+ let mnemonic = "tensor_encoding";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
+// Test memref layout wrapping an arbitrary string (`dummy`), printed as
+// `#test.memref_layout<"...">`. Implements MemRefLayoutAttrInterface; the C++
+// getAffineMap() is defined in TestAttributes.cpp and returns a 1-D identity
+// map. The test bufferization helpers copy the `dummy` string over from
+// TestTensorEncodingAttr.
+def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
+ [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
+ let mnemonic = "memref_layout";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index fe1e9166a3099..9db7b01dd193b 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -541,6 +541,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct(
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// TestTensorEncodingAttr
+//===----------------------------------------------------------------------===//
+
+/// Accepts every test encoding unconditionally; no diagnostic is ever emitted.
+::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding(
+    mlir::ArrayRef<int64_t> /*shape*/, mlir::Type /*elementType*/,
+    llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/) const {
+  // The string payload is opaque: any shape / element type combination is OK.
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestMemRefLayoutAttr
+//===----------------------------------------------------------------------===//
+
+/// Returns the 1-D identity map `(d0) -> (d0)`, independent of `dummy`.
+/// NOTE(review): only meaningful for rank-1 memrefs; the attribute does not
+/// know the rank of the memref it is attached to.
+mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const {
+  const unsigned numDims = 1;
+  return mlir::AffineMap::getMultiDimIdentityMap(numDims, getContext());
+}
+
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 778d84fae7365..0ad5ab641c6d0 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,6 +24,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/TensorEncoding.h"
// generated files require above includes to come first
#include "TestAttrInterfaces.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index f2adca6310d78..bcf3b55d33cb9 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -18,6 +18,7 @@
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 2b5491fc0c6a0..37a263f1d10b8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect {
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
- let dependentDialects = ["::mlir::DLTIDialect"];
+ let dependentDialects = [
+ "::mlir::DLTIDialect",
+ "::mlir::bufferization::BufferizationDialect"
+ ];
let discardableAttrs = (ins
"mlir::IntegerAttr":$discardable_attr_key,
"SimpleAAttr":$other_discardable_attr_key
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 53055fea215b7..b211e243f234c 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+namespace {
+/// Returns test dialect's memref layout for test dialect's tensor encoding when
+/// applicable, i.e. a `test::TestMemRefLayoutAttr` carrying the same string
+/// payload; returns a null layout otherwise.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+  // Tensors need not carry an encoding, so getEncoding() may be null:
+  // `dyn_cast` would assert on it, `dyn_cast_if_present` does not.
+  if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+          tensorType.getEncoding())) {
+    return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+        tensorType.getContext(), encoding.getDummy()));
+  }
+  return {};
+}
+
+/// Auxiliary bufferization function for test and builtin tensors.
+///
+/// Computes the buffer type of `tensorLike` via TensorLikeType::getBufferType
+/// (diagnostics are attached to `op`). When the result is a builtin memref of a
+/// builtin ranked tensor, its layout is replaced by the one derived from the
+/// tensor encoding; any other buffer type is returned unchanged.
+bufferization::BufferLikeType
+convertTensorToBuffer(mlir::Operation *op,
+                      const bufferization::BufferizationOptions &options,
+                      bufferization::TensorLikeType tensorLike) {
+  FailureOr<bufferization::BufferLikeType> maybeBuffer =
+      tensorLike.getBufferType(options, [&]() { return op->emitError(); });
+  // Callers expect a valid type; fail loudly instead of dereferencing an
+  // empty FailureOr.
+  assert(succeeded(maybeBuffer) && "failed to compute the buffer type");
+  bufferization::BufferLikeType buffer = *maybeBuffer;
+
+  // Only builtin ranked tensors carry an encoding we know how to translate;
+  // use `dyn_cast` so other TensorLike types do not trip an assertion.
+  auto memref = dyn_cast<MemRefType>(buffer);
+  auto rankedTensor = dyn_cast<RankedTensorType>(tensorLike);
+  if (!memref || !rankedTensor)
+    return buffer;
+
+  // Note: For the sake of testing, we want to ensure that encoding -> layout
+  // bufferization happens. This is currently achieved manually.
+  MemRefLayoutAttrInterface layout =
+      getMemRefLayoutForTensorEncoding(rankedTensor);
+  return cast<bufferization::BufferLikeType>(
+      MemRefType::get(memref.getShape(), memref.getElementType(), layout,
+                      memref.getMemorySpace()));
+}
+} // namespace
+
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
return failure();
const auto outType = getOutput().getType();
- const auto bufferizedOutType = test::TestMemrefType::get(
- getContext(), outType.getShape(), outType.getElementType(), nullptr);
+ const auto bufferizedOutType =
+ convertTensorToBuffer(getOperation(), options, outType);
// replace op with memref analogy
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ ::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
mlir::FailureOr<mlir::bufferization::BufferLikeType>
test::TestCreateTensorOp::getBufferType(
- mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
const mlir::bufferization::BufferizationState &,
llvm::SmallVector<::mlir::Value> &) {
- const auto type = dyn_cast<test::TestTensorType>(value.getType());
+ const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
if (type == nullptr)
return failure();
- return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
- getContext(), type.getShape(), type.getElementType(), nullptr));
+ return convertTensorToBuffer(getOperation(), options, type);
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..b08933c08d132 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
// Include the attribute definitions.
include "TestAttrDefs.td"
@@ -2322,7 +2323,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
}
//===----------------------------------------------------------------------===//
-// Copy Operation Test
+// Copy Operation Test
//===----------------------------------------------------------------------===//
def CopyOp : TEST_Op<"copy", []> {
@@ -3663,10 +3664,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
["bufferize", "bufferizesToMemoryRead",
"bufferizesToMemoryWrite", "getAliasingValues"]>]> {
let arguments = (ins
- Arg<TestTensorType>:$input
+ Arg<Bufferization_TensorLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestTensorType>:$output
+ Arg<Bufferization_TensorLikeTypeInterface>:$output
);
let extraClassDefinition = [{
@@ -3688,10 +3689,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
let arguments = (ins
- Arg<TestMemrefType>:$input
+ Arg<Bufferization_BufferLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestMemrefType>:$output
+ Arg<Bufferization_BufferLikeTypeInterface>:$output
);
}
@@ -3701,7 +3702,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
"bufferizesToMemoryWrite", "getAliasingValues",
"bufferizesToAllocation"]>]> {
let arguments = (ins);
- let results = (outs Arg<TestTensorType>:$output);
+ let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
let extraClassDefinition = [{
bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
const ::mlir::bufferization::AnalysisState&) {
@@ -3725,7 +3726,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
let arguments = (ins);
- let results = (outs Arg<TestMemrefType>:$output);
+ let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output);
}
//===----------------------------------------------------------------------===//
>From eb51e559e9b012d00a31c7b023d9336e38d005d1 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 6 Oct 2025 11:56:28 +0000
Subject: [PATCH 3/3] Revert "[mlir][bufferization] Convert tensor encoding ->
 memref layout"
This reverts commit 94419f20c22c9645f2bb5bbd87fa66d94f34f665.
---
.../IR/BufferizableOpInterface.h | 9 -
.../Bufferization/IR/BufferizationDialect.cpp | 5 +-
.../Bufferization/IR/BufferizationOps.cpp | 9 -
mlir/unittests/Transforms/CMakeLists.txt | 5 +-
.../Transforms/OneShotBufferization.cpp | 171 ------------------
5 files changed, 2 insertions(+), 197 deletions(-)
delete mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index bc5ebbcc64031..dd693a25fd54f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,9 +272,6 @@ struct BufferizationOptions {
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
- /// Construct a MemRefLayoutAttrInterface from a tensor type.
- using ConstructMemRefLayoutFn =
- std::function<MemRefLayoutAttrInterface(TensorType t)>;
BufferizationOptions();
@@ -367,12 +364,6 @@ struct BufferizationOptions {
DefaultMemorySpaceFn defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return Attribute(); };
- /// Construction function used to determine the memref layout based on the
- /// original tensor type. Can be used to specialize tensor encoding -> memref
- /// layout conversion. By default, it is unset, making the layout construction
- /// behavior depend on the place where it is used.
- ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
-
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index db06edcbe3d59..6c08cdfb669f3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,12 +65,9 @@ struct BuiltinTensorExternalModel
auto memSpace = options.defaultMemorySpaceFn(tensorType);
if (!memSpace.has_value())
return emitError() << "could not infer memory space";
- MemRefLayoutAttrInterface layout = {};
- if (options.constructMemRefLayoutFn)
- layout = options.constructMemRefLayoutFn(tensorType);
return cast<BufferLikeType>(
- getMemRefType(tensorType, options, layout, *memSpace));
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index a55b38aea6297..56ff2121e4620 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,15 +244,6 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
- // explicitly specified by the user. Otherwise, the default behavior is to
- // return a fully dynamic layout map which is the opposite of the default
- // behavior of this function.
- if (options.constructMemRefLayoutFn) {
- return cast<TensorLikeType>(getType()).getBufferType(
- options, [&]() { return emitError(); });
- }
-
return cast<BufferLikeType>(
getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index cd2548f45c94e..dc5920087b505 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,11 +1,8 @@
add_mlir_unittest(MLIRTransformsTests
Canonicalizer.cpp
DialectConversion.cpp
- OneShotBufferization.cpp
)
mlir_target_link_libraries(MLIRTransformsTests
PRIVATE
MLIRParser
- MLIRTransforms
- MLIRBufferizationTransforms
-)
+ MLIRTransforms)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
deleted file mode 100644
index a1d888b556c8c..0000000000000
--- a/mlir/unittests/Transforms/OneShotBufferization.cpp
+++ /dev/null
@@ -1,171 +0,0 @@
-//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/TensorEncoding.h"
-#include "mlir/Parser/Parser.h"
-#include "mlir/Pass/PassManager.h"
-
-#include "gtest/gtest.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestTensorAttr : public StringAttr {
- using mlir::StringAttr::StringAttr;
-
- static bool classof(mlir::Attribute attr) {
- return mlir::isa<mlir::StringAttr>(attr);
- }
-
- static TestTensorAttr fromStringAttr(StringAttr attr) {
- return mlir::dyn_cast<TestTensorAttr>(attr);
- }
-};
-
-class TestTensorEncodingVerifier final
- : public mlir::VerifiableTensorEncoding::ExternalModel<
- TestTensorEncodingVerifier, TestTensorAttr> {
-public:
- using ConcreteEntity = mlir::StringAttr;
-
- mlir::LogicalResult verifyEncoding(
- mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
- mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- std::ignore = shape;
-
- if (mlir::isa<TestTensorAttr>(attr)) {
- return mlir::success();
- }
- return emitError() << "Unknown Tensor enconding: " << attr;
- }
-};
-
-struct TestMemRefAttr : public mlir::StringAttr {
- using mlir::StringAttr::StringAttr;
-
- static bool classof(mlir::Attribute attr) {
- return mlir::isa<mlir::StringAttr>(attr);
- }
-
- mlir::AffineMap getAffineMap() const {
- return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
- }
-};
-
-class TestMemRefAttrLayout final
- : public mlir::MemRefLayoutAttrInterface::ExternalModel<
- TestMemRefAttrLayout, TestMemRefAttr> {
-public:
- using ConcreteEntity = mlir::StringAttr;
-
- bool isIdentity(mlir::Attribute) const { return true; }
- mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
- return cast<TestMemRefAttr>(attr).getAffineMap();
- }
- mlir::LogicalResult
- verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
- mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
- std::ignore = shape;
-
- if (mlir::isa<TestMemRefAttr>(attr)) {
- return mlir::success();
- }
- return emitError() << "Unknown MemRef layout: " << attr;
- }
-};
-
-TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
- MLIRContext context;
- context.getOrLoadDialect<BuiltinDialect>();
- context.getOrLoadDialect<func::FuncDialect>();
- context.getOrLoadDialect<bufferization::BufferizationDialect>();
-
- DialectRegistry registry;
- registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
- TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
- TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
- });
- bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
- registry);
- context.appendDialectRegistry(registry);
-
- const char *const code = R"mlir(
- func.func @foo(%t: tensor<42xf32, "hello">)
- -> tensor<42xf32, "hello"> {
- return %t : tensor<42xf32, "hello">
- }
-
- func.func @bar(%t1: tensor<42xf32, "hello">)
- -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
- %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
- -> tensor<42xf32, "hello">
-
- %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
-
- return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
- }
- )mlir";
-
- OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
- ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
-
- bufferization::OneShotBufferizationOptions options{};
- options.bufferizeFunctionBoundaries = true;
- options.constructMemRefLayoutFn =
- [](TensorType tensor) -> MemRefLayoutAttrInterface {
- assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
- auto tensorType = cast<RankedTensorType>(tensor);
- if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
- return cast<MemRefLayoutAttrInterface>(
- TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
- }
- return {};
- };
- options.functionArgTypeConverterFn =
- [&](bufferization::TensorLikeType tensor, Attribute memSpace,
- func::FuncOp, const bufferization::BufferizationOptions &) {
- assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
- auto tensorType = cast<RankedTensorType>(tensor);
- auto layout = options.constructMemRefLayoutFn(tensorType);
- return cast<bufferization::BufferLikeType>(
- MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
- layout, memSpace));
- };
-
- bufferization::BufferizationState state;
- ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
- module->getOperation(), options, state)));
-
- const auto checkType = [](Type type, StringRef expectedLayoutValue) {
- if (auto memref = dyn_cast<MemRefType>(type)) {
- if (auto layout = memref.getLayout();
- isa_and_nonnull<TestMemRefAttr>(layout)) {
- return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
- }
- }
- return false;
- };
-
- auto fooOp = *module->getOps<func::FuncOp>().begin();
- ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
- ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
-
- auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
- ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
- ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
- ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
-}
-
-} // end anonymous namespace
More information about the Mlir-commits
mailing list