[Mlir-commits] [mlir] a9bde16 - [mlir][llvm] Move LLVMPointerType to TypeDef

Jeff Niu llvmlistbot at llvm.org
Fri Oct 21 15:13:21 PDT 2022


Author: Jeff Niu
Date: 2022-10-21T15:13:08-07:00
New Revision: a9bde16ba5f7aa06c34e3091d71334071df5b114

URL: https://github.com/llvm/llvm-project/commit/a9bde16ba5f7aa06c34e3091d71334071df5b114
DIFF: https://github.com/llvm/llvm-project/commit/a9bde16ba5f7aa06c34e3091d71334071df5b114.diff

LOG: [mlir][llvm] Move LLVMPointerType to TypeDef

Depends on D136485

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D136498

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/test/Dialect/LLVMIR/layout.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index 38990ac128bd..a6b47cf50e8b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -122,7 +122,7 @@ class LLVM_LifetimeBaseOp<string opName> : LLVM_ZeroResultIntrOp<opName> {
                             moduleTranslation.lookupValue(op.getPtr())});
   }];
 
-  let assemblyFormat = "$size `,` $ptr attr-dict `:` type($ptr)";
+  let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
 }
 
 def LLVM_LifetimeStartOp : LLVM_LifetimeBaseOp<"lifetime.start"> {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index b4429fc6d13f..7e6f4e9e0ec6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -1232,7 +1232,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
     LLVMFuncOp getFunction(SymbolTableCollection &symbolTable);
   }];
 
-  let assemblyFormat = "$global_name attr-dict `:` type($res)";
+  let assemblyFormat = "$global_name attr-dict `:` qualified(type($res))";
 }
 
 def LLVM_MetadataOp : LLVM_Op<"metadata", [
@@ -1656,7 +1656,7 @@ def LLVM_NullOp
 
   let results = (outs LLVM_AnyPointer:$res);
   let builders = [LLVM_OneResultOpBuilder];
-  let assemblyFormat = "attr-dict `:` type($res)";
+  let assemblyFormat = "attr-dict `:` qualified(type($res))";
 }
 
 def LLVM_UndefOp : LLVM_Op<"mlir.undef", [Pure]>,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 237082b44479..08bcaaeaa14e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -1,4 +1,4 @@
-//===- LLVMDialect.h - MLIR LLVM dialect types ------------------*- C++ -*-===//
+//===- LLVMTypes.h - MLIR LLVM dialect types --------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -72,75 +72,6 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 
 #undef DEFINE_TRIVIAL_LLVM_TYPE
 
-//===----------------------------------------------------------------------===//
-// LLVMPointerType.
-//===----------------------------------------------------------------------===//
-
-/// LLVM dialect pointer type. This type typically represents a reference to an
-/// object in memory. Pointers may be opaque or parameterized by the element
-/// type. Both opaque and non-opaque pointers are additionally parameterized by
-/// the address space.
-class LLVMPointerType
-    : public Type::TypeBase<
-          LLVMPointerType, Type, detail::LLVMPointerTypeStorage,
-          DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> {
-public:
-  /// Inherit base constructors.
-  using Base::Base;
-
-  /// Checks if the given type can have a pointer type pointing to it.
-  static bool isValidElementType(Type type);
-
-  /// Gets or creates an instance of LLVM dialect pointer type pointing to an
-  /// object of `pointee` type in the given address space. The pointer type is
-  /// created in the same context as `pointee`. If the pointee is not provided,
-  /// creates an opaque pointer in the given context and address space.
-  static LLVMPointerType get(MLIRContext *context, unsigned addressSpace = 0);
-  static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
-  static LLVMPointerType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, Type pointee,
-             unsigned addressSpace = 0);
-  static LLVMPointerType
-  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
-             unsigned addressSpace = 0);
-
-  /// Returns the pointed-to type. It may be null if the pointer is opaque.
-  Type getElementType() const;
-
-  /// Returns `true` if this type is the opaque pointer type, i.e., it has no
-  /// pointed-to type.
-  bool isOpaque() const;
-
-  /// Returns the address space of the pointer.
-  unsigned getAddressSpace() const;
-
-  /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type pointee, unsigned);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              MLIRContext *context, unsigned) {
-    return success();
-  }
-
-  /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
-  /// DataLayout instance and query it instead.
-  unsigned getTypeSizeInBits(const DataLayout &dataLayout,
-                             DataLayoutEntryListRef params) const;
-  unsigned getABIAlignment(const DataLayout &dataLayout,
-                           DataLayoutEntryListRef params) const;
-  unsigned getPreferredAlignment(const DataLayout &dataLayout,
-                                 DataLayoutEntryListRef params) const;
-  bool areCompatible(DataLayoutEntryListRef oldLayout,
-                     DataLayoutEntryListRef newLayout) const;
-  LogicalResult verifyEntries(DataLayoutEntryListRef entries,
-                              Location loc) const;
-
-  void walkImmediateSubElements(function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) const;
-  Type replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
-                                   ArrayRef<Type> replTypes) const;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMStructType.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 6ddef17aa92a..960953866469 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -118,4 +118,53 @@ def LLVMFunctionType : LLVMType<"LLVMFunction", "func", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// LLVMPointerType
+//===----------------------------------------------------------------------===//
+
+def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
+      "areCompatible", "verifyEntries"]>,
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>]> {
+  let summary = "LLVM pointer type";
+  let description = [{
+    The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
+    a reference to an object in memory. Pointers may be opaque or parameterized
+    by the element type. Both opaque and non-opaque pointers are additionally
+    parameterized by the address space.
+
+    Example:
+
+    ```mlir
+    !llvm.ptr<i8>
+    !llvm.ptr
+    ```
+  }];
+  // A null (default) $elementType denotes an opaque pointer; see isOpaque().
+  let parameters = (ins DefaultValuedParameter<"Type", "Type()">:$elementType,
+                        DefaultValuedParameter<"unsigned", "0">:$addressSpace);
+  let assemblyFormat = [{
+    (`<` custom<Pointer>($elementType, $addressSpace)^ `>`)?
+  }];
+  // custom<Pointer> is implemented by parsePointer/printPointer (LLVMTypes.cpp).
+  let genVerifyDecl = 1;
+  // Builders: a typed pointer (context taken from the element type) or opaque.
+  let builders = [
+    TypeBuilderWithInferredContext<(ins "Type":$elementType,
+                                         CArg<"unsigned", "0">:$addressSpace)>,
+    TypeBuilder<(ins CArg<"unsigned", "0">:$addressSpace), [{
+      return $_get($_ctxt, Type(), addressSpace);
+    }]>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns `true` if this type is the opaque pointer type, i.e., it has no
+    /// pointed-to type.
+    bool isOpaque() const { return !getElementType(); }
+
+    /// Checks if the given type can have a pointer type pointing to it.
+    static bool isValidElementType(Type type);
+  }];
+}
+
 #endif // LLVMTYPES_TD

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6851b23cc997..eea3516949ac 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -168,7 +168,7 @@ def NVVM_VoteBallotOp :
   let hasCustomAssemblyFormat = 1;
 }
 
-def NVVM_SyncWarpOp : 
+def NVVM_SyncWarpOp :
   NVVM_Op<"bar.warp.sync">,
   Arguments<(ins LLVM_Type:$mask)> {
   string llvmBuilder = [{
@@ -534,9 +534,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
   );
 }
 
-// Returns a list of operation suffixes corresponding to possible b1 
-// multiply-and-accumulate operations for all fragments which have a 
-// b1 type. For all other fragments, the list returned holds a list 
+// Returns a list of operation suffixes corresponding to possible b1
+// multiply-and-accumulate operations for all fragments which have a
+// b1 type. For all other fragments, the list returned holds a list
 // containing the empty string.
 class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
   list<string> ret = !cond(
@@ -555,7 +555,7 @@ class MMA_SYNC_NAME<string ALayout, string BLayout, string b1op, int Satfinite,
                 # "_" # ALayout
                 # "_" # BLayout
                 # !if(Satfinite, "_satfinite", "")
-                # signature;  
+                # signature;
 }
 
 /// Helper to create the mapping between the configuration and the mma.sync
@@ -572,13 +572,13 @@ class MMA_SYNC_INTR {
       "if (layoutA == \"" # layoutA #  "\" && layoutB == \"" # layoutB #  "\" && "
       "    m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k #
       "    && \"" # op[0].ptx_elt_type # "\" == eltypeA && \""
-       # op[1].ptx_elt_type # "\" == eltypeB && " 
+       # op[1].ptx_elt_type # "\" == eltypeB && "
        # " \"" # op[2].ptx_elt_type # "\" == eltypeC && "
        # " \"" # op[3].ptx_elt_type # "\" == eltypeD "
        # " && (sat.has_value()  ? " # sat # " == static_cast<int>(*sat) : true)"
        # !if(!ne(b1op, ""), " && (b1Op.has_value() ? MMAB1Op::" # b1op # " == b1Op.value() : true)", "") # ")\n"
        # "  return " #
-       MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";", 
+       MMA_SYNC_NAME<layoutA, layoutB, b1op, sat, op[0], op[1], op[2], op[3]>.id # ";",
           "") // if supported
           ) // b1op
         ) // sat
@@ -586,7 +586,7 @@ class MMA_SYNC_INTR {
     ) // layoutA
   ); // all_mma_sync_ops
   list<list<list<string>>> f1 = !foldl([[[""]]],
-                                  !foldl([[[[""]]]], cond0, acc, el, 
+                                  !foldl([[[[""]]]], cond0, acc, el,
                                       !listconcat(acc, el)),
                                     acc1, el1, !listconcat(acc1, el1));
   list<list<string>> f2 = !foldl([[""]], f1, acc1, el1, !listconcat(acc1, el1));
@@ -776,7 +776,10 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
     ```
   }];
 
-  let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)";
+  let assemblyFormat = [{
+    $ptr `,` $stride `,` $args attr-dict `:` qualified(type($ptr)) `,`
+    type($args)
+  }];
   let hasVerifier = 1;
 }
 
@@ -884,32 +887,32 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 
   let description = [{
     The `nvvm.mma.sync` operation collectively performs the operation
-    `D = matmul(A, B) + C` using all threads in a warp.    
+    `D = matmul(A, B) + C` using all threads in a warp.
 
     All the threads in the warp must execute the same `mma.sync` operation.
 
     For each possible multiplicand PTX data type, there are one or more possible
     instruction shapes given as "mMnNkK". The below table describes the posssibilities
-    as well as the types required for the operands. Note that the data type for 
-    C (the accumulator) and D (the result) can vary independently when there are 
+    as well as the types required for the operands. Note that the data type for
+    C (the accumulator) and D (the result) can vary independently when there are
     multiple possibilities in the "C/D Type" column.
 
     When an optional attribute cannot be immediately inferred from the types of
     the operands and the result during parsing or validation, an error will be
     raised.
 
-    `b1Op` is only relevant when the binary (b1) type is given to 
+    `b1Op` is only relevant when the binary (b1) type is given to
     `multiplicandDataType`. It specifies how the multiply-and-acumulate is
     performed and is either `xor_popc` or `and_poc`. The default is `xor_popc`.
 
-    `intOverflowBehavior` is only relevant when the `multiplicandType` attribute 
+    `intOverflowBehavior` is only relevant when the `multiplicandType` attribute
     is one of `u8, s8, u4, s4`, this attribute describes how overflow is handled
     in the accumulator. When the attribute is `satfinite`, the accumulator values
     are clamped in the int32 range on overflow. This is the default behavior.
-    Alternatively, accumulator behavior `wrapped` can also be specified, in 
+    Alternatively, accumulator behavior `wrapped` can also be specified, in
     which case overflow wraps from one end of the range to the other.
 
-    `layoutA` and `layoutB` are required and should generally be set to 
+    `layoutA` and `layoutB` are required and should generally be set to
     `#nvvm.mma_layout<row>` and `#nvvm.mma_layout<col>` respectively, but other
     combinations are possible for certain layouts according to the table below.
 
@@ -938,12 +941,12 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     Example:
     ```mlir
 
-    %128 = nvvm.mma.sync A[%120, %121, %122, %123]  
-                         B[%124, %125]  
-                         C[%126, %127]  
-                         {layoutA = #nvvm.mma_layout<row>, 
-                          layoutB = #nvvm.mma_layout<col>, 
-                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} 
+    %128 = nvvm.mma.sync A[%120, %121, %122, %123]
+                         B[%124, %125]
+                         C[%126, %127]
+                         {layoutA = #nvvm.mma_layout<row>,
+                          layoutB = #nvvm.mma_layout<col>,
+                          shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}}
         : (vector<2xf16>, vector<2xf16>, vector<2xf16>)
            -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
     ```
@@ -951,7 +954,7 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 
   let results = (outs LLVM_AnyStruct:$res);
   let arguments = (ins NVVM_MMAShapeAttr:$shape,
-             OptionalAttr<MMAB1OpAttr>:$b1Op, 
+             OptionalAttr<MMAB1OpAttr>:$b1Op,
              OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
              MMALayoutAttr:$layoutA,
              MMALayoutAttr:$layoutB,
@@ -959,12 +962,12 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
              OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
              Variadic<LLVM_Type>:$operandA,
              Variadic<LLVM_Type>:$operandB,
-             Variadic<LLVM_Type>:$operandC); 
+             Variadic<LLVM_Type>:$operandC);
 
   let extraClassDeclaration = !strconcat([{
       static llvm::Intrinsic::ID getIntrinsicID(
             int64_t m, int64_t n, uint64_t k,
-            llvm::Optional<MMAB1Op> b1Op, 
+            llvm::Optional<MMAB1Op> b1Op,
             llvm::Optional<MMAIntOverflow> sat,
             mlir::NVVM::MMALayout layoutAEnum, mlir::NVVM::MMALayout layoutBEnum,
             mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
@@ -988,7 +991,7 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
     }]);
 
   let builders = [
-      OpBuilder<(ins  "Type":$resultType, "ValueRange":$operandA, 
+      OpBuilder<(ins  "Type":$resultType, "ValueRange":$operandA,
         "ValueRange":$operandB, "ValueRange":$operandC,
         "ArrayRef<int64_t>":$shape, "Optional<MMAB1Op>":$b1Op,
         "Optional<MMAIntOverflow>":$intOverflow,
@@ -999,12 +1002,12 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
   string llvmBuilder = [{
     auto operands = moduleTranslation.lookupValues(opInst.getOperands());
     auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
-        $shape.getM(), $shape.getN(), $shape.getK(), 
+        $shape.getM(), $shape.getN(), $shape.getK(),
         $b1Op, $intOverflowBehavior,
         $layoutA, $layoutB,
-        $multiplicandAPtxType.value(), 
+        $multiplicandAPtxType.value(),
         $multiplicandBPtxType.value(),
-        op.accumPtxType(), 
+        op.accumPtxType(),
         op.resultPtxType());
 
     $res = createIntrinsicCall(

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 71ba80d0ce7e..de87d8e34e79 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2573,7 +2573,6 @@ void LLVMDialect::initialize() {
            LLVMTokenType,
            LLVMLabelType,
            LLVMMetadataType,
-           LLVMPointerType,
            LLVMFixedVectorType,
            LLVMScalableVectorType,
            LLVMStructType>();

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 566ef63110b3..4a640221d999 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -124,20 +124,8 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
 
   printer << getTypeKeyword(type);
 
-  if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
-    if (ptrType.isOpaque()) {
-      if (ptrType.getAddressSpace() != 0)
-        printer << '<' << ptrType.getAddressSpace() << '>';
-      return;
-    }
-
-    printer << '<';
-    dispatchPrint(printer, ptrType.getElementType());
-    if (ptrType.getAddressSpace() != 0)
-      printer << ", " << ptrType.getAddressSpace();
-    printer << '>';
-    return;
-  }
+  if (auto ptrType = type.dyn_cast<LLVMPointerType>())
+    return ptrType.print(printer);
 
   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
     return arrayType.print(printer);
@@ -164,37 +152,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) {
 
 static ParseResult dispatchParse(AsmParser &parser, Type &type);
 
-/// Parses an LLVM dialect pointer type.
-///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
-///               | `ptr` (`<` integer `>`)?
-static LLVMPointerType parsePointerType(AsmParser &parser) {
-  SMLoc loc = parser.getCurrentLocation();
-  Type elementType;
-  if (parser.parseOptionalLess()) {
-    return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
-                                              /*addressSpace=*/0);
-  }
-
-  unsigned addressSpace = 0;
-  OptionalParseResult opr = parser.parseOptionalInteger(addressSpace);
-  if (opr.has_value()) {
-    if (failed(*opr) || parser.parseGreater())
-      return LLVMPointerType();
-    return parser.getChecked<LLVMPointerType>(loc, parser.getContext(),
-                                              addressSpace);
-  }
-
-  if (dispatchParse(parser, elementType))
-    return LLVMPointerType();
-
-  if (succeeded(parser.parseOptionalComma()) &&
-      failed(parser.parseInteger(addressSpace)))
-    return LLVMPointerType();
-  if (failed(parser.parseGreater()))
-    return LLVMPointerType();
-  return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
-}
-
 /// Parses an LLVM dialect vector type.
 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
 /// Supports both fixed and scalable vectors.
@@ -391,7 +348,7 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) {
       .Case("label", [&] { return LLVMLabelType::get(ctx); })
       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
       .Case("func", [&] { return LLVMFunctionType::parse(parser); })
-      .Case("ptr", [&] { return parsePointerType(parser); })
+      .Case("ptr", [&] { return LLVMPointerType::parse(parser); })
       .Case("vec", [&] { return parseVectorType(parser); })
       .Case("array", [&] { return LLVMArrayType::parse(parser); })
       .Case("struct", [&] { return parseStructType(parser); })

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index f55d2ae45c45..85e439185d8c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -1,3 +1,4 @@
+//===- LLVMTypes.cpp - MLIR LLVM dialect types ------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -75,6 +76,41 @@ static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
   p << ')';
 }
 
+//===----------------------------------------------------------------------===//
+// custom<Pointer>
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType,
+                                FailureOr<unsigned> &addressSpace) {
+  addressSpace = 0;
+  // Inside `<...>`: a bare integer is the address space of an opaque pointer.
+  OptionalParseResult result = p.parseOptionalInteger(*addressSpace);
+  if (result.has_value()) {
+    if (failed(result.value()))
+      return failure();
+    elementType = Type();
+    return success();
+  }
+  // Otherwise: an element type, optionally followed by `,` addressSpace.
+  if (parsePrettyLLVMType(p, elementType))
+    return failure();
+  if (succeeded(p.parseOptionalComma()))
+    return p.parseInteger(*addressSpace);
+  // No `,`: the address space stays at the 0 assigned above.
+  return success();
+}
+
+static void printPointer(AsmPrinter &p, Type elementType,
+                         unsigned addressSpace) {
+  if (elementType) // Null for opaque pointers; nothing to print then.
+    printPrettyLLVMType(p, elementType);
+  if (addressSpace != 0) { // Address space 0 is the default; elide it.
+    if (elementType)
+      p << ", ";
+    p << addressSpace;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // ODS-Generated Definitions
 //===----------------------------------------------------------------------===//
@@ -228,7 +264,7 @@ Type LLVMFunctionType::replaceImmediateSubElements(
 }
 
 //===----------------------------------------------------------------------===//
-// Pointer type.
+// LLVMPointerType
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
@@ -246,32 +282,6 @@ LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
   return Base::get(pointee.getContext(), pointee, addressSpace);
 }
 
-LLVMPointerType LLVMPointerType::get(MLIRContext *context,
-                                     unsigned addressSpace) {
-  return Base::get(context, Type(), addressSpace);
-}
-
-LLVMPointerType
-LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                            Type pointee, unsigned addressSpace) {
-  return Base::getChecked(emitError, pointee.getContext(), pointee,
-                          addressSpace);
-}
-
-LLVMPointerType
-LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                            MLIRContext *context, unsigned addressSpace) {
-  return Base::getChecked(emitError, context, Type(), addressSpace);
-}
-
-Type LLVMPointerType::getElementType() const { return getImpl()->pointeeType; }
-
-bool LLVMPointerType::isOpaque() const { return !getImpl()->pointeeType; }
-
-unsigned LLVMPointerType::getAddressSpace() const {
-  return getImpl()->addressSpace;
-}
-
 LogicalResult
 LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
                         Type pointee, unsigned) {
@@ -280,6 +290,9 @@ LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DataLayoutTypeInterface
+
 constexpr const static unsigned kDefaultPointerSizeBits = 64;
 constexpr const static unsigned kDefaultPointerAlignment = 8;
 
@@ -426,6 +439,9 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SubElementTypeInterface
+
 void LLVMPointerType::walkImmediateSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) const {

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index d13452f6e0af..d30e94c1eca1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -321,33 +321,6 @@ struct LLVMStructTypeStorage : public TypeStorage {
   unsigned identifiedBodySizeAndFlags = 0;
 };
 
-//===----------------------------------------------------------------------===//
-// LLVMPointerTypeStorage.
-//===----------------------------------------------------------------------===//
-
-/// Storage type for LLVM dialect pointer types. These are uniqued by a pair of
-/// element type and address space. The element type may be null indicating that
-/// the pointer is opaque.
-struct LLVMPointerTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<Type, unsigned>;
-
-  LLVMPointerTypeStorage(const KeyTy &key)
-      : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}
-
-  static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator,
-                                           const KeyTy &key) {
-    return new (allocator.allocate<LLVMPointerTypeStorage>())
-        LLVMPointerTypeStorage(key);
-  }
-
-  bool operator==(const KeyTy &key) const {
-    return std::make_tuple(pointeeType, addressSpace) == key;
-  }
-
-  Type pointeeType;
-  unsigned addressSpace;
-};
-
 //===----------------------------------------------------------------------===//
 // LLVMTypeAndSizeStorage.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bf56f4fece53..1a50afa8c986 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -495,7 +495,7 @@ func.func @invalid_vector_type_5(%a : vector<4xf32>, %idx : i32) -> vector<4xf32
 // -----
 
 func.func @null_non_llvm_type() {
-  // expected-error at +1 {{custom op 'llvm.mlir.null' invalid kind of type specified}}
+  // expected-error at +1 {{'llvm.mlir.null' op result #0 must be LLVM pointer type, but got 'i32'}}
   llvm.mlir.null : i32
 }
 

diff  --git a/mlir/test/Dialect/LLVMIR/layout.mlir b/mlir/test/Dialect/LLVMIR/layout.mlir
index af471bb587cc..c2f162dd6175 100644
--- a/mlir/test/Dialect/LLVMIR/layout.mlir
+++ b/mlir/test/Dialect/LLVMIR/layout.mlir
@@ -82,12 +82,12 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<
     // CHECK: size = 8
     "test.data_layout_query"() : () -> !llvm.ptr<i8, 5>
     // CHECK: alignment = 4
-	// CHECK: bitsize = 32
+	  // CHECK: bitsize = 32
     // CHECK: preferred = 8
     // CHECK: size = 4
     "test.data_layout_query"() : () -> !llvm.ptr<3>
     // CHECK: alignment = 8
-	// CHECK: bitsize = 32
+	  // CHECK: bitsize = 32
     // CHECK: preferred = 8
     // CHECK: size = 4
 	"test.data_layout_query"() : () -> !llvm.ptr<4>


        


More information about the Mlir-commits mailing list