[Mlir-commits] [mlir] 6abd7e2 - [mlir] provide same APIs as existing LLVMType in the new LLVM type modeling

Alex Zinenko llvmlistbot at llvm.org
Tue Aug 4 04:49:22 PDT 2020


Author: Alex Zinenko
Date: 2020-08-04T13:49:14+02:00
New Revision: 6abd7e2e622bc7eabdb673a7815f6673523a1e94

URL: https://github.com/llvm/llvm-project/commit/6abd7e2e622bc7eabdb673a7815f6673523a1e94
DIFF: https://github.com/llvm/llvm-project/commit/6abd7e2e622bc7eabdb673a7815f6673523a1e94.diff

LOG: [mlir] provide same APIs as existing LLVMType in the new LLVM type modeling

These are intended to smooth the transition and may be removed in the future
in favor of more MLIR-compatible APIs. They intentionally have the same
semantics as the existing functions, which must remain stable until the
transition is complete.

Depends On D85019

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D85020

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 6764f9815c3f..e409d6880283 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -26,6 +26,8 @@ class DialectAsmParser;
 class DialectAsmPrinter;
 
 namespace LLVM {
+class LLVMDialect;
+
 namespace detail {
 struct LLVMFunctionTypeStorage;
 struct LLVMIntegerTypeStorage;
@@ -34,6 +36,12 @@ struct LLVMStructTypeStorage;
 struct LLVMTypeAndSizeStorage;
 } // namespace detail
 
+class LLVMBFloatType;
+class LLVMHalfType;
+class LLVMFloatType;
+class LLVMDoubleType;
+class LLVMIntegerType;
+
 //===----------------------------------------------------------------------===//
 // LLVMTypeNew.
 //===----------------------------------------------------------------------===//
@@ -96,6 +104,150 @@ class LLVMTypeNew : public Type {
   static bool kindof(unsigned kind) {
     return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE;
   }
+
+  /// Returns the LLVM dialect registered in the context of this type.
+  LLVMDialect &getDialect();
+
+  /// Floating-point type utilities. isFloatingPointTy ignores fp128/x86_fp80.
+  bool isBFloatTy() { return isa<LLVMBFloatType>(); }
+  bool isHalfTy() { return isa<LLVMHalfType>(); }
+  bool isFloatTy() { return isa<LLVMFloatType>(); }
+  bool isDoubleTy() { return isa<LLVMDoubleType>(); }
+  bool isFloatingPointTy() {
+    return isHalfTy() || isBFloatTy() || isFloatTy() || isDoubleTy();
+  }
+
+  /// Array type utilities. The accessors require an array type.
+  bool isArrayTy();
+  LLVMTypeNew getArrayElementType();
+  unsigned getArrayNumElements();
+
+  /// Integer type utilities. getIntegerBitWidth requires an integer type.
+  bool isIntegerTy() { return isa<LLVMIntegerType>(); }
+  bool isIntegerTy(unsigned bitwidth);
+  unsigned getIntegerBitWidth();
+
+  /// Vector type utilities. getVectorNumElements requires a fixed vector.
+  bool isVectorTy();
+  LLVMTypeNew getVectorElementType();
+  llvm::ElementCount getVectorElementCount();
+  unsigned getVectorNumElements();
+
+  /// Function type utilities. All but isFunctionTy require a function type.
+  bool isFunctionTy();
+  bool isFunctionVarArg();
+  LLVMTypeNew getFunctionResultType();
+  unsigned getFunctionNumParams();
+  LLVMTypeNew getFunctionParamType(unsigned argIdx);
+
+  /// Pointer type utilities. Void/token/metadata/label are invalid pointees.
+  bool isPointerTy();
+  static bool isValidPointerElementType(LLVMTypeNew type);
+  LLVMTypeNew getPointerTo(unsigned addrSpace = 0);
+  LLVMTypeNew getPointerElementTy();
+
+  /// Struct type utilities. The accessors index into the struct body.
+  bool isStructTy();
+  unsigned getStructNumElements();
+  LLVMTypeNew getStructElementType(unsigned i);
+
+  /// Utilities used to generate floating point types.
+  static LLVMTypeNew getHalfTy(LLVMDialect *dialect);
+  static LLVMTypeNew getBFloatTy(LLVMDialect *dialect);
+  static LLVMTypeNew getFloatTy(LLVMDialect *dialect);
+  static LLVMTypeNew getDoubleTy(LLVMDialect *dialect);
+  static LLVMTypeNew getX86_FP80Ty(LLVMDialect *dialect);
+  static LLVMTypeNew getFP128Ty(LLVMDialect *dialect);
+
+  /// Utilities used to generate integer types, all built on getIntNTy.
+  static LLVMTypeNew getIntNTy(LLVMDialect *dialect, unsigned numBits);
+  static LLVMTypeNew getInt1Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/1);
+  }
+  static LLVMTypeNew getInt8Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/8);
+  }
+  static LLVMTypeNew getInt16Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/16);
+  }
+  static LLVMTypeNew getInt32Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/32);
+  }
+  static LLVMTypeNew getInt64Ty(LLVMDialect *dialect) {
+    return getIntNTy(dialect, /*numBits=*/64);
+  }
+  static LLVMTypeNew getInt8PtrTy(LLVMDialect *dialect) {
+    return getInt8Ty(dialect).getPointerTo(/*addrSpace=*/0);
+  }
+
+  /// Utilities used to generate other miscellaneous types.
+  static LLVMTypeNew getArrayTy(LLVMTypeNew elementType, uint64_t numElements);
+  static LLVMTypeNew getVectorTy(LLVMTypeNew elementType, unsigned numElements);
+  static LLVMTypeNew getFunctionTy(LLVMTypeNew result,
+                                   ArrayRef<LLVMTypeNew> params, bool isVarArg);
+  static LLVMTypeNew getFunctionTy(LLVMTypeNew result, bool isVarArg) {
+    return getFunctionTy(result, llvm::None, isVarArg);
+  }
+  static LLVMTypeNew getStructTy(LLVMDialect *dialect,
+                                 ArrayRef<LLVMTypeNew> elements,
+                                 bool isPacked = false);
+  static LLVMTypeNew getStructTy(LLVMDialect *dialect, bool isPacked = false) {
+    return getStructTy(dialect, llvm::None, isPacked);
+  }
+  template <typename... Args>
+  static std::enable_if_t<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                          LLVMTypeNew>
+  getStructTy(LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    return getStructTy(&elt1.getDialect(), fields);
+  }
+
+  /// Void type utilities.
+  static LLVMTypeNew getVoidTy(LLVMDialect *dialect);
+  bool isVoidTy();
+
+  /// Creation and setting of LLVM's identified struct types.
+  static LLVMTypeNew createStructTy(LLVMDialect *dialect,
+                                    ArrayRef<LLVMTypeNew> elements,
+                                    Optional<StringRef> name,
+                                    bool isPacked = false);
+
+  static LLVMTypeNew createStructTy(LLVMDialect *dialect,
+                                    Optional<StringRef> name) {
+    return createStructTy(dialect, llvm::None, name);
+  }
+
+  /// Takes the dialect from the first element; `elements` must be non-empty.
+  static LLVMTypeNew createStructTy(ArrayRef<LLVMTypeNew> elements,
+                                    Optional<StringRef> name,
+                                    bool isPacked = false) {
+    assert(!elements.empty() &&
+           "This method may not be invoked with an empty list");
+    // Copy: getDialect() is non-const and front() returns a const reference.
+    LLVMTypeNew ele0 = elements.front();
+    return createStructTy(&ele0.getDialect(), elements, name, isPacked);
+  }
+
+  template <typename... Args>
+  static std::enable_if_t<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                          LLVMTypeNew>
+  createStructTy(StringRef name, LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    Optional<StringRef> optName(name);
+    return createStructTy(&elt1.getDialect(), fields, optName);
+  }
+
+  static LLVMTypeNew setStructTyBody(LLVMTypeNew structType,
+                                     ArrayRef<LLVMTypeNew> elements,
+                                     bool isPacked = false);
+
+  template <typename... Args>
+  static std::enable_if_t<llvm::are_base_of<LLVMTypeNew, Args...>::value,
+                          LLVMTypeNew>
+  setStructTyBody(LLVMTypeNew structType, LLVMTypeNew elt1, Args... elts) {
+    SmallVector<LLVMTypeNew, 8> fields({elt1, elts...});
+    return setStructTyBody(structType, fields);
+  }
 };
 
 //===----------------------------------------------------------------------===//
@@ -323,6 +475,9 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMTypeNew,
   /// Checks if a struct is opaque.
   bool isOpaque();
 
+  /// Checks if a struct is initialized; isOpaque() is true if it is not.
+  bool isInitialized();
+
   /// Returns the name of an identified struct.
   StringRef getName();
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 3540091e90a3..abecbccb1d4a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -13,6 +13,7 @@
 
 #include "TypeDetail.h"
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
@@ -22,6 +23,213 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
+//===----------------------------------------------------------------------===//
+// LLVMTypeNew.
+//===----------------------------------------------------------------------===//
+
+// TODO: remove once LLVMDialect registers these types (use Type::getDialect).
+LLVMDialect &LLVMTypeNew::getDialect() {
+  MLIRContext *ctx = getContext();
+  return *ctx->getRegisteredDialect<LLVM::LLVMDialect>();
+}
+
+//----------------------------------------------------------------------------//
+// Integer type utilities.
+
+// Returns true if this is an integer type of exactly `bitwidth` bits.
+bool LLVMTypeNew::isIntegerTy(unsigned bitwidth) {
+  auto intType = dyn_cast<LLVMIntegerType>();
+  return intType && intType.getBitWidth() == bitwidth;
+}
+
+unsigned LLVMTypeNew::getIntegerBitWidth() {
+  return cast<LLVMIntegerType>().getBitWidth();
+}
+
+//----------------------------------------------------------------------------//
+// Array type utilities.
+
+LLVMTypeNew LLVMTypeNew::getArrayElementType() {
+  return cast<LLVMArrayType>().getElementType();
+}
+
+unsigned LLVMTypeNew::getArrayNumElements() {
+  return cast<LLVMArrayType>().getNumElements();
+}
+
+bool LLVMTypeNew::isArrayTy() { return isa<LLVMArrayType>(); }
+
+//----------------------------------------------------------------------------//
+// Vector type utilities (getVectorNumElements expects a fixed-length vector).
+
+bool LLVMTypeNew::isVectorTy() { return isa<LLVMVectorType>(); }
+
+LLVMTypeNew LLVMTypeNew::getVectorElementType() {
+  return cast<LLVMVectorType>().getElementType();
+}
+
+llvm::ElementCount LLVMTypeNew::getVectorElementCount() {
+  return cast<LLVMVectorType>().getElementCount();
+}
+unsigned LLVMTypeNew::getVectorNumElements() {
+  return cast<LLVMFixedVectorType>().getNumElements();
+}
+
+//----------------------------------------------------------------------------//
+// Function type utilities; all but isFunctionTy require a function type.
+
+bool LLVMTypeNew::isFunctionTy() { return isa<LLVMFunctionType>(); }
+
+bool LLVMTypeNew::isFunctionVarArg() {
+  return cast<LLVMFunctionType>().isVarArg();
+}
+
+LLVMTypeNew LLVMTypeNew::getFunctionResultType() {
+  return cast<LLVMFunctionType>().getReturnType();
+}
+
+unsigned LLVMTypeNew::getFunctionNumParams() {
+  return cast<LLVMFunctionType>().getNumParams();
+}
+
+LLVMTypeNew LLVMTypeNew::getFunctionParamType(unsigned argIdx) {
+  return cast<LLVMFunctionType>().getParamType(argIdx);
+}
+
+//----------------------------------------------------------------------------//
+// Pointer type utilities. Void/token/metadata/label are invalid pointees.
+
+bool LLVMTypeNew::isPointerTy() { return isa<LLVMPointerType>(); }
+
+LLVMTypeNew LLVMTypeNew::getPointerTo(unsigned addrSpace) {
+  return LLVMPointerType::get(*this, addrSpace);
+}
+
+LLVMTypeNew LLVMTypeNew::getPointerElementTy() {
+  return cast<LLVMPointerType>().getElementType();
+}
+
+bool LLVMTypeNew::isValidPointerElementType(LLVMTypeNew type) {
+  return !(type.isa<LLVMVoidType>() || type.isa<LLVMTokenType>() ||
+           type.isa<LLVMMetadataType>() || type.isa<LLVMLabelType>());
+}
+
+//----------------------------------------------------------------------------//
+// Struct type utilities; the element accessors read the struct body.
+
+bool LLVMTypeNew::isStructTy() { return isa<LLVMStructType>(); }
+
+unsigned LLVMTypeNew::getStructNumElements() {
+  return cast<LLVMStructType>().getBody().size();
+}
+
+LLVMTypeNew LLVMTypeNew::getStructElementType(unsigned i) {
+  return cast<LLVMStructType>().getBody()[i];
+}
+
+//----------------------------------------------------------------------------//
+// Floating point type generators (`dialect` only supplies the context).
+
+LLVMTypeNew LLVMTypeNew::getHalfTy(LLVMDialect *dialect) {
+  return LLVMHalfType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getBFloatTy(LLVMDialect *dialect) {
+  return LLVMBFloatType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getFloatTy(LLVMDialect *dialect) {
+  return LLVMFloatType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getDoubleTy(LLVMDialect *dialect) {
+  return LLVMDoubleType::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getX86_FP80Ty(LLVMDialect *dialect) {
+  return LLVMX86FP80Type::get(dialect->getContext());
+}
+
+LLVMTypeNew LLVMTypeNew::getFP128Ty(LLVMDialect *dialect) {
+  return LLVMFP128Type::get(dialect->getContext());
+}
+
+//----------------------------------------------------------------------------//
+// Utilities used to generate integer types (getInt*Ty helpers forward here).
+
+LLVMTypeNew LLVMTypeNew::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+  return LLVMIntegerType::get(dialect->getContext(), numBits);
+}
+
+//----------------------------------------------------------------------------//
+// Utilities used to generate other types (literal structs, fixed vectors).
+
+LLVMTypeNew LLVMTypeNew::getArrayTy(LLVMTypeNew elementType,
+                                    uint64_t numElements) {
+  return LLVMArrayType::get(elementType, numElements);
+}
+
+LLVMTypeNew LLVMTypeNew::getVectorTy(LLVMTypeNew elementType,
+                                     unsigned numElements) {
+  return LLVMFixedVectorType::get(elementType, numElements);
+}
+
+LLVMTypeNew LLVMTypeNew::getFunctionTy(LLVMTypeNew result,
+                                       ArrayRef<LLVMTypeNew> params,
+                                       bool isVarArg) {
+  return LLVMFunctionType::get(result, params, isVarArg);
+}
+
+LLVMTypeNew LLVMTypeNew::getStructTy(LLVMDialect *dialect,
+                                     ArrayRef<LLVMTypeNew> elements,
+                                     bool isPacked) {
+  return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked);
+}
+
+//----------------------------------------------------------------------------//
+// Void type utilities.
+
+bool LLVMTypeNew::isVoidTy() { return isa<LLVMVoidType>(); }
+
+LLVMTypeNew LLVMTypeNew::getVoidTy(LLVMDialect *dialect) {
+  return LLVMVoidType::get(dialect->getContext());
+}
+
+//----------------------------------------------------------------------------//
+// Creation and setting of LLVM's identified struct types.
+
+// Creates an identified struct type with the given body. When the identifier
+// `name` already denotes an initialized struct (or its body cannot be set to
+// `elements`), ".1", ".2", ... is appended until a usable name is found.
+LLVMTypeNew LLVMTypeNew::createStructTy(LLVMDialect *dialect,
+                                        ArrayRef<LLVMTypeNew> elements,
+                                        Optional<StringRef> name,
+                                        bool isPacked) {
+  assert(name.hasValue() &&
+         "identified structs with no identifier not supported");
+  StringRef stringNameBase = name.getValueOr("");
+  std::string stringName = stringNameBase.str();
+  for (unsigned counter = 1;; ++counter) {
+    auto type =
+        LLVMStructType::getIdentified(dialect->getContext(), stringName);
+    if (!type.isInitialized() && succeeded(type.setBody(elements, isPacked)))
+      return type;
+    // Name is taken: retry with the next numeric suffix.
+    stringName =
+        (Twine(stringNameBase) + "." + std::to_string(counter)).str();
+  }
+}
+
+LLVMTypeNew LLVMTypeNew::setStructTyBody(LLVMTypeNew structType,
+                                         ArrayRef<LLVMTypeNew> elements,
+                                         bool isPacked) {
+  LLVMStructType asStruct = structType.cast<LLVMStructType>();
+  LogicalResult status = asStruct.setBody(elements, isPacked);
+  assert(succeeded(status) && "failed to set the body");
+  (void)status; // Only inspected by the assert above.
+  return structType;
+}
+
 //===----------------------------------------------------------------------===//
 // Array type.
 
@@ -117,6 +325,7 @@ bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); }
 bool LLVMStructType::isOpaque() {
   return getImpl()->isOpaque() || !getImpl()->isInitialized();
 }
+// Note: isOpaque() above is also true whenever this returns false.
+bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
 ArrayRef<LLVMTypeNew> LLVMStructType::getBody() {
   return isIdentified() ? getImpl()->getIdentifiedStructBody()


        


More information about the Mlir-commits mailing list