[llvm] 664b7a4 - [SCCP] Fix conversion of range to constant for vectors (PR63380)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 19 03:29:52 PDT 2023
Author: Nikita Popov
Date: 2023-06-19T12:29:44+02:00
New Revision: 664b7a4cd51d9273888e79688f64cc8bbcbdbe25
URL: https://github.com/llvm/llvm-project/commit/664b7a4cd51d9273888e79688f64cc8bbcbdbe25
DIFF: https://github.com/llvm/llvm-project/commit/664b7a4cd51d9273888e79688f64cc8bbcbdbe25.diff
LOG: [SCCP] Fix conversion of range to constant for vectors (PR63380)
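The ConstantRange specifies the range of the scalar elements in the
vector. Previously, a single-element range was always converted into a
scalar ConstantInt, even when the value itself had vector type. When
converting into a Constant, we need to create a vector splat with the
correct type. For that purpose, pass in the expected type for the
constant.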
Fixes https://github.com/llvm/llvm-project/issues/63380.
Added:
Modified:
llvm/include/llvm/Transforms/Utils/SCCPSolver.h
llvm/lib/Transforms/Utils/SCCPSolver.cpp
llvm/test/Transforms/SCCP/intrinsics.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 3754b51f4722d..7930d95e1deaf 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -160,7 +160,7 @@ class SCCPSolver {
/// Helper to return a Constant if \p LV is either a constant or a constant
/// range with a single element.
- Constant *getConstant(const ValueLatticeElement &LV) const;
+ Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const;
/// Return either a Constant or nullptr for a given Value.
Constant *getConstantOrNull(Value *V) const;
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 902651ab84f68..24d1a46cfd40f 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -394,8 +394,8 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
LLVMContext &Ctx;
private:
- ConstantInt *getConstantInt(const ValueLatticeElement &IV) const {
- return dyn_cast_or_null<ConstantInt>(getConstant(IV));
+ ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const {
+ return dyn_cast_or_null<ConstantInt>(getConstant(IV, Ty));
}
// pushToWorkList - Helper for markConstant/markOverdefined
@@ -778,7 +778,7 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
bool isStructLatticeConstant(Function *F, StructType *STy);
- Constant *getConstant(const ValueLatticeElement &LV) const;
+ Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const;
Constant *getConstantOrNull(Value *V) const;
@@ -881,14 +881,18 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) {
return true;
}
-Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
- if (LV.isConstant())
- return LV.getConstant();
+Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV,
+ Type *Ty) const {
+ if (LV.isConstant()) {
+ Constant *C = LV.getConstant();
+ assert(C->getType() == Ty && "Type mismatch");
+ return C;
+ }
if (LV.isConstantRange()) {
const auto &CR = LV.getConstantRange();
if (CR.getSingleElement())
- return ConstantInt::get(Ctx, *CR.getSingleElement());
+ return ConstantInt::get(Ty, *CR.getSingleElement());
}
return nullptr;
}
@@ -904,7 +908,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
ValueLatticeElement LV = LVs[I];
ConstVals.push_back(SCCPSolver::isConstant(LV)
- ? getConstant(LV)
+ ? getConstant(LV, ST->getElementType(I))
: UndefValue::get(ST->getElementType(I)));
}
Const = ConstantStruct::get(ST, ConstVals);
@@ -912,7 +916,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
const ValueLatticeElement &LV = getLatticeValueFor(V);
if (SCCPSolver::isOverdefined(LV))
return nullptr;
- Const = SCCPSolver::isConstant(LV) ? getConstant(LV)
+ Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType())
: UndefValue::get(V->getType());
}
assert(Const && "Constant is nullptr here!");
@@ -1007,7 +1011,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
}
ValueLatticeElement BCValue = getValueState(BI->getCondition());
- ConstantInt *CI = getConstantInt(BCValue);
+ ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType());
if (!CI) {
// Overdefined condition variables, and branches on unfoldable constant
// conditions, mean the branch could go either way.
@@ -1033,7 +1037,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
return;
}
const ValueLatticeElement &SCValue = getValueState(SI->getCondition());
- if (ConstantInt *CI = getConstantInt(SCValue)) {
+ if (ConstantInt *CI =
+ getConstantInt(SCValue, SI->getCondition()->getType())) {
Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;
return;
}
@@ -1064,7 +1069,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
// Casts are folded by visitCastInst.
ValueLatticeElement IBRValue = getValueState(IBR->getAddress());
- BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue));
+ BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(
+ getConstant(IBRValue, IBR->getAddress()->getType()));
if (!Addr) { // Overdefined or unknown condition?
// All destinations are executable!
if (!IBRValue.isUnknownOrUndef())
@@ -1219,7 +1225,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
if (OpSt.isUnknownOrUndef())
return;
- if (Constant *OpC = getConstant(OpSt)) {
+ if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) {
// Fold the constant as we build.
Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);
markConstant(&I, C);
@@ -1354,7 +1360,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) {
if (CondValue.isUnknownOrUndef())
return;
- if (ConstantInt *CondCB = getConstantInt(CondValue)) {
+ if (ConstantInt *CondCB =
+ getConstantInt(CondValue, I.getCondition()->getType())) {
Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue();
mergeInValue(&I, getValueState(OpVal));
return;
@@ -1387,8 +1394,8 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) {
return;
if (SCCPSolver::isConstant(V0State))
- if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(),
- getConstant(V0State), DL))
+ if (Constant *C = ConstantFoldUnaryOpOperand(
+ I.getOpcode(), getConstant(V0State, I.getType()), DL))
return (void)markConstant(IV, &I, C);
markOverdefined(&I);
@@ -1412,8 +1419,8 @@ void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) {
return;
if (SCCPSolver::isConstant(V0State) &&
- isGuaranteedNotToBeUndefOrPoison(getConstant(V0State)))
- return (void)markConstant(IV, &I, getConstant(V0State));
+ isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType())))
+ return (void)markConstant(IV, &I, getConstant(V0State, I.getType()));
markOverdefined(&I);
}
@@ -1437,10 +1444,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) {
// If either of the operands is a constant, try to fold it to a constant.
// TODO: Use information from notconstant better.
if ((V1State.isConstant() || V2State.isConstant())) {
- Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State)
- : I.getOperand(0);
- Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State)
- : I.getOperand(1);
+ Value *V1 = SCCPSolver::isConstant(V1State)
+ ? getConstant(V1State, I.getOperand(0)->getType())
+ : I.getOperand(0);
+ Value *V2 = SCCPSolver::isConstant(V2State)
+ ? getConstant(V2State, I.getOperand(1)->getType())
+ : I.getOperand(1);
Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL));
auto *C = dyn_cast_or_null<Constant>(R);
if (C) {
@@ -1518,7 +1527,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
if (SCCPSolver::isOverdefined(State))
return (void)markOverdefined(&I);
- if (Constant *C = getConstant(State)) {
+ if (Constant *C = getConstant(State, I.getOperand(i)->getType())) {
Operands.push_back(C);
continue;
}
@@ -1584,7 +1593,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) {
ValueLatticeElement &IV = ValueState[&I];
if (SCCPSolver::isConstant(PtrVal)) {
- Constant *Ptr = getConstant(PtrVal);
+ Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType());
// load null is undefined.
if (isa<ConstantPointerNull>(Ptr)) {
@@ -1647,7 +1656,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) {
if (SCCPSolver::isOverdefined(State))
return (void)markOverdefined(&CB);
assert(SCCPSolver::isConstant(State) && "Unknown state!");
- Operands.push_back(getConstant(State));
+ Operands.push_back(getConstant(State, A->getType()));
}
if (SCCPSolver::isOverdefined(getValueState(&CB)))
@@ -2067,8 +2076,9 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) {
return Visitor->isStructLatticeConstant(F, STy);
}
-Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const {
- return Visitor->getConstant(LV);
+Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV,
+ Type *Ty) const {
+ return Visitor->getConstant(LV, Ty);
}
Constant *SCCPSolver::getConstantOrNull(Value *V) const {
diff --git a/llvm/test/Transforms/SCCP/intrinsics.ll b/llvm/test/Transforms/SCCP/intrinsics.ll
index 3fc7637ab7327..5edb31738e685 100644
--- a/llvm/test/Transforms/SCCP/intrinsics.ll
+++ b/llvm/test/Transforms/SCCP/intrinsics.ll
@@ -122,3 +122,17 @@ exit:
%p_umax = call i8 @llvm.umax.i8(i8 %p, i8 1)
ret i8 %p_umax
}
+
+define <4 x i32> @pr63380(<4 x i32> %input) {
+; CHECK-LABEL: @pr63380(
+; CHECK-NEXT: [[CTLZ_1:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[INPUT:%.*]], i1 false)
+; CHECK-NEXT: [[CTLZ_2:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[CTLZ_1]], i1 true)
+; CHECK-NEXT: ret <4 x i32> <i32 27, i32 27, i32 27, i32 27>
+;
+ %ctlz.1 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %input, i1 false)
+ %ctlz.2 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.1, i1 true)
+ %ctlz.3 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.2, i1 true)
+ ret <4 x i32> %ctlz.3
+}
+
+declare <4 x i32> @llvm.ctlz.v4i32(<4 x i32>, i1 immarg)
More information about the llvm-commits mailing list