[Mlir-commits] [mlir] ada9aa5 - [mlir] Make MemRef element type extensible

Alex Zinenko llvmlistbot at llvm.org
Tue Jun 8 02:11:41 PDT 2021


Author: Alex Zinenko
Date: 2021-06-08T11:11:30+02:00
New Revision: ada9aa5a228200cb71269c371308e82c42fd4abc

URL: https://github.com/llvm/llvm-project/commit/ada9aa5a228200cb71269c371308e82c42fd4abc
DIFF: https://github.com/llvm/llvm-project/commit/ada9aa5a228200cb71269c371308e82c42fd4abc.diff

LOG: [mlir] Make MemRef element type extensible

Historically, MemRef only supported a restricted list of element types that
were known to be storable in memory. This is unnecessarily restrictive given
the open nature of MLIR's type system. Allow types to opt into being used as
MemRef elements by implementing a type interface. For now, the interface is
merely a declaration with no methods. Methods may be added later, e.g., to query
the type size or whether a type can alias elements of another type.

Harden the "standard"-to-LLVM conversion against memrefs with non-builtin
element types: memrefs whose element type cannot be converted are now left
unconverted.

See https://llvm.discourse.group/t/rfc-memref-of-custom-types/3558.

Depends On D103826

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D103827

Added: 
    

Modified: 
    mlir/docs/Dialects/Builtin.md
    mlir/include/mlir/IR/BuiltinTypes.h
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/include/mlir/IR/CMakeLists.txt
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/IR/BuiltinTypes.cpp
    mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
    mlir/test/Conversion/StandardToLLVM/invalid.mlir
    mlir/test/IR/parser.mlir
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestTypeDefs.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/Builtin.md b/mlir/docs/Dialects/Builtin.md
index c48fc1bede687..b39506a39b5ab 100644
--- a/mlir/docs/Dialects/Builtin.md
+++ b/mlir/docs/Dialects/Builtin.md
@@ -30,3 +30,7 @@ Operations.
 ## Types
 
 [include "Dialects/BuiltinTypes.md"]
+
+## Type Interfaces
+
+[include "Dialects/BuiltinTypeInterfaces.md"]

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 718fffd3e7b6f..d858c3129091b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -192,6 +192,12 @@ class BaseMemRefType : public ShapedType {
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
+
 namespace mlir {
 //===----------------------------------------------------------------------===//
 // MemRefType
@@ -266,7 +272,8 @@ inline bool BaseMemRefType::classof(Type type) {
 }
 
 inline bool BaseMemRefType::isValidElementType(Type type) {
-  return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>();
+  return type.isIntOrIndexOrFloat() || type.isa<ComplexType, VectorType>() ||
+         type.isa<MemRefElementTypeInterface>();
 }
 
 inline bool FloatType::classof(Type type) {

diff  --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 349da5663f9d8..85787afc49547 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -248,6 +248,31 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MemRefElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type can be used as element in memref types.
+
+    Implementing this interface establishes a contract between this type and the
+    memref type indicating that this type can be used as element of ranked or
+    unranked memrefs. The type is expected to:
+
+      - model an entity stored in memory;
+      - have non-zero size.
+
+    For example, scalar values such as integers can implement this interface,
+    but indicator types such as `void` or `unit` should not.
+
+    The interface currently has no methods and is used by types to opt into
+    being memref elements. This may change in the future, in particular to
+    require types to provide their size or alignment given a data layout.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefType
 //===----------------------------------------------------------------------===//
@@ -282,6 +307,14 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
     on the rank. Other uses of this type are disallowed or will have undefined
     behavior.
 
+    Are accepted as elements:
+
+    - built-in integer types;
+    - built-in index type;
+    - built-in floating point types;
+    - built-in vector types with elements of the above types;
+    - any other type implementing `MemRefElementTypeInterface`.
+
     ##### Codegen of Unranked Memref
 
     Using unranked memref in codegen besides the case mentioned above is highly

diff  --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
index 42e07811a4a54..b8b49aa425a9b 100644
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -24,6 +24,8 @@ add_public_tablegen_target(MLIRBuiltinOpsIncGen)
 set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
 mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
 mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
+mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRBuiltinTypesIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
@@ -35,3 +37,4 @@ add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
 add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)
 add_mlir_doc(BuiltinTypes BuiltinTypes Dialects/ -gen-typedef-doc)
+add_mlir_doc(BuiltinTypes BuiltinTypeInterfaces Dialects/ -gen-type-interface-docs)

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index dcbb4b336213f..11d0cd6fdc766 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -349,6 +349,8 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   // unpack the `sizes` and `strides` arrays.
   SmallVector<Type, 5> types =
       getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
+  if (types.empty())
+    return {};
   return LLVM::LLVMStructType::getLiteral(&getContext(), types);
 }
 
@@ -368,6 +370,8 @@ SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
 }
 
 Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
+  if (!convertType(type.getElementType()))
+    return {};
   return LLVM::LLVMStructType::getLiteral(&getContext(),
                                           getUnrankedMemRefDescriptorFields());
 }

diff  --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index baadd8d0433cc..77d64080de6e2 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -31,6 +31,12 @@ using namespace mlir::detail;
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // BuiltinDialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
index 27623393148f3..6df3c94943759 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir
@@ -427,3 +427,23 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32>> } {
     return
   }
 }
+
+// -----
+
+// Should not convert memrefs with unsupported types in any convention.
+
+// CHECK: @unsupported_memref_element_type
+// CHECK-SAME: memref<
+// CHECK-NOT: !llvm.struct
+// BAREPTR: @unsupported_memref_element_type
+// BAREPTR-SAME: memref<
+// BAREPTR-NOT: !llvm.ptr
+func private @unsupported_memref_element_type() -> memref<42 x !test.memref_element>
+
+// CHECK: @unsupported_unranked_memref_element_type
+// CHECK-SAME: memref<
+// CHECK-NOT: !llvm.struct
+// BAREPTR: @unsupported_unranked_memref_element_type
+// BAREPTR-SAME: memref<
+// BAREPTR-NOT: !llvm.ptr
+func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element>

diff  --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
index 8dbc2bfddd800..5b6e7577cc77e 100644
--- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -6,3 +6,4 @@ func private @unsupported_signature() -> tensor<10 x i32>
 // -----
 
 func private @partially_supported_signature() -> (vector<10 x i32>, tensor<10 x i32>)
+

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 7e8810c7479d9..2a3487cffe4c4 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -178,6 +178,9 @@ func private @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
 // CHECK: func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
 func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
 
+// CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
+func private @memref_with_custom_elem(memref<1x?x!test.memref_element>)
+
 // CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 func private @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
 

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index d1cf46ae5788b..30fe52e150790 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -17,8 +17,8 @@ mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(MLIRTestAttrDefIncGen)
 
 set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
-mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
-mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
+mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test)
+mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test)
 add_public_tablegen_target(MLIRTestTypeDefIncGen)
 
 

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 9821774eeede7..a5ae219780b4b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,6 +15,7 @@
 
 // To get the test dialect def.
 include "TestOps.td"
+include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 
 // All of the types will extend this class.
@@ -176,4 +177,9 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
   }];
 }
 
+def TestMemRefElementType : Test_Type<"TestMemRefElementType",
+                                      [MemRefElementTypeInterface]> {
+  let mnemonic = "memref_element";
+}
+
 #endif // TEST_TYPEDEFS


        


More information about the Mlir-commits mailing list