[Mlir-commits] [mlir] [mlir][bufferization] Add tensor-like and memref-like interfaces (PR #134220)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 3 01:44:53 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
The current one-shot bufferization infrastructure operates on top of TensorType and BaseMemRefType. These are non-extensible base classes of the respective builtin types, tensor and memref. Thus, the infrastructure is bound to work only with builtin tensor/memref types, even though there are customization points that allow one to provide custom logic to control the bufferization behavior.
This patch introduces new type interfaces: tensor-like and memref-like that aim to supersede TensorType/BaseMemRefType within the bufferization dialect and allow custom tensors / memrefs to be used. Additionally, these new type interfaces are attached to the respective builtin types so that the switch is seamless.
Note that this patch does only minimal initial work: it does NOT refactor the bufferization infrastructure.
---
Full diff: https://github.com/llvm/llvm-project/pull/134220.diff
8 Files Affected:
- (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+21)
- (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td (+51)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt (+6)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+23)
- (modified) mlir/test/lib/Dialect/Test/CMakeLists.txt (+1)
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+40)
- (modified) mlir/test/lib/Dialect/Test/TestTypes.h (+1)
- (modified) mlir/unittests/IR/InterfaceTest.cpp (+77)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
new file mode 100644
index 0000000000000..9e83f0e3ad2de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -0,0 +1,21 @@
+//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
+
+// Provides ShapedType, the base interface of TensorLikeType and MemRefLikeType.
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Bufferization Type Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
new file mode 100644
index 0000000000000..c4c760d38203c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -0,0 +1,51 @@
+//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the type interfaces used by the Bufferization dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_TYPE_INTERFACES
+#define BUFFERIZATION_TYPE_INTERFACES
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+
+def Bufferization_TensorLikeTypeInterface
+ : TypeInterface<"TensorLikeType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir::bufferization";
+ let description = [{
+ Indicates that the type that attaches this interface can be treated as a
+ tensor type (similarly to an MLIR builtin tensor) within the Bufferization
+ dialect.
+
+ Implementing this interface means that the type also implements
+ ShapedTypeInterface.
+
+ The interface currently has no methods as it is used by types to opt into
+ being supported by the bufferization procedures.
+ }];
+}
+
+def Bufferization_MemRefLikeTypeInterface
+ : TypeInterface<"MemRefLikeType", [ShapedTypeInterface]> {
+ let cppNamespace = "::mlir::bufferization";
+ let description = [{
+ Indicates that the type that attaches this interface can be treated as a
+ memref type (similarly to an MLIR builtin memref) within the Bufferization
+ dialect.
+
+ Implementing this interface means that the type also implements
+ ShapedTypeInterface.
+
+ The interface currently has no methods as it is used by types to opt into
+ being supported by the bufferization procedures.
+ }];
+}
+
+#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 13a5bc370a4fc..3ead52148c208 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
+# Type interfaces (TensorLikeType, MemRefLikeType) of the Bufferization dialect.
+set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
+mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
+add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index e5a0c3c45b09e..31faa96695379 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/InliningUtils.h"
@@ -51,6 +53,22 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
return true;
}
};
+
+/// External model attaching `TensorLikeType` to a builtin tensor type; it is
+/// instantiated for RankedTensorType and UnrankedTensorType in initialize().
+/// The interface has no methods yet, so the model is empty.
+template <typename Tensor>
+struct BuiltinTensorExternalModel
+ : TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
+ Tensor> {};
+
+/// External model attaching `MemRefLikeType` to a builtin memref type; it is
+/// instantiated for MemRefType and UnrankedMemRefType in initialize().
+/// The interface has no methods yet, so the model is empty.
+template <typename MemRef>
+struct BuiltinMemRefExternalModel
+ : MemRefLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
+ MemRef> {};
} // namespace
//===----------------------------------------------------------------------===//
@@ -63,6 +81,17 @@ void mlir::bufferization::BufferizationDialect::initialize() {
#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
>();
addInterfaces<BufferizationInlinerInterface>();
+
+ // Attach the bufferization type interfaces to the builtin tensor and memref
+ // types so that they can be used as TensorLikeType / MemRefLikeType.
+ MLIRContext &ctx = *getContext();
+ RankedTensorType::attachInterface<
+ BuiltinTensorExternalModel<RankedTensorType>>(ctx);
+ UnrankedTensorType::attachInterface<
+ BuiltinTensorExternalModel<UnrankedTensorType>>(ctx);
+ MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(ctx);
+ UnrankedMemRefType::attachInterface<
+ BuiltinMemRefExternalModel<UnrankedMemRefType>>(ctx);
}
LogicalResult BufferizationDialect::verifyRegionArgAttribute(
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index a48ac24ca056d..6e608e4772391 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRTransformUtils
MLIRTransforms
MLIRValueBoundsOpInterface
+ MLIRBufferizationDialect
)
add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..76f3644345215 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -19,6 +19,7 @@ include "TestAttrDefs.td"
include "TestInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
// All of the types will extend this class.
class Test_Type<string name, list<Trait> traits = []>
@@ -403,4 +404,43 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
let mnemonic = "op_asm_type_interface";
}
+def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterface]> {
+ let mnemonic = "test_tensor";
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "mlir::Type":$elementType
+ );
+ let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
+
+ let extraClassDeclaration = [{
+ // ShapedTypeInterface (getShape/getElementType are generated getters):
+ bool hasRank() const {
+ return true;
+ }
+ test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+ return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType);
+ }
+ }];
+}
+
+def TestMemrefType : Test_Type<"TestMemref", [Bufferization_MemRefLikeTypeInterface]> {
+ let mnemonic = "test_memref";
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "mlir::Type":$elementType,
+ DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
+ );
+ let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
+
+ let extraClassDeclaration = [{
+ // ShapedTypeInterface (getShape/getElementType are generated getters):
+ bool hasRank() const {
+ return true;
+ }
+ test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+ return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace());
+ }
+ }];
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index cef3f056a7986..6499a96f495d0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -18,6 +18,7 @@
#include <tuple>
#include "TestTraits.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7da..4547e2dd3c8d0 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -84,3 +86,84 @@ TEST(InterfaceTest, TestImplicitConversion) {
typeA = typeB;
EXPECT_EQ(typeA, typeB);
}
+
+TEST(InterfaceTest, TestBuiltinTensorIsTensorLikeType) {
+ MLIRContext context;
+ // Note: attaches external model to builtins
+ context.loadDialect<bufferization::BufferizationDialect>();
+
+ auto builtinRankedTensor = mlir::RankedTensorType::get(
+ {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+ EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinRankedTensor));
+ EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedTensor));
+
+ auto builtinUnrankedTensor =
+ mlir::UnrankedTensorType::get(mlir::IntegerType::get(&context, 32));
+ EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedTensor));
+ EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedTensor));
+}
+
+TEST(InterfaceTest, TestCustomTensorIsTensorLikeType) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ auto customTensorType = test::TestTensorType::get(
+ &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+ EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customTensorType));
+
+ const int64_t newShape[] = {3, 4, 5};
+ auto customCloneType = customTensorType.cloneWith(
+ ArrayRef<int64_t>(newShape), customTensorType.getElementType());
+ EXPECT_EQ(customCloneType.getShape(), ArrayRef<int64_t>(newShape));
+ EXPECT_EQ(customTensorType.getElementType(),
+ customCloneType.getElementType());
+ EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customCloneType));
+ EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));
+
+ // User-specified conversion: the concrete type implicitly converts to the
+ // interface type, and the result still refers to the same type.
+ bufferization::TensorLikeType baseCopy = customTensorType;
+ EXPECT_EQ(baseCopy, customTensorType);
+}
+
+TEST(InterfaceTest, TestBuiltinMemrefIsMemRefLikeType) {
+ MLIRContext context;
+ // Note: attaches external model to builtins
+ context.loadDialect<bufferization::BufferizationDialect>();
+
+ auto builtinRankedMemref =
+ mlir::MemRefType::get({1, 2, 3}, mlir::IntegerType::get(&context, 32));
+ EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedMemref));
+ EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinRankedMemref));
+
+ auto builtinUnrankedMemref = mlir::UnrankedMemRefType::get(
+ mlir::IntegerType::get(&context, 32), nullptr);
+ EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedMemref));
+ EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedMemref));
+}
+
+TEST(InterfaceTest, TestCustomMemrefIsMemRefLikeType) {
+ MLIRContext context;
+ context.loadDialect<test::TestDialect>();
+
+ auto customMemrefType = test::TestMemrefType::get(
+ &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
+ mlir::StringAttr::get(&context, "some_memspace"));
+ EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customMemrefType));
+
+ const int64_t newShape[] = {3, 4, 5};
+ auto customCloneType = customMemrefType.cloneWith(
+ ArrayRef<int64_t>(newShape), customMemrefType.getElementType());
+ EXPECT_EQ(customCloneType.getShape(), ArrayRef<int64_t>(newShape));
+ EXPECT_EQ(customMemrefType.getElementType(),
+ customCloneType.getElementType());
+ EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customCloneType));
+ EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
+ // cloneWith must carry over the memory space of the original type.
+ EXPECT_EQ(customMemrefType.getMemSpace(), customCloneType.getMemSpace());
+
+ // User-specified conversion: the concrete type implicitly converts to the
+ // interface type, and the result still refers to the same type.
+ bufferization::MemRefLikeType baseCopy = customMemrefType;
+ EXPECT_EQ(baseCopy, customMemrefType);
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/134220
More information about the Mlir-commits mailing list