[llvm] 2dbb454 - [ValueTracking][LVI] Consolidate vector constant range calculation

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 06:19:36 PDT 2024


Author: Nikita Popov
Date: 2024-07-03T15:19:26+02:00
New Revision: 2dbb454791044e3ef91c8e7069f953b7406d78c6

URL: https://github.com/llvm/llvm-project/commit/2dbb454791044e3ef91c8e7069f953b7406d78c6
DIFF: https://github.com/llvm/llvm-project/commit/2dbb454791044e3ef91c8e7069f953b7406d78c6.diff

LOG: [ValueTracking][LVI] Consolidate vector constant range calculation

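Add a common helper, getVectorConstantRange(), used by both
computeConstantRange() and LVI. The implementation is a mix of both,
with the efficient handling for ConstantDataVector taken from
computeConstantRange(), and the general handling (including non-splat
poison) from LVI. As a result, computeConstantRange() now also handles
non-splat vectors containing poison elements, which lets InstCombine
fold the usub.sat in the updated saturating-add-sub.ll test.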

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ValueTracking.h
    llvm/lib/Analysis/LazyValueInfo.cpp
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b7b78cb9edab3..a67ad501982d2 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -904,6 +904,9 @@ bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
 /// based on the vscale_range function attribute.
 ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);
 
+/// Determine the possible range of a vector constant; poison lanes are ignored.
+ConstantRange getVectorConstantRange(const Constant *C);
+
 /// Determine the possible constant range of an integer or vector of integer
 /// value. This is intended as a cheap, non-recursive check.
 ConstantRange computeConstantRange(const Value *V, bool ForSigned,

diff  --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index e9051e74b4577..468b08a15d7df 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -836,24 +836,6 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
   }
 }
 
-static ConstantRange getConstantRangeFromFixedVector(Constant *C,
-                                                     FixedVectorType *Ty) {
-  unsigned BW = Ty->getScalarSizeInBits();
-  ConstantRange CR = ConstantRange::getEmpty(BW);
-  for (unsigned I = 0; I < Ty->getNumElements(); ++I) {
-    Constant *Elem = C->getAggregateElement(I);
-    if (!Elem)
-      return ConstantRange::getFull(BW);
-    if (isa<PoisonValue>(Elem))
-      continue;
-    auto *CI = dyn_cast<ConstantInt>(Elem);
-    if (!CI)
-      return ConstantRange::getFull(BW);
-    CR = CR.unionWith(CI->getValue());
-  }
-  return CR;
-}
-
 static ConstantRange toConstantRange(const ValueLatticeElement &Val,
                                      Type *Ty, bool UndefAllowed = false) {
   assert(Ty->isIntOrIntVectorTy() && "Must be integer type");
@@ -862,13 +844,8 @@ static ConstantRange toConstantRange(const ValueLatticeElement &Val,
   unsigned BW = Ty->getScalarSizeInBits();
   if (Val.isUnknown())
     return ConstantRange::getEmpty(BW);
-  if (Val.isConstant() && Ty->isVectorTy()) {
-    if (auto *CI = dyn_cast_or_null<ConstantInt>(
-            Val.getConstant()->getSplatValue(/*AllowPoison=*/true)))
-      return ConstantRange(CI->getValue());
-    if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
-      return getConstantRangeFromFixedVector(Val.getConstant(), VTy);
-  }
+  if (Val.isConstant() && Ty->isVectorTy())
+    return getVectorConstantRange(Val.getConstant());
   return ConstantRange::getFull(BW);
 }
 

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 7660009b088d0..5476dc5d85182 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9498,6 +9498,39 @@ static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
   }
 }
 
+ConstantRange llvm::getVectorConstantRange(const Constant *C) {
+  assert(C->getType()->isVectorTy() && "Expected vector constant");
+  if (auto *CI = dyn_cast_or_null<ConstantInt>(
+          C->getSplatValue(/*AllowPoison=*/true)))
+    return ConstantRange(CI->getValue()); // Splat; poison lanes ignored.
+  // Non-splat packed data: union the value of every element.
+  unsigned BitWidth = C->getType()->getScalarSizeInBits();
+  if (auto *CDV = dyn_cast<ConstantDataVector>(C)) {
+    ConstantRange CR = ConstantRange::getEmpty(BitWidth);
+    for (unsigned I = 0, E = CDV->getNumElements(); I < E; ++I)
+      CR = CR.unionWith(CDV->getElementAsAPInt(I));
+    return CR;
+  }
+  // ConstantVector: skip poison lanes; any other non-ConstantInt -> full.
+  if (auto *CV = dyn_cast<ConstantVector>(C)) {
+    ConstantRange CR = ConstantRange::getEmpty(BitWidth);
+    for (unsigned I = 0, E = CV->getNumOperands(); I < E; ++I) {
+      Constant *Elem = C->getAggregateElement(I);
+      if (!Elem)
+        return ConstantRange::getFull(BitWidth);
+      if (isa<PoisonValue>(Elem))
+        continue;
+      auto *CI = dyn_cast<ConstantInt>(Elem);
+      if (!CI)
+        return ConstantRange::getFull(BitWidth);
+      CR = CR.unionWith(CI->getValue());
+    }
+    return CR;
+  }
+  // Any other kind of vector constant conservatively gets the full range.
+  return ConstantRange::getFull(BitWidth);
+}
+
 ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
                                          bool UseInstrInfo, AssumptionCache *AC,
                                          const Instruction *CtxI,
@@ -9508,19 +9541,15 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
   if (Depth == MaxAnalysisRecursionDepth)
     return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
 
-  const APInt *C;
-  if (match(V, m_APInt(C)))
-    return ConstantRange(*C);
-  unsigned BitWidth = V->getType()->getScalarSizeInBits();
-
-  if (auto *VC = dyn_cast<ConstantDataVector>(V)) {
-    ConstantRange CR = ConstantRange::getEmpty(BitWidth);
-    for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
-         ++ElemIdx)
-      CR = CR.unionWith(VC->getElementAsAPInt(ElemIdx));
-    return CR;
+  if (auto *C = dyn_cast<Constant>(V)) {
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      return ConstantRange(CI->getValue());
+    if (C->getType()->isVectorTy())
+      return getVectorConstantRange(C);
+    return ConstantRange::getFull(C->getType()->getScalarSizeInBits());
   }
 
+  unsigned BitWidth = V->getType()->getScalarSizeInBits();
   InstrInfoQuery IIQ(UseInstrInfo);
   ConstantRange CR = ConstantRange::getFull(BitWidth);
   if (auto *BO = dyn_cast<BinaryOperator>(V)) {

diff  --git a/llvm/test/Transforms/InstCombine/saturating-add-sub.ll b/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
index 5a29ee7f66e35..bf1568f1cd8c0 100644
--- a/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
+++ b/llvm/test/Transforms/InstCombine/saturating-add-sub.ll
@@ -1064,8 +1064,7 @@ define <2 x i8> @test_vector_usub_add_nuw_no_ov_nonsplat1(<2 x i8> %a) {
 
 define <3 x i8> @test_vector_usub_add_nuw_no_ov_nonsplat1_poison(<3 x i8> %a) {
 ; CHECK-LABEL: @test_vector_usub_add_nuw_no_ov_nonsplat1_poison(
-; CHECK-NEXT:    [[B:%.*]] = add nuw <3 x i8> [[A:%.*]], <i8 10, i8 10, i8 10>
-; CHECK-NEXT:    [[R:%.*]] = call <3 x i8> @llvm.usub.sat.v3i8(<3 x i8> [[B]], <3 x i8> <i8 10, i8 9, i8 poison>)
+; CHECK-NEXT:    [[R:%.*]] = add <3 x i8> [[A:%.*]], <i8 0, i8 1, i8 poison>
 ; CHECK-NEXT:    ret <3 x i8> [[R]]
 ;
   %b = add nuw <3 x i8> %a, <i8 10, i8 10, i8 10>


        


More information about the llvm-commits mailing list