[Mlir-commits] [mlir] e4b79a5 - [mlir] add an interface to support custom types in LLVM dialect pointers

Alex Zinenko llvmlistbot at llvm.org
Fri Jul 16 04:05:40 PDT 2021


Author: Alex Zinenko
Date: 2021-07-16T13:05:27+02:00
New Revision: e4b79a542e2217b1838cae8e3434ebc3705e8d43

URL: https://github.com/llvm/llvm-project/commit/e4b79a542e2217b1838cae8e3434ebc3705e8d43
DIFF: https://github.com/llvm/llvm-project/commit/e4b79a542e2217b1838cae8e3434ebc3705e8d43.diff

LOG: [mlir] add an interface to support custom types in LLVM dialect pointers

This may be necessary in a partial, multi-stage conversion when a container
type from dialect A that contains types from dialect B goes through a
conversion in which only dialect A is converted to the LLVM dialect. The IR
then needs to keep a pointer to a non-LLVM type until a later conversion
converts the dialect B types to LLVM types.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D106076

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/test/Dialect/LLVMIR/types.mlir
    mlir/test/lib/Dialect/Test/CMakeLists.txt
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
index 91754f16e8a3..f7ca6dd7624b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -16,7 +16,12 @@ add_public_tablegen_target(MLIRLLVMOpsIncGen)
 
 add_mlir_doc(LLVMOps LLVMOps Dialects/ -gen-op-doc)
 
-add_mlir_interface(LLVMOpsInterfaces)
+set(LLVM_TARGET_DEFINITIONS LLVMOpsInterfaces.td)
+mlir_tablegen(LLVMOpsInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(LLVMOpsInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(LLVMTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(LLVMTypeInterfaces.cpp.inc -gen-type-interface-defs)
+add_public_tablegen_target(MLIRLLVMOpsInterfacesIncGen)
 
 set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
 mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 716260f3819d..750cf53a1e99 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -108,10 +108,18 @@ def LLVM_OpaqueStruct : Type<
   And<[LLVM_AnyStruct.predicate,
        CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().isOpaque()">]>>;
 
+// Type constraint accepting types that implement the pointer element type
+// interface.
+def LLVM_PointerElementType : Type<
+  CPred<"$_self.isa<::mlir::LLVM::PointerElementTypeInterface>()">,
+  "LLVM-compatible pointer element type">;
+
+
 // Type constraint accepting any LLVM type that can be loaded or stored, i.e. a
 // type that has size (not void, function or opaque struct type).
 def LLVM_LoadableType : Type<
-  And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
+  Or<[And<[LLVM_PrimitiveType.predicate, Neg<LLVM_OpaqueStruct.predicate>]>,
+      LLVM_PointerElementType.predicate]>,
   "LLVM type with size">;
 
 // Type constraint accepting any LLVM aggregate type, i.e. structure or array.

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 3853017b9a46..cdf57ce28a1e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -331,7 +331,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
                    OptionalAttr<SymbolRefArrayAttr>:$access_groups,
                    OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
                    UnitAttr:$nontemporal);
-  let results = (outs LLVM_Type:$res);
+  let results = (outs LLVM_LoadableType:$res);
   string llvmBuilder = [{
     auto *inst = builder.CreateLoad(
       $addr->getType()->getPointerElementType(), $addr, $volatile_);

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
index d31ae81ab2dd..495518621015 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td
@@ -23,8 +23,38 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   let cppNamespace = "::mlir::LLVM";
 
   let methods = [
-    InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "fastmathFlags">,
+    InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags",
+                    "fastmathFlags">,
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVM dialect type interfaces.
+//===----------------------------------------------------------------------===//
+
+// An interface for LLVM pointer element types.
+def LLVM_PointerElementTypeInterface
+    : TypeInterface<"PointerElementTypeInterface"> {
+  let cppNamespace = "::mlir::LLVM";
+
+  let description = [{
+    An interface for types that are allowed as element types of LLVM pointers.
+    Such types must have a size.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/"Returns the size of the type in bytes.",
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getSizeInBytes",
+      /*args=*/(ins "const DataLayout &":$dataLayout),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return dataLayout.getTypeSize($_type);
+      }]
+    >
+  ];
+}
+
+
 #endif // LLVM_OPS_INTERFACES

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index bf10bd51b4fc..150d62fcd5d9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -36,6 +36,13 @@ struct LLVMPointerTypeStorage;
 struct LLVMStructTypeStorage;
 struct LLVMTypeAndSizeStorage;
 } // namespace detail
+} // namespace LLVM
+} // namespace mlir
+
+#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc"
+
+namespace mlir {
+namespace LLVM {
 
 //===----------------------------------------------------------------------===//
 // Trivial types.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index fc4663db18f4..d9560bf9139d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -120,8 +120,9 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
-  return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
-                   LLVMLabelType>();
+  return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
+                                            LLVMMetadataType, LLVMLabelType>()
+                                : type.isa<PointerElementTypeInterface>();
 }
 
 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -607,3 +608,5 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
         return llvm::TypeSize::Fixed(0);
       });
 }
+
+#include "mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc"

diff  --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir
index 45c864060f5c..9a53f56ce70e 100644
--- a/mlir/test/Dialect/LLVMIR/types.mlir
+++ b/mlir/test/Dialect/LLVMIR/types.mlir
@@ -176,6 +176,14 @@ func @verbose() {
   return
 }
 
+// CHECK-LABEL: @ptr_elem_interface
+// CHECK-COUNT-3: !llvm.ptr<!test.smpla>
+func @ptr_elem_interface(%arg0: !llvm.ptr<!test.smpla>) {
+  %0 = llvm.load %arg0 : !llvm.ptr<!test.smpla>
+  llvm.store %0, %arg0 : !llvm.ptr<!test.smpla>
+  return
+}
+
 // -----
 
 // Check that type aliases can be used inside LLVM dialect types. Note that

diff  --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 5f37b09dda4c..91af79e9578c 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -62,6 +62,7 @@ add_mlir_library(MLIRTestDialect
   MLIRIR
   MLIRInferTypeOpInterface
   MLIRLinalgTransforms
+  MLIRLLVMIR
   MLIRPass
   MLIRReduce
   MLIRStandard

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 96cfd4cd66ac..6f01540c8b39 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -13,6 +13,7 @@
 
 #include "TestTypes.h"
 #include "TestDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Types.h"
@@ -222,11 +223,19 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
 // TestDialect
 //===----------------------------------------------------------------------===//
 
+namespace {
+
+struct PtrElementModel
+    : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
+                                                              SimpleAType> {};
+} // namespace
+
 void TestDialect::registerTypes() {
   addTypes<TestRecursiveType,
 #define GET_TYPEDEF_LIST
 #include "TestTypeDefs.cpp.inc"
            >();
+  SimpleAType::attachInterface<PtrElementModel>(*getContext());
 }
 
 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index d219480c76fb..c9f831f07a7d 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2443,6 +2443,14 @@ gentbl_cc_library(
             ["-gen-op-interface-defs"],
             "include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.cpp.inc",
         ),
+        (
+            ["-gen-type-interface-decls"],
+            "include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.h.inc",
+        ),
+        (
+            ["-gen-type-interface-defs"],
+            "include/mlir/Dialect/LLVMIR/LLVMTypeInterfaces.cpp.inc",
+        ),
     ],
     tblgen = ":mlir-tblgen",
     td_file = "include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td",

diff  --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index f76a5a4ec0fe..b6c6b0c5eb02 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -227,6 +227,7 @@ cc_library(
         "//mlir:Dialect",
         "//mlir:IR",
         "//mlir:InferTypeOpInterface",
+        "//mlir:LLVMDialect",
         "//mlir:Pass",
         "//mlir:Reducer",
         "//mlir:SideEffects",


        


More information about the Mlir-commits mailing list