[clang] 118abf2 - [SVE] Update API ConstantVector::getSplat() to use ElementCount.
Huihui Zhang via cfe-commits
cfe-commits at lists.llvm.org
Thu Mar 12 13:30:43 PDT 2020
Author: Huihui Zhang
Date: 2020-03-12T13:22:41-07:00
New Revision: 118abf20173899e9e1667db1a9c850dc5570b6ae
URL: https://github.com/llvm/llvm-project/commit/118abf20173899e9e1667db1a9c850dc5570b6ae
DIFF: https://github.com/llvm/llvm-project/commit/118abf20173899e9e1667db1a9c850dc5570b6ae.diff
LOG: [SVE] Update API ConstantVector::getSplat() to use ElementCount.
Summary:
Support ConstantInt::get() and Constant::getAllOnesValue() for scalable
vector types. This requires ConstantVector::getSplat() to take an 'ElementCount'
instead of an 'unsigned' element count.
This change is needed for D73753.
Reviewers: sdesmalen, efriedma, apazos, spatel, huntergr, willlovett
Reviewed By: efriedma
Subscribers: tschuett, hiraditya, rkruppe, psnobl, cfe-commits, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D74386
Added:
llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll
Modified:
clang/lib/CodeGen/CGBuiltin.cpp
llvm/include/llvm/Analysis/Utils/Local.h
llvm/include/llvm/IR/Constants.h
llvm/lib/Analysis/InstructionSimplify.cpp
llvm/lib/CodeGen/CodeGenPrepare.cpp
llvm/lib/IR/ConstantFold.cpp
llvm/lib/IR/Constants.cpp
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
llvm/test/Transforms/InstSimplify/gep.ll
llvm/unittests/FuzzMutate/OperationsTest.cpp
llvm/unittests/IR/VerifierTest.cpp
Removed:
################################################################################
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2cd0f8814cc..436084ef23cb 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -4496,8 +4496,8 @@ static llvm::VectorType *GetFloatNeonType(CodeGenFunction *CGF,
}
Value *CodeGenFunction::EmitNeonSplat(Value *V, Constant *C) {
- unsigned nElts = V->getType()->getVectorNumElements();
- Value* SV = llvm::ConstantVector::getSplat(nElts, C);
+ ElementCount EC = V->getType()->getVectorElementCount();
+ Value *SV = llvm::ConstantVector::getSplat(EC, C);
return Builder.CreateShuffleVector(V, V, SV, "lane");
}
@@ -8701,7 +8701,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
llvm::VectorType::get(VTy->getElementType(), VTy->getNumElements() / 2) :
VTy;
llvm::Constant *cst = cast<Constant>(Ops[3]);
- Value *SV = llvm::ConstantVector::getSplat(VTy->getNumElements(), cst);
+ Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(), cst);
Ops[1] = Builder.CreateBitCast(Ops[1], SourceTy);
Ops[1] = Builder.CreateShuffleVector(Ops[1], Ops[1], SV, "lane");
@@ -8730,7 +8730,7 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
llvm::Type *STy = llvm::VectorType::get(VTy->getElementType(),
VTy->getNumElements() * 2);
Ops[2] = Builder.CreateBitCast(Ops[2], STy);
- Value* SV = llvm::ConstantVector::getSplat(VTy->getNumElements(),
+ Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(),
cast<ConstantInt>(Ops[3]));
Ops[2] = Builder.CreateShuffleVector(Ops[2], Ops[2], SV, "lane");
diff --git a/llvm/include/llvm/Analysis/Utils/Local.h b/llvm/include/llvm/Analysis/Utils/Local.h
index ca505960cbeb..84e884e46d0b 100644
--- a/llvm/include/llvm/Analysis/Utils/Local.h
+++ b/llvm/include/llvm/Analysis/Utils/Local.h
@@ -63,7 +63,7 @@ Value *EmitGEPOffset(IRBuilderTy *Builder, const DataLayout &DL, User *GEP,
// Splat the constant if needed.
if (IntIdxTy->isVectorTy() && !OpC->getType()->isVectorTy())
- OpC = ConstantVector::getSplat(IntIdxTy->getVectorNumElements(), OpC);
+ OpC = ConstantVector::getSplat(IntIdxTy->getVectorElementCount(), OpC);
Constant *Scale = ConstantInt::get(IntIdxTy, Size);
Constant *OC = ConstantExpr::getIntegerCast(OpC, IntIdxTy, true /*SExt*/);
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 90bf22bd4344..e6d8c0eb4d92 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -517,7 +517,7 @@ class ConstantVector final : public ConstantAggregate {
public:
/// Return a ConstantVector with the specified constant in each element.
- static Constant *getSplat(unsigned NumElts, Constant *Elt);
+ static Constant *getSplat(ElementCount EC, Constant *Elt);
/// Specialize the getType() method to always return a VectorType,
/// which reduces the amount of casting needed in parts of the compiler.
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index de7310623e84..bc9ae5ebffb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -707,9 +707,8 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V,
Offset = Offset.sextOrTrunc(IntIdxTy->getIntegerBitWidth());
Constant *OffsetIntPtr = ConstantInt::get(IntIdxTy, Offset);
- if (V->getType()->isVectorTy())
- return ConstantVector::getSplat(V->getType()->getVectorNumElements(),
- OffsetIntPtr);
+ if (VectorType *VecTy = dyn_cast<VectorType>(V->getType()))
+ return ConstantVector::getSplat(VecTy->getElementCount(), OffsetIntPtr);
return OffsetIntPtr;
}
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 262036499b30..d2e31e492db5 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -6565,19 +6565,23 @@ class VectorPromoteHelper {
UseSplat = true;
}
- unsigned End = getTransitionType()->getVectorNumElements();
+ ElementCount EC = getTransitionType()->getVectorElementCount();
if (UseSplat)
- return ConstantVector::getSplat(End, Val);
-
- SmallVector<Constant *, 4> ConstVec;
- UndefValue *UndefVal = UndefValue::get(Val->getType());
- for (unsigned Idx = 0; Idx != End; ++Idx) {
- if (Idx == ExtractIdx)
- ConstVec.push_back(Val);
- else
- ConstVec.push_back(UndefVal);
- }
- return ConstantVector::get(ConstVec);
+ return ConstantVector::getSplat(EC, Val);
+
+ if (!EC.Scalable) {
+ SmallVector<Constant *, 4> ConstVec;
+ UndefValue *UndefVal = UndefValue::get(Val->getType());
+ for (unsigned Idx = 0; Idx != EC.Min; ++Idx) {
+ if (Idx == ExtractIdx)
+ ConstVec.push_back(Val);
+ else
+ ConstVec.push_back(UndefVal);
+ }
+ return ConstantVector::get(ConstVec);
+ } else
+ llvm_unreachable(
+ "Generate scalable vector for non-splat is unimplemented");
}
/// Check if promoting to a vector type an operand at \p OperandIdx
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index acd10b46c86f..0da027f56fdd 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -2229,8 +2229,7 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
Constant *Idx0 = cast<Constant>(Idxs[0]);
if (Idxs.size() == 1 && (Idx0->isNullValue() || isa<UndefValue>(Idx0)))
return GEPTy->isVectorTy() && !C->getType()->isVectorTy()
- ? ConstantVector::getSplat(
- cast<VectorType>(GEPTy)->getNumElements(), C)
+ ? ConstantVector::getSplat(GEPTy->getVectorElementCount(), C)
: C;
if (C->isNullValue()) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 399bd41c82b2..eb0e5894ae54 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -370,7 +370,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) {
// Broadcast a scalar to a vector, if necessary.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- C = ConstantVector::getSplat(VTy->getNumElements(), C);
+ C = ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -387,7 +387,7 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
}
VectorType *VTy = cast<VectorType>(Ty);
- return ConstantVector::getSplat(VTy->getNumElements(),
+ return ConstantVector::getSplat(VTy->getElementCount(),
getAllOnesValue(VTy->getElementType()));
}
@@ -681,7 +681,7 @@ Constant *ConstantInt::getTrue(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), TrueC);
+ return ConstantVector::getSplat(VTy->getElementCount(), TrueC);
return TrueC;
}
@@ -689,7 +689,7 @@ Constant *ConstantInt::getFalse(Type *Ty) {
assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1.");
ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext());
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), FalseC);
+ return ConstantVector::getSplat(VTy->getElementCount(), FalseC);
return FalseC;
}
@@ -712,7 +712,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -736,7 +736,7 @@ Constant *ConstantInt::get(Type *Ty, const APInt& V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -781,7 +781,7 @@ Constant *ConstantFP::get(Type *Ty, double V) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -793,7 +793,7 @@ Constant *ConstantFP::get(Type *Ty, const APFloat &V) {
// For vectors, broadcast the value.
if (auto *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -806,7 +806,7 @@ Constant *ConstantFP::get(Type *Ty, StringRef Str) {
// For vectors, broadcast the value.
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -817,7 +817,7 @@ Constant *ConstantFP::getNaN(Type *Ty, bool Negative, uint64_t Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -828,7 +828,7 @@ Constant *ConstantFP::getQNaN(Type *Ty, bool Negative, APInt *Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -839,7 +839,7 @@ Constant *ConstantFP::getSNaN(Type *Ty, bool Negative, APInt *Payload) {
Constant *C = get(Ty->getContext(), NaN);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -850,7 +850,7 @@ Constant *ConstantFP::getNegativeZero(Type *Ty) {
Constant *C = get(Ty->getContext(), NegZero);
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -898,7 +898,7 @@ Constant *ConstantFP::getInfinity(Type *Ty, bool Negative) {
Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
if (VectorType *VTy = dyn_cast<VectorType>(Ty))
- return ConstantVector::getSplat(VTy->getNumElements(), C);
+ return ConstantVector::getSplat(VTy->getElementCount(), C);
return C;
}
@@ -1204,15 +1204,35 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
return nullptr;
}
-Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
- // If this splat is compatible with ConstantDataVector, use it instead of
- // ConstantVector.
- if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
- ConstantDataSequential::isElementTypeCompatible(V->getType()))
- return ConstantDataVector::getSplat(NumElts, V);
+Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) {
+ if (!EC.Scalable) {
+ // If this splat is compatible with ConstantDataVector, use it instead of
+ // ConstantVector.
+ if ((isa<ConstantFP>(V) || isa<ConstantInt>(V)) &&
+ ConstantDataSequential::isElementTypeCompatible(V->getType()))
+ return ConstantDataVector::getSplat(EC.Min, V);
- SmallVector<Constant*, 32> Elts(NumElts, V);
- return get(Elts);
+ SmallVector<Constant *, 32> Elts(EC.Min, V);
+ return get(Elts);
+ }
+
+ Type *VTy = VectorType::get(V->getType(), EC);
+
+ if (V->isNullValue())
+ return ConstantAggregateZero::get(VTy);
+ else if (isa<UndefValue>(V))
+ return UndefValue::get(VTy);
+
+ Type *I32Ty = Type::getInt32Ty(VTy->getContext());
+
+ // Move scalar into vector.
+ Constant *UndefV = UndefValue::get(VTy);
+ V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0));
+ // Build shuffle mask to perform the splat.
+ Type *MaskTy = VectorType::get(I32Ty, EC);
+ Constant *Zeros = ConstantAggregateZero::get(MaskTy);
+ // Splat.
+ return ConstantExpr::getShuffleVector(V, UndefV, Zeros);
}
ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) {
@@ -2098,15 +2118,15 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
unsigned AS = C->getType()->getPointerAddressSpace();
Type *ReqTy = DestTy->getPointerTo(AS);
- unsigned NumVecElts = 0;
- if (C->getType()->isVectorTy())
- NumVecElts = C->getType()->getVectorNumElements();
+ ElementCount EltCount = {0, false};
+ if (VectorType *VecTy = dyn_cast<VectorType>(C->getType()))
+ EltCount = VecTy->getElementCount();
else for (auto Idx : Idxs)
- if (Idx->getType()->isVectorTy())
- NumVecElts = Idx->getType()->getVectorNumElements();
+ if (VectorType *VecTy = dyn_cast<VectorType>(Idx->getType()))
+ EltCount = VecTy->getElementCount();
- if (NumVecElts)
- ReqTy = VectorType::get(ReqTy, NumVecElts);
+ if (EltCount.Min != 0)
+ ReqTy = VectorType::get(ReqTy, EltCount);
if (OnlyIfReducedTy == ReqTy)
return nullptr;
@@ -2117,12 +2137,12 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
ArgVec.push_back(C);
for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
assert((!Idxs[i]->getType()->isVectorTy() ||
- Idxs[i]->getType()->getVectorNumElements() == NumVecElts) &&
+ Idxs[i]->getType()->getVectorElementCount() == EltCount) &&
"getelementptr index type missmatch");
Constant *Idx = cast<Constant>(Idxs[i]);
- if (NumVecElts && !Idxs[i]->getType()->isVectorTy())
- Idx = ConstantVector::getSplat(NumVecElts, Idx);
+ if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy())
+ Idx = ConstantVector::getSplat(EltCount, Idx);
ArgVec.push_back(Idx);
}
@@ -2759,7 +2779,7 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
return getFP(V->getContext(), Elts);
}
}
- return ConstantVector::getSplat(NumElts, V);
+ return ConstantVector::getSplat({NumElts, false}, V);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index b68621e9478c..63bf3462faac 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5379,8 +5379,9 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
if (ScalarC && ScalarM) {
// We allow undefs in matching, but this transform removes those for safety.
// Demanded elements analysis should be able to recover some/all of that.
- C = ConstantVector::getSplat(V1Ty->getVectorNumElements(), ScalarC);
- M = ConstantVector::getSplat(M->getType()->getVectorNumElements(), ScalarM);
+ C = ConstantVector::getSplat(V1Ty->getVectorElementCount(), ScalarC);
+ M = ConstantVector::getSplat(M->getType()->getVectorElementCount(),
+ ScalarM);
Value *NewCmp = IsFP ? Builder.CreateFCmp(Pred, V1, C)
: Builder.CreateICmp(Pred, V1, C);
return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 49b178359729..0a842b4e1047 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -774,7 +774,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
}
@@ -809,7 +809,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
if (VectorType *VT = dyn_cast<VectorType>(X->getType()))
- Mask = ConstantVector::getSplat(VT->getNumElements(), Mask);
+ Mask = ConstantVector::getSplat(VT->getElementCount(), Mask);
return BinaryOperator::CreateAnd(X, Mask);
}
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index c6bf118a8c37..a1957ccda3a1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1717,9 +1717,10 @@ void InnerLoopVectorizer::createVectorIntOrFpInductionPHI(
// FIXME: If the step is non-constant, we create the vector splat with
// IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't
// handle a constant vector splat.
- Value *SplatVF = isa<Constant>(Mul)
- ? ConstantVector::getSplat(VF, cast<Constant>(Mul))
- : Builder.CreateVectorSplat(VF, Mul);
+ Value *SplatVF =
+ isa<Constant>(Mul)
+ ? ConstantVector::getSplat({VF, false}, cast<Constant>(Mul))
+ : Builder.CreateVectorSplat(VF, Mul);
Builder.restoreIP(CurrIP);
// We may need to add the step a number of times, depending on the unroll
@@ -3731,7 +3732,7 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) {
// incoming scalar reduction.
VectorStart = ReductionStartValue;
} else {
- Identity = ConstantVector::getSplat(VF, Iden);
+ Identity = ConstantVector::getSplat({VF, false}, Iden);
// This vector is the Identity vector where the first element is the
// incoming scalar reduction.
diff --git a/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll b/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll
new file mode 100644
index 000000000000..77f1747ea9cf
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll
@@ -0,0 +1,23 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=aarch64 -codegenprepare -S < %s | FileCheck %s
+
+; This test intends to check vector promotion for scalable vector. Current target lowering
+; rejects scalable vector before reaching getConstantVector() in CodeGenPrepare. This test
+; will assert once target lowering is ready, then we can bring in implementation for non-splat
+; codepath for scalable vector.
+
+define void @simpleOneInstructionPromotion(<vscale x 2 x i32>* %addr1, i32* %dest) {
+; CHECK-LABEL: @simpleOneInstructionPromotion(
+; CHECK-NEXT: [[IN1:%.*]] = load <vscale x 2 x i32>, <vscale x 2 x i32>* [[ADDR1:%.*]], align 8
+; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement <vscale x 2 x i32> [[IN1]], i32 1
+; CHECK-NEXT: [[OUT:%.*]] = or i32 [[EXTRACT]], 1
+; CHECK-NEXT: store i32 [[OUT]], i32* [[DEST:%.*]], align 4
+; CHECK-NEXT: ret void
+;
+ %in1 = load <vscale x 2 x i32>, <vscale x 2 x i32>* %addr1, align 8
+ %extract = extractelement <vscale x 2 x i32> %in1, i32 1
+ %out = or i32 %extract, 1
+ store i32 %out, i32* %dest, align 4
+ ret void
+}
+
diff --git a/llvm/test/Transforms/InstSimplify/gep.ll b/llvm/test/Transforms/InstSimplify/gep.ll
index 1fb882777834..c4fa8df717cc 100644
--- a/llvm/test/Transforms/InstSimplify/gep.ll
+++ b/llvm/test/Transforms/InstSimplify/gep.ll
@@ -103,3 +103,69 @@ define <8 x i64*> @undef_vec2() {
ret <8 x i64*> %el
}
+; Check ConstantExpr::getGetElementPtr() using ElementCount for size queries - begin.
+
+; Constant ptr
+
+define i32* @ptr_idx_scalar() {
+; CHECK-LABEL: @ptr_idx_scalar(
+; CHECK-NEXT: ret i32* inttoptr (i64 4 to i32*)
+;
+ %gep = getelementptr <4 x i32>, <4 x i32>* null, i64 0, i64 1
+ ret i32* %gep
+}
+
+define <2 x i32*> @ptr_idx_vector() {
+; CHECK-LABEL: @ptr_idx_vector(
+; CHECK-NEXT: ret <2 x i32*> getelementptr (i32, i32* null, <2 x i64> <i64 1, i64 1>)
+;
+ %gep = getelementptr i32, i32* null, <2 x i64> <i64 1, i64 1>
+ ret <2 x i32*> %gep
+}
+
+define <4 x i32*> @ptr_idx_mix_scalar_vector(){
+; CHECK-LABEL: @ptr_idx_mix_scalar_vector(
+; CHECK-NEXT: ret <4 x i32*> getelementptr ([42 x [3 x i32]], [42 x [3 x i32]]* null, <4 x i64> zeroinitializer, <4 x i64> <i64 0, i64 1, i64 2, i64 3>, <4 x i64> zeroinitializer)
+;
+ %gep = getelementptr [42 x [3 x i32]], [42 x [3 x i32]]* null, i64 0, <4 x i64> <i64 0, i64 1, i64 2, i64 3>, i64 0
+ ret <4 x i32*> %gep
+}
+
+; Constant vector
+
+define <4 x i32*> @vector_idx_scalar() {
+; CHECK-LABEL: @vector_idx_scalar(
+; CHECK-NEXT: ret <4 x i32*> getelementptr (i32, <4 x i32*> zeroinitializer, <4 x i64> <i64 1, i64 1, i64 1, i64 1>)
+;
+ %gep = getelementptr i32, <4 x i32*> zeroinitializer, i64 1
+ ret <4 x i32*> %gep
+}
+
+define <4 x i32*> @vector_idx_vector() {
+; CHECK-LABEL: @vector_idx_vector(
+; CHECK-NEXT: ret <4 x i32*> getelementptr (i32, <4 x i32*> zeroinitializer, <4 x i64> <i64 1, i64 1, i64 1, i64 1>)
+;
+ %gep = getelementptr i32, <4 x i32*> zeroinitializer, <4 x i64> <i64 1, i64 1, i64 1, i64 1>
+ ret <4 x i32*> %gep
+}
+
+%struct = type { double, float }
+define <4 x float*> @vector_idx_mix_scalar_vector() {
+; CHECK-LABEL: @vector_idx_mix_scalar_vector(
+; CHECK-NEXT: ret <4 x float*> getelementptr (%struct, <4 x %struct*> zeroinitializer, <4 x i64> zeroinitializer, <4 x i32> <i32 1, i32 1, i32 1, i32 1>)
+;
+ %gep = getelementptr %struct, <4 x %struct*> zeroinitializer, i32 0, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+ ret <4 x float*> %gep
+}
+
+; Constant scalable
+
+define <vscale x 4 x i32*> @scalable_idx_scalar() {
+; CHECK-LABEL: @scalable_idx_scalar(
+; CHECK-NEXT: ret <vscale x 4 x i32*> getelementptr (i32, <vscale x 4 x i32*> zeroinitializer, <vscale x 4 x i64> shufflevector (<vscale x 4 x i64> insertelement (<vscale x 4 x i64> undef, i64 1, i32 0), <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer))
+;
+ %gep = getelementptr i32, <vscale x 4 x i32*> zeroinitializer, i64 1
+ ret <vscale x 4 x i32*> %gep
+}
+
+; Check ConstantExpr::getGetElementPtr() using ElementCount for size queries - end.
diff --git a/llvm/unittests/FuzzMutate/OperationsTest.cpp b/llvm/unittests/FuzzMutate/OperationsTest.cpp
index a077c5cd59e0..78a7a13615d5 100644
--- a/llvm/unittests/FuzzMutate/OperationsTest.cpp
+++ b/llvm/unittests/FuzzMutate/OperationsTest.cpp
@@ -92,8 +92,8 @@ TEST(OperationsTest, SourcePreds) {
ConstantStruct::get(StructType::create(Ctx, "OpaqueStruct"));
Constant *a =
ConstantArray::get(ArrayType::get(i32->getType(), 2), {i32, i32});
- Constant *v8i8 = ConstantVector::getSplat(8, i8);
- Constant *v4f16 = ConstantVector::getSplat(4, f16);
+ Constant *v8i8 = ConstantVector::getSplat({8, false}, i8);
+ Constant *v4f16 = ConstantVector::getSplat({4, false}, f16);
Constant *p0i32 =
ConstantPointerNull::get(PointerType::get(i32->getType(), 0));
diff --git a/llvm/unittests/IR/VerifierTest.cpp b/llvm/unittests/IR/VerifierTest.cpp
index 46721e4ec8b4..f6a6a6ec7128 100644
--- a/llvm/unittests/IR/VerifierTest.cpp
+++ b/llvm/unittests/IR/VerifierTest.cpp
@@ -57,7 +57,7 @@ TEST(VerifierTest, Freeze) {
ConstantInt *CI = ConstantInt::get(ITy, 0);
// Valid type : freeze(<2 x i32>)
- Constant *CV = ConstantVector::getSplat(2, CI);
+ Constant *CV = ConstantVector::getSplat({2, false}, CI);
FreezeInst *FI_vec = new FreezeInst(CV);
FI_vec->insertBefore(RI);
More information about the cfe-commits
mailing list