[llvm] 751d533 - [llvm][IR][CastInst] Update `castIsValid` for scalable vectors.
Francesco Petrogalli via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 30 14:14:21 PDT 2020
Author: Francesco Petrogalli
Date: 2020-03-30T21:13:40Z
New Revision: 751d5332bd63d93de64f4ada0fc6a8f5ab01de66
URL: https://github.com/llvm/llvm-project/commit/751d5332bd63d93de64f4ada0fc6a8f5ab01de66
DIFF: https://github.com/llvm/llvm-project/commit/751d5332bd63d93de64f4ada0fc6a8f5ab01de66.diff
LOG: [llvm][IR][CastInst] Update `castIsValid` for scalable vectors.
Reviewers: sdesmalen
Subscribers: hiraditya, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D76738
Added:
Modified:
llvm/lib/IR/Instructions.cpp
llvm/unittests/IR/InstructionsTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index f6748c8d864e..3807752eae0d 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3241,57 +3241,54 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
SrcTy->isAggregateType() || DstTy->isAggregateType())
return false;
- // Get the size of the types in bits, we'll need this later
- unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
- unsigned DstBitSize = DstTy->getScalarSizeInBits();
+ // Get the size of the types in bits, and whether we are dealing
+ // with vector types, we'll need this later.
+ bool SrcIsVec = isa<VectorType>(SrcTy);
+ bool DstIsVec = isa<VectorType>(DstTy);
+ unsigned SrcScalarBitSize = SrcTy->getScalarSizeInBits();
+ unsigned DstScalarBitSize = DstTy->getScalarSizeInBits();
// If these are vector types, get the lengths of the vectors (using zero for
// scalar types means that checking that vector lengths match also checks that
// scalars are not being converted to vectors or vectors to scalars).
- unsigned SrcLength = SrcTy->isVectorTy() ?
- cast<VectorType>(SrcTy)->getNumElements() : 0;
- unsigned DstLength = DstTy->isVectorTy() ?
- cast<VectorType>(DstTy)->getNumElements() : 0;
+ ElementCount SrcEC = SrcIsVec ? cast<VectorType>(SrcTy)->getElementCount()
+ : ElementCount(0, false);
+ ElementCount DstEC = DstIsVec ? cast<VectorType>(DstTy)->getElementCount()
+ : ElementCount(0, false);
// Switch on the opcode provided
switch (op) {
default: return false; // This is an input error
case Instruction::Trunc:
return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
- SrcLength == DstLength && SrcBitSize > DstBitSize;
+ SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize;
case Instruction::ZExt:
return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
- SrcLength == DstLength && SrcBitSize < DstBitSize;
+ SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
case Instruction::SExt:
return SrcTy->isIntOrIntVectorTy() && DstTy->isIntOrIntVectorTy() &&
- SrcLength == DstLength && SrcBitSize < DstBitSize;
+ SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
case Instruction::FPTrunc:
return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
- SrcLength == DstLength && SrcBitSize > DstBitSize;
+ SrcEC == DstEC && SrcScalarBitSize > DstScalarBitSize;
case Instruction::FPExt:
return SrcTy->isFPOrFPVectorTy() && DstTy->isFPOrFPVectorTy() &&
- SrcLength == DstLength && SrcBitSize < DstBitSize;
+ SrcEC == DstEC && SrcScalarBitSize < DstScalarBitSize;
case Instruction::UIToFP:
case Instruction::SIToFP:
return SrcTy->isIntOrIntVectorTy() && DstTy->isFPOrFPVectorTy() &&
- SrcLength == DstLength;
+ SrcEC == DstEC;
case Instruction::FPToUI:
case Instruction::FPToSI:
return SrcTy->isFPOrFPVectorTy() && DstTy->isIntOrIntVectorTy() &&
- SrcLength == DstLength;
+ SrcEC == DstEC;
case Instruction::PtrToInt:
- if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
+ if (SrcEC != DstEC)
return false;
- if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
- if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
- return false;
return SrcTy->isPtrOrPtrVectorTy() && DstTy->isIntOrIntVectorTy();
case Instruction::IntToPtr:
- if (isa<VectorType>(SrcTy) != isa<VectorType>(DstTy))
+ if (SrcEC != DstEC)
return false;
- if (VectorType *VT = dyn_cast<VectorType>(SrcTy))
- if (VT->getNumElements() != cast<VectorType>(DstTy)->getNumElements())
- return false;
return SrcTy->isIntOrIntVectorTy() && DstTy->isPtrOrPtrVectorTy();
case Instruction::BitCast: {
PointerType *SrcPtrTy = dyn_cast<PointerType>(SrcTy->getScalarType());
@@ -3312,14 +3309,12 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
return false;
// A vector of pointers must have the same number of elements.
- VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy);
- VectorType *DstVecTy = dyn_cast<VectorType>(DstTy);
- if (SrcVecTy && DstVecTy)
- return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
- if (SrcVecTy)
- return SrcVecTy->getNumElements() == 1;
- if (DstVecTy)
- return DstVecTy->getNumElements() == 1;
+ if (SrcIsVec && DstIsVec)
+ return SrcEC == DstEC;
+ if (SrcIsVec)
+ return SrcEC == ElementCount(1, false);
+ if (DstIsVec)
+ return DstEC == ElementCount(1, false);
return true;
}
@@ -3335,14 +3330,7 @@ CastInst::castIsValid(Instruction::CastOps op, Value *S, Type *DstTy) {
if (SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace())
return false;
- if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy)) {
- if (VectorType *DstVecTy = dyn_cast<VectorType>(DstTy))
- return (SrcVecTy->getNumElements() == DstVecTy->getNumElements());
-
- return false;
- }
-
- return true;
+ return SrcEC == DstEC;
}
}
}
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index c2f70724337c..f49c75a6015f 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -197,6 +197,12 @@ TEST(InstructionsTest, CastInst) {
Type *V2Int32Ty = VectorType::get(Int32Ty, 2);
Type *V2Int64Ty = VectorType::get(Int64Ty, 2);
Type *V4Int16Ty = VectorType::get(Int16Ty, 4);
+ Type *V1Int16Ty = VectorType::get(Int16Ty, 1);
+
+ Type *VScaleV2Int32Ty = VectorType::get(Int32Ty, 2, true);
+ Type *VScaleV2Int64Ty = VectorType::get(Int64Ty, 2, true);
+ Type *VScaleV4Int16Ty = VectorType::get(Int16Ty, 4, true);
+ Type *VScaleV1Int16Ty = VectorType::get(Int16Ty, 1, true);
Type *Int32PtrTy = PointerType::get(Int32Ty, 0);
Type *Int64PtrTy = PointerType::get(Int64Ty, 0);
@@ -207,11 +213,15 @@ TEST(InstructionsTest, CastInst) {
Type *V2Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 2);
Type *V2Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 2);
Type *V4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4);
+ Type *VScaleV4Int32PtrAS1Ty = VectorType::get(Int32PtrAS1Ty, 4, true);
Type *V4Int64PtrAS1Ty = VectorType::get(Int64PtrAS1Ty, 4);
Type *V2Int64PtrTy = VectorType::get(Int64PtrTy, 2);
Type *V2Int32PtrTy = VectorType::get(Int32PtrTy, 2);
+ Type *VScaleV2Int32PtrTy = VectorType::get(Int32PtrTy, 2, true);
Type *V4Int32PtrTy = VectorType::get(Int32PtrTy, 4);
+ Type *VScaleV4Int32PtrTy = VectorType::get(Int32PtrTy, 4, true);
+ Type *VScaleV4Int64PtrTy = VectorType::get(Int64PtrTy, 4, true);
const Constant* c8 = Constant::getNullValue(V8x8Ty);
const Constant* c64 = Constant::getNullValue(V8x64Ty);
@@ -286,6 +296,75 @@ TEST(InstructionsTest, CastInst) {
Constant::getNullValue(V2Int32PtrTy),
V4Int32PtrAS1Ty));
+ // Address space cast of fixed/scalable vectors of pointers to scalable/fixed
+ // vector of pointers.
+ EXPECT_FALSE(CastInst::castIsValid(
+ Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+ V4Int32PtrTy));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+ Constant::getNullValue(V4Int32PtrTy),
+ VScaleV4Int32PtrAS1Ty));
+ // Address space cast of scalable vectors of pointers to scalable vector of
+ // pointers.
+ EXPECT_FALSE(CastInst::castIsValid(
+ Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+ VScaleV2Int32PtrTy));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+ Constant::getNullValue(VScaleV2Int32PtrTy),
+ VScaleV4Int32PtrAS1Ty));
+ EXPECT_TRUE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+ Constant::getNullValue(VScaleV4Int64PtrTy),
+ VScaleV4Int32PtrAS1Ty));
+ // Same number of lanes, different address space.
+ EXPECT_TRUE(CastInst::castIsValid(
+ Instruction::AddrSpaceCast, Constant::getNullValue(VScaleV4Int32PtrAS1Ty),
+ VScaleV4Int32PtrTy));
+ // Same number of lanes, same address space.
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::AddrSpaceCast,
+ Constant::getNullValue(VScaleV4Int64PtrTy),
+ VScaleV4Int32PtrTy));
+
+ // Bit casting fixed/scalable vector to scalable/fixed vectors.
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(V2Int32Ty),
+ VScaleV2Int32Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(V2Int64Ty),
+ VScaleV2Int64Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(V4Int16Ty),
+ VScaleV4Int16Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV2Int32Ty),
+ V2Int32Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV2Int64Ty),
+ V2Int64Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV4Int16Ty),
+ V4Int16Ty));
+
+ // Bit casting scalable vectors to scalable vectors.
+ EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV4Int16Ty),
+ VScaleV2Int32Ty));
+ EXPECT_TRUE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV2Int32Ty),
+ VScaleV4Int16Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV2Int64Ty),
+ VScaleV2Int32Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV2Int32Ty),
+ VScaleV2Int64Ty));
+
+ // Bitcasting to/from <vscale x 1 x Ty>
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(VScaleV1Int16Ty),
+ V1Int16Ty));
+ EXPECT_FALSE(CastInst::castIsValid(Instruction::BitCast,
+ Constant::getNullValue(V1Int16Ty),
+ VScaleV1Int16Ty));
// Check that assertion is not hit when creating a cast with a vector of
// pointers
More information about the llvm-commits mailing list