[llvm] f0faea5 - [InstSimplify] fold exact divide to poison if it is known to not divide evenly

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 29 07:30:03 PST 2022


Author: Sanjay Patel
Date: 2022-12-29T10:26:50-05:00
New Revision: f0faea571403eb75a1d2d0dceca1dd52a7824d33

URL: https://github.com/llvm/llvm-project/commit/f0faea571403eb75a1d2d0dceca1dd52a7824d33
DIFF: https://github.com/llvm/llvm-project/commit/f0faea571403eb75a1d2d0dceca1dd52a7824d33.diff

LOG: [InstSimplify] fold exact divide to poison if it is known to not divide evenly

This is related to the discussion in D140665. I was looking over the demanded
bits implementation in IR and noticed that we just bail out of a potential
fold if a udiv is exact:
https://github.com/llvm/llvm-project/blob/82be8a1d2b00f6e89096b86f670a8be894c7b9e6/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp#L799

Also, see tests added with 7f0c11509e8f.

Then, I saw that we could lose a fold to poison if that transform drops the
exact flag, so this patch tries to catch that as a preliminary step.

Alive2 proofs:
https://alive2.llvm.org/ce/z/zCjKM7
https://alive2.llvm.org/ce/z/-tz_RK (the trailing-zeros comparison must be strict "less-than")
https://alive2.llvm.org/ce/z/c9CMsJ (general proof and specific example)

Differential Revision: https://reviews.llvm.org/D140733

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/udiv-simplify.ll
    llvm/test/Transforms/InstSimplify/div.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 52d43bf5c2a61..0a2f199794f8b 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -188,10 +188,12 @@ Value *simplifyFMAFMul(Value *LHS, Value *RHS, FastMathFlags FMF,
 Value *simplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
 /// Given operands for an SDiv, fold the result or return null.
-Value *simplifySDivInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
+Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact,
+                        const SimplifyQuery &Q);
 
 /// Given operands for a UDiv, fold the result or return null.
-Value *simplifyUDivInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
+Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact,
+                        const SimplifyQuery &Q);
 
 /// Given operands for an FDiv, fold the result or return null.
 Value *

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e8f0b4e6d8795..78b493d478823 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1143,13 +1143,24 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
 
 /// These are simplifications common to SDiv and UDiv.
 static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
-                          const SimplifyQuery &Q, unsigned MaxRecurse) {
+                          bool IsExact, const SimplifyQuery &Q,
+                          unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
   if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse))
     return V;
 
+  // If this is an exact divide by a constant, then the dividend (Op0) must have
+  // at least as many trailing zeros as the divisor to divide evenly. If it has
+  // fewer trailing zeros, then the result must be poison.
+  const APInt *DivC;
+  if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countTrailingZeros()) {
+    KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+    if (KnownOp0.countMaxTrailingZeros() < DivC->countTrailingZeros())
+      return PoisonValue::get(Op0->getType());
+  }
+
   bool IsSigned = Opcode == Instruction::SDiv;
 
   // (X rem Y) / Y -> 0
@@ -1230,28 +1241,30 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
 
 /// Given operands for an SDiv, see if we can fold the result.
 /// If not, this returns null.
-static Value *simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
-                               unsigned MaxRecurse) {
+static Value *simplifySDivInst(Value *Op0, Value *Op1, bool IsExact,
+                               const SimplifyQuery &Q, unsigned MaxRecurse) {
   // If two operands are negated and no signed overflow, return -1.
   if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true))
     return Constant::getAllOnesValue(Op0->getType());
 
-  return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);
+  return simplifyDiv(Instruction::SDiv, Op0, Op1, IsExact, Q, MaxRecurse);
 }
 
-Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
-  return ::simplifySDivInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, bool IsExact,
+                              const SimplifyQuery &Q) {
+  return ::simplifySDivInst(Op0, Op1, IsExact, Q, RecursionLimit);
 }
 
 /// Given operands for a UDiv, see if we can fold the result.
 /// If not, this returns null.
-static Value *simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
-                               unsigned MaxRecurse) {
-  return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse);
+static Value *simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact,
+                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+  return simplifyDiv(Instruction::UDiv, Op0, Op1, IsExact, Q, MaxRecurse);
 }
 
-Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
-  return ::simplifyUDivInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact,
+                              const SimplifyQuery &Q) {
+  return ::simplifyUDivInst(Op0, Op1, IsExact, Q, RecursionLimit);
 }
 
 /// Given operands for an SRem, see if we can fold the result.
@@ -1405,6 +1418,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
     return IsExact ? Op0 : Constant::getNullValue(Op0->getType());
 
   // The low bit cannot be shifted out of an exact shift if it is set.
+  // TODO: Generalize by counting trailing zeros (see fold for exact division).
   if (IsExact) {
     KnownBits Op0Known =
         computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
@@ -5678,9 +5692,9 @@ static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
   case Instruction::Mul:
     return simplifyMulInst(LHS, RHS, Q, MaxRecurse);
   case Instruction::SDiv:
-    return simplifySDivInst(LHS, RHS, Q, MaxRecurse);
+    return simplifySDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
   case Instruction::UDiv:
-    return simplifyUDivInst(LHS, RHS, Q, MaxRecurse);
+    return simplifyUDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
   case Instruction::SRem:
     return simplifySRemInst(LHS, RHS, Q, MaxRecurse);
   case Instruction::URem:
@@ -6553,9 +6567,11 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
   case Instruction::Mul:
     return simplifyMulInst(NewOps[0], NewOps[1], Q);
   case Instruction::SDiv:
-    return simplifySDivInst(NewOps[0], NewOps[1], Q);
+    return simplifySDivInst(NewOps[0], NewOps[1],
+                            Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
   case Instruction::UDiv:
-    return simplifyUDivInst(NewOps[0], NewOps[1], Q);
+    return simplifyUDivInst(NewOps[0], NewOps[1],
+                            Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
   case Instruction::FDiv:
     return simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
   case Instruction::SRem:

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 1a113d7f340be..2484b59682e9f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1207,7 +1207,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
 }
 
 Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
-  if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1),
+  if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
@@ -1287,7 +1287,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
 }
 
 Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
-  if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1),
+  if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 

diff  --git a/llvm/test/Transforms/InstCombine/udiv-simplify.ll b/llvm/test/Transforms/InstCombine/udiv-simplify.ll
index 559800f18deb2..f33c744ca79d6 100644
--- a/llvm/test/Transforms/InstCombine/udiv-simplify.ll
+++ b/llvm/test/Transforms/InstCombine/udiv-simplify.ll
@@ -95,13 +95,11 @@ define i8 @udiv_demanded_low_bits_set(i8 %a) {
   ret i8 %u
 }
 
-; TODO: This can't divide evenly, so it is poison.
+; This can't divide evenly, so it is poison.
 
 define i8 @udiv_exact_demanded_low_bits_set(i8 %a) {
 ; CHECK-LABEL: @udiv_exact_demanded_low_bits_set(
-; CHECK-NEXT:    [[O:%.*]] = or i8 [[A:%.*]], 3
-; CHECK-NEXT:    [[U:%.*]] = udiv exact i8 [[O]], 12
-; CHECK-NEXT:    ret i8 [[U]]
+; CHECK-NEXT:    ret i8 poison
 ;
   %o = or i8 %a, 3
   %u = udiv exact i8 %o, 12

diff  --git a/llvm/test/Transforms/InstSimplify/div.ll b/llvm/test/Transforms/InstSimplify/div.ll
index e9e07ea5328a7..4882e1f848902 100644
--- a/llvm/test/Transforms/InstSimplify/div.ll
+++ b/llvm/test/Transforms/InstSimplify/div.ll
@@ -333,17 +333,19 @@ define i1 @const_urem_1() {
   ret i1 %rem
 }
 
+; Can't divide evenly, so create poison.
+
 define i8 @sdiv_exact_trailing_zeros(i8 %x) {
 ; CHECK-LABEL: @sdiv_exact_trailing_zeros(
-; CHECK-NEXT:    [[O:%.*]] = or i8 [[X:%.*]], 1
-; CHECK-NEXT:    [[R:%.*]] = sdiv exact i8 [[O]], -42
-; CHECK-NEXT:    ret i8 [[R]]
+; CHECK-NEXT:    ret i8 poison
 ;
   %o = or i8 %x, 1           ; odd number
   %r = sdiv exact i8 %o, -42 ; can't divide exactly
   ret i8 %r
 }
 
+; Negative test - could divide evenly.
+
 define i8 @sdiv_exact_trailing_zeros_eq(i8 %x) {
 ; CHECK-LABEL: @sdiv_exact_trailing_zeros_eq(
 ; CHECK-NEXT:    [[O:%.*]] = or i8 [[X:%.*]], 2
@@ -355,6 +357,8 @@ define i8 @sdiv_exact_trailing_zeros_eq(i8 %x) {
   ret i8 %r
 }
 
+; Negative test - must be exact div.
+
 define i8 @sdiv_trailing_zeros(i8 %x) {
 ; CHECK-LABEL: @sdiv_trailing_zeros(
 ; CHECK-NEXT:    [[O:%.*]] = or i8 [[X:%.*]], 1
@@ -366,17 +370,32 @@ define i8 @sdiv_trailing_zeros(i8 %x) {
   ret i8 %r
 }
 
+; TODO: Match non-splat vector constants.
+
+define <2 x i8> @sdiv_exact_trailing_zeros_nonuniform_vector(<2 x i8> %x) {
+; CHECK-LABEL: @sdiv_exact_trailing_zeros_nonuniform_vector(
+; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 3, i8 1>
+; CHECK-NEXT:    [[R:%.*]] = sdiv exact <2 x i8> [[O]], <i8 12, i8 2>
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %o = or <2 x i8> %x, <i8 3, i8 1>
+  %r = sdiv exact <2 x i8> %o, <i8 12, i8 2>
+  ret <2 x i8> %r
+}
+
+; Can't divide evenly, so create poison.
+
 define <2 x i8> @udiv_exact_trailing_zeros(<2 x i8> %x) {
 ; CHECK-LABEL: @udiv_exact_trailing_zeros(
-; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 3, i8 3>
-; CHECK-NEXT:    [[R:%.*]] = udiv exact <2 x i8> [[O]], <i8 12, i8 12>
-; CHECK-NEXT:    ret <2 x i8> [[R]]
+; CHECK-NEXT:    ret <2 x i8> poison
 ;
   %o = or <2 x i8> %x, <i8 3, i8 3>
   %r = udiv exact <2 x i8> %o, <i8 12, i8 12>  ; can't divide exactly
   ret <2 x i8> %r
 }
 
+; Negative test - could divide evenly.
+
 define <2 x i8> @udiv_exact_trailing_zeros_eq(<2 x i8> %x) {
 ; CHECK-LABEL: @udiv_exact_trailing_zeros_eq(
 ; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 28, i8 28>
@@ -388,6 +407,8 @@ define <2 x i8> @udiv_exact_trailing_zeros_eq(<2 x i8> %x) {
   ret <2 x i8> %r
 }
 
+; Negative test - must be exact div.
+
 define i8 @udiv_trailing_zeros(i8 %x) {
 ; CHECK-LABEL: @udiv_trailing_zeros(
 ; CHECK-NEXT:    [[O:%.*]] = or i8 [[X:%.*]], 1
@@ -399,4 +420,17 @@ define i8 @udiv_trailing_zeros(i8 %x) {
   ret i8 %r
 }
 
+; Negative test - only the first element is poison.
+
+define <2 x i8> @udiv_exact_trailing_zeros_nonuniform_vector(<2 x i8> %x) {
+; CHECK-LABEL: @udiv_exact_trailing_zeros_nonuniform_vector(
+; CHECK-NEXT:    [[O:%.*]] = or <2 x i8> [[X:%.*]], <i8 3, i8 3>
+; CHECK-NEXT:    [[R:%.*]] = udiv exact <2 x i8> [[O]], <i8 12, i8 1>
+; CHECK-NEXT:    ret <2 x i8> [[R]]
+;
+  %o = or <2 x i8> %x, <i8 3, i8 3>
+  %r = udiv exact <2 x i8> %o, <i8 12, i8 1>
+  ret <2 x i8> %r
+}
+
 !0 = !{i32 0, i32 3}


        


More information about the llvm-commits mailing list