[llvm] 8607a02 - [InstSimplify] Transform X * Y % Y --> 0

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue May 25 07:29:00 PDT 2021


Author: David Goldblatt
Date: 2021-05-25T10:16:04-04:00
New Revision: 8607a023574f29cbb0b3fdd26f36872ca6b4af5e

URL: https://github.com/llvm/llvm-project/commit/8607a023574f29cbb0b3fdd26f36872ca6b4af5e
DIFF: https://github.com/llvm/llvm-project/commit/8607a023574f29cbb0b3fdd26f36872ca6b4af5e.diff

LOG: [InstSimplify] Transform X * Y % Y --> 0

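simplifyDiv already handles the case X * Y / Y --> X when the multiplication
cannot overflow. This adds the equivalent fold X * Y % Y --> 0 to simplifyRem
by moving the shared check into simplifyDivRem, which now takes the opcode
instead of an IsDiv flag.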

Correctness proofs (Alive2):
https://alive2.llvm.org/ce/z/J2cUbS
https://alive2.llvm.org/ce/z/us9NUM
https://alive2.llvm.org/ce/z/AvaDGJ
https://alive2.llvm.org/ce/z/kq9ige

Applying this transform in any broader set of situations would not be
correct; counterexamples:
https://alive2.llvm.org/ce/z/Lf9V63
https://alive2.llvm.org/ce/z/6RPQK3
https://alive2.llvm.org/ce/z/p9UdxC
https://alive2.llvm.org/ce/z/A2zlhE
https://alive2.llvm.org/ce/z/vHTtLw
https://alive2.llvm.org/ce/z/lvpH42

Differential Revision: https://reviews.llvm.org/D102864

Added: 
    

Modified: 
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/test/Transforms/InstSimplify/rem.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 0f5a5bb63735..778ce93f61f8 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -923,8 +923,11 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
 
 /// Check for common or similar folds of integer division or integer remainder.
 /// This applies to all 4 opcodes (sdiv/udiv/srem/urem).
-static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
-                             const SimplifyQuery &Q) {
+static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
+                             Value *Op1, const SimplifyQuery &Q) {
+  bool IsDiv = (Opcode == Instruction::SDiv || Opcode == Instruction::UDiv);
+  bool IsSigned = (Opcode == Instruction::SDiv || Opcode == Instruction::SRem);
+
   Type *Ty = Op0->getType();
 
   // X / undef -> poison
@@ -976,6 +979,21 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
       (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return IsDiv ? Op0 : Constant::getNullValue(Ty);
 
+  // If X * Y does not overflow, then:
+  //   X * Y / Y -> X
+  //   X * Y % Y -> 0
+  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+    // The multiplication can't overflow if it is defined not to, or if
+    // X == A / Y for some A.
+    if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
+        (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)) ||
+        (IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) {
+      return IsDiv ? X : Constant::getNullValue(Op0->getType());
+    }
+  }
+
   return nullptr;
 }
 
@@ -1047,25 +1065,11 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
-  if (Value *V = simplifyDivRem(Op0, Op1, true, Q))
+  if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
     return V;
 
   bool IsSigned = Opcode == Instruction::SDiv;
 
-  // (X * Y) / Y -> X if the multiplication does not overflow.
-  Value *X;
-  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
-    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
-    // If the Mul does not overflow, then we are good to go.
-    if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
-        (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)))
-      return X;
-    // If X has the form X = A / Y, then X * Y cannot overflow.
-    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
-        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
-      return X;
-  }
-
   // (X rem Y) / Y -> 0
   if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
       (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
@@ -1073,7 +1077,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
 
   // (X /u C1) /u C2 -> 0 if C1 * C2 overflow
   ConstantInt *C1, *C2;
-  if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) &&
+  if (!IsSigned && match(Op0, m_UDiv(m_Value(), m_ConstantInt(C1))) &&
       match(Op1, m_ConstantInt(C2))) {
     bool Overflow;
     (void)C1->getValue().umul_ov(C2->getValue(), Overflow);
@@ -1105,7 +1109,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
-  if (Value *V = simplifyDivRem(Op0, Op1, false, Q))
+  if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
     return V;
 
   // (X % Y) % Y -> X % Y

diff  --git a/llvm/test/Transforms/InstSimplify/rem.ll b/llvm/test/Transforms/InstSimplify/rem.ll
index 5ee893a3e77c..c42d6d855d44 100644
--- a/llvm/test/Transforms/InstSimplify/rem.ll
+++ b/llvm/test/Transforms/InstSimplify/rem.ll
@@ -335,9 +335,7 @@ define i8 @srem_minusone_divisor() {
 
 define i32 @srem_of_mul_nsw(i32 %x, i32 %y) {
 ; CHECK-LABEL: @srem_of_mul_nsw(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem i32 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i32 [[MOD]]
+; CHECK-NEXT:    ret i32 0
 ;
   %mul = mul nsw i32 %x, %y
   %mod = srem i32 %mul, %y
@@ -349,9 +347,7 @@ define i32 @srem_of_mul_nsw(i32 %x, i32 %y) {
 ;   - vector types
 define <2 x i32> @srem_of_mul_nsw_vec_commuted(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @srem_of_mul_nsw_vec_commuted(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw <2 x i32> [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem <2 x i32> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <2 x i32> [[MOD]]
+; CHECK-NEXT:    ret <2 x i32> zeroinitializer
 ;
   %mul = mul nsw <2 x i32> %y, %x
   %mod = srem <2 x i32> %mul, %y
@@ -393,9 +389,7 @@ define i32 @urem_of_mul_nsw(i32 %x, i32 %y) {
 
 define i32 @urem_of_mul_nuw(i32 %x, i32 %y) {
 ; CHECK-LABEL: @urem_of_mul_nuw(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem i32 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i32 [[MOD]]
+; CHECK-NEXT:    ret i32 0
 ;
   %mul = mul nuw i32 %x, %y
   %mod = urem i32 %mul, %y
@@ -404,9 +398,7 @@ define i32 @urem_of_mul_nuw(i32 %x, i32 %y) {
 
 define <2 x i32> @srem_of_mul_nuw_vec_commuted(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @srem_of_mul_nuw_vec_commuted(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw <2 x i32> [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem <2 x i32> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <2 x i32> [[MOD]]
+; CHECK-NEXT:    ret <2 x i32> zeroinitializer
 ;
   %mul = mul nuw <2 x i32> %y, %x
   %mod = urem <2 x i32> %mul, %y
@@ -426,10 +418,7 @@ define i32 @urem_of_mul(i32 %x, i32 %y) {
 
 define i4 @srem_mul_sdiv(i4 %x, i4 %y) {
 ; CHECK-LABEL: @srem_mul_sdiv(
-; CHECK-NEXT:    [[D:%.*]] = sdiv i4 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul i4 [[D]], [[Y]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem i4 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i4 [[MOD]]
+; CHECK-NEXT:    ret i4 0
 ;
   %d = sdiv i4 %x, %y
   %mul = mul i4 %d, %y
@@ -452,10 +441,7 @@ define i8 @srem_mul_udiv(i8 %x, i8 %y) {
 
 define <3 x i7> @urem_mul_udiv_vec_commuted(<3 x i7> %x, <3 x i7> %y) {
 ; CHECK-LABEL: @urem_mul_udiv_vec_commuted(
-; CHECK-NEXT:    [[D:%.*]] = udiv <3 x i7> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul <3 x i7> [[Y]], [[D]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem <3 x i7> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <3 x i7> [[MOD]]
+; CHECK-NEXT:    ret <3 x i7> zeroinitializer
 ;
   %d = udiv <3 x i7> %x, %y
   %mul = mul <3 x i7> %y, %d


        


More information about the llvm-commits mailing list