[Mlir-commits] [mlir] [mlir][bufferization] Convert tensor encoding -> memref layout (PR #161166)

Andrei Golubev llvmlistbot at llvm.org
Mon Sep 29 03:44:29 PDT 2025


https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/161166

>From 94419f20c22c9645f2bb5bbd87fa66d94f34f665 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Fri, 26 Sep 2025 14:50:09 +0000
Subject: [PATCH] [mlir][bufferization] Convert tensor encoding -> memref
 layout

Support custom types (4/N): allow user-specified bufferization of tensor
encoding into memref layout.

Both the tensor encoding and the memref layout can be user-specified
attributes that store arbitrary information. Often, this information
has to be preserved during tensor -> memref bufferization. Thus, add an
optional BufferizationOptions callback that constructs the memref
layout during TensorType::getBufferType() execution.

As a drive-by, update AllocTensorOp::getBufferType() to delegate to
TensorLikeType::getBufferType() when the memref layout is user-specified.
---
 .../IR/BufferizableOpInterface.h              |   9 +
 .../Bufferization/IR/BufferizationDialect.cpp |   5 +-
 .../Bufferization/IR/BufferizationOps.cpp     |   9 +
 mlir/unittests/Transforms/CMakeLists.txt      |   5 +-
 .../Transforms/OneShotBufferization.cpp       | 171 ++++++++++++++++++
 5 files changed, 197 insertions(+), 2 deletions(-)
 create mode 100644 mlir/unittests/Transforms/OneShotBufferization.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index dd693a25fd54f..bc5ebbcc64031 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -272,6 +272,9 @@ struct BufferizationOptions {
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
       std::function<std::optional<Attribute>(TensorType t)>;
+  /// Construct a MemRefLayoutAttrInterface from a tensor type.
+  using ConstructMemRefLayoutFn =
+      std::function<MemRefLayoutAttrInterface(TensorType t)>;
 
   BufferizationOptions();
 
@@ -364,6 +367,12 @@ struct BufferizationOptions {
   DefaultMemorySpaceFn defaultMemorySpaceFn =
       [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
 
+  /// Construction function used to determine the memref layout based on the
+  /// original tensor type. Can be used to specialize tensor encoding -> memref
+  /// layout conversion. By default, it is unset, making the layout construction
+  /// behavior depend on the place where it is used.
+  ConstructMemRefLayoutFn constructMemRefLayoutFn = nullptr;
+
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
   bool copyBeforeWrite = false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 6c08cdfb669f3..db06edcbe3d59 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -65,9 +65,12 @@ struct BuiltinTensorExternalModel
     auto memSpace = options.defaultMemorySpaceFn(tensorType);
     if (!memSpace.has_value())
       return emitError() << "could not infer memory space";
+    MemRefLayoutAttrInterface layout = {};
+    if (options.constructMemRefLayoutFn)
+      layout = options.constructMemRefLayoutFn(tensorType);
 
     return cast<BufferLikeType>(
-        getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
+        getMemRefType(tensorType, options, layout, *memSpace));
   }
 
   mlir::LogicalResult verifyCompatibleBufferType(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..a55b38aea6297 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -244,6 +244,15 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
+  // Note: Only rely on TensorLikeType::getBufferType() if memref layout is
+  // explicitly specified by the user. Otherwise, the default behavior is to
+  // return a fully dynamic layout map which is the opposite of the default
+  // behavior of this function.
+  if (options.constructMemRefLayoutFn) {
+    return cast<TensorLikeType>(getType()).getBufferType(
+        options, [&]() { return emitError(); });
+  }
+
   return cast<BufferLikeType>(
       getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
index dc5920087b505..cd2548f45c94e 100644
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -1,8 +1,11 @@
 add_mlir_unittest(MLIRTransformsTests
   Canonicalizer.cpp
   DialectConversion.cpp
+  OneShotBufferization.cpp
 )
 mlir_target_link_libraries(MLIRTransformsTests
   PRIVATE
   MLIRParser
-  MLIRTransforms)
+  MLIRTransforms
+  MLIRBufferizationTransforms
+)
diff --git a/mlir/unittests/Transforms/OneShotBufferization.cpp b/mlir/unittests/Transforms/OneShotBufferization.cpp
new file mode 100644
index 0000000000000..a1d888b556c8c
--- /dev/null
+++ b/mlir/unittests/Transforms/OneShotBufferization.cpp
@@ -0,0 +1,171 @@
+//===- OneShotBufferization.cpp - One-shot bufferization unit tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/PassManager.h"
+
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test tensor encoding attribute; any StringAttr qualifies (see classof).
+struct TestTensorAttr : public mlir::StringAttr {
+  using mlir::StringAttr::StringAttr;
+  /// LLVM-style RTTI hook used by isa/cast/dyn_cast.
+  static bool classof(mlir::Attribute a) { return mlir::isa<StringAttr>(a); }
+
+  /// Reinterprets a StringAttr as a TestTensorAttr via dyn_cast.
+  static TestTensorAttr fromStringAttr(StringAttr attr) {
+    return mlir::dyn_cast<TestTensorAttr>(attr);
+  }
+};
+
+/// Makes TestTensorAttr (any StringAttr) a verifiable tensor encoding.
+class TestTensorEncodingVerifier final
+    : public mlir::VerifiableTensorEncoding::ExternalModel<
+          TestTensorEncodingVerifier, TestTensorAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  /// Accepts TestTensorAttr encodings irrespective of shape and element type.
+  mlir::LogicalResult verifyEncoding(
+      mlir::Attribute attr, mlir::ArrayRef<int64_t> /*shape*/, mlir::Type,
+      mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    if (mlir::isa<TestTensorAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown Tensor encoding: " << attr;
+  }
+};
+
+/// Test memref layout attribute; any StringAttr qualifies (see classof).
+struct TestMemRefAttr : public mlir::StringAttr {
+  using mlir::StringAttr::StringAttr;
+  /// LLVM-style RTTI hook used by isa/cast/dyn_cast.
+  static bool classof(mlir::Attribute a) { return mlir::isa<StringAttr>(a); }
+
+  /// Always the 1-D identity map, i.e. only rank-1 memrefs are modeled.
+  mlir::AffineMap getAffineMap() const {
+    mlir::MLIRContext *ctx = getContext();
+    return mlir::AffineMap::getMultiDimIdentityMap(/*numDims=*/1, ctx);
+  }
+};
+
+/// Makes TestMemRefAttr (any StringAttr) usable as a memref layout.
+class TestMemRefAttrLayout final
+    : public mlir::MemRefLayoutAttrInterface::ExternalModel<
+          TestMemRefAttrLayout, TestMemRefAttr> {
+public:
+  using ConcreteEntity = mlir::StringAttr;
+
+  bool isIdentity(mlir::Attribute) const { return true; }
+  mlir::AffineMap getAffineMap(mlir::Attribute attr) const {
+    return cast<TestMemRefAttr>(attr).getAffineMap();
+  }
+  /// Accepts TestMemRefAttr layouts for any shape.
+  mlir::LogicalResult
+  verifyLayout(mlir::Attribute attr, mlir::ArrayRef<int64_t> /*shape*/,
+               mlir::function_ref<mlir::InFlightDiagnostic()> emitError) const {
+    if (mlir::isa<TestMemRefAttr>(attr)) {
+      return mlir::success();
+    }
+    return emitError() << "Unknown MemRef layout: " << attr;
+  }
+};
+
+/// Checks that constructMemRefLayoutFn maps tensor encodings to memref layouts.
+TEST(OneShotBufferizationTest, BufferizeTensorEncodingIntoMemRefLayout) {
+  MLIRContext context;
+  context.getOrLoadDialect<BuiltinDialect>();
+  context.getOrLoadDialect<func::FuncDialect>();
+  context.getOrLoadDialect<bufferization::BufferizationDialect>();
+
+  DialectRegistry registry;
+  registry.addExtension(+[](mlir::MLIRContext *ctx, BuiltinDialect *) {
+    TestTensorAttr::attachInterface<TestTensorEncodingVerifier>(*ctx);
+    TestMemRefAttr::attachInterface<TestMemRefAttrLayout>(*ctx);
+  });
+  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+      registry);
+  context.appendDialectRegistry(registry);
+
+  const char *const code = R"mlir(
+    func.func @foo(%t: tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello"> {
+      return %t : tensor<42xf32, "hello">
+    }
+
+    func.func @bar(%t1: tensor<42xf32, "hello">)
+        -> (tensor<42xf32, "hello">, tensor<12xf32, "not hello">) {
+      %out1 = func.call @foo(%t1) : (tensor<42xf32, "hello">)
+        -> tensor<42xf32, "hello">
+
+      %out2 = bufferization.alloc_tensor() : tensor<12xf32, "not hello">
+
+      return %out1, %out2 : tensor<42xf32, "hello">, tensor<12xf32, "not hello">
+    }
+  )mlir";
+
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
+  ASSERT_NE(module.get(), nullptr) << "parsing should be successful";
+  bufferization::OneShotBufferizationOptions options{};
+  options.bufferizeFunctionBoundaries = true;
+  options.constructMemRefLayoutFn =
+      [](TensorType tensor) -> MemRefLayoutAttrInterface {
+    assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+    // getEncoding() is null for unencoded tensors; plain dyn_cast would assert.
+    Attribute encoding = cast<RankedTensorType>(tensor).getEncoding();
+    if (auto attr = llvm::dyn_cast_if_present<TestTensorAttr>(encoding)) {
+      return cast<MemRefLayoutAttrInterface>(
+          TestMemRefAttr::get(tensor.getContext(), attr.strref()));
+    }
+    return {};
+  };
+  options.functionArgTypeConverterFn =
+      [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+          func::FuncOp, const bufferization::BufferizationOptions &) {
+        assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+        auto tensorType = cast<RankedTensorType>(tensor);
+        auto layout = options.constructMemRefLayoutFn(tensorType);
+        return cast<bufferization::BufferLikeType>(
+            MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
+                            layout, memSpace));
+      };
+  bufferization::BufferizationState state;
+  ASSERT_TRUE(succeeded(bufferization::runOneShotModuleBufferize(
+      module->getOperation(), options, state)));
+
+  const auto checkType = [](Type type, StringRef expectedLayoutValue) {
+    auto memref = dyn_cast<MemRefType>(type);
+    if (!memref)
+      return false;
+    auto layout = llvm::dyn_cast_if_present<TestMemRefAttr>(memref.getLayout());
+    return layout && layout == expectedLayoutValue;
+  };
+
+  // Look the functions up by symbol name instead of by position in the module.
+  auto fooOp = module->lookupSymbol<func::FuncOp>("foo");
+  ASSERT_TRUE(fooOp) << "@foo not found";
+  ASSERT_TRUE(checkType(fooOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(fooOp.getResultTypes()[0], "hello"));
+  auto barOp = module->lookupSymbol<func::FuncOp>("bar");
+  ASSERT_TRUE(barOp) << "@bar not found";
+  ASSERT_TRUE(checkType(barOp.getArgumentTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[0], "hello"));
+  ASSERT_TRUE(checkType(barOp.getResultTypes()[1], "not hello"));
+}
+
+} // end anonymous namespace



More information about the Mlir-commits mailing list