[llvm] 13ec913 - [InstCombine] Recognize `((x * y) s/ x) !=/== y` as a signed multiplication overflow check (PR48769)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 20 11:30:11 PDT 2021


Author: Roman Lebedev
Date: 2021-04-20T21:29:43+03:00
New Revision: 13ec913bdf500e2354cc55bf29e2f5d99e0c709e

URL: https://github.com/llvm/llvm-project/commit/13ec913bdf500e2354cc55bf29e2f5d99e0c709e
DIFF: https://github.com/llvm/llvm-project/commit/13ec913bdf500e2354cc55bf29e2f5d99e0c709e.diff

LOG: [InstCombine] Recognize `((x * y) s/ x) !=/== y` as a signed multiplication overflow check (PR48769)

We already had support for its unsigned variant, so simply extend it
to also handle the signed variant.

Fixes https://bugs.llvm.org/show_bug.cgi?id=48769

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll
    llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 4e3ddae9023e..41b485706753 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3672,19 +3672,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
 
 /// Fold
 ///   (-1 u/ x) u< y
-///   ((x * y) u/ x) != y
+///   ((x * y) ?/ x) != y
 /// to
-///   @llvm.umul.with.overflow(x, y) plus extraction of overflow bit
+///   @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit
 /// Note that the comparison is commutative, while inverted (u>=, ==) predicate
 /// will mean that we are looking for the opposite answer.
-Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
+Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
   ICmpInst::Predicate Pred;
   Value *X, *Y;
   Instruction *Mul;
+  Instruction *Div;
   bool NeedNegation;
   // Look for: (-1 u/ x) u</u>= y
   if (!I.isEquality() &&
-      match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
+      match(&I, m_c_ICmp(Pred,
+                         m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
+                                      m_Instruction(Div)),
                          m_Value(Y)))) {
     Mul = nullptr;
 
@@ -3699,13 +3702,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
     default:
       return nullptr; // Wrong predicate.
     }
-  } else // Look for: ((x * y) u/ x) !=/== y
+  } else // Look for: ((x * y) / x) !=/== y
       if (I.isEquality() &&
-          match(&I, m_c_ICmp(Pred, m_Value(Y),
-                             m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
+          match(&I,
+                m_c_ICmp(Pred, m_Value(Y),
+                         m_CombineAnd(
+                             m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
                                                                   m_Value(X)),
                                                           m_Instruction(Mul)),
-                                             m_Deferred(X)))))) {
+                                             m_Deferred(X))),
+                             m_Instruction(Div))))) {
     NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
   } else
     return nullptr;
@@ -3717,19 +3723,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
   if (MulHadOtherUses)
     Builder.SetInsertPoint(Mul);
 
-  Function *F = Intrinsic::getDeclaration(
-      I.getModule(), Intrinsic::umul_with_overflow, X->getType());
-  CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul");
+  Function *F = Intrinsic::getDeclaration(I.getModule(),
+                                          Div->getOpcode() == Instruction::UDiv
+                                              ? Intrinsic::umul_with_overflow
+                                              : Intrinsic::smul_with_overflow,
+                                          X->getType());
+  CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");
 
   // If the multiplication was used elsewhere, to ensure that we don't leave
   // "duplicate" instructions, replace uses of that original multiplication
   // with the multiplication result from the with.overflow intrinsic.
   if (MulHadOtherUses)
-    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val"));
+    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));
 
-  Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov");
+  Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
   if (NeedNegation) // This technically increases instruction count.
-    Res = Builder.CreateNot(Res, "umul.not.ov");
+    Res = Builder.CreateNot(Res, "mul.not.ov");
 
   // If we replaced the mul, erase it. Do this after all uses of Builder,
   // as the mul is used as insertion point.
@@ -4126,7 +4135,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
     }
   }
 
-  if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+  if (Value *V = foldMultiplicationOverflowCheck(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index edf8f0f2782a..15152bb6f11a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -656,7 +656,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldSignBitTest(ICmpInst &I);
   Instruction *foldICmpWithZero(ICmpInst &Cmp);
 
-  Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp);
+  Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp);
 
   Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select,
                                       ConstantInt *C);

diff  --git a/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll b/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll
index 39a1bc7d6a29..d2a5d5a3ddaa 100644
--- a/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll
@@ -8,10 +8,10 @@
 
 define i1 @t0_basic(i8 %x, i8 %y) {
 ; CHECK-LABEL: @t0_basic(
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
+; CHECK-NEXT:    ret i1 [[MUL_NOT_OV]]
 ;
   %t0 = mul i8 %x, %y
   %t1 = sdiv i8 %t0, %x
@@ -21,10 +21,10 @@ define i1 @t0_basic(i8 %x, i8 %y) {
 
 define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @t1_vec(
-; CHECK-NEXT:    [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq <2 x i8> [[T1]], [[Y]]
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor <2 x i1> [[MUL_OV]], <i1 true, i1 true>
+; CHECK-NEXT:    ret <2 x i1> [[MUL_NOT_OV]]
 ;
   %t0 = mul <2 x i8> %x, %y
   %t1 = sdiv <2 x i8> %t0, %x
@@ -37,10 +37,10 @@ declare i8 @gen8()
 define i1 @t2_commutative(i8 %x) {
 ; CHECK-LABEL: @t2_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
+; CHECK-NEXT:    ret i1 [[MUL_NOT_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -52,10 +52,10 @@ define i1 @t2_commutative(i8 %x) {
 define i1 @t3_commutative(i8 %x) {
 ; CHECK-LABEL: @t3_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
+; CHECK-NEXT:    ret i1 [[MUL_NOT_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -67,10 +67,10 @@ define i1 @t3_commutative(i8 %x) {
 define i1 @t4_commutative(i8 %x) {
 ; CHECK-LABEL: @t4_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[Y]], [[T1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
+; CHECK-NEXT:    ret i1 [[MUL_NOT_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -85,11 +85,12 @@ declare void @use8(i8)
 
 define i1 @t5_extrause0(i8 %x, i8 %y) {
 ; CHECK-LABEL: @t5_extrause0(
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    call void @use8(i8 [[T0]])
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true
+; CHECK-NEXT:    call void @use8(i8 [[MUL_VAL]])
+; CHECK-NEXT:    ret i1 [[MUL_NOT_OV]]
 ;
   %t0 = mul i8 %x, %y
   call void @use8(i8 %t0)

diff  --git a/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll b/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll
index 81c04a06e30b..f84ae67a3a0c 100644
--- a/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll
+++ b/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll
@@ -8,10 +8,9 @@
 
 define i1 @t0_basic(i8 %x, i8 %y) {
 ; CHECK-LABEL: @t0_basic(
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    ret i1 [[MUL_OV]]
 ;
   %t0 = mul i8 %x, %y
   %t1 = sdiv i8 %t0, %x
@@ -21,10 +20,9 @@ define i1 @t0_basic(i8 %x, i8 %y) {
 
 define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @t1_vec(
-; CHECK-NEXT:    [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne <2 x i8> [[T1]], [[Y]]
-; CHECK-NEXT:    ret <2 x i1> [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1
+; CHECK-NEXT:    ret <2 x i1> [[MUL_OV]]
 ;
   %t0 = mul <2 x i8> %x, %y
   %t1 = sdiv <2 x i8> %t0, %x
@@ -37,10 +35,9 @@ declare i8 @gen8()
 define i1 @t2_commutative(i8 %x) {
 ; CHECK-LABEL: @t2_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    ret i1 [[MUL_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -52,10 +49,9 @@ define i1 @t2_commutative(i8 %x) {
 define i1 @t3_commutative(i8 %x) {
 ; CHECK-LABEL: @t3_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    ret i1 [[MUL_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -67,10 +63,9 @@ define i1 @t3_commutative(i8 %x) {
 define i1 @t4_commutative(i8 %x) {
 ; CHECK-LABEL: @t4_commutative(
 ; CHECK-NEXT:    [[Y:%.*]] = call i8 @gen8()
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[Y]], [[T1]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]])
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    ret i1 [[MUL_OV]]
 ;
   %y = call i8 @gen8()
   %t0 = mul i8 %y, %x ; swapped
@@ -85,11 +80,11 @@ declare void @use8(i8)
 
 define i1 @t5_extrause0(i8 %x, i8 %y) {
 ; CHECK-LABEL: @t5_extrause0(
-; CHECK-NEXT:    [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    call void @use8(i8 [[T0]])
-; CHECK-NEXT:    [[T1:%.*]] = sdiv i8 [[T0]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[T1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
+; CHECK-NEXT:    [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0
+; CHECK-NEXT:    [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1
+; CHECK-NEXT:    call void @use8(i8 [[MUL_VAL]])
+; CHECK-NEXT:    ret i1 [[MUL_OV]]
 ;
   %t0 = mul i8 %x, %y
   call void @use8(i8 %t0)


        


More information about the llvm-commits mailing list