[Mlir-commits] [mlir] 92dde8a - Clean up usages of asserting vector getters in Type

Christopher Tetreault llvmlistbot at llvm.org
Fri Apr 10 13:46:35 PDT 2020


Author: Christopher Tetreault
Date: 2020-04-10T13:46:18-07:00
New Revision: 92dde8a6579108ae9e1d213e5d7709e1e1c1e46c

URL: https://github.com/llvm/llvm-project/commit/92dde8a6579108ae9e1d213e5d7709e1e1c1e46c
DIFF: https://github.com/llvm/llvm-project/commit/92dde8a6579108ae9e1d213e5d7709e1e1c1e46c.diff

LOG: Clean up usages of asserting vector getters in Type

Summary:
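Remove usages of the asserting vector getters in Type (getVectorNumElements,
getVectorElementType, getVectorElementCount and getVectorIsScalable) in
preparation for the VectorType refactor. Callers now cast explicitly to
llvm::VectorType and use its getNumElements, getElementType, getElementCount
and isScalable members instead. The existence of these Type getters
complicates the refactor while adding little value.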

Reviewers: rriddle, efriedma, sdesmalen

Reviewed By: sdesmalen

Subscribers: frgossen, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, Joonsoo, grosul1, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77258

Added: 
    

Modified: 
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 57663a39e132..ea4f996a13b2 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1804,9 +1804,10 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
         op, operands, typeConverter,
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
-                  floatType),
+              mlir::VectorType::get({(unsigned)cast<llvm::VectorType>(
+                                         llvmVectorTy.getUnderlyingType())
+                                         ->getNumElements()},
+                                    floatType),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 82bbe18dd01e..01e35006d44d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -102,7 +102,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
   if (argType.getUnderlyingType()->isVectorTy())
     resultType = LLVMType::getVectorTy(
-        resultType, argType.getUnderlyingType()->getVectorNumElements());
+        resultType, llvm::cast<llvm::VectorType>(argType.getUnderlyingType())
+                        ->getNumElements());
 
   result.addTypes({resultType});
   return success();
@@ -1772,10 +1773,12 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
 
 /// Vector type utilities.
 LLVMType LLVMType::getVectorElementType() {
-  return get(getContext(), getUnderlyingType()->getVectorElementType());
+  return get(
+      getContext(),
+      llvm::cast<llvm::VectorType>(getUnderlyingType())->getElementType());
 }
 unsigned LLVMType::getVectorNumElements() {
-  return getUnderlyingType()->getVectorNumElements();
+  return llvm::cast<llvm::VectorType>(getUnderlyingType())->getNumElements();
 }
 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 5f1ae738280a..bcab9af93b65 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -169,14 +169,15 @@ LLVMType Importer::processType(llvm::Type *type) {
     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
   }
   case llvm::Type::VectorTyID: {
-    if (type->getVectorIsScalable()) {
+    auto *typeVTy = llvm::cast<llvm::VectorType>(type);
+    if (typeVTy->isScalable()) {
       emitError(unknownLoc) << "scalable vector types not supported";
       return nullptr;
     }
-    LLVMType elementType = processType(type->getVectorElementType());
+    LLVMType elementType = processType(typeVTy->getElementType());
     if (!elementType)
       return nullptr;
-    return LLVMType::getVectorTy(elementType, type->getVectorNumElements());
+    return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
   }
   case llvm::Type::VoidTyID:
     return LLVMType::getVoidTy(dialect);
@@ -243,7 +244,8 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
 
   // LLVM vectors can only contain scalars.
   if (type.isVectorTy()) {
-    auto numElements = type.getUnderlyingType()->getVectorElementCount();
+    auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
+                           ->getElementCount();
     if (numElements.Scalable) {
       emitError(unknownLoc) << "scalable vectors not supported";
       return nullptr;
@@ -269,7 +271,8 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
     if (type.getArrayElementType().isVectorTy()) {
       LLVMType vectorType = type.getArrayElementType();
       auto numElements =
-          vectorType.getUnderlyingType()->getVectorElementCount();
+          llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
+              ->getElementCount();
       if (numElements.Scalable) {
         emitError(unknownLoc) << "scalable vectors not supported";
         return nullptr;


        


More information about the Mlir-commits mailing list