[llvm] [InstCombine] Simplify fractions when there is no overflow (PR #92949)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 21 16:22:29 PDT 2024


https://github.com/AtariDreams updated https://github.com/llvm/llvm-project/pull/92949

>From 2f7dc2c6e0e766be23c7df4e641253cb04979b23 Mon Sep 17 00:00:00 2001
From: Rose <gfunni234 at gmail.com>
Date: Tue, 21 May 2024 14:09:05 -0400
Subject: [PATCH] [InstCombine] Simplify fractions when there is no overflow

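Reduce (X * C1) / C2 to (X * (C1 / G)) / (C2 / G), where G = gcd(C1, C2),
when the multiply is known not to overflow; e.g. X * 150 / 100 becomes
X * 3 / 2 (and then X * 3 u>> 1). Also add the sdiv form of
(X << Z) / (X * Y) -> (1 << Z) / Y and propagate no-wrap flags to the
multiplies and shifts created by these folds.

Alive2: https://alive2.llvm.org/ce/z/-36zkQ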
---
 .../InstCombine/InstCombineMulDivRem.cpp      | 76 ++++++++++++++++---
 llvm/test/Transforms/InstCombine/div.ll       | 15 +++-
 2 files changed, 80 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index ca1b1921404d8..ec178204ff105 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1057,7 +1057,8 @@ static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) {
 
     // (X * Y) s/ (X << Z) --> Y s/ (1 << Z)
     if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) {
-      Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
+      Value *Shl = Builder.CreateShl(
+          ConstantInt::get(Ty, 1), Z, "", /*HasNUW=*/true, HasNUW && HasNSW);
       return Builder.CreateSDiv(Y, Shl, "", I.isExact());
     }
   }
@@ -1169,13 +1170,53 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
 
       // (X * C1) / C2 -> X * (C1 / C2) if C1 is a multiple of C2.
       if (isMultiple(*C1, *C2, Quotient, IsSigned)) {
-        auto *Mul = BinaryOperator::Create(Instruction::Mul, X,
-                                           ConstantInt::get(Ty, Quotient));
+        // keep nuw only when Quotient >= 0 (sdiv by C2 < 0 can break it); nsw
+        // holds: sdiv matched an nsw mul; udiv has C2 u> 1 (X u/ 1 is folded).
+        auto *Mul =
+            BinaryOperator::CreateNSWMul(X, ConstantInt::get(Ty, Quotient));
         auto *OBO = cast<OverflowingBinaryOperator>(Op0);
-        Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
-        Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
+        Mul->setHasNoUnsignedWrap(OBO->hasNoUnsignedWrap() && Quotient.sge(0));
         return Mul;
       }
+
+      // (X * C1) / C2 -> (X * (C1 / G)) / (C2 / G) with G = gcd(C1, C2),
+      // e.g. X * 150 / 100 -> X * 3 / 2. The multiply and divide are both
+      // recreated, so only do this when the old multiply goes away.
+      //
+      // C1 is not a multiple of C2 here, so G is a proper divisor of C2 (of
+      // |C2| for sdiv). For sdiv that bounds G by 2^(BW-2), so the signed
+      // divisions below are exact and keep the sign of each constant.
+      if (Op0->hasOneUse() && !C2->isZero() &&
+          !(IsSigned && C1->isMinSignedValue() && C2->isAllOnes())) {
+        // udiv must use the raw unsigned values: abs() would read e.g.
+        // 0xFFFFFFFD as 3, yielding a "divisor" that does not divide C2.
+        // C2 may also be the sign mask here (e.g. udiv by 0x80000000).
+        APInt GCD =
+            IsSigned ? APIntOps::GreatestCommonDivisor(C1->abs(), C2->abs())
+                     : APIntOps::GreatestCommonDivisor(*C1, *C2);
+        if (!GCD.isOne() && !GCD.isZero()) {
+          APInt NewC1 = IsSigned ? C1->sdiv(GCD) : C1->udiv(GCD);
+          APInt NewC2 = IsSigned ? C2->sdiv(GCD) : C2->udiv(GCD);
+          // (X * -A) s/ -B == (X * A) s/ B; canonicalize to positive
+          // constants (negating cannot overflow as GCD >= 2).
+          if (IsSigned && NewC1.isNegative() && NewC2.isNegative()) {
+            NewC1.negate();
+            NewC2.negate();
+          }
+
+          // |NewC1| <= |C1| / 2, so X * NewC1 is nsw (for udiv, X * C1 u< 2^BW
+          // bounds it below 2^(BW-1)). nuw carries over: a non-negative NewC1
+          // is u<= C1, and a negative C1 with nuw already forces X u<= 1.
+          auto *MulOBO = cast<OverflowingBinaryOperator>(Op0);
+          Value *NewMul =
+              Builder.CreateMul(X, ConstantInt::get(Ty, NewC1), "",
+                                MulOBO->hasNoUnsignedWrap(), /*HasNSW=*/true);
+          auto *NewDiv = BinaryOperator::Create(I.getOpcode(), NewMul,
+                                                ConstantInt::get(Ty, NewC2));
+          NewDiv->setIsExact(I.isExact());
+          return NewDiv;
+        }
+      }
     }
 
     if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) &&
@@ -1198,8 +1239,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
-        auto *Mul = BinaryOperator::Create(Instruction::Mul, X,
-                                           ConstantInt::get(Ty, Quotient));
-        auto *OBO = cast<OverflowingBinaryOperator>(Op0);
-        Mul->setHasNoUnsignedWrap(!IsSigned && OBO->hasNoUnsignedWrap());
-        Mul->setHasNoSignedWrap(OBO->hasNoSignedWrap());
+        // Always nsw; nuw only if Quotient >= 0 (sdiv by C2 < 0 can break it).
+        auto *Mul =
+            BinaryOperator::CreateNSWMul(X, ConstantInt::get(Ty, Quotient));
+        auto *OBO = cast<OverflowingBinaryOperator>(Op0);
+        Mul->setHasNoUnsignedWrap(OBO->hasNoUnsignedWrap() && Quotient.sge(0));
         return Mul;
       }
     }
@@ -1273,13 +1314,30 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
   }
 
   // (X << Z) / (X * Y) -> (1 << Z) / Y
-  // TODO: Handle sdiv.
   if (!IsSigned && Op1->hasOneUse() &&
       match(Op0, m_NUWShl(m_Value(X), m_Value(Z))) &&
       match(Op1, m_c_Mul(m_Specific(X), m_Value(Y))))
     if (cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap()) {
       Instruction *NewDiv = BinaryOperator::CreateUDiv(
-          Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*NUW*/ true), Y);
+          // No nsw: 1 << Z may be the sign mask (X == 1, Z == BW - 1).
+          Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*HasNUW=*/true),
+          Y);
+      NewDiv->setIsExact(I.isExact());
+      return NewDiv;
+    }
+
+  // (X << Z) s/ (X * Y) -> (1 << Z) s/ Y
+  // The shl must also be nuw: with nsw alone, X == -1, Z == BW - 1 is legal
+  // and 2^Z does not fit. With both, Z == BW - 1 forces X == 0 (UB divisor).
+  if (IsSigned && Op1->hasOneUse() &&
+      match(Op0, m_NSWShl(m_Value(X), m_Value(Z))) &&
+      match(Op1, m_c_Mul(m_Specific(X), m_Value(Y))))
+    if (cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap() &&
+        cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap()) {
+      Instruction *NewDiv = BinaryOperator::CreateSDiv(
+          Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*HasNUW=*/true,
+                            /*HasNSW=*/true),
+          Y);
       NewDiv->setIsExact(I.isExact());
       return NewDiv;
     }
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index e8a25ff44d029..6f7e10aef515a 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -380,7 +380,7 @@ define i32 @test26(i32 %a) {
 
 define i32 @test27(i32 %a) {
 ; CHECK-LABEL: @test27(
-; CHECK-NEXT:    [[DIV:%.*]] = shl nuw i32 [[A:%.*]], 1
+; CHECK-NEXT:    [[DIV:%.*]] = shl nuw nsw i32 [[A:%.*]], 1
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %shl = shl nuw i32 %a, 2
@@ -390,7 +390,7 @@ define i32 @test27(i32 %a) {
 
 define i32 @test28(i32 %a) {
 ; CHECK-LABEL: @test28(
-; CHECK-NEXT:    [[DIV:%.*]] = mul nuw i32 [[A:%.*]], 12
+; CHECK-NEXT:    [[DIV:%.*]] = mul nuw nsw i32 [[A:%.*]], 12
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %mul = mul nuw i32 %a, 36
@@ -1678,6 +1678,56 @@ define i32 @sdiv_sdiv_mul_nsw_exact_use(i32 %x, i32 %y, i32 %z) {
   ret i32 %r
 }
 
+define i32 @x_times_150_over_100(i32 %x) {
+; CHECK-LABEL: @x_times_150_over_100(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[D1:%.*]] = lshr i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i32 [[D1]]
+;
+  %m = mul nuw nsw i32 %x, 150
+  %d = udiv i32 %m, 100
+  ret i32 %d
+}
+
+; Both constants are negative; the reduced fraction uses positive constants.
+
+define i32 @x_times_neg150_sdiv_neg100(i32 %x) {
+; CHECK-LABEL: @x_times_neg150_sdiv_neg100(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul nsw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[D:%.*]] = sdiv i32 [[TMP1]], 2
+; CHECK-NEXT:    ret i32 [[D]]
+;
+  %m = mul nsw i32 %x, -150
+  %d = sdiv i32 %m, -100
+  ret i32 %d
+}
+
+; The quotient 6 / -3 is negative, so the nuw on the mul must be dropped.
+
+define i32 @x_times_6_sdiv_neg3(i32 %x) {
+; CHECK-LABEL: @x_times_6_sdiv_neg3(
+; CHECK-NEXT:    [[D:%.*]] = mul nsw i32 [[X:%.*]], -2
+; CHECK-NEXT:    ret i32 [[D]]
+;
+  %m = mul nuw nsw i32 %x, 6
+  %d = sdiv i32 %m, -3
+  ret i32 %d
+}
+
+; X == 1, Z == 31 is valid here, so 1 << Z must not get nsw.
+
+define i32 @udiv_shl_nuw_mul_nuw_nsw(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @udiv_shl_nuw_mul_nuw_nsw(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw i32 1, [[Z:%.*]]
+; CHECK-NEXT:    [[D:%.*]] = udiv i32 [[TMP1]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[D]]
+;
+  %s = shl nuw i32 %x, %z
+  %m = mul nuw nsw i32 %x, %y
+  %d = udiv i32 %s, %m
+  ret i32 %d
+}
+
 ; negative test - must have nsw
 
 define i8 @sdiv_sdiv_mul_nuw(i8 %x, i8 %y, i8 %z) {



More information about the llvm-commits mailing list