[llvm] 71597d4 - [ConstantFolding] refactor helper for vector reductions; NFC

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 29 09:10:14 PDT 2021


Author: Sanjay Patel
Date: 2021-04-29T12:09:22-04:00
New Revision: 71597d40e878c42990d315b23885771f87d6af4d

URL: https://github.com/llvm/llvm-project/commit/71597d40e878c42990d315b23885771f87d6af4d
DIFF: https://github.com/llvm/llvm-project/commit/71597d40e878c42990d315b23885771f87d6af4d.diff

LOG: [ConstantFolding] refactor helper for vector reductions; NFC

We should also handle other cases (undef/poison), so first reduce
the duplication caused by the repeated switches.

Added: 
    

Modified: 
    llvm/lib/Analysis/ConstantFolding.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dc55930a9f26..cfc1fc4055c5 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1702,19 +1702,29 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V,
   return GetConstantFoldFPValue(V, Ty);
 }
 
-Constant *ConstantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
+Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
   FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
   if (!VT)
     return nullptr;
-  ConstantInt *CI = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
-  if (!CI)
+
+  // This isn't strictly necessary, but handle the special/common case of zero:
+  // all integer reductions of a zero input produce zero.
+  if (isa<ConstantAggregateZero>(Op))
+    return ConstantInt::get(VT->getElementType(), 0);
+
+  // TODO: Handle undef and poison.
+  if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op))
     return nullptr;
-  APInt Acc = CI->getValue();
 
-  for (unsigned I = 1; I < VT->getNumElements(); I++) {
-    if (!(CI = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
+  auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
+  if (!EltC)
+    return nullptr;
+
+  APInt Acc = EltC->getValue();
+  for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) {
+    if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
       return nullptr;
-    const APInt &X = CI->getValue();
+    const APInt &X = EltC->getValue();
     switch (IID) {
     case Intrinsic::vector_reduce_add:
       Acc = Acc + X;
@@ -2241,20 +2251,20 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
     }
   }
 
-  if (isa<ConstantAggregateZero>(Operands[0])) {
-    switch (IntrinsicID) {
-    default: break;
-    case Intrinsic::vector_reduce_add:
-    case Intrinsic::vector_reduce_mul:
-    case Intrinsic::vector_reduce_and:
-    case Intrinsic::vector_reduce_or:
-    case Intrinsic::vector_reduce_xor:
-    case Intrinsic::vector_reduce_smin:
-    case Intrinsic::vector_reduce_smax:
-    case Intrinsic::vector_reduce_umin:
-    case Intrinsic::vector_reduce_umax:
-      return ConstantInt::get(Ty, 0);
-    }
+  switch (IntrinsicID) {
+  default: break;
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_mul:
+  case Intrinsic::vector_reduce_and:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_xor:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax:
+  case Intrinsic::vector_reduce_umin:
+  case Intrinsic::vector_reduce_umax:
+    if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
+      return C;
+    break;
   }
 
   // Support ConstantVector in case we have an Undef in the top.
@@ -2263,18 +2273,6 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
     auto *Op = cast<Constant>(Operands[0]);
     switch (IntrinsicID) {
     default: break;
-    case Intrinsic::vector_reduce_add:
-    case Intrinsic::vector_reduce_mul:
-    case Intrinsic::vector_reduce_and:
-    case Intrinsic::vector_reduce_or:
-    case Intrinsic::vector_reduce_xor:
-    case Intrinsic::vector_reduce_smin:
-    case Intrinsic::vector_reduce_smax:
-    case Intrinsic::vector_reduce_umin:
-    case Intrinsic::vector_reduce_umax:
-      if (Constant *C = ConstantFoldVectorReduce(IntrinsicID, Op))
-        return C;
-      break;
     case Intrinsic::x86_sse_cvtss2si:
     case Intrinsic::x86_sse_cvtss2si64:
     case Intrinsic::x86_sse2_cvtsd2si:


        


More information about the llvm-commits mailing list