[llvm] 8a4266a - [InstSimplify] Fold `u/sdiv exact (mul nsw/nuw X, C), C --> X` when C is not a power of 2 (#76445)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 28 01:36:29 PST 2023


Author: Yingwei Zheng
Date: 2023-12-28T17:36:25+08:00
New Revision: 8a4266a626914765c0c69839e8a51be383013c1a

URL: https://github.com/llvm/llvm-project/commit/8a4266a626914765c0c69839e8a51be383013c1a
DIFF: https://github.com/llvm/llvm-project/commit/8a4266a626914765c0c69839e8a51be383013c1a.diff

LOG: [InstSimplify] Fold `u/sdiv exact (mul nsw/nuw X, C), C --> X` when C is not a power of 2 (#76445)

Alive2: https://alive2.llvm.org/ce/z/3D9R7d

Added: 
    

Modified: 
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/test/Transforms/InstSimplify/div.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 5beac5547d65e0..ef2c3765400bdd 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1189,14 +1189,26 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse))
     return V;
 
-  // If this is an exact divide by a constant, then the dividend (Op0) must have
-  // at least as many trailing zeros as the divisor to divide evenly. If it has
-  // less trailing zeros, then the result must be poison.
   const APInt *DivC;
-  if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
-    KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
-    if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
-      return PoisonValue::get(Op0->getType());
+  if (IsExact && match(Op1, m_APInt(DivC))) {
+    // If this is an exact divide by a constant, then the dividend (Op0) must
+    // have at least as many trailing zeros as the divisor to divide evenly. If
+    // it has fewer trailing zeros, then the result must be poison.
+    if (DivC->countr_zero()) {
+      KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
+      if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
+        return PoisonValue::get(Op0->getType());
+    }
+
+    // udiv exact (mul nsw X, C), C --> X
+    // sdiv exact (mul nuw X, C), C --> X
+    // where C is not a power of 2.
+    Value *X;
+    if (!DivC->isPowerOf2() &&
+        (Opcode == Instruction::UDiv
+             ? match(Op0, m_NSWMul(m_Value(X), m_Specific(Op1)))
+             : match(Op0, m_NUWMul(m_Value(X), m_Specific(Op1)))))
+      return X;
   }
 
   return nullptr;

diff  --git a/llvm/test/Transforms/InstSimplify/div.ll b/llvm/test/Transforms/InstSimplify/div.ll
index a379e1ec9efe22..e13b6f139bcf53 100644
--- a/llvm/test/Transforms/InstSimplify/div.ll
+++ b/llvm/test/Transforms/InstSimplify/div.ll
@@ -567,3 +567,100 @@ define <2 x i8> @sdiv_vec_multi_one_bit_divisor(<2 x i8> %x, <2 x i8> %y) {
   %res = sdiv <2 x i8> %y, %and
   ret <2 x i8> %res
 }
+
+define i8 @udiv_exact_mul_nsw(i8 %x) {
+; CHECK-LABEL: @udiv_exact_mul_nsw(
+; CHECK-NEXT:    ret i8 [[X:%.*]]
+;
+  %a = mul nsw i8 %x, 24
+  %b = udiv exact i8 %a, 24
+  ret i8 %b
+}
+
+define i8 @sdiv_exact_mul_nuw(i8 %x) {
+; CHECK-LABEL: @sdiv_exact_mul_nuw(
+; CHECK-NEXT:    ret i8 [[X:%.*]]
+;
+  %a = mul nuw i8 %x, 24
+  %b = sdiv exact i8 %a, 24
+  ret i8 %b
+}
+
+; Negative tests
+
+define i8 @udiv_exact_mul_nsw_mismatch(i8 %x) {
+; CHECK-LABEL: @udiv_exact_mul_nsw_mismatch(
+; CHECK-NEXT:    [[A:%.*]] = mul nsw i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[B:%.*]] = udiv exact i8 [[A]], 12
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul nsw i8 %x, 24
+  %b = udiv exact i8 %a, 12
+  ret i8 %b
+}
+
+define i8 @udiv_exact_mul_nsw_power_of_2(i8 %x) {
+; CHECK-LABEL: @udiv_exact_mul_nsw_power_of_2(
+; CHECK-NEXT:    [[A:%.*]] = mul nsw i8 [[X:%.*]], 8
+; CHECK-NEXT:    [[B:%.*]] = udiv exact i8 [[A]], 8
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul nsw i8 %x, 8
+  %b = udiv exact i8 %a, 8
+  ret i8 %b
+}
+
+define i8 @sdiv_exact_mul_nuw_power_of_2(i8 %x) {
+; CHECK-LABEL: @sdiv_exact_mul_nuw_power_of_2(
+; CHECK-NEXT:    [[A:%.*]] = mul nuw i8 [[X:%.*]], 8
+; CHECK-NEXT:    [[B:%.*]] = sdiv exact i8 [[A]], 8
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul nuw i8 %x, 8
+  %b = sdiv exact i8 %a, 8
+  ret i8 %b
+}
+
+define i8 @udiv_exact_mul(i8 %x) {
+; CHECK-LABEL: @udiv_exact_mul(
+; CHECK-NEXT:    [[A:%.*]] = mul i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[B:%.*]] = udiv exact i8 [[A]], 24
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul i8 %x, 24
+  %b = udiv exact i8 %a, 24
+  ret i8 %b
+}
+
+define i8 @sdiv_exact_mul(i8 %x) {
+; CHECK-LABEL: @sdiv_exact_mul(
+; CHECK-NEXT:    [[A:%.*]] = mul i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[B:%.*]] = sdiv exact i8 [[A]], 24
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul i8 %x, 24
+  %b = sdiv exact i8 %a, 24
+  ret i8 %b
+}
+
+define i8 @udiv_mul_nsw(i8 %x) {
+; CHECK-LABEL: @udiv_mul_nsw(
+; CHECK-NEXT:    [[A:%.*]] = mul nsw i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[B:%.*]] = udiv i8 [[A]], 24
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul nsw i8 %x, 24
+  %b = udiv i8 %a, 24
+  ret i8 %b
+}
+
+define i8 @sdiv_mul_nuw(i8 %x) {
+; CHECK-LABEL: @sdiv_mul_nuw(
+; CHECK-NEXT:    [[A:%.*]] = mul nuw i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[B:%.*]] = sdiv i8 [[A]], 24
+; CHECK-NEXT:    ret i8 [[B]]
+;
+  %a = mul nuw i8 %x, 24
+  %b = sdiv i8 %a, 24
+  ret i8 %b
+}


        


More information about the llvm-commits mailing list