[Mlir-commits] [mlir] [mlir][bufferization] Add tensor-like and memref-like interfaces (PR #134220)

Andrei Golubev llvmlistbot at llvm.org
Mon Apr 7 05:26:51 PDT 2025


https://github.com/andrey-golubev updated https://github.com/llvm/llvm-project/pull/134220

>From 33e03b3035becca0467985447414044aeee81850 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Thu, 3 Apr 2025 07:58:32 +0000
Subject: [PATCH 1/3] [mlir][bufferization] Add tensor-like and memref-like
 interfaces

The current one-shot bufferization infrastructure operates on top of
TensorType and BaseMemRefType. These are non-extensible base classes of
the respective builtins: tensor and memref. Thus, the infrastructure is
bound to work only with builtin tensor/memref types. At the same time,
there are customization points that allow one to provide custom logic to
control the bufferization behavior.

This patch introduces two new type interfaces, tensor-like and
memref-like, that aim to supersede TensorType/BaseMemRefType within the
bufferization dialect and allow custom tensors / memrefs to be used.
Additionally, these new type interfaces are attached to the respective
builtin types so that the switch is seamless.

Note that this patch does only minimal initial work; it does NOT
refactor the bufferization infrastructure.
---
 .../IR/BufferizationTypeInterfaces.h          | 21 +++++
 .../IR/BufferizationTypeInterfaces.td         | 51 ++++++++++++
 .../Dialect/Bufferization/IR/CMakeLists.txt   |  6 ++
 .../Bufferization/IR/BufferizationDialect.cpp | 23 ++++++
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |  1 +
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    | 40 ++++++++++
 mlir/test/lib/Dialect/Test/TestTypes.h        |  1 +
 mlir/unittests/IR/InterfaceTest.cpp           | 77 +++++++++++++++++++
 8 files changed, 220 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
 create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
new file mode 100644
index 0000000000000..9e83f0e3ad2de
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -0,0 +1,21 @@
+//===- BufferizationTypeInterfaces.h - Bufferization type interfaces -*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
+
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// Bufferization Type Interfaces
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
new file mode 100644
index 0000000000000..c4c760d38203c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -0,0 +1,51 @@
+//===- BufferizationTypeInterfaces.td - Bufferization type interfaces -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------------------===//
+//
+// This is the definition file for type interfaces used in Bufferization.
+//
+//===---------------------------------------------------------------------------------===//
+
+#ifndef BUFFERIZATION_TYPE_INTERFACES
+#define BUFFERIZATION_TYPE_INTERFACES
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
+
+def Bufferization_TensorLikeTypeInterface
+    : TypeInterface<"TensorLikeType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir::bufferization";
+  let description = [{
+    Indicates that the type that attaches this interface can be treated as a
+    tensor type (similarly to a MLIR builtin tensor) within the Bufferization
+    dialect.
+
+    Implementing this interface means that the type also implements
+    ShapedTypeInterface.
+
+    The interface currently has no methods as it is used by types to opt into
+    being supported by the bufferization procedures.
+  }];
+}
+
+def Bufferization_MemRefLikeTypeInterface
+    : TypeInterface<"MemRefLikeType", [ShapedTypeInterface]> {
+  let cppNamespace = "::mlir::bufferization";
+  let description = [{
+    Indicates that the type that attaches this interface can be treated as a
+    memref type (similarly to a MLIR builtin memref) within the Bufferization
+    dialect.
+
+    Implementing this interface means that the type also implements
+    ShapedTypeInterface.
+
+    The interface currently has no methods as it is used by types to opt into
+    being supported by the bufferization procedures.
+  }];
+}
+
+#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 13a5bc370a4fc..3ead52148c208 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -10,3 +10,9 @@ mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
 mlir_tablegen(BufferizationEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRBufferizationEnumsIncGen)
 add_dependencies(mlir-headers MLIRBufferizationEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS BufferizationTypeInterfaces.td)
+mlir_tablegen(BufferizationTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(BufferizationTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRBufferizationTypeInterfacesIncGen)
+add_dependencies(mlir-headers MLIRBufferizationTypeInterfacesIncGen)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index e5a0c3c45b09e..31faa96695379 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -9,8 +9,11 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/ExecutionEngine/CRunnerUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 
@@ -51,6 +54,16 @@ struct BufferizationInlinerInterface : public DialectInlinerInterface {
     return true;
   }
 };
+
+template <typename Tensor>
+struct BuiltinTensorExternalModel
+    : TensorLikeType::ExternalModel<BuiltinTensorExternalModel<Tensor>,
+                                    Tensor> {};
+
+template <typename MemRef>
+struct BuiltinMemRefExternalModel
+    : MemRefLikeType::ExternalModel<BuiltinMemRefExternalModel<MemRef>,
+                                    MemRef> {};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -63,6 +76,16 @@ void mlir::bufferization::BufferizationDialect::initialize() {
 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
       >();
   addInterfaces<BufferizationInlinerInterface>();
+
+  assert(getContext() != nullptr);
+  RankedTensorType::attachInterface<
+      BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
+  UnrankedTensorType::attachInterface<
+      BuiltinTensorExternalModel<UnrankedTensorType>>(*getContext());
+  MemRefType::attachInterface<BuiltinMemRefExternalModel<MemRefType>>(
+      *getContext());
+  UnrankedMemRefType::attachInterface<
+      BuiltinMemRefExternalModel<UnrankedMemRefType>>(*getContext());
 }
 
 LogicalResult BufferizationDialect::verifyRegionArgAttribute(
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index a48ac24ca056d..6e608e4772391 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -93,6 +93,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
   MLIRTransformUtils
   MLIRTransforms
   MLIRValueBoundsOpInterface
+  MLIRBufferizationDialect
 )
 
 add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..76f3644345215 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -19,6 +19,7 @@ include "TestAttrDefs.td"
 include "TestInterfaces.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 
 // All of the types will extend this class.
 class Test_Type<string name, list<Trait> traits = []>
@@ -403,4 +404,43 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
   let mnemonic = "op_asm_type_interface";
 }
 
+def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterface]> {
+  let mnemonic = "test_tensor";
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "mlir::Type":$elementType
+  );
+  let assemblyFormat = "`<` `[` $shape `]` `,` $elementType `>`";
+
+  let extraClassDeclaration = [{
+    // ShapedTypeInterface:
+    bool hasRank() const {
+      return true;
+    }
+    test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+      return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType);
+    }
+  }];
+}
+
+def TestMemrefType : Test_Type<"TestMemref", [Bufferization_MemRefLikeTypeInterface]> {
+  let mnemonic = "test_memref";
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "mlir::Type":$elementType,
+    DefaultValuedParameter<"mlir::Attribute", "nullptr">:$memSpace
+  );
+  let assemblyFormat = "`<` `[` $shape `]` `,` $elementType (`,` $memSpace^)? `>`";
+
+  let extraClassDeclaration = [{
+    // ShapedTypeInterface:
+    bool hasRank() const {
+      return true;
+    }
+    test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
+      return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace());
+    }
+  }];
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index cef3f056a7986..6499a96f495d0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -18,6 +18,7 @@
 #include <tuple>
 
 #include "TestTraits.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectImplementation.h"
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 42196b003e7da..4547e2dd3c8d0 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -84,3 +86,78 @@ TEST(InterfaceTest, TestImplicitConversion) {
   typeA = typeB;
   EXPECT_EQ(typeA, typeB);
 }
+
+TEST(InterfaceTest, TestBuiltinTensorIsTensorLikeType) {
+  MLIRContext context;
+  // Note: attaches external model to builtins
+  context.loadDialect<bufferization::BufferizationDialect>();
+
+  auto builtinRankedTensor = mlir::RankedTensorType::get(
+      {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinRankedTensor));
+  EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedTensor));
+
+  auto builtinUnrankedTensor =
+      mlir::UnrankedTensorType::get(mlir::IntegerType::get(&context, 32));
+  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedTensor));
+  EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedTensor));
+}
+
+TEST(InterfaceTest, TestCustomTensorIsTensorLikeType) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  auto customTensorType = test::TestTensorType::get(
+      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
+  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customTensorType));
+
+  auto customCloneType = customTensorType.cloneWith(
+      ArrayRef<int64_t>{3, 4, 5}, customTensorType.getElementType());
+  EXPECT_EQ(customTensorType.getElementType(),
+            customCloneType.getElementType());
+  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customCloneType));
+  EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));
+
+  // user-specified conversions
+  bufferization::TensorLikeType baseCopy = customTensorType;
+  std::ignore = baseCopy;
+}
+
+TEST(InterfaceTest, TestBuiltinMemrefIsMemRefLikeType) {
+  MLIRContext context;
+  // Note: attaches external model to builtins
+  context.loadDialect<bufferization::BufferizationDialect>();
+
+  auto builtinRankedMemref =
+      mlir::MemRefType::get({1, 2, 3}, mlir::IntegerType::get(&context, 32));
+  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedMemref));
+  EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinRankedMemref));
+
+  auto builtinUnrankedMemref = mlir::UnrankedMemRefType::get(
+      mlir::IntegerType::get(&context, 32), nullptr);
+  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedMemref));
+  EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedMemref));
+}
+
+TEST(InterfaceTest, TestCustomMemrefIsMemRefLikeType) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  auto customMemrefType = test::TestMemrefType::get(
+      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
+      mlir::StringAttr::get(&context, "some_memspace"));
+  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customMemrefType));
+
+  auto customCloneType = customMemrefType.cloneWith(
+      ArrayRef<int64_t>{3, 4, 5}, customMemrefType.getElementType());
+  EXPECT_EQ(customMemrefType.getElementType(),
+            customCloneType.getElementType());
+  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customCloneType));
+  EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
+  EXPECT_EQ(customMemrefType.getMemSpace(),
+            mlir::cast<test::TestMemrefType>(customCloneType).getMemSpace());
+
+  // user-specified conversions
+  bufferization::MemRefLikeType baseCopy = customMemrefType;
+  std::ignore = baseCopy;
+}

>From cb5892f5d674b9d18c27d8ddbc1dc29cf9c7c6ba Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Fri, 4 Apr 2025 16:30:17 +0000
Subject: [PATCH 2/3] Address code review feedback (simpler fixes)

---
 .../Bufferization/IR/BufferizationTypeInterfaces.h   |  7 +++----
 .../Bufferization/IR/BufferizationTypeInterfaces.td  | 12 +++++-------
 .../Bufferization/IR/BufferizationDialect.cpp        |  2 --
 mlir/test/lib/Dialect/Test/TestTypeDefs.td           | 12 ++++++++----
 4 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 9e83f0e3ad2de..b72d2b0419bbb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -1,16 +1,15 @@
-//===- BufferizationTypeInterfaces.h - Bufferization type interfaces -*- C++
-//-*-===//
+//===- BufferizationTypeInterfaces.h - Type Interfaces ----------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===---------------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
 
-#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h" // for ShapedTypeInterface
 
 //===----------------------------------------------------------------------===//
 // Bufferization Type Interfaces
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index c4c760d38203c..e4ad668224ede 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -1,14 +1,14 @@
-//===- BufferizationTypeInterfaces.td - Bufferization type interfaces -*- tablegen -*-===//
+//===- BufferizationTypeInterfaces.td - Type Interfaces ----*- tablegen -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
-//===---------------------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 //
 // This is the definition file for type interfaces used in Bufferization.
 //
-//===---------------------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
 
 #ifndef BUFFERIZATION_TYPE_INTERFACES
 #define BUFFERIZATION_TYPE_INTERFACES
@@ -21,8 +21,7 @@ def Bufferization_TensorLikeTypeInterface
   let cppNamespace = "::mlir::bufferization";
   let description = [{
     Indicates that the type that attaches this interface can be treated as a
-    tensor type (similarly to a MLIR builtin tensor) within the Bufferization
-    dialect.
+    tensor type (similarly to a MLIR builtin tensor) during bufferization.
 
     Implementing this interface means that the type also implements
     ShapedTypeInterface.
@@ -37,8 +36,7 @@ def Bufferization_MemRefLikeTypeInterface
   let cppNamespace = "::mlir::bufferization";
   let description = [{
     Indicates that the type that attaches this interface can be treated as a
-    memref type (similarly to a MLIR builtin memref) within the Bufferization
-    dialect.
+    memref type (similarly to a MLIR builtin memref) during bufferization.
 
     Implementing this interface means that the type also implements
     ShapedTypeInterface.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 31faa96695379..3ed66fcc479f8 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -12,7 +12,6 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -77,7 +76,6 @@ void mlir::bufferization::BufferizationDialect::initialize() {
       >();
   addInterfaces<BufferizationInlinerInterface>();
 
-  assert(getContext() != nullptr);
   RankedTensorType::attachInterface<
       BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
   UnrankedTensorType::attachInterface<
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 76f3644345215..71b6c287f3193 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -417,8 +417,10 @@ def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterf
     bool hasRank() const {
       return true;
     }
-    test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
-      return test::TestTensorType::get(getContext(), shape.value_or(getShape()), elementType);
+    test::TestTensorType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+                                   mlir::Type elementType) const {
+      return test::TestTensorType::get(
+        getContext(), shape.value_or(getShape()), elementType);
     }
   }];
 }
@@ -437,8 +439,10 @@ def TestMemrefType : Test_Type<"TestMemref", [Bufferization_MemRefLikeTypeInterf
     bool hasRank() const {
       return true;
     }
-    test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, mlir::Type elementType) const {
-      return test::TestMemrefType::get(getContext(), shape.value_or(getShape()), elementType, getMemSpace());
+    test::TestMemrefType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape,
+                                   mlir::Type elementType) const {
+      return test::TestMemrefType::get(
+        getContext(), shape.value_or(getShape()), elementType, getMemSpace());
     }
   }];
 }

>From 911096ee101c3762388768432995570ec62f3cf4 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Mon, 7 Apr 2025 12:04:02 +0000
Subject: [PATCH 3/3] Address code review feedback (part 2): rewrite tests,
 simplify code

---
 .../IR/BufferizationTypeInterfaces.h          |  2 -
 .../IR/BufferizationTypeInterfaces.td         | 19 ++--
 .../Bufferization/Transforms/Passes.td        |  4 +
 .../Bufferization/IR/BufferizationDialect.cpp |  5 +
 .../Bufferization/Transforms/Bufferize.cpp    |  5 -
 .../Transforms/tensorlike-memreflike.mlir     | 37 +++++++
 .../lib/Dialect/Bufferization/CMakeLists.txt  |  8 ++
 .../TestTensorLikeAndMemRefLike.cpp           | 99 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestTypeDefs.td    |  6 +-
 mlir/tools/mlir-opt/mlir-opt.cpp              |  2 +
 mlir/unittests/IR/InterfaceTest.cpp           | 77 ---------------
 11 files changed, 165 insertions(+), 99 deletions(-)
 create mode 100644 mlir/test/Dialect/Bufferization/Transforms/tensorlike-memreflike.mlir
 create mode 100644 mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndMemRefLike.cpp

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index b72d2b0419bbb..f6b296eccd748 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONTYPEINTERFACES_H_
 
-#include "mlir/IR/BuiltinTypeInterfaces.h" // for ShapedTypeInterface
-
 //===----------------------------------------------------------------------===//
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index e4ad668224ede..38345ebbcec00 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -14,17 +14,13 @@
 #define BUFFERIZATION_TYPE_INTERFACES
 
 include "mlir/IR/OpBase.td"
-include "mlir/IR/BuiltinTypeInterfaces.td"
 
 def Bufferization_TensorLikeTypeInterface
-    : TypeInterface<"TensorLikeType", [ShapedTypeInterface]> {
+    : TypeInterface<"TensorLikeType"> {
   let cppNamespace = "::mlir::bufferization";
   let description = [{
-    Indicates that the type that attaches this interface can be treated as a
-    tensor type (similarly to a MLIR builtin tensor) during bufferization.
-
-    Implementing this interface means that the type also implements
-    ShapedTypeInterface.
+    Indicates that this type is a tensor type (similarly to a MLIR builtin
+    tensor) for bufferization purposes.
 
     The interface currently has no methods as it is used by types to opt into
     being supported by the bufferization procedures.
@@ -32,14 +28,11 @@ def Bufferization_TensorLikeTypeInterface
 }
 
 def Bufferization_MemRefLikeTypeInterface
-    : TypeInterface<"MemRefLikeType", [ShapedTypeInterface]> {
+    : TypeInterface<"MemRefLikeType"> {
   let cppNamespace = "::mlir::bufferization";
   let description = [{
-    Indicates that the type that attaches this interface can be treated as a
-    memref type (similarly to a MLIR builtin memref) during bufferization.
-
-    Implementing this interface means that the type also implements
-    ShapedTypeInterface.
+    Indicates that this type is a memref type (similarly to a MLIR builtin
+    memref) for bufferization purposes.
 
     The interface currently has no methods as it is used by types to opt into
     being supported by the bufferization procedures.
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
index f53f569070f09..ee33476f441ee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -471,6 +471,10 @@ def OneShotBufferizePass : Pass<"one-shot-bufferize", "ModuleOp"> {
     Statistic<"numTensorOutOfPlace", "num-tensor-out-of-place",
               "Number of out-of-place tensor OpOperands">,
   ];
+
+  let dependentDialects = [
+    "bufferization::BufferizationDialect", "memref::MemRefDialect"
+  ];
 }
 
 def PromoteBuffersToStackPass
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
index 3ed66fcc479f8..dd00e1293dd2b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -76,6 +76,11 @@ void mlir::bufferization::BufferizationDialect::initialize() {
       >();
   addInterfaces<BufferizationInlinerInterface>();
 
+  // Note: Unlike with other external models, declaring bufferization's
+  // "promised interfaces" in builtins for TensorLike and MemRefLike type
+  // interfaces is not possible (due to builtins being independent of
+  // bufferization). Thus, the compromise is to attach these interfaces directly
+  // during dialect initialization.
   RankedTensorType::attachInterface<
       BuiltinTensorExternalModel<RankedTensorType>>(*getContext());
   UnrankedTensorType::attachInterface<
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index e97b34b20ff72..0b60c44ece5fd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -57,11 +57,6 @@ struct OneShotBufferizePass
           OneShotBufferizePass> {
   using Base::Base;
 
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
-  }
-
   void runOnOperation() override {
     OneShotBufferizationOptions opt;
     if (!options) {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/tensorlike-memreflike.mlir b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-memreflike.mlir
new file mode 100644
index 0000000000000..f676fc5fdea65
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/tensorlike-memreflike.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -test-tensorlike-memreflike -split-input-file | FileCheck %s
+
+// CHECK: func.func @builtin_unranked
+// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_memref_like"}}
+func.func @builtin_unranked(%t: tensor<*xf32>) -> (memref<*xf32>)
+{
+  %0 = bufferization.to_memref %t : tensor<*xf32> to memref<*xf32>
+  return %0 : memref<*xf32>
+}
+
+// -----
+
+// CHECK: func.func @builtin_ranked
+// CHECK-SAME: {found = {operand_0 = "is_tensor_like", result_0 = "is_memref_like"}}
+func.func @builtin_ranked(%t: tensor<42xf32>) -> (memref<42xf32>)
+{
+  %0 = bufferization.to_memref %t : tensor<42xf32> to memref<42xf32>
+  return %0 : memref<42xf32>
+}
+
+// -----
+
+// CHECK: func.func @custom_tensor
+// CHECK-SAME: {found = {operand_0 = "is_tensor_like"}}
+func.func @custom_tensor(%t: !test.test_tensor<[42], f32>) -> ()
+{
+  return
+}
+
+// -----
+
+// CHECK: func.func @custom_memref
+// CHECK-SAME: {found = {operand_0 = "is_memref_like"}}
+func.func @custom_memref(%t: !test.test_memref<[42], f32>) -> ()
+{
+  return
+}
diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
index c14a9f2cc9bb0..a80541b298cbb 100644
--- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRBufferizationTestPasses
   TestTensorCopyInsertion.cpp
+  TestTensorLikeAndMemRefLike.cpp
 
   EXCLUDE_FROM_LIBMLIR
 )
@@ -9,4 +10,11 @@ mlir_target_link_libraries(MLIRBufferizationTestPasses PUBLIC
   MLIRBufferizationTransforms
   MLIRIR
   MLIRPass
+  MLIRTestDialect
 )
+
+target_include_directories(MLIRBufferizationTestPasses
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
+  )
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndMemRefLike.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndMemRefLike.cpp
new file mode 100644
index 0000000000000..7fd07bd518086
--- /dev/null
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorLikeAndMemRefLike.cpp
@@ -0,0 +1,99 @@
+//===- TestTensorLikeAndMemRefLike.cpp - Bufferization Test -----*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Pass/Pass.h"
+
+#include <string>
+
+using namespace mlir;
+
+namespace {
+/// Maps `type` to "is_tensor_like" or "is_memref_like" according to the
+/// bufferization type interface it implements (tensor-like is tested first);
+/// yields an empty string when neither interface is implemented.
+std::string getImplementationStatus(Type type) {
+  if (isa<bufferization::TensorLikeType>(type))
+    return "is_tensor_like";
+  const bool memrefLike = isa<bufferization::MemRefLikeType>(type);
+  return memrefLike ? "is_memref_like" : "";
+}
+
+/// Collects the signature slots of `funcOp` whose types are tensor-like or
+/// memref-like.
+///
+/// Every matching input is recorded under the key "operand_<i>" and every
+/// matching result under "result_<i>", where <i> is the position within the
+/// inputs / results list; the value is the string produced by
+/// getImplementationStatus. Types implementing neither interface are skipped.
+DictionaryAttr findAllImplementeesOfTensorOrMemRefLike(func::FuncOp funcOp) {
+  MLIRContext *ctx = funcOp.getContext();
+  llvm::SmallVector<NamedAttribute> entries;
+
+  // Appends one "<prefix><i>" entry per type implementing either interface.
+  auto collect = [&](const char *prefix, ArrayRef<Type> types) {
+    for (auto [idx, type] : llvm::enumerate(types)) {
+      const std::string kind = getImplementationStatus(type);
+      if (kind.empty())
+        continue;
+
+      entries.push_back(
+          NamedAttribute(StringAttr::get(ctx, prefix + std::to_string(idx)),
+                         StringAttr::get(ctx, kind)));
+    }
+  };
+
+  const FunctionType signature = funcOp.getFunctionType();
+  collect("operand_", signature.getInputs());
+  collect("result_", signature.getResults());
+  return DictionaryAttr::get(ctx, entries);
+}
+
+/// Test pass checking which function signature types implement the TensorLike
+/// and / or MemRefLike type interfaces defined in bufferization.
+///
+/// Each func.func in the module has its signature inspected. When at least one
+/// input or result type implements one of the interfaces, a "found" dictionary
+/// attribute naming those operands / results is attached to the function.
+struct TestTensorLikeAndMemRefLikePass
+    : public PassWrapper<TestTensorLikeAndMemRefLikePass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTensorLikeAndMemRefLikePass)
+
+  StringRef getArgument() const final { return "test-tensorlike-memreflike"; }
+  StringRef getDescription() const final {
+    return "Module pass to test custom types that implement TensorLike / "
+           "MemRefLike interfaces";
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, test::TestDialect>();
+  }
+
+  void runOnOperation() override {
+    // Functions without any tensor-like / memref-like type stay untouched.
+    getOperation().walk([](func::FuncOp fn) {
+      const DictionaryAttr found = findAllImplementeesOfTensorOrMemRefLike(fn);
+      if (found.empty())
+        return;
+      fn->setAttr("found", found);
+    });
+  }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestTensorLikeAndMemRefLikePass() { // called from mlir-opt.cpp
+  PassRegistration<TestTensorLikeAndMemRefLikePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 71b6c287f3193..eaf645cad2c43 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -404,7 +404,8 @@ def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
   let mnemonic = "op_asm_type_interface";
 }
 
-def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterface]> {
+def TestTensorType : Test_Type<"TestTensor",
+    [Bufferization_TensorLikeTypeInterface, ShapedTypeInterface]> {
   let mnemonic = "test_tensor";
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
@@ -425,7 +426,8 @@ def TestTensorType : Test_Type<"TestTensor", [Bufferization_TensorLikeTypeInterf
   }];
 }
 
-def TestMemrefType : Test_Type<"TestMemref", [Bufferization_MemRefLikeTypeInterface]> {
+def TestMemrefType : Test_Type<"TestMemref",
+    [Bufferization_MemRefLikeTypeInterface, ShapedTypeInterface]> {
   let mnemonic = "test_memref";
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index d06ff8070e7cf..c3506d8966e05 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -149,6 +149,7 @@ void registerTestSPIRVCPURunnerPipeline();
 void registerTestSPIRVFuncSignatureConversion();
 void registerTestSPIRVVectorUnrolling();
 void registerTestTensorCopyInsertionPass();
+void registerTestTensorLikeAndMemRefLikePass();
 void registerTestTensorTransforms();
 void registerTestTopologicalSortAnalysisPass();
 void registerTestTransformDialectEraseSchedulePass();
@@ -291,6 +292,7 @@ void registerTestPasses() {
   mlir::test::registerTestSPIRVFuncSignatureConversion();
   mlir::test::registerTestSPIRVVectorUnrolling();
   mlir::test::registerTestTensorCopyInsertionPass();
+  mlir::test::registerTestTensorLikeAndMemRefLikePass();
   mlir::test::registerTestTensorTransforms();
   mlir::test::registerTestTopologicalSortAnalysisPass();
   mlir::test::registerTestTransformDialectEraseSchedulePass();
diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 4547e2dd3c8d0..42196b003e7da 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -6,8 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -86,78 +84,3 @@ TEST(InterfaceTest, TestImplicitConversion) {
   typeA = typeB;
   EXPECT_EQ(typeA, typeB);
 }
-
-TEST(InterfaceTest, TestBuiltinTensorIsTensorLikeType) {
-  MLIRContext context;
-  // Note: attaches external model to builtins
-  context.loadDialect<bufferization::BufferizationDialect>();
-
-  auto builtinRankedTensor = mlir::RankedTensorType::get(
-      {1, 2, 3}, mlir::IntegerType::get(&context, 32));
-  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinRankedTensor));
-  EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedTensor));
-
-  auto builtinUnrankedTensor =
-      mlir::UnrankedTensorType::get(mlir::IntegerType::get(&context, 32));
-  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedTensor));
-  EXPECT_FALSE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedTensor));
-}
-
-TEST(InterfaceTest, TestCustomTensorIsTensorLikeType) {
-  MLIRContext context;
-  context.loadDialect<test::TestDialect>();
-
-  auto customTensorType = test::TestTensorType::get(
-      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32));
-  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customTensorType));
-
-  auto customCloneType = customTensorType.cloneWith(
-      ArrayRef<int64_t>{3, 4, 5}, customTensorType.getElementType());
-  EXPECT_EQ(customTensorType.getElementType(),
-            customCloneType.getElementType());
-  EXPECT_TRUE(mlir::isa<bufferization::TensorLikeType>(customCloneType));
-  EXPECT_TRUE(mlir::isa<test::TestTensorType>(customCloneType));
-
-  // user-specified conversions
-  bufferization::TensorLikeType baseCopy = customTensorType;
-  std::ignore = baseCopy;
-}
-
-TEST(InterfaceTest, TestBuiltinMemrefIsMemRefLikeType) {
-  MLIRContext context;
-  // Note: attaches external model to builtins
-  context.loadDialect<bufferization::BufferizationDialect>();
-
-  auto builtinRankedMemref =
-      mlir::MemRefType::get({1, 2, 3}, mlir::IntegerType::get(&context, 32));
-  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinRankedMemref));
-  EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinRankedMemref));
-
-  auto builtinUnrankedMemref = mlir::UnrankedMemRefType::get(
-      mlir::IntegerType::get(&context, 32), nullptr);
-  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(builtinUnrankedMemref));
-  EXPECT_FALSE(mlir::isa<bufferization::TensorLikeType>(builtinUnrankedMemref));
-}
-
-TEST(InterfaceTest, TestCustomMemrefIsMemRefLikeType) {
-  MLIRContext context;
-  context.loadDialect<test::TestDialect>();
-
-  auto customMemrefType = test::TestMemrefType::get(
-      &context, {1, 2, 3}, mlir::IntegerType::get(&context, 32),
-      mlir::StringAttr::get(&context, "some_memspace"));
-  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customMemrefType));
-
-  auto customCloneType = customMemrefType.cloneWith(
-      ArrayRef<int64_t>{3, 4, 5}, customMemrefType.getElementType());
-  EXPECT_EQ(customMemrefType.getElementType(),
-            customCloneType.getElementType());
-  EXPECT_TRUE(mlir::isa<bufferization::MemRefLikeType>(customCloneType));
-  EXPECT_TRUE(mlir::isa<test::TestMemrefType>(customCloneType));
-  EXPECT_EQ(customMemrefType.getMemSpace(),
-            mlir::cast<test::TestMemrefType>(customCloneType).getMemSpace());
-
-  // user-specified conversions
-  bufferization::MemRefLikeType baseCopy = customMemrefType;
-  std::ignore = baseCopy;
-}



More information about the Mlir-commits mailing list