[llvm] 751d533 - [llvm][IR][CastInst] Update `castIsValid` for scalable vectors.

Francesco Petrogalli via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 30 14:14:21 PDT 2020


Author: Francesco Petrogalli
Date: 2020-03-30T21:13:40Z
New Revision: 751d5332bd63d93de64f4ada0fc6a8f5ab01de66

URL: https://github.com/llvm/llvm-project/commit/751d5332bd63d93de64f4ada0fc6a8f5ab01de66
DIFF: https://github.com/llvm/llvm-project/commit/751d5332bd63d93de64f4ada0fc6a8f5ab01de66.diff

LOG: [llvm][IR][CastInst] Update `castIsValid` for scalable vectors.

Reviewers: sdesmalen

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76738

Added: 
    

Modified: 
    llvm/lib/IR/Instructions.cpp
    llvm/unittests/IR/InstructionsTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index f6748c8d864e..3807752eae0d 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3241,57 +3241,54 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
       SrcTy->isAggregateType() || DstTy->isAggregateType())
     return false;
 
-  // Get the size of the types in bits, we'll need this later
-  unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
-  unsigned DstBitSize = DstTy->getScalarSizeInBits();
+  // Get the scalar size of the types in bits, and whether we are dealing
+  // with vector types; we'll need these later.
+  bool SrcIsVec = isa<VectorType>(SrcTy);
+  bool DstIsVec = isa<VectorType>(DstTy);
+  unsigned SrcScalarBitSize = SrcTy->getScalarSizeInBits();
+  unsigned DstScalarBitSize = DstTy->getScalarSizeInBits();
 
   // If these are vector types, get the lengths of the vectors (using zero for
   // scalar types means that checking that vector lengths match also checks that
   // scalars are not being converted to vectors or vectors to scalars).
-  unsigned SrcLength = SrcTy->isVectorTy() ?
-    cast<VectorType>(SrcTy)->getNumElements() : 0;
-  unsigned DstLength = DstTy->isVectorTy() ?
-    cast<VectorType>(DstTy)->getNumElements() : 0;
+  ElementCount SrcEC = SrcIsVec ? cast<VectorType>(SrcTy)->getElementCount()
+                                : ElementCount(0, false);
+  ElementCount DstEC = DstIsVec ? cast<VectorType>(DstTy)->getElementCount()
+                                : ElementCount(0, false);
 
   // Switch on the opcode provided
   switch (op) {
   default: return false; // This is an input error
   case Instruction::Trunc:
     return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
-      SrcLength == DstLength && SrcBitSize > DstBitSize;
+           SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize;
   case Instruction::ZExt:
     return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
-      SrcLength == DstLength && SrcBitSize < DstBitSize;
+           SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
   case Instruction::SExt:
     return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
-      SrcLength == DstLength && SrcBitSize < DstBitSize;
+           SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
   case Instruction::FPTrunc:
     return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
-      SrcLength == DstLength && SrcBitSize > DstBitSize;
+           SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize;
   case Instruction::FPExt:
     return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
-      SrcLength == DstLength && SrcBitSize < DstBitSize;
+           SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
   case Instruction::UIToFP:
   case Instruction::SIToFP:
     return SrcTy->isIntOrIntVectorTy() && DstTy->isFPOrFPVectorTy() &&
-      SrcLength == DstLength;
+           SrcEC == DstEC;
   case Instruction::FPToUI:
   case Instruction::FPToSI:
     return SrcTy->isFPOrFPVectorTy() && DstTy->isIntOrIntVectorTy() &&
-      SrcLength == DstLength;
+           SrcEC == DstEC;
   case Instruction::PtrToInt:
-    if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
+    if (SrcEC != DstEC)
       return false;
-    if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
-      if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
-        return false;
     return SrcTy->isPtrOrPtrVectorTy() && DstTy->isIntOrIntVectorTy();
   case Instruction::IntToPtr:
-    if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
+    if (SrcEC != DstEC)
       return false;
-    if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
-      if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
-        return false;
     return SrcTy->isIntOrIntVectorTy() && DstTy->isPtrOrPtrVectorTy();
   case Instruction::BitCast: {
     PointerType *SrcPtrTy = dyn_cast<PointerType>(SrcTy->getScalarType());
@@ -3312,14 +3309,12 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
       return false;
 
     // A vector of pointers must have the same number of elements.
-    VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy);
-    VectorType *DstVecTy = dyn_cast<VectorType>(DstTy);
-    if (SrcVecTy && DstVecTy)
-      return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
-    if (SrcVecTy)
-      return SrcVecTy->getNumElements() == 1;
-    if (DstVecTy)
-      return DstVecTy->getNumElements() == 1;
+    if (SrcIsVec && DstIsVec)
+      return SrcEC == DstEC;
+    if (SrcIsVec)
+      return SrcEC == ElementCount(1, false);
+    if (DstIsVec)
+      return DstEC == ElementCount(1, false);
 
     return true;
   }
@@ -3335,14 +3330,7 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
     if (SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace())
       return false;
 
-    if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy)) {
-      if (VectorType *DstVecTy = dyn_cast<VectorType>(DstTy))
-        return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
-
-      return false;
-    }
-
-    return true;
+    return SrcEC == DstEC;
   }
   }
 }

diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index c2f70724337c..f49c75a6015f 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -197,6 +197,12 @@ TEST(InstructionsTest, CastInst) {
   Type *V2Int32Ty = VectorType::get(Int32Ty, 2);
   Type *V2Int64Ty = VectorType::get(Int64Ty, 2);
   Type *V4Int16Ty = VectorType::get(Int16Ty, 4);
+  Type *V1Int16Ty = VectorType::get(Int16Ty, 1);
+
+  Type *VScaleV2Int32Ty = VectorType::get(Int32Ty, 2, true);
+  Type *VScaleV2Int64Ty = VectorType::get(Int64Ty, 2, true);
+  Type *VScaleV4Int16Ty = VectorType::get(Int16Ty, 4, true);
+  Type *VScaleV1Int16Ty = VectorType::get(Int16Ty, 1, true);
 
   Type *Int32PtrTy = PointerType::get(Int32Ty, 0);
   Type *Int64PtrTy = PointerType::get(Int64Ty, 0);
@@ -207,11 +213,15 @@ TEST(InstructionsTest, CastInst) {
   Type *V2Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 2);
   Type *V2Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 2);
   Type *V4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4);
+  Type *VScaleV4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4, true);
   Type *V4Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 4);
 
   Type *V2Int64PtrTy = VectorType::get(Int64PtrTy, 2);
   Type *V2Int32PtrTy = VectorType::get(Int32PtrTy, 2);
+  Type *VScaleV2Int32PtrTy = VectorType::get(Int32PtrTy, 2, true);
   Type *V4Int32PtrTy = VectorType::get(Int32PtrTy, 4);
+  Type *VScaleV4Int32PtrTy = VectorType::get(Int32PtrTy, 4, true);
+  Type *VScaleV4Int64PtrTy = VectorType::get(Int64PtrTy, 4, true);
 
   const Constant* c8 = Constant::getNullValue(V8x8Ty);
   const Constant* c64 = Constant::getNullValue(V8x64Ty);
@@ -286,6 +296,75 @@ TEST(InstructionsTest, CastInst) {
                                      Constant::getNullValue(V2Int32PtrTy),
                                      V4Int32PtrAS1Ty));
 
+  // Address space cast of fixed/scalable vectors of pointers to scalable/fixed
+  // vector of pointers.
+  EXPECT_FALSE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      V4Int32PtrTy));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(V4Int32PtrTy),
+                                     VScaleV4Int32PtrAS1Ty));
+  // Address space cast of scalable vectors of pointers to scalable vector of
+  // pointers.
+  EXPECT_FALSE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      VScaleV2Int32PtrTy));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(VScaleV2Int32PtrTy),
+                                     VScaleV4Int32PtrAS1Ty));
+  EXPECT_TRUE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                    Constant::getNullValue(VScaleV4Int64PtrTy),
+                                    VScaleV4Int32PtrAS1Ty));
+  // Same number of lanes, different address space.
+  EXPECT_TRUE(CastInst::castIsValid(
+      Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+      VScaleV4Int32PtrTy));
+  // Same number of lanes, same address space.
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+                                     Constant::getNullValue(VScaleV4Int64PtrTy),
+                                     VScaleV4Int32PtrTy));
+
+  // Bit casting fixed/scalable vector to scalable/fixed vectors.
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V2Int32Ty),
+                                     VScaleV2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V2Int64Ty),
+                                     VScaleV2Int64Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V4Int16Ty),
+                                     VScaleV4Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int32Ty),
+                                     V2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int64Ty),
+                                     V2Int64Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV4Int16Ty),
+                                     V4Int16Ty));
+
+  // Bit casting scalable vectors to scalable vectors.
+  EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+                                    Constant::getNullValue(VScaleV4Int16Ty),
+                                    VScaleV2Int32Ty));
+  EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+                                    Constant::getNullValue(VScaleV2Int32Ty),
+                                    VScaleV4Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int64Ty),
+                                     VScaleV2Int32Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV2Int32Ty),
+                                     VScaleV2Int64Ty));
+
+  // Bitcasting to/from <vscale x 1 x Ty>
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(VScaleV1Int16Ty),
+                                     V1Int16Ty));
+  EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+                                     Constant::getNullValue(V1Int16Ty),
+                                     VScaleV1Int16Ty));
 
   // Check that assertion is not hit when creating a cast with a vector of
   // pointers


        


More information about the llvm-commits mailing list