[llvm] c912f88 - [InstCombine] Remove false commutativity from processUMulZExtIdiom() (NFCI)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 25 02:31:40 PDT 2023


Author: Nikita Popov
Date: 2023-10-25T11:31:31+02:00
New Revision: c912f88c2177f44d9a584b338f94b29a7873e028

URL: https://github.com/llvm/llvm-project/commit/c912f88c2177f44d9a584b338f94b29a7873e028
DIFF: https://github.com/llvm/llvm-project/commit/c912f88c2177f44d9a584b338f94b29a7873e028.diff

LOG: [InstCombine] Remove false commutativity from processUMulZExtIdiom() (NFCI)

This fold requires a comparison against a constant, which will always be
on the RHS. If the swapped fold had actually triggered, it would have
resulted in a miscompile, because it did not correctly account for the
swapped predicate when the operands were swapped.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 1b5c4b1ffd7f809..8e3242a86199b75 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5807,15 +5807,13 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
 /// \returns Instruction which must replace the compare instruction, NULL if no
 ///          replacement required.
 static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
-                                         Value *OtherVal,
+                                         const APInt *OtherVal,
                                          InstCombinerImpl &IC) {
   // Don't bother doing this transformation for pointers, don't do it for
   // vectors.
   if (!isa<IntegerType>(MulVal->getType()))
     return nullptr;
 
-  assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
-  assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
   auto *MulInstr = dyn_cast<Instruction>(MulVal);
   if (!MulInstr)
     return nullptr;
@@ -5875,28 +5873,26 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   // Recognize patterns
   switch (I.getPredicate()) {
-  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGT: {
     // Recognize pattern:
     //   mulval = mul(zext A, zext B)
     //   cmp ugt mulval, max
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getMaxValue(MulWidth);
-      MaxVal = MaxVal.zext(CI->getBitWidth());
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
+    APInt MaxVal = APInt::getMaxValue(MulWidth);
+    MaxVal = MaxVal.zext(OtherVal->getBitWidth());
+    if (MaxVal.eq(*OtherVal))
+      break; // Recognized
     return nullptr;
+  }
 
-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // Recognize pattern:
     //   mulval = mul(zext A, zext B)
     //   cmp ule mulval, max + 1
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
+    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+    if (MaxVal.eq(*OtherVal))
+      break; // Recognized
     return nullptr;
+  }
 
   default:
     return nullptr;
@@ -5922,7 +5918,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   if (MulVal->hasNUsesOrMore(2)) {
     Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
     for (User *U : make_early_inc_range(MulVal->users())) {
-      if (U == &I || U == OtherVal)
+      if (U == &I)
         continue;
       if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
         if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
@@ -5943,27 +5939,10 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
       IC.addToWorklist(cast<Instruction>(U));
     }
   }
-  if (isa<Instruction>(OtherVal))
-    IC.addToWorklist(cast<Instruction>(OtherVal));
 
   // The original icmp gets replaced with the overflow value, maybe inverted
   // depending on predicate.
-  bool Inverse = false;
-  switch (I.getPredicate()) {
-  case ICmpInst::ICMP_UGT:
-    if (I.getOperand(0) == MulVal)
-      break;
-    Inverse = true;
-    break;
-  case ICmpInst::ICMP_ULT:
-    if (I.getOperand(1) == MulVal)
-      break;
-    Inverse = true;
-    break;
-  default:
-    llvm_unreachable("Unexpected predicate");
-  }
-  if (Inverse) {
+  if (I.getPredicate() == ICmpInst::ICMP_ULT) {
     Value *Res = Builder.CreateExtractValue(Call, 1);
     return BinaryOperator::CreateNot(Res);
   }
@@ -7083,12 +7062,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     }
 
     // (zext a) * (zext b)  --> llvm.umul.with.overflow.
-    if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
-      if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
-        return R;
-    }
-    if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
-      if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this))
+    if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B)))) &&
+        match(Op1, m_APInt(C))) {
+      if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this))
         return R;
     }
 


        


More information about the llvm-commits mailing list