[llvm] 8819202 - [SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold
Christopher Tetreault via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 17 14:20:13 PDT 2020
Author: Christopher Tetreault
Date: 2020-06-17T14:19:56-07:00
New Revision: 8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a
URL: https://github.com/llvm/llvm-project/commit/8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a
DIFF: https://github.com/llvm/llvm-project/commit/8819202dfd2c39a7ed4dd69f0d7e0e0bcf409e2a.diff
LOG: [SVE] Eliminate bad VectorType::getNumElements() calls from ConstantFold
Summary:
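Assume all uses of VectorType::getNumElements() in this file are explicitly
fixed-width operations, and cast to FixedVectorType accordingly. Also fold
extractelement on a scalable-vector zeroinitializer when the constant index
is below the vector's minimum element count.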
Reviewers: efriedma, sdesmalen, c-rhodes, majnemer, dblaikie
Reviewed By: sdesmalen
Subscribers: tschuett, hiraditya, rkruppe, psnobl, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D80262
Added:
llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll
Modified:
llvm/lib/IR/ConstantFold.cpp
Removed:
################################################################################
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3fb49e94870f..ef584afc68bc 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -55,8 +55,8 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
// If this cast changes element count then we can't handle it here:
// doing so requires endianness information. This should be handled by
// Analysis/ConstantFolding.cpp
- unsigned NumElts = DstTy->getNumElements();
- if (NumElts != cast<VectorType>(CV->getType())->getNumElements())
+ unsigned NumElts = cast<FixedVectorType>(DstTy)->getNumElements();
+ if (NumElts != cast<FixedVectorType>(CV->getType())->getNumElements())
return nullptr;
Type *DstEltTy = DstTy->getElementType();
@@ -573,8 +573,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
// count may be mismatched; don't attempt to handle that here.
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
DestTy->isVectorTy() &&
- cast<VectorType>(DestTy)->getNumElements() ==
- cast<VectorType>(V->getType())->getNumElements()) {
+ cast<FixedVectorType>(DestTy)->getNumElements() ==
+ cast<FixedVectorType>(V->getType())->getNumElements()) {
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
// Fast path for splatted constants.
@@ -585,7 +585,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
}
SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
- for (unsigned i = 0, e = cast<VectorType>(V->getType())->getNumElements();
+ for (unsigned i = 0,
+ e = cast<FixedVectorType>(V->getType())->getNumElements();
i != e; ++i) {
Constant *C =
ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i));
@@ -809,9 +810,11 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
if (!CIdx)
return nullptr;
- // ee({w,x,y,z}, wrong_value) -> undef
- if (CIdx->uge(ValVTy->getNumElements()))
- return UndefValue::get(ValVTy->getElementType());
+ if (auto *ValFVTy = dyn_cast<FixedVectorType>(Val->getType())) {
+ // ee({w,x,y,z}, wrong_value) -> undef
+ if (CIdx->uge(ValFVTy->getNumElements()))
+ return UndefValue::get(ValFVTy->getElementType());
+ }
// ee (gep (ptr, idx0, ...), idx) -> gep (ee (ptr, idx), ee (idx0, idx), ...)
if (auto *CE = dyn_cast<ConstantExpr>(Val)) {
@@ -823,7 +826,7 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
if (Op->getType()->isVectorTy()) {
Constant *ScalarOp = ConstantExpr::getExtractElement(Op, Idx);
if (!ScalarOp)
- return nullptr;
+ return nullptr;
Ops.push_back(ScalarOp);
} else
Ops.push_back(Op);
@@ -833,6 +836,16 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
}
}
+ // CAZ of type ScalableVectorType and n < CAZ->getMinNumElements() =>
+ // extractelt CAZ, n -> 0
+ if (auto *ValSVTy = dyn_cast<ScalableVectorType>(Val->getType())) {
+ if (!CIdx->uge(ValSVTy->getMinNumElements())) {
+ if (auto *CAZ = dyn_cast<ConstantAggregateZero>(Val))
+ return CAZ->getElementValue(CIdx->getZExtValue());
+ }
+ return nullptr;
+ }
+
return Val->getAggregateElement(CIdx);
}
@@ -847,11 +860,12 @@ Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
// Do not iterate on scalable vector. The num of elements is unknown at
// compile-time.
- VectorType *ValTy = cast<VectorType>(Val->getType());
- if (isa<ScalableVectorType>(ValTy))
+ if (isa<ScalableVectorType>(Val->getType()))
return nullptr;
- unsigned NumElts = cast<VectorType>(Val->getType())->getNumElements();
+ auto *ValTy = cast<FixedVectorType>(Val->getType());
+
+ unsigned NumElts = ValTy->getNumElements();
if (CIdx->uge(NumElts))
return UndefValue::get(Val->getType());
@@ -898,7 +912,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, Constant *V2,
if (isa<ScalableVectorType>(V1VTy))
return nullptr;
- unsigned SrcNumElts = V1VTy->getNumElements();
+ unsigned SrcNumElts = V1VTy->getElementCount().Min;
// Loop over the shuffle mask, evaluating each element.
SmallVector<Constant*, 32> Result;
@@ -998,11 +1012,8 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
case Instruction::FNeg:
return ConstantFP::get(C->getContext(), neg(CV));
}
- } else if (VectorType *VTy = dyn_cast<VectorType>(C->getType())) {
- // Do not iterate on scalable vector. The number of elements is unknown at
- // compile-time.
- if (IsScalableVector)
- return nullptr;
+ } else if (auto *VTy = dyn_cast<FixedVectorType>(C->getType())) {
+
Type *Ty = IntegerType::get(VTy->getContext(), 32);
// Fast path for splatted constants.
if (Constant *Splat = C->getSplatValue()) {
@@ -1011,7 +1022,7 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
}
// Fold each element and create a vector constant from those constants.
- SmallVector<Constant*, 16> Result;
+ SmallVector<Constant *, 16> Result;
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *ExtractIdx = ConstantInt::get(Ty, i);
Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1367,11 +1378,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
return ConstantFP::get(C1->getContext(), C3V);
}
}
- } else if (VectorType *VTy = dyn_cast<VectorType>(C1->getType())) {
+ } else if (IsScalableVector) {
// Do not iterate on scalable vector. The number of elements is unknown at
// compile-time.
- if (IsScalableVector)
- return nullptr;
+ // FIXME: this branch can potentially be removed
+ return nullptr;
+ } else if (auto *VTy = dyn_cast<FixedVectorType>(C1->getType())) {
// Fast path for splatted constants.
if (Constant *C2Splat = C2->getSplatValue()) {
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
@@ -2014,7 +2026,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
SmallVector<Constant*, 4> ResElts;
Type *Ty = IntegerType::get(C1->getContext(), 32);
// Compare the elements, producing an i1 result or constant expr.
- for (unsigned i = 0, e = C1VTy->getNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = C1VTy->getElementCount().Min; i != e; ++i) {
Constant *C1E =
ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i));
Constant *C2E =
@@ -2286,14 +2298,18 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
assert(Ty && "Invalid indices for GEP!");
Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace());
- if (VectorType *VT = dyn_cast<VectorType>(C->getType()))
- GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
-
+ if (VectorType *VT = dyn_cast<VectorType>(C->getType())) {
+ // FIXME: handle scalable vectors (use getElementCount())
+ GEPTy = FixedVectorType::get(
+ OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
+ }
// The GEP returns a vector of pointers when one of more of
// its arguments is a vector.
for (unsigned i = 0, e = Idxs.size(); i != e; ++i) {
if (auto *VT = dyn_cast<VectorType>(Idxs[i]->getType())) {
- GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements());
+ // FIXME: handle scalable vectors
+ GEPTy = FixedVectorType::get(
+ OrigGEPTy, cast<FixedVectorType>(VT)->getNumElements());
break;
}
}
@@ -2500,19 +2516,19 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
if (!IsCurrIdxVector && IsPrevIdxVector)
CurrIdx = ConstantDataVector::getSplat(
- cast<VectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
+ cast<FixedVectorType>(PrevIdx->getType())->getNumElements(), CurrIdx);
if (!IsPrevIdxVector && IsCurrIdxVector)
PrevIdx = ConstantDataVector::getSplat(
- cast<VectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
+ cast<FixedVectorType>(CurrIdx->getType())->getNumElements(), PrevIdx);
Constant *Factor =
ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
if (UseVector)
Factor = ConstantDataVector::getSplat(
IsPrevIdxVector
- ? cast<VectorType>(PrevIdx->getType())->getNumElements()
- : cast<VectorType>(CurrIdx->getType())->getNumElements(),
+ ? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
+ : cast<FixedVectorType>(CurrIdx->getType())->getNumElements(),
Factor);
NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
@@ -2531,8 +2547,8 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
ExtendedTy = FixedVectorType::get(
ExtendedTy,
IsPrevIdxVector
- ? cast<VectorType>(PrevIdx->getType())->getNumElements()
- : cast<VectorType>(CurrIdx->getType())->getNumElements());
+ ? cast<FixedVectorType>(PrevIdx->getType())->getNumElements()
+ : cast<FixedVectorType>(CurrIdx->getType())->getNumElements());
if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);
diff --git a/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll
new file mode 100644
index 000000000000..c4b42be45019
--- /dev/null
+++ b/llvm/test/Analysis/ConstantFolding/extractelement-vscale.ll
@@ -0,0 +1,13 @@
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+; CHECK-LABEL: definitely_in_bounds
+; CHECK: ret i8 0
+define i8 @definitely_in_bounds() { ; index 15 < 16 (minimum element count), so it is in bounds for every vscale and folds to 0
+ ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 15)
+}
+
+; CHECK-LABEL: maybe_in_bounds
+; CHECK: ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
+define i8 @maybe_in_bounds() { ; index 16 is not below the minimum element count (16), so it may be out of bounds (e.g. vscale = 1) and must stay unfolded
+ ret i8 extractelement (<vscale x 16 x i8> zeroinitializer, i64 16)
+}
More information about the llvm-commits mailing list