[Mlir-commits] [mlir] [mlir][bufferization] Convert tensor encoding into memref layout (PR #161166)

Andrei Golubev llvmlistbot at llvm.org
Mon Oct 6 05:41:42 PDT 2025


https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/161166

>From ece4805a82221b7b8cca6e6352980d424a9a8e41 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Fri, 26 Sep 2025 14:50:09 +0000
Subject: [PATCH 1/3] [mlir][bufferization] Convert tensor encoding -> memref
 layout

Support custom types (4/N): allow user-specified bufferization of tensor
encoding into memref layout.

Both tensor encoding and memref layout can be user-specified
attributes that store arbitrary information. It is often the case that
this information has to be preserved during tensor -> memref
bufferization. Thus, provide an option function to create the memref
layout during TensorType::getBufferType() execution.

As a drive-by, update AllocTensorOp::getBufferType() to work via
TensorLikeType::getBufferType() when the memref layout is user-specified.
---
 .../IR/BufferizableOpInterface.h              |   9 +
 .../Bufferization/IR/BufferizationDialect.cpp |   5 +-
 .../Bufferization/IR/BufferizationOps.cpp     |   9 +
 mlir/unittests/Transforms/CMakeLists.txt      |   5 +-
 .../Transforms/OneShotBufferization.cpp       | 171 ++++++++++++++++++
 5 files changed, 197 insertions(+), 2 deletions(-)
 create mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
+  /// Construct a MemRefLayoutAttrInterface from a tensor type.
+  using ConstructMemRefLayoutFn =
+      std::function<MemRefLayoutAttrInterface(TensorType t)>;
 
   BufferizationOptions();
 
@@ -364,6 +367,12 @@ struct BufferizationOptions {
   DefaultMemorySpaceFn defaultMemorySpaceFn =
       [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
 
+  /// Construction function used to determine the memref layout based on the
+  /// original tensor type. Can be used to specialize tensor encoding -> memref
+  /// layout conversion. By default, it is unset, making the layout construction
+  /// behavior depend on the place where it is used.
+  ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
   bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..db06edcbe3d59 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
     auto memSpace = options.defaultMemorySpaceFn(tensorType);
     if (!memSpace.has_value())
       return emitError() << "could not infer memory space";
+    MemRefLayoutAttrInterface layout = {};
+    if (options.constructMemRefLayoutFn)
+      layout = options.constructMemRefLayoutFn(tensorType);
 
     return cast<BufferLikeType>(
-        getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+        getMemRefType(tensorType, options, layout, *memSpace));
   }
 
   mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
+  // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
+  // explicitly specified by the user. Otherwise, the default behavior is to
+  // return a fully dynamic layout map which is the opposite of the default
+  // behavior of this function.
+  if (options.constructMemRefLayoutFn) {
+    return cast<TensorLikeType>(getType()).getBufferType(
+        options, [&]() { return emitError(); });
+  }
+
   return cast<BufferLikeType>(
       getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
+  OneShotBufferization.cpp
 )
 mlir_target_link_libraries(MLIRTransformsTests
   PRIVATE
   MLIRParser
-  MLIRTransforms)
+  MLIRTransforms
+  MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestTensorAttr : public StringAttr {
+  using mlir::StringAttr::StringAttr;
+
+  static bool classof(mlir::Attribute attr) {
+    return mlir::isa<mlir::StringAttr>(attr);
+  }
+
+  static TestTensorAttr fromStringAttr(StringAttr attr) {
+    return mlir::dyn_cast<TestTensorAttr>(attr);
+  }
+};
+
+class TestTensorEncodingVerifier final
+    : public mlir::VerifiableTensorEncoding::ExternalModel<
+          TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  mlir::LogicalResult verifyEncoding(
+      mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
+      mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    std::ignore = shape;
+
+    if (mlir::isa<TestTensorAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown Tensor enconding: " << attr;
+  }
+};
+
+struct TestMemRefAttr : public mlir::StringAttr {
+  using mlir::StringAttr::StringAttr;
+
+  static bool classof(mlir::Attribute attr) {
+    return mlir::isa<mlir::StringAttr>(attr);
+  }
+
+  mlir::AffineMap getAffineMap() const {
+    return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+  }
+};
+
+class TestMemRefAttrLayout final
+    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+          TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  bool isIdentity(mlir::Attribute) const { return true; }
+  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+    return cast<TestMemRefAttr>(attr).getAffineMap();
+  }
+  mlir::LogicalResult
+  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
+               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    std::ignore = shape;
+
+    if (mlir::isa<TestMemRefAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown MemRef layout: " << attr;
+  }
+};
+
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+  MLIRContext context;
+  context.getOrLoadDialect<BuiltinDialect>();
+  context.getOrLoadDialect<func::FuncDialect>();
+  context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+  DialectRegistry registry;
+  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+  });
+  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+      registry);
+  context.appendDialectRegistry(registry);
+
+  const char *const code = R"mlir(
+    func.func @foo(%t: tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello"> {
+      return %t : tensor<42xf32, "hello">
+    }
+
+    func.func @bar(%t1: tensor<42xf32, "hello">)
+        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello">
+
+      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+    }
+  )mlir";
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+
+  bufferization::OneShotBufferizationOptions options{};
+  options.bufferizeFunctionBoundaries = true;
+  options.constructMemRefLayoutFn =
+      [](TensorType tensor) -> MemRefLayoutAttrInterface {
+    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+    auto tensorType = cast<RankedTensorType>(tensor);
+    if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
+      return cast<MemRefLayoutAttrInterface>(
+          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
+    }
+    return {};
+  };
+  options.functionArgTypeConverterFn =
+      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+          func::FuncOp, const bufferization::BufferizationOptions &) {
+        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+        auto tensorType = cast<RankedTensorType>(tensor);
+        auto layout = options.constructMemRefLayoutFn(tensorType);
+        return cast<bufferization::BufferLikeType>(
+            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                            layout, memSpace));
+      };
+
+  bufferization::BufferizationState state;
+  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+      module->getOperation(), options, state)));
+
+  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
+    if (auto memref = dyn_cast<MemRefType>(type)) {
+      if (auto layout = memref.getLayout();
+          isa_and_nonnull<TestMemRefAttr>(layout)) {
+        return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
+      }
+    }
+    return false;
+  };
+
+  auto fooOp = *module->getOps<func::FuncOp>().begin();
+  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
+
+  auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
+  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace

>From a6c3ef2d965eb76fae4d206bb5b94b2ee900f83b Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 6 Oct 2025 11:50:34 +0000
Subject: [PATCH 2/3] [mlir][bufferization] Test tensor encoding -> memref
 layout conversion

Support custom types (4/N): test that it is possible to customize memref
layout specification for custom operations and function boundaries.

This is purely a test setup (no API modifications) to ensure users are
able to pass information from tensors to memrefs within the bufferization
process. To achieve this, a test pass is required (since bufferization
options have to be set manually). As there is already a
--test-one-shot-module-bufferize pass present, it is extended for this
purpose.
---
 .../one-shot-non-module-bufferize.mlir        | 38 +++++++++++++++-
 .../TestOneShotModuleBufferize.cpp            | 26 +++++++++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 17 +++++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 18 ++++++++
 mlir/test/lib/Dialect/Test/TestAttributes.h   |  1 +
 mlir/test/lib/Dialect/Test/TestDialect.h      |  1 +
 mlir/test/lib/Dialect/Test/TestDialect.td     |  5 ++-
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 44 ++++++++++++++++---
 mlir/test/lib/Dialect/Test/TestOps.td         | 15 ++++---
 9 files changed, 150 insertions(+), 15 deletions(-)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
index e2ab876f8b46a..b52612d0d1f10 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-non-module-bufferize.mlir
@@ -24,10 +24,46 @@
     // CHECK-NOT: copy
     // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]])
     %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
-    // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32,{{.*}}>
+    // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref<?xf32{{.*}}>
     return %1, %0 : f32, tensor<?xf32>
   }
   "test.finish" () : () -> ()
 }) : () -> ()
 
+// -----
 
+#enc1 = #test.tensor_encoding<"hello">
+#enc2 = #test.tensor_encoding<"not hello">
+
+"test.symbol_scope_isolated"() ({
+  // CHECK: func @inner_func(
+  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+  // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"hello">>
+  func.func @inner_func(%t: tensor<?xf32, #enc1>)
+      -> tensor<?xf32, #enc1> {
+    // CHECK: return %[[arg0]]
+    return %t : tensor<?xf32, #enc1>
+  }
+
+  // CHECK: func @outer_func(
+  // CHECK-SAME:  %[[arg0:.*]]: memref<?xf32, #test.memref_layout<"hello">>)
+  // CHECK-SAME:  -> (memref<?xf32, #test.memref_layout<"hello">>,
+  // CHECK-SAME:      memref<?xf32, #test.memref_layout<"not hello">>)
+  func.func @outer_func(%t0: tensor<?xf32, #enc1>)
+      -> (tensor<?xf32, #enc1>, tensor<?xf32, #enc2>) {
+    // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+    %0 = call @inner_func(%t0)
+      : (tensor<?xf32, #enc1>) -> (tensor<?xf32, #enc1>)
+
+    // CHECK: %[[local:.*]] = "test.create_memref_op"() : ()
+    // CHECK-SAME:  -> memref<?xf32, #test.memref_layout<"not hello">>
+    %local = "test.create_tensor_op"() : () -> tensor<?xf32, #enc2>
+    // CHECK: %[[dummy:.*]] = "test.dummy_memref_op"(%[[local]])
+    %1 = "test.dummy_tensor_op"(%local) : (tensor<?xf32, #enc2>)
+      -> tensor<?xf32, #enc2>
+
+    // CHECK: return %[[call]], %[[dummy]]
+    return %0, %1 : tensor<?xf32, #enc1>, tensor<?xf32, #enc2>
+  }
+  "test.finish" () : () -> ()
+}) : () -> ()
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index 1e2d4a7c8f08d..4069a74dbf63c 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -11,11 +11,25 @@
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 
+#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr
+#include "TestDialect.h"
+
 using namespace mlir;
 
 namespace {
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+  Attribute enc = tensorType.getEncoding();
+  if (auto testEnc = dyn_cast_if_present<test::TestTensorEncodingAttr>(enc))
+    return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+        tensorType.getContext(), testEnc.getDummy()));
+  // No test encoding on this tensor: fall back to a null (default) layout.
+  return {};
+}
+
 struct TestOneShotModuleBufferizePass
     : public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
@@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass
       : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>();
     registry.insert<bufferization::BufferizationDialect>();
   }
   StringRef getArgument() const final {
@@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass
     bufferization::OneShotBufferizationOptions opt;
 
     opt.bufferizeFunctionBoundaries = true;
+    opt.functionArgTypeConverterFn =
+        [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+            func::FuncOp, const bufferization::BufferizationOptions &) {
+          assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+          auto tensorType = cast<RankedTensorType>(tensor);
+          auto layout = getMemRefLayoutForTensorEncoding(tensorType);
+          return cast<bufferization::BufferLikeType>(
+              MemRefType::get(tensorType.getShape(),
+                              tensorType.getElementType(), layout, memSpace));
+        };
+
     bufferization::BufferizationState bufferizationState;
 
     if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 5685004bbbd25..9e7e4f883b576 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/TensorEncoding.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> {
     let hasStorageCustomConstructor = 1;
 }
 
+def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding",
+    [DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> {
+  let mnemonic = "tensor_encoding";
+
+  let parameters = (ins "mlir::StringAttr":$dummy);
+  let assemblyFormat = "`<` $dummy `>`";
+}
+
+def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
+    [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
+  let mnemonic = "memref_layout";
+
+  let parameters = (ins "mlir::StringAttr":$dummy);
+  let assemblyFormat = "`<` $dummy `>`";
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index fe1e9166a3099..9db7b01dd193b 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -541,6 +541,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct(
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// TestTensorEncodingAttr
+//===----------------------------------------------------------------------===//
+
+::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding(
+    mlir::ArrayRef<int64_t> shape, mlir::Type elementType,
+    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestMemRefLayoutAttr
+//===----------------------------------------------------------------------===//
+
+mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const {
+  return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 778d84fae7365..0ad5ab641c6d0 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,6 +24,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/TensorEncoding.h"
 
 // generated files require above includes to come first
 #include "TestAttrInterfaces.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index f2adca6310d78..bcf3b55d33cb9 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -18,6 +18,7 @@
 #include "TestInterfaces.h"
 #include "TestTypes.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 2b5491fc0c6a0..37a263f1d10b8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect {
   let useDefaultTypePrinterParser = 0;
   let useDefaultAttributePrinterParser = 1;
   let isExtensible = 1;
-  let dependentDialects = ["::mlir::DLTIDialect"];
+  let dependentDialects = [
+    "::mlir::DLTIDialect",
+    "::mlir::bufferization::BufferizationDialect"
+  ];
   let discardableAttrs = (ins
      "mlir::IntegerAttr":$discardable_attr_key,
      "SimpleAAttr":$other_discardable_attr_key
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 53055fea215b7..b211e243f234c 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
 }
 
+namespace {
+/// Returns the test dialect's memref layout matching the tensor's test dialect
+/// encoding, or a null layout if the encoding is absent or of another kind.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+  if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+          tensorType.getEncoding())) {
+    return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+        tensorType.getContext(), encoding.getDummy()));
+  }
+  return {};
+}
+
+/// Test helper: bufferizes `tensorLike`; getBufferType() failure is unchecked.
+bufferization::BufferLikeType
+convertTensorToBuffer(mlir::Operation *op,
+                      const bufferization::BufferizationOptions &options,
+                      bufferization::TensorLikeType tensorLike) {
+  auto buffer =
+      *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
+  if (auto memref = dyn_cast<MemRefType>(buffer)) {
+    // Note: For testing, the memref layout is rebuilt from the tensor encoding
+    // so that encoding -> layout bufferization happens. This is done manually.
+    auto layout =
+        getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
+    return cast<bufferization::BufferLikeType>(
+        MemRefType::get(memref.getShape(), memref.getElementType(), layout,
+                        memref.getMemorySpace()));
+  }
+  return buffer;
+}
+} // namespace
+
 ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
     ::mlir::RewriterBase &rewriter,
     const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
     return failure();
 
   const auto outType = getOutput().getType();
-  const auto bufferizedOutType = test::TestMemrefType::get(
-      getContext(), outType.getShape(), outType.getElementType(), nullptr);
+  const auto bufferizedOutType =
+      convertTensorToBuffer(getOperation(), options, outType);
   // replace op with memref analogy
   auto dummyMemrefOp = test::TestDummyMemrefOp::create(
       rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ ::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
 
 mlir::FailureOr<mlir::bufferization::BufferLikeType>
 test::TestCreateTensorOp::getBufferType(
-    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
     const mlir::bufferization::BufferizationState &,
     llvm::SmallVector<::mlir::Value> &) {
-  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
   if (type == nullptr)
     return failure();
 
-  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
-      getContext(), type.getShape(), type.getElementType(), nullptr));
+  return convertTensorToBuffer(getOperation(), options, type);
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187655ee..b08933c08d132 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ValueBoundsOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 
 // Include the attribute definitions.
 include "TestAttrDefs.td"
@@ -2322,7 +2323,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
 }
 
 //===----------------------------------------------------------------------===//
-// Copy Operation Test 
+// Copy Operation Test
 //===----------------------------------------------------------------------===//
 
 def CopyOp : TEST_Op<"copy", []> {
@@ -3663,10 +3664,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
         ["bufferize", "bufferizesToMemoryRead",
          "bufferizesToMemoryWrite", "getAliasingValues"]>]> {
   let arguments = (ins
-    Arg<TestTensorType>:$input
+    Arg<Bufferization_TensorLikeTypeInterface>:$input
   );
   let results = (outs
-    Arg<TestTensorType>:$output
+    Arg<Bufferization_TensorLikeTypeInterface>:$output
   );
 
   let extraClassDefinition = [{
@@ -3688,10 +3689,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
 
 def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   let arguments = (ins
-    Arg<TestMemrefType>:$input
+    Arg<Bufferization_BufferLikeTypeInterface>:$input
   );
   let results = (outs
-    Arg<TestMemrefType>:$output
+    Arg<Bufferization_BufferLikeTypeInterface>:$output
   );
 }
 
@@ -3701,7 +3702,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
          "bufferizesToMemoryWrite", "getAliasingValues",
          "bufferizesToAllocation"]>]> {
   let arguments = (ins);
-  let results = (outs Arg<TestTensorType>:$output);
+  let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
   let extraClassDefinition = [{
     bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
         const ::mlir::bufferization::AnalysisState&) {
@@ -3725,7 +3726,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
 
 def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
   let arguments = (ins);
-  let results = (outs Arg<TestMemrefType>:$output);
+  let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output);
 }
 
 //===----------------------------------------------------------------------===//

>From eb51e559e9b012d00a31c7b023d9336e38d005d1 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 6 Oct 2025 11:56:28 +0000
Subject: [PATCH 3/3] Revert "[mlir][bufferization] Convert tensor encoding ->
 memref layout"

This reverts commit 94419f20c22c9645f2bb5bbd87fa66d94f34f665.
---
 .../IR/BufferizableOpInterface.h              |   9 -
 .../Bufferization/IR/BufferizationDialect.cpp |   5 +-
 .../Bufferization/IR/BufferizationOps.cpp     |   9 -
 mlir/unittests/Transforms/CMakeLists.txt      |   5 +-
 .../Transforms/OneShotBufferization.cpp       | 171 ------------------
 5 files changed, 2 insertions(+), 197 deletions(-)
 delete mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index bc5ebbcc64031..dd693a25fd54f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,9 +272,6 @@ struct BufferizationOptions {
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
-  /// Construct a MemRefLayoutAttrInterface from a tensor type.
-  using ConstructMemRefLayoutFn =
-      std::function<MemRefLayoutAttrInterface(TensorType t)>;
 
   BufferizationOptions();
 
@@ -367,12 +364,6 @@ struct BufferizationOptions {
   DefaultMemorySpaceFn defaultMemorySpaceFn =
       [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
 
-  /// Construction function used to determine the memref layout based on the
-  /// original tensor type. Can be used to specialize tensor encoding -> memref
-  /// layout conversion. By default, it is unset, making the layout construction
-  /// behavior depend on the place where it is used.
-  ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
-
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
   bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index db06edcbe3d59..6c08cdfb669f3 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,12 +65,9 @@ struct BuiltinTensorExternalModel
     auto memSpace = options.defaultMemorySpaceFn(tensorType);
     if (!memSpace.has_value())
       return emitError() << "could not infer memory space";
-    MemRefLayoutAttrInterface layout = {};
-    if (options.constructMemRefLayoutFn)
-      layout = options.constructMemRefLayoutFn(tensorType);
 
     return cast<BufferLikeType>(
-        getMemRefType(tensorType, options, layout, *memSpace));
+        getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
   }
 
   mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index a55b38aea6297..56ff2121e4620 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,15 +244,6 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
-  // explicitly specified by the user. Otherwise, the default behavior is to
-  // return a fully dynamic layout map which is the opposite of the default
-  // behavior of this function.
-  if (options.constructMemRefLayoutFn) {
-    return cast<TensorLikeType>(getType()).getBufferType(
-        options, [&]() { return emitError(); });
-  }
-
   return cast<BufferLikeType>(
       getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index cd2548f45c94e..dc5920087b505 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,11 +1,8 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
-  OneShotBufferization.cpp
 )
 mlir_target_link_libraries(MLIRTransformsTests
   PRIVATE
   MLIRParser
-  MLIRTransforms
-  MLIRBufferizationTransforms
-)
+  MLIRTransforms)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
deleted file mode 100644
index a1d888b556c8c..0000000000000
--- a/mlir/unittests/Transforms/OneShotBufferization.cpp
+++ /dev/null
@@ -1,171 +0,0 @@
-//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/BuiltinDialect.h"
-#include "mlir/IR/TensorEncoding.h"
-#include "mlir/Parser/Parser.h"
-#include "mlir/Pass/PassManager.h"
-
-#include "gtest/gtest.h"
-
-using namespace mlir;
-
-namespace {
-
-struct TestTensorAttr : public StringAttr {
-  using mlir::StringAttr::StringAttr;
-
-  static bool classof(mlir::Attribute attr) {
-    return mlir::isa<mlir::StringAttr>(attr);
-  }
-
-  static TestTensorAttr fromStringAttr(StringAttr attr) {
-    return mlir::dyn_cast<TestTensorAttr>(attr);
-  }
-};
-
-class TestTensorEncodingVerifier final
-    : public mlir::VerifiableTensorEncoding::ExternalModel<
-          TestTensorEncodingVerifier, TestTensorAttr> {
-public:
-  using ConcreteEntity = mlir::StringAttr;
-
-  mlir::LogicalResult verifyEncoding(
-      mlir::Attribute attr, mlir::ArrayRef<int64_t> shape, mlir::Type,
-      mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
-    std::ignore = shape;
-
-    if (mlir::isa<TestTensorAttr>(attr)) {
-      return mlir::success();
-    }
-    return emitError() << "Unknown Tensor enconding: " << attr;
-  }
-};
-
-struct TestMemRefAttr : public mlir::StringAttr {
-  using mlir::StringAttr::StringAttr;
-
-  static bool classof(mlir::Attribute attr) {
-    return mlir::isa<mlir::StringAttr>(attr);
-  }
-
-  mlir::AffineMap getAffineMap() const {
-    return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
-  }
-};
-
-class TestMemRefAttrLayout final
-    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
-          TestMemRefAttrLayout, TestMemRefAttr> {
-public:
-  using ConcreteEntity = mlir::StringAttr;
-
-  bool isIdentity(mlir::Attribute) const { return true; }
-  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
-    return cast<TestMemRefAttr>(attr).getAffineMap();
-  }
-  mlir::LogicalResult
-  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> shape,
-               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
-    std::ignore = shape;
-
-    if (mlir::isa<TestMemRefAttr>(attr)) {
-      return mlir::success();
-    }
-    return emitError() << "Unknown MemRef layout: " << attr;
-  }
-};
-
-TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
-  MLIRContext context;
-  context.getOrLoadDialect<BuiltinDialect>();
-  context.getOrLoadDialect<func::FuncDialect>();
-  context.getOrLoadDialect<bufferization::BufferizationDialect>();
-
-  DialectRegistry registry;
-  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
-    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
-    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
-  });
-  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
-      registry);
-  context.appendDialectRegistry(registry);
-
-  const char *const code = R"mlir(
-    func.func @foo(%t: tensor<42xf32, "hello">)
-        -> tensor<42xf32, "hello"> {
-      return %t : tensor<42xf32, "hello">
-    }
-
-    func.func @bar(%t1: tensor<42xf32, "hello">)
-        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
-      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
-        -> tensor<42xf32, "hello">
-
-      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
-
-      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
-    }
-  )mlir";
-
-  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
-  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
-
-  bufferization::OneShotBufferizationOptions options{};
-  options.bufferizeFunctionBoundaries = true;
-  options.constructMemRefLayoutFn =
-      [](TensorType tensor) -> MemRefLayoutAttrInterface {
-    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
-    auto tensorType = cast<RankedTensorType>(tensor);
-    if (auto encoding = dyn_cast<TestTensorAttr>(tensorType.getEncoding())) {
-      return cast<MemRefLayoutAttrInterface>(
-          TestMemRefAttr::get(tensor.getContext(), encoding.strref()));
-    }
-    return {};
-  };
-  options.functionArgTypeConverterFn =
-      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
-          func::FuncOp, const bufferization::BufferizationOptions &) {
-        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
-        auto tensorType = cast<RankedTensorType>(tensor);
-        auto layout = options.constructMemRefLayoutFn(tensorType);
-        return cast<bufferization::BufferLikeType>(
-            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
-                            layout, memSpace));
-      };
-
-  bufferization::BufferizationState state;
-  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
-      module->getOperation(), options, state)));
-
-  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
-    if (auto memref = dyn_cast<MemRefType>(type)) {
-      if (auto layout = memref.getLayout();
-          isa_and_nonnull<TestMemRefAttr>(layout)) {
-        return cast<TestMemRefAttr>(layout) == expectedLayoutValue;
-      }
-    }
-    return false;
-  };
-
-  auto fooOp = *module->getOps<func::FuncOp>().begin();
-  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
-  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
-
-  auto barOp = *std::next(module->getOps<func::FuncOp>().begin());
-  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
-  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
-  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
-}
-
-} // end anonymous namespace



More information about the Mlir-commits mailing list