[llvm] 664b7a4 - [SCCP] Fix conversion of range to constant for vectors (PR63380)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 19 03:29:52 PDT 2023


Author: Nikita Popov
Date: 2023-06-19T12:29:44+02:00
New Revision: 664b7a4cd51d9273888e79688f64cc8bbcbdbe25

URL: https://github.com/llvm/llvm-project/commit/664b7a4cd51d9273888e79688f64cc8bbcbdbe25
DIFF: https://github.com/llvm/llvm-project/commit/664b7a4cd51d9273888e79688f64cc8bbcbdbe25.diff

LOG: [SCCP] Fix conversion of range to constant for vectors (PR63380)

The ConstantRange describes the range of the scalar elements of the
vector. When converting it into a Constant, we need to create a vector
splat of the correct type rather than a scalar ConstantInt. To that end,
pass the expected type of the constant to getConstant().

Fixes https://github.com/llvm/llvm-project/issues/63380.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/SCCPSolver.h
    llvm/lib/Transforms/Utils/SCCPSolver.cpp
    llvm/test/Transforms/SCCP/intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
index 3754b51f4722d..7930d95e1deaf 100644
--- a/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
+++ b/llvm/include/llvm/Transforms/Utils/SCCPSolver.h
@@ -160,7 +160,7 @@ class SCCPSolver {
 
   /// Helper to return a Constant if \p LV is either a constant or a constant
   /// range with a single element.
-  Constant *getConstant(const ValueLatticeElement &LV) const;
+  Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const;
 
   /// Return either a Constant or nullptr for a given Value.
   Constant *getConstantOrNull(Value *V) const;

diff  --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 902651ab84f68..24d1a46cfd40f 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -394,8 +394,8 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
   LLVMContext &Ctx;
 
 private:
-  ConstantInt *getConstantInt(const ValueLatticeElement &IV) const {
-    return dyn_cast_or_null<ConstantInt>(getConstant(IV));
+  ConstantInt *getConstantInt(const ValueLatticeElement &IV, Type *Ty) const {
+    return dyn_cast_or_null<ConstantInt>(getConstant(IV, Ty));
   }
 
   // pushToWorkList - Helper for markConstant/markOverdefined
@@ -778,7 +778,7 @@ class SCCPInstVisitor : public InstVisitor<SCCPInstVisitor> {
 
   bool isStructLatticeConstant(Function *F, StructType *STy);
 
-  Constant *getConstant(const ValueLatticeElement &LV) const;
+  Constant *getConstant(const ValueLatticeElement &LV, Type *Ty) const;
 
   Constant *getConstantOrNull(Value *V) const;
 
@@ -881,14 +881,18 @@ bool SCCPInstVisitor::isStructLatticeConstant(Function *F, StructType *STy) {
   return true;
 }
 
-Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV) const {
-  if (LV.isConstant())
-    return LV.getConstant();
+Constant *SCCPInstVisitor::getConstant(const ValueLatticeElement &LV,
+                                       Type *Ty) const {
+  if (LV.isConstant()) {
+    Constant *C = LV.getConstant();
+    assert(C->getType() == Ty && "Type mismatch");
+    return C;
+  }
 
   if (LV.isConstantRange()) {
     const auto &CR = LV.getConstantRange();
     if (CR.getSingleElement())
-      return ConstantInt::get(Ctx, *CR.getSingleElement());
+      return ConstantInt::get(Ty, *CR.getSingleElement());
   }
   return nullptr;
 }
@@ -904,7 +908,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
     for (unsigned I = 0, E = ST->getNumElements(); I != E; ++I) {
       ValueLatticeElement LV = LVs[I];
       ConstVals.push_back(SCCPSolver::isConstant(LV)
-                              ? getConstant(LV)
+                              ? getConstant(LV, ST->getElementType(I))
                               : UndefValue::get(ST->getElementType(I)));
     }
     Const = ConstantStruct::get(ST, ConstVals);
@@ -912,7 +916,7 @@ Constant *SCCPInstVisitor::getConstantOrNull(Value *V) const {
     const ValueLatticeElement &LV = getLatticeValueFor(V);
     if (SCCPSolver::isOverdefined(LV))
       return nullptr;
-    Const = SCCPSolver::isConstant(LV) ? getConstant(LV)
+    Const = SCCPSolver::isConstant(LV) ? getConstant(LV, V->getType())
                                        : UndefValue::get(V->getType());
   }
   assert(Const && "Constant is nullptr here!");
@@ -1007,7 +1011,7 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
     }
 
     ValueLatticeElement BCValue = getValueState(BI->getCondition());
-    ConstantInt *CI = getConstantInt(BCValue);
+    ConstantInt *CI = getConstantInt(BCValue, BI->getCondition()->getType());
     if (!CI) {
       // Overdefined condition variables, and branches on unfoldable constant
       // conditions, mean the branch could go either way.
@@ -1033,7 +1037,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
       return;
     }
     const ValueLatticeElement &SCValue = getValueState(SI->getCondition());
-    if (ConstantInt *CI = getConstantInt(SCValue)) {
+    if (ConstantInt *CI =
+            getConstantInt(SCValue, SI->getCondition()->getType())) {
       Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;
       return;
     }
@@ -1064,7 +1069,8 @@ void SCCPInstVisitor::getFeasibleSuccessors(Instruction &TI,
   if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
     // Casts are folded by visitCastInst.
     ValueLatticeElement IBRValue = getValueState(IBR->getAddress());
-    BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue));
+    BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(
+        getConstant(IBRValue, IBR->getAddress()->getType()));
     if (!Addr) { // Overdefined or unknown condition?
       // All destinations are executable!
       if (!IBRValue.isUnknownOrUndef())
@@ -1219,7 +1225,7 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
   if (OpSt.isUnknownOrUndef())
     return;
 
-  if (Constant *OpC = getConstant(OpSt)) {
+  if (Constant *OpC = getConstant(OpSt, I.getOperand(0)->getType())) {
     // Fold the constant as we build.
     Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);
     markConstant(&I, C);
@@ -1354,7 +1360,8 @@ void SCCPInstVisitor::visitSelectInst(SelectInst &I) {
   if (CondValue.isUnknownOrUndef())
     return;
 
-  if (ConstantInt *CondCB = getConstantInt(CondValue)) {
+  if (ConstantInt *CondCB =
+          getConstantInt(CondValue, I.getCondition()->getType())) {
     Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue();
     mergeInValue(&I, getValueState(OpVal));
     return;
@@ -1387,8 +1394,8 @@ void SCCPInstVisitor::visitUnaryOperator(Instruction &I) {
     return;
 
   if (SCCPSolver::isConstant(V0State))
-    if (Constant *C = ConstantFoldUnaryOpOperand(I.getOpcode(),
-                                                 getConstant(V0State), DL))
+    if (Constant *C = ConstantFoldUnaryOpOperand(
+            I.getOpcode(), getConstant(V0State, I.getType()), DL))
       return (void)markConstant(IV, &I, C);
 
   markOverdefined(&I);
@@ -1412,8 +1419,8 @@ void SCCPInstVisitor::visitFreezeInst(FreezeInst &I) {
     return;
 
   if (SCCPSolver::isConstant(V0State) &&
-      isGuaranteedNotToBeUndefOrPoison(getConstant(V0State)))
-    return (void)markConstant(IV, &I, getConstant(V0State));
+      isGuaranteedNotToBeUndefOrPoison(getConstant(V0State, I.getType())))
+    return (void)markConstant(IV, &I, getConstant(V0State, I.getType()));
 
   markOverdefined(&I);
 }
@@ -1437,10 +1444,12 @@ void SCCPInstVisitor::visitBinaryOperator(Instruction &I) {
   // If either of the operands is a constant, try to fold it to a constant.
   // TODO: Use information from notconstant better.
   if ((V1State.isConstant() || V2State.isConstant())) {
-    Value *V1 = SCCPSolver::isConstant(V1State) ? getConstant(V1State)
-                                                : I.getOperand(0);
-    Value *V2 = SCCPSolver::isConstant(V2State) ? getConstant(V2State)
-                                                : I.getOperand(1);
+    Value *V1 = SCCPSolver::isConstant(V1State)
+                    ? getConstant(V1State, I.getOperand(0)->getType())
+                    : I.getOperand(0);
+    Value *V2 = SCCPSolver::isConstant(V2State)
+                    ? getConstant(V2State, I.getOperand(1)->getType())
+                    : I.getOperand(1);
     Value *R = simplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL));
     auto *C = dyn_cast_or_null<Constant>(R);
     if (C) {
@@ -1518,7 +1527,7 @@ void SCCPInstVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
     if (SCCPSolver::isOverdefined(State))
       return (void)markOverdefined(&I);
 
-    if (Constant *C = getConstant(State)) {
+    if (Constant *C = getConstant(State, I.getOperand(i)->getType())) {
       Operands.push_back(C);
       continue;
     }
@@ -1584,7 +1593,7 @@ void SCCPInstVisitor::visitLoadInst(LoadInst &I) {
   ValueLatticeElement &IV = ValueState[&I];
 
   if (SCCPSolver::isConstant(PtrVal)) {
-    Constant *Ptr = getConstant(PtrVal);
+    Constant *Ptr = getConstant(PtrVal, I.getOperand(0)->getType());
 
     // load null is undefined.
     if (isa<ConstantPointerNull>(Ptr)) {
@@ -1647,7 +1656,7 @@ void SCCPInstVisitor::handleCallOverdefined(CallBase &CB) {
       if (SCCPSolver::isOverdefined(State))
         return (void)markOverdefined(&CB);
       assert(SCCPSolver::isConstant(State) && "Unknown state!");
-      Operands.push_back(getConstant(State));
+      Operands.push_back(getConstant(State, A->getType()));
     }
 
     if (SCCPSolver::isOverdefined(getValueState(&CB)))
@@ -2067,8 +2076,9 @@ bool SCCPSolver::isStructLatticeConstant(Function *F, StructType *STy) {
   return Visitor->isStructLatticeConstant(F, STy);
 }
 
-Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV) const {
-  return Visitor->getConstant(LV);
+Constant *SCCPSolver::getConstant(const ValueLatticeElement &LV,
+                                  Type *Ty) const {
+  return Visitor->getConstant(LV, Ty);
 }
 
 Constant *SCCPSolver::getConstantOrNull(Value *V) const {

diff  --git a/llvm/test/Transforms/SCCP/intrinsics.ll b/llvm/test/Transforms/SCCP/intrinsics.ll
index 3fc7637ab7327..5edb31738e685 100644
--- a/llvm/test/Transforms/SCCP/intrinsics.ll
+++ b/llvm/test/Transforms/SCCP/intrinsics.ll
@@ -122,3 +122,17 @@ exit:
   %p_umax = call i8 @llvm.umax.i8(i8 %p, i8 1)
   ret i8 %p_umax
 }
+
+define <4 x i32> @pr63380(<4 x i32> %input) {
+; CHECK-LABEL: @pr63380(
+; CHECK-NEXT:    [[CTLZ_1:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[INPUT:%.*]], i1 false)
+; CHECK-NEXT:    [[CTLZ_2:%.*]] = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> [[CTLZ_1]], i1 true)
+; CHECK-NEXT:    ret <4 x i32> <i32 27, i32 27, i32 27, i32 27>
+;
+  %ctlz.1 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %input, i1 false)
+  %ctlz.2 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.1, i1 true)
+  %ctlz.3 = call <4 x i32> @llvm.ctlz.v4i32(<4 x i32> %ctlz.2, i1 true)
+  ret <4 x i32> %ctlz.3
+}
+
+declare <4 x i32> @llvm.ctlz.v4i32(<4 x i32>, i1 immarg)


        


More information about the llvm-commits mailing list