[Mlir-commits] [mlir] aec38c6 - [mlir] LLVMType: make getUnderlyingType private

Alex Zinenko llvmlistbot at llvm.org
Wed Jul 29 04:43:46 PDT 2020


Author: Alex Zinenko
Date: 2020-07-29T13:43:38+02:00
New Revision: aec38c619dfa1c41b6b0f6c52a31d221ac108b6b

URL: https://github.com/llvm/llvm-project/commit/aec38c619dfa1c41b6b0f6c52a31d221ac108b6b
DIFF: https://github.com/llvm/llvm-project/commit/aec38c619dfa1c41b6b0f6c52a31d221ac108b6b.diff

LOG: [mlir] LLVMType: make getUnderlyingType private

The current modeling of LLVM IR types in MLIR is based on the LLVMType class
that wraps a raw `llvm::Type *` and delegates uniquing, printing and parsing to
LLVM itself. This model makes thread-safe type manipulation hard and is
being progressively replaced with a cleaner MLIR model that replicates the type
system. In the new model, LLVMType will no longer have an underlying LLVM IR
type. Restrict access to this type in the current model in preparation for the
change.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D84389

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 078cb1cfa4e5..52acfbfa8e50 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -47,6 +47,14 @@ struct LLVMTypeStorage;
 struct LLVMDialectImpl;
 } // namespace detail
 
+class LLVMType;
+
+/// Converts an MLIR LLVM dialect type to LLVM IR type. Note that this function
+/// exists exclusively for the purpose of gradual transition to the first-party
+/// modeling of LLVM types. It should not be used outside MLIR-to-LLVM
+/// translation.
+llvm::Type *convertLLVMType(LLVMType type);
+
 class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
                                              detail::LLVMTypeStorage> {
 public:
@@ -59,26 +67,32 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
   static bool kindof(unsigned kind) { return kind == LLVM_TYPE; }
 
   LLVMDialect &getDialect();
-  llvm::Type *getUnderlyingType() const;
 
   /// Utilities to identify types.
   bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); }
   bool isHalfTy() { return getUnderlyingType()->isHalfTy(); }
   bool isFloatTy() { return getUnderlyingType()->isFloatTy(); }
   bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); }
-  bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
-  bool isIntegerTy(unsigned bitwidth) {
-    return getUnderlyingType()->isIntegerTy(bitwidth);
-  }
+  bool isFloatingPointTy() { return getUnderlyingType()->isFloatingPointTy(); }
 
   /// Array type utilities.
   LLVMType getArrayElementType();
   unsigned getArrayNumElements();
   bool isArrayTy();
 
+  /// Integer type utilities.
+  unsigned getIntegerBitWidth() {
+    return getUnderlyingType()->getIntegerBitWidth();
+  }
+  bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); }
+  bool isIntegerTy(unsigned bitwidth) {
+    return getUnderlyingType()->isIntegerTy(bitwidth);
+  }
+
   /// Vector type utilities.
   LLVMType getVectorElementType();
   unsigned getVectorNumElements();
+  llvm::ElementCount getVectorElementCount();
   bool isVectorTy();
 
   /// Function type utilities.
@@ -86,11 +100,13 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
   unsigned getFunctionNumParams();
   LLVMType getFunctionResultType();
   bool isFunctionTy();
+  bool isFunctionVarArg();
 
   /// Pointer type utilities.
   LLVMType getPointerTo(unsigned addrSpace = 0);
   LLVMType getPointerElementTy();
   bool isPointerTy();
+  static bool isValidPointerElementType(LLVMType type);
 
   /// Struct type utilities.
   LLVMType getStructElementType(unsigned i);
@@ -194,6 +210,14 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
 
 private:
   friend LLVMDialect;
+  friend llvm::Type *convertLLVMType(LLVMType type);
+
+  /// Get the underlying LLVM IR type.
+  llvm::Type *getUnderlyingType() const;
+
+  /// Get the underlying LLVM IR types for the given array of types.
+  static void getUnderlyingTypes(ArrayRef<LLVMType> types,
+                                 SmallVectorImpl<llvm::Type *> &result);
 
   /// Get an LLVMType with a pre-existing llvm type.
   static LLVMType get(MLIRContext *context, llvm::Type *llvmType);

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index d88b372dbf43..4d99bf265c65 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -134,11 +134,9 @@ class ListIntSubst<string pattern, list<int> values> {
 // or result in the operation.
 def LLVM_IntrPatterns {
   string operand =
-    [{opInst.getOperand($0).getType()
-      .cast<LLVM::LLVMType>().getUnderlyingType()}];
+    [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
   string result =
-    [{opInst.getResult($0).getType()
-      .cast<LLVM::LLVMType>().getUnderlyingType()}];
+    [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
 }
 
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 29d7fd930030..4da90575524b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -61,9 +61,8 @@ def LLVM_VoidResultTypeOpBuilder : OpBuilder<
   [{
     auto llvmType = resultType.dyn_cast<LLVM::LLVMType>(); (void)llvmType;
     assert(llvmType && "result must be an LLVM type");
-    assert(llvmType.getUnderlyingType() &&
-            llvmType.getUnderlyingType()->isVoidTy() &&
-            "for zero-result operands, only 'void' is accepted as result type");
+    assert(llvmType.isVoidTy() &&
+           "for zero-result operands, only 'void' is accepted as result type");
     build(builder, result, operands, attributes);
   }]>;
 
@@ -477,7 +476,7 @@ def LLVM_ShuffleVectorOp
   let verifier = [{
     auto wrappedVectorType1 = v1().getType().cast<LLVM::LLVMType>();
     auto wrappedVectorType2 = v2().getType().cast<LLVM::LLVMType>();
-    if (!wrappedVectorType2.getUnderlyingType()->isVectorTy())
+    if (!wrappedVectorType2.isVectorTy())
       return emitOpError("expected LLVM IR Dialect vector type for operand #2");
     if (wrappedVectorType1.getVectorElementType() !=
         wrappedVectorType2.getVectorElementType())
@@ -765,10 +764,10 @@ def LLVM_LLVMFuncOp
           .getValue().cast<LLVMType>();
     }
     bool isVarArg() {
-      return getType().getUnderlyingType()->isFunctionVarArg();
+      return getType().isFunctionVarArg();
     }
 
-    // Hook for OpTrait::FunctionLike, returns the number of function arguments.
+    // Hook for OpTrait::FunctionLike, returns the number of function arguments.
     // Depends on the type attribute being correct as checked by verifyType.
     unsigned getNumFuncArguments();
 

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 786b9ef217bd..0cd11690daa8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -139,7 +139,7 @@ def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
 // Vector buffer load/store intrinsics
 
 def ROCDL_MubufLoadOp :
-  ROCDL_Op<"buffer.load">, 
+  ROCDL_Op<"buffer.load">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_Type:$rsrc,
                  LLVM_Type:$vindex,
@@ -160,7 +160,7 @@ def ROCDL_MubufLoadOp :
 }
 
 def ROCDL_MubufStoreOp :
-  ROCDL_Op<"buffer.store">, 
+  ROCDL_Op<"buffer.store">,
   Arguments<(ins LLVM_Type:$vdata,
                  LLVM_Type:$rsrc,
                  LLVM_Type:$vindex,
@@ -168,14 +168,13 @@ def ROCDL_MubufStoreOp :
                  LLVM_Type:$glc,
                  LLVM_Type:$slc)>{
   string llvmBuilder = [{
-    auto vdataType = op.vdata().getType().cast<LLVM::LLVMType>()
-                       .getUnderlyingType();
+    auto vdataType = convertType(op.vdata().getType().cast<LLVM::LLVMType>());
     createIntrinsicCall(builder,
-          llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, 
+          llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
           $offset, $glc, $slc}, {vdataType});
   }];
   let parser = [{ return parseROCDLMubufStoreOp(parser, result); }];
-  let printer = [{ 
+  let printer = [{
     Operation *op = this->getOperation();
     p << op->getName() << " " << op->getOperands()
       << " : " << vdata().getType();

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index e44ae976e0dd..61f8f9fce64c 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -89,6 +89,10 @@ class ModuleTranslation {
                                             llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpParallel(Operation &op,
                                            llvm::IRBuilder<> &builder);
+
+  /// Converts the type from MLIR LLVM dialect to LLVM.
+  llvm::Type *convertType(LLVMType type);
+
   static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
 
   /// A helper to look up remapped operands in the value remapping table.

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
index 25a3ac07d5f4..fd0e96b79d2b 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -64,10 +64,8 @@ static unsigned getBitWidth(Type type) {
 
 /// Returns the bit width of LLVMType integer or vector.
 static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
-  return type.isVectorTy() ? type.getVectorElementType()
-                                 .getUnderlyingType()
-                                 ->getIntegerBitWidth()
-                           : type.getUnderlyingType()->getIntegerBitWidth();
+  return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth()
+                           : type.getIntegerBitWidth();
 }
 
 /// Creates `IntegerAttribute` with all bits set for given type

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 1e6fa6a8754b..0d154796f049 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2248,10 +2248,8 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
         op, operands, typeConverter,
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {cast<llvm::FixedVectorType>(llvmVectorTy.getUnderlyingType())
-                       ->getNumElements()},
-                  floatType),
+              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
+                                    floatType),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
@@ -2511,8 +2509,8 @@ struct IndexCastOpLowering : public ConvertOpToLLVMPattern<IndexCastOp> {
         this->typeConverter.convertType(indexCastOp.getResult().getType())
             .cast<LLVM::LLVMType>();
     auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
-    unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
-    unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
+    unsigned targetBits = targetType.getIntegerBitWidth();
+    unsigned sourceBits = sourceType.getIntegerBitWidth();
 
     if (targetBits == sourceBits)
       rewriter.replaceOp(op, transformed.in());

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5dbc8394b03a..4fa7b573f84e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -127,7 +127,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
 
   auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout();
   align = dataLayout.getPrefTypeAlignment(
-      elementTy.cast<LLVM::LLVMType>().getUnderlyingType());
+      LLVM::convertLLVMType(elementTy.cast<LLVM::LLVMType>()));
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index cf7a5d926528..17848c6bf3ee 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -105,11 +105,9 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
   auto argType = type.dyn_cast<LLVM::LLVMType>();
   if (!argType)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
-  if (argType.getUnderlyingType()->isVectorTy())
-    resultType = LLVMType::getVectorTy(
-        resultType,
-        llvm::cast<llvm::FixedVectorType>(argType.getUnderlyingType())
-            ->getNumElements());
+  if (argType.isVectorTy())
+    resultType =
+        LLVMType::getVectorTy(resultType, argType.getVectorNumElements());
 
   result.addTypes({resultType});
   return success();
@@ -214,7 +212,7 @@ static Type getLoadStoreElementType(OpAsmParser &parser, Type type,
   if (!llvmTy)
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
            nullptr;
-  if (!llvmTy.getUnderlyingType()->isPointerTy())
+  if (!llvmTy.isPointerTy())
     return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"),
            nullptr;
   return llvmTy.getPointerElementTy();
@@ -683,8 +681,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
       parser.resolveOperand(position, positionType, result.operands))
     return failure();
   auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType ||
-      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   result.addTypes(wrappedVectorType.getVectorElementType());
@@ -725,16 +722,15 @@ static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
                               "expected an array of integer literals"),
              nullptr;
     int position = positionElementAttr.getInt();
-    auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
-    if (llvmContainerType->isArrayTy()) {
+    if (wrappedContainerType.isArrayTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
-                              llvmContainerType->getArrayNumElements())
+                              wrappedContainerType.getArrayNumElements())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
       wrappedContainerType = wrappedContainerType.getArrayElementType();
-    } else if (llvmContainerType->isStructTy()) {
+    } else if (wrappedContainerType.isStructTy()) {
       if (position < 0 || static_cast<unsigned>(position) >=
-                              llvmContainerType->getStructNumElements())
+                              wrappedContainerType.getStructNumElements())
         return parser.emitError(attributeLoc, "position out of bounds"),
                nullptr;
       wrappedContainerType =
@@ -803,8 +799,7 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
     return failure();
 
   auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedVectorType ||
-      !wrappedVectorType.getUnderlyingType()->isVectorTy())
+  if (!wrappedVectorType || !wrappedVectorType.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   auto valueType = wrappedVectorType.getVectorElementType();
@@ -1125,7 +1120,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static LogicalResult verify(GlobalOp op) {
-  if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
+  if (!LLVMType::isValidPointerElementType(op.getType()))
     return op.emitOpError(
         "expects type to be a valid element type for an LLVM pointer");
   if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp()))
@@ -1133,8 +1128,7 @@ static LogicalResult verify(GlobalOp op) {
 
   if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
     auto type = op.getType();
-    if (!type.getUnderlyingType()->isArrayTy() ||
-        !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
+    if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) ||
         type.getArrayNumElements() != strAttr.getValue().size())
       return op.emitOpError(
           "requires an i8 array type of the length equal to that of the string "
@@ -1197,8 +1191,7 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
       parser.resolveOperand(v2, typeV2, result.operands))
     return failure();
   auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedContainerType1 ||
-      !wrappedContainerType1.getUnderlyingType()->isVectorTy())
+  if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy())
     return parser.emitError(
         loc, "expected LLVM IR dialect vector type for operand #1");
   auto vType = LLVMType::getVectorTy(
@@ -1239,7 +1232,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
   if (argAttrs.empty())
     return;
 
-  unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
+  unsigned numInputs = type.getFunctionNumParams();
   assert(numInputs == argAttrs.size() &&
          "expected as many argument attribute lists as arguments");
   SmallString<8> argAttrName;
@@ -1374,7 +1367,7 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
 // getNumArguments hook not failing.
 LogicalResult LLVMFuncOp::verifyType() {
   auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
-  if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
+  if (!llvmType || !llvmType.isFunctionTy())
     return emitOpError("requires '" + getTypeAttrName() +
                        "' attribute of wrapped LLVM function type");
 
@@ -1384,7 +1377,7 @@ LogicalResult LLVMFuncOp::verifyType() {
 // Hook for OpTrait::FunctionLike, returns the number of function arguments.
 // Depends on the type attribute being correct as checked by verifyType
 unsigned LLVMFuncOp::getNumFuncArguments() {
-  return getType().getUnderlyingType()->getFunctionNumParams();
+  return getType().getFunctionNumParams();
 }
 
 // Hook for OpTrait::FunctionLike, returns the number of function results.
@@ -1424,8 +1417,7 @@ static LogicalResult verify(LLVMFuncOp op) {
   if (op.isVarArg())
     return op.emitOpError("only external functions can be variadic");
 
-  auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
-  unsigned numArguments = funcType->getNumParams();
+  unsigned numArguments = op.getType().getFunctionNumParams();
   Block &entryBlock = op.front();
   for (unsigned i = 0; i < numArguments; ++i) {
     Type argType = entryBlock.getArgument(i).getType();
@@ -1433,7 +1425,7 @@ static LogicalResult verify(LLVMFuncOp op) {
     if (!argLLVMType)
       return op.emitOpError("entry block argument #")
              << i << " is not of LLVM type";
-    if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
+    if (op.getType().getFunctionParamType(i) != argLLVMType)
       return op.emitOpError("the type of entry block argument #")
              << i << " does not match the function signature";
   }
@@ -1566,7 +1558,7 @@ static LogicalResult verify(AtomicRMWOp op) {
     return op.emitOpError(
         "expected LLVM IR result type to match type for operand #1");
   if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
-    if (!valType.getUnderlyingType()->isFloatingPointTy())
+    if (!valType.isFloatingPointTy())
       return op.emitOpError("expected LLVM IR floating point type");
   } else if (op.bin_op() == AtomicBinOp::xchg) {
     if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
@@ -1842,6 +1834,13 @@ llvm::Type *LLVMType::getUnderlyingType() const {
   return getImpl()->underlyingType;
 }
 
+void LLVMType::getUnderlyingTypes(ArrayRef<LLVMType> types,
+                                  SmallVectorImpl<llvm::Type *> &result) {
+  result.reserve(result.size() + types.size());
+  for (LLVMType ty : types)
+    result.push_back(ty.getUnderlyingType());
+}
+
 /// Array type utilities.
 LLVMType LLVMType::getArrayElementType() {
   return get(getContext(), getUnderlyingType()->getArrayElementType());
@@ -1861,6 +1860,9 @@ unsigned LLVMType::getVectorNumElements() {
   return llvm::cast<llvm::FixedVectorType>(getUnderlyingType())
       ->getNumElements();
 }
+llvm::ElementCount LLVMType::getVectorElementCount() {
+  return llvm::cast<llvm::VectorType>(getUnderlyingType())->getElementCount();
+}
 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 
 /// Function type utilities.
@@ -1876,6 +1878,9 @@ LLVMType LLVMType::getFunctionResultType() {
       llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
 }
 bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); }
+bool LLVMType::isFunctionVarArg() {
+  return getUnderlyingType()->isFunctionVarArg();
+}
 
 /// Pointer type utilities.
 LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
@@ -1888,6 +1893,9 @@ LLVMType LLVMType::getPointerElementTy() {
   return get(getContext(), getUnderlyingType()->getPointerElementType());
 }
 bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); }
+bool LLVMType::isValidPointerElementType(LLVMType type) {
+  return llvm::PointerType::isValidElementType(type.getUnderlyingType());
+}
 
 /// Struct type utilities.
 LLVMType LLVMType::getStructElementType(unsigned i) {
@@ -1974,18 +1982,12 @@ LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
                                  isPacked);
   });
 }
-inline static SmallVector<llvm::Type *, 8>
-toUnderlyingTypes(ArrayRef<LLVMType> elements) {
-  SmallVector<llvm::Type *, 8> llvmElements;
-  for (auto elt : elements)
-    llvmElements.push_back(elt.getUnderlyingType());
-  return llvmElements;
-}
 LLVMType LLVMType::createStructTy(LLVMDialect *dialect,
                                   ArrayRef<LLVMType> elements,
                                   Optional<StringRef> name, bool isPacked) {
   StringRef sr = name.hasValue() ? *name : "";
-  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  SmallVector<llvm::Type *, 8> llvmElements;
+  getUnderlyingTypes(elements, llvmElements);
   return getLocked(dialect, [=] {
     auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr);
     if (!llvmElements.empty())
@@ -1997,7 +1999,8 @@ LLVMType LLVMType::setStructTyBody(LLVMType structType,
                                    ArrayRef<LLVMType> elements, bool isPacked) {
   llvm::StructType *st =
       llvm::cast<llvm::StructType>(structType.getUnderlyingType());
-  SmallVector<llvm::Type *, 8> llvmElements(toUnderlyingTypes(elements));
+  SmallVector<llvm::Type *, 8> llvmElements;
+  getUnderlyingTypes(elements, llvmElements);
   return getLocked(&structType.getDialect(), [=] {
     st->setBody(llvmElements, isPacked);
     return st;
@@ -2017,6 +2020,10 @@ LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
 
 bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); }
 
+llvm::Type *mlir::LLVM::convertLLVMType(LLVMType type) {
+  return type.getUnderlyingType();
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 9754c614efdf..77897d65e1a5 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -234,18 +234,17 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
     return nullptr;
 
   if (type.isIntegerTy())
-    return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth());
+    return b.getIntegerType(type.getIntegerBitWidth());
 
-  if (type.getUnderlyingType()->isFloatTy())
+  if (type.isFloatTy())
     return b.getF32Type();
 
-  if (type.getUnderlyingType()->isDoubleTy())
+  if (type.isDoubleTy())
     return b.getF64Type();
 
   // LLVM vectors can only contain scalars.
   if (type.isVectorTy()) {
-    auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
-                           ->getElementCount();
+    auto numElements = type.getVectorElementCount();
     if (numElements.Scalable) {
       emitError(unknownLoc) << "scalable vectors not supported";
       return nullptr;
@@ -270,9 +269,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
     // attribute type.
     if (type.getArrayElementType().isVectorTy()) {
       LLVMType vectorType = type.getArrayElementType();
-      auto numElements =
-          llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
-              ->getElementCount();
+      auto numElements = vectorType.getVectorElementCount();
       if (numElements.Scalable) {
         emitError(unknownLoc) << "scalable vectors not supported";
         return nullptr;

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 3a70dd3932e9..a0aefc988a5d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -574,7 +574,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   }
 
   if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
-    llvm::Type *ty = lpOp.getType().dyn_cast<LLVMType>().getUnderlyingType();
+    llvm::Type *ty = convertType(lpOp.getType().cast<LLVMType>());
     llvm::LandingPadInst *lpi =
         builder.CreateLandingPad(ty, lpOp.getNumOperands());
 
@@ -661,7 +661,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
       if (!wrappedType)
         return emitError(bb.front().getLoc(),
                          "block argument does not have an LLVM type");
-      llvm::Type *type = wrappedType.getUnderlyingType();
+      llvm::Type *type = convertType(wrappedType);
       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
       valueMapping[arg] = phi;
     }
@@ -687,7 +687,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
   llvm::sys::SmartScopedLock<true> scopedLock(
       llvmDialect->getLLVMContextMutex());
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
-    llvm::Type *type = op.getType().getUnderlyingType();
+    llvm::Type *type = convertType(op.getType());
     llvm::Constant *cst = llvm::UndefValue::get(type);
     if (op.getValueOrNull()) {
       // String attributes are treated separately because they cannot appear as
@@ -826,7 +826,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be boolean, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.getUnderlyingType()->isPointerTy())
+      if (!argTy.isPointerTy())
         return func.emitError(
             "llvm.noalias attribute attached to LLVM non-pointer argument");
       if (attr.getValue())
@@ -837,7 +837,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
       // NB: Attribute already verified to be int, so check if we can indeed
       // attach the attribute to this argument, based on its type.
       auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.getUnderlyingType()->isPointerTy())
+      if (!argTy.isPointerTy())
         return func.emitError(
             "llvm.align attribute attached to LLVM non-pointer argument");
       llvmArg.addAttrs(
@@ -896,7 +896,7 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
     llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
         function.getName(),
-        cast<llvm::FunctionType>(function.getType().getUnderlyingType()));
+        cast<llvm::FunctionType>(convertType(function.getType())));
     llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
     llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage()));
     functionMapping[function.getName()] = llvmFunc;
@@ -928,6 +928,10 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
+llvm::Type *ModuleTranslation::convertType(LLVMType type) {
+  return LLVM::convertLLVMType(type);
+}
+
 /// A helper to look up remapped operands in the value remapping table.`
 SmallVector<llvm::Value *, 8>
 ModuleTranslation::lookupValues(ValueRange values) {

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 9edbad8fdd54..f62e7aebe24a 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -135,8 +135,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
     } else if (isResultName(op, name)) {
       bs << formatv("valueMapping[op.{0}()]", name);
     } else if (name == "_resultType") {
-      bs << "op.getResult().getType().cast<LLVM::LLVMType>()."
-            "getUnderlyingType()";
+      bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
     } else if (name == "_hasResult") {
       bs << "opInst.getNumResults() == 1";
     } else if (name == "_location") {


        


More information about the Mlir-commits mailing list