[llvm] [Reassociate] Use uint64_t for repeat count (PR #94232)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 7 09:14:24 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/94232

From 6b6a134f56a5e515c30528bc3a3bb0e9142d6ece Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 3 Jun 2024 23:12:58 +0800
Subject: [PATCH 1/2] [Reassociate] Use uint64_t for repeat count

---
 llvm/lib/Transforms/Scalar/Reassociate.cpp  | 119 ++------------------
 llvm/test/Transforms/Reassociate/repeats.ll |  45 +++++---
 2 files changed, 42 insertions(+), 122 deletions(-)
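
In short: LinearizeExprTree used to track how many times each leaf occurs in
an APInt whose width matched the scalar type of the expression, and
IncorporateWeight() kept multiplication weights in range by exploiting the
Carmichael function (x^lambda(2^n) == 1 mod 2^n for every odd x). With the
repeat count held in a plain uint64_t the accumulation becomes a simple
addition, the reduction machinery can be deleted, and the second patch below
adds an assert for the practically unreachable case where a 64-bit counter
would wrap. A minimal standalone sketch of the new accumulation scheme
(std::map and std::string are stand-ins for the pass's LeafMap; nothing here
is LLVM API):

  // Standalone illustration: leaf weights as plain 64-bit path counters,
  // matching the new DenseMap<Value *, uint64_t> LeafMap in the patch.
  #include <cassert>
  #include <cstdint>
  #include <iostream>
  #include <map>
  #include <string>

  int main() {
    std::map<std::string, uint64_t> Leaves; // leaf -> number of paths so far
    // The worklist reaches the leaf "x" twice, with 3 and then 2 paths.
    for (uint64_t Weight : {uint64_t(3), uint64_t(2)}) {
      uint64_t &Count = Leaves["x"];
      Count += Weight;
      // Same idea as the assert added in the second patch: an unsigned sum
      // that wrapped would be smaller than the addend.
      assert(Count >= Weight && "Weight overflows");
    }
    std::cout << "x occurs " << Leaves["x"] << " times\n"; // prints 5
  }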

diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index c73d7c8d83bec..6cf097094ddd0 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -302,98 +302,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
   return Res;
 }
 
-/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael
-/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for
-/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
-/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
-/// even x in Bitwidth-bit arithmetic.
-static unsigned CarmichaelShift(unsigned Bitwidth) {
-  if (Bitwidth < 3)
-    return Bitwidth - 1;
-  return Bitwidth - 2;
-}
-
-/// Add the extra weight 'RHS' to the existing weight 'LHS',
-/// reducing the combined weight using any special properties of the operation.
-/// The existing weight LHS represents the computation X op X op ... op X where
-/// X occurs LHS times.  The combined weight represents  X op X op ... op X with
-/// X occurring LHS + RHS times.  If op is "Xor" for example then the combined
-/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
-/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
-static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
-  // If we were working with infinite precision arithmetic then the combined
-  // weight would be LHS + RHS.  But we are using finite precision arithmetic,
-  // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
-  // for nilpotent operations and addition, but not for idempotent operations
-  // and multiplication), so it is important to correctly reduce the combined
-  // weight back into range if wrapping would be wrong.
-
-  // If RHS is zero then the weight didn't change.
-  if (RHS.isMinValue())
-    return;
-  // If LHS is zero then the combined weight is RHS.
-  if (LHS.isMinValue()) {
-    LHS = RHS;
-    return;
-  }
-  // From this point on we know that neither LHS nor RHS is zero.
-
-  if (Instruction::isIdempotent(Opcode)) {
-    // Idempotent means X op X === X, so any non-zero weight is equivalent to a
-    // weight of 1.  Keeping weights at zero or one also means that wrapping is
-    // not a problem.
-    assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
-    return; // Return a weight of 1.
-  }
-  if (Instruction::isNilpotent(Opcode)) {
-    // Nilpotent means X op X === 0, so reduce weights modulo 2.
-    assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
-    LHS = 0; // 1 + 1 === 0 modulo 2.
-    return;
-  }
-  if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) {
-    // TODO: Reduce the weight by exploiting nsw/nuw?
-    LHS += RHS;
-    return;
-  }
-
-  assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) &&
-         "Unknown associative operation!");
-  unsigned Bitwidth = LHS.getBitWidth();
-  // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth
-  // can be replaced with W-CM.  That's because x^W=x^(W-CM) for every Bitwidth
-  // bit number x, since either x is odd in which case x^CM = 1, or x is even in
-  // which case both x^W and x^(W - CM) are zero.  By subtracting off multiples
-  // of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
-  // which by a happy accident means that they can always be represented using
-  // Bitwidth bits.
-  // TODO: Reduce the weight by exploiting nsw/nuw?  (Could do much better than
-  // the Carmichael number).
-  if (Bitwidth > 3) {
-    /// CM - The value of Carmichael's lambda function.
-    APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
-    // Any weight W >= Threshold can be replaced with W - CM.
-    APInt Threshold = CM + Bitwidth;
-    assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
-    // For Bitwidth 4 or more the following sum does not overflow.
-    LHS += RHS;
-    while (LHS.uge(Threshold))
-      LHS -= CM;
-  } else {
-    // To avoid problems with overflow do everything the same as above but using
-    // a larger type.
-    unsigned CM = 1U << CarmichaelShift(Bitwidth);
-    unsigned Threshold = CM + Bitwidth;
-    assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
-           "Weights not reduced!");
-    unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
-    while (Total >= Threshold)
-      Total -= CM;
-    LHS = Total;
-  }
-}
-
-using RepeatedValue = std::pair<Value*, APInt>;
+using RepeatedValue = std::pair<Value *, uint64_t>;
 
 /// Given an associative binary expression, return the leaf
 /// nodes in Ops along with their weights (how many times the leaf occurs).  The
@@ -475,7 +384,6 @@ static bool LinearizeExprTree(Instruction *I,
   assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
          "Expected a UnaryOperator or BinaryOperator!");
   LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
-  unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
   unsigned Opcode = I->getOpcode();
   assert(I->isAssociative() && I->isCommutative() &&
          "Expected an associative and commutative operation!");
@@ -490,8 +398,8 @@ static bool LinearizeExprTree(Instruction *I,
   // with their weights, representing a certain number of paths to the operator.
   // If an operator occurs in the worklist multiple times then we found multiple
   // ways to get to it.
-  SmallVector<std::pair<Instruction*, APInt>, 8> Worklist; // (Op, Weight)
-  Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1)));
+  SmallVector<std::pair<Instruction *, uint64_t>, 8> Worklist; // (Op, Weight)
+  Worklist.push_back(std::make_pair(I, 1));
   bool Changed = false;
 
   // Leaves of the expression are values that either aren't the right kind of
@@ -509,7 +417,7 @@ static bool LinearizeExprTree(Instruction *I,
 
   // Leaves - Keeps track of the set of putative leaves as well as the number of
   // paths to each leaf seen so far.
-  using LeafMap = DenseMap<Value *, APInt>;
+  using LeafMap = DenseMap<Value *, uint64_t>;
   LeafMap Leaves; // Leaf -> Total weight so far.
   SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order.
   const DataLayout DL = I->getModule()->getDataLayout();
@@ -518,8 +426,8 @@ static bool LinearizeExprTree(Instruction *I,
   SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme.
 #endif
   while (!Worklist.empty()) {
-    std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
-    I = P.first; // We examine the operands of this binary operator.
+    // We examine the operands of this binary operator.
+    auto [I, Weight] = Worklist.pop_back_val();
 
     if (isa<OverflowingBinaryOperator>(I)) {
       Flags.HasNUW &= I->hasNoUnsignedWrap();
@@ -528,7 +436,6 @@ static bool LinearizeExprTree(Instruction *I,
 
     for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
-      APInt Weight = P.second; // Number of paths to this operand.
       LLVM_DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n");
       assert(!Op->use_empty() && "No uses, so how did we get to it?!");
 
@@ -562,7 +469,7 @@ static bool LinearizeExprTree(Instruction *I,
                "In leaf map but not visited!");
 
         // Update the number of paths to the leaf.
-        IncorporateWeight(It->second, Weight, Opcode);
+        It->second += Weight;
 
         // If we still have uses that are not accounted for by the expression
         // then it is not safe to modify the value.
@@ -625,10 +532,7 @@ static bool LinearizeExprTree(Instruction *I,
       // Node initially thought to be a leaf wasn't.
       continue;
     assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
-    APInt Weight = It->second;
-    if (Weight.isMinValue())
-      // Leaf already output or weight reduction eliminated it.
-      continue;
+    uint64_t Weight = It->second;
     // Ensure the leaf is only output once.
     It->second = 0;
     Ops.push_back(std::make_pair(V, Weight));
@@ -642,7 +546,7 @@ static bool LinearizeExprTree(Instruction *I,
   if (Ops.empty()) {
     Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
     assert(Identity && "Associative operation without identity!");
-    Ops.emplace_back(Identity, APInt(Bitwidth, 1));
+    Ops.emplace_back(Identity, 1);
   }
 
   return Changed;
@@ -1188,8 +1092,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
     RepeatedValue E = Tree[i];
-    Factors.append(E.second.getZExtValue(),
-                   ValueEntry(getRank(E.first), E.first));
+    Factors.append(E.second, ValueEntry(getRank(E.first), E.first));
   }
 
   bool FoundFactor = false;
@@ -2368,7 +2271,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
   SmallVector<ValueEntry, 8> Ops;
   Ops.reserve(Tree.size());
   for (const RepeatedValue &E : Tree)
-    Ops.append(E.second.getZExtValue(), ValueEntry(getRank(E.first), E.first));
+    Ops.append(E.second, ValueEntry(getRank(E.first), E.first));
 
   LLVM_DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n');
 
diff --git a/llvm/test/Transforms/Reassociate/repeats.ll b/llvm/test/Transforms/Reassociate/repeats.ll
index ba25c4bfc643c..8600777877bb3 100644
--- a/llvm/test/Transforms/Reassociate/repeats.ll
+++ b/llvm/test/Transforms/Reassociate/repeats.ll
@@ -60,7 +60,8 @@ define i3 @foo3x5(i3 %x) {
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]]
-; CHECK-NEXT:    ret i3 [[TMP4]]
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    ret i3 [[TMP5]]
 ;
   %tmp1 = mul i3 %x, %x
   %tmp2 = mul i3 %tmp1, %x
@@ -74,7 +75,8 @@ define i3 @foo3x5_nsw(i3 %x) {
 ; CHECK-LABEL: define i3 @foo3x5_nsw(
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]]
 ; CHECK-NEXT:    ret i3 [[TMP4]]
 ;
   %tmp1 = mul i3 %x, %x
@@ -89,7 +91,8 @@ define i3 @foo3x6(i3 %x) {
 ; CHECK-LABEL: define i3 @foo3x6(
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i3 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]]
 ; CHECK-NEXT:    ret i3 [[TMP2]]
 ;
   %tmp1 = mul i3 %x, %x
@@ -106,7 +109,9 @@ define i3 @foo3x7(i3 %x) {
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = mul i3 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]]
-; CHECK-NEXT:    ret i3 [[TMP6]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[TMP6]], [[X]]
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i3 [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    ret i3 [[TMP7]]
 ;
   %tmp1 = mul i3 %x, %x
   %tmp2 = mul i3 %tmp1, %x
@@ -123,7 +128,8 @@ define i4 @foo4x8(i4 %x) {
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    ret i4 [[TMP4]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
+; CHECK-NEXT:    ret i4 [[TMP3]]
 ;
   %tmp1 = mul i4 %x, %x
   %tmp2 = mul i4 %tmp1, %x
@@ -140,8 +146,9 @@ define i4 @foo4x9(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x9(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]]
 ; CHECK-NEXT:    ret i4 [[TMP8]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -160,7 +167,8 @@ define i4 @foo4x10(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x10(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
 ; CHECK-NEXT:    ret i4 [[TMP3]]
 ;
@@ -181,7 +189,8 @@ define i4 @foo4x11(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x11(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]]
 ; CHECK-NEXT:    ret i4 [[TMP10]]
@@ -204,7 +213,9 @@ define i4 @foo4x12(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x12(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]]
 ; CHECK-NEXT:    ret i4 [[TMP2]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -227,7 +238,9 @@ define i4 @foo4x13(i4 %x) {
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
+; CHECK-NEXT:    [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]]
 ; CHECK-NEXT:    ret i4 [[TMP12]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -252,7 +265,9 @@ define i4 @foo4x14(i4 %x) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]]
-; CHECK-NEXT:    ret i4 [[TMP7]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP7]], [[X]]
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]]
+; CHECK-NEXT:    ret i4 [[TMP5]]
 ;
   %tmp1 = mul i4 %x, %x
   %tmp2 = mul i4 %tmp1, %x
@@ -276,8 +291,10 @@ define i4 @foo4x15(i4 %x) {
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]]
-; CHECK-NEXT:    [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP6]], [[TMP6]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i4 [[TMP4]], [[X]]
+; CHECK-NEXT:    [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP4]]
 ; CHECK-NEXT:    ret i4 [[TMP14]]
 ;
   %tmp1 = mul i4 %x, %x
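
The test churn above comes from losing that reduction for narrow integer
types. With APInt weights, a repeat count of 5 on an i3 value was first
lowered to 3 (lambda(2^3) = 2, and any weight of at least CM + Bitwidth = 5
had CM subtracted), so the old foo3x5 output needed only two multiplies; with
plain uint64_t counts the full power is materialized, hence the extra
multiplies in the updated CHECK lines. The identity behind the old behaviour
is easy to brute-force; a small standalone check (illustration only, not part
of the patch):

  // Brute-force check of the identity the removed IncorporateWeight() relied
  // on for i3: x^5 == x^3 for every 3-bit value, because for odd x we have
  // x^2 == 1 (mod 8) and for even x both powers are 0 (mod 8).
  #include <cstdint>
  #include <iostream>

  static uint8_t powI3(uint8_t X, unsigned N) {
    uint8_t R = 1;
    for (unsigned I = 0; I < N; ++I)
      R = (R * X) & 7; // keep only the low 3 bits, i.e. i3 arithmetic
    return R;
  }

  int main() {
    for (uint8_t X = 0; X < 8; ++X) {
      if (powI3(X, 5) != powI3(X, 3)) {
        std::cout << "mismatch for x=" << unsigned(X) << '\n'; // never reached
        return 1;
      }
    }
    std::cout << "x^5 == x^3 holds for all i3 values\n";
    return 0;
  }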

From cd30a7a533174f6a4bf03635cd9cf4365410e944 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 8 Jun 2024 00:13:40 +0800
Subject: [PATCH 2/2] [Reassociate] Add overflow checks.

---
 llvm/lib/Transforms/Scalar/Reassociate.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 6cf097094ddd0..f36e21b296bd1 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -470,6 +470,7 @@ static bool LinearizeExprTree(Instruction *I,
 
         // Update the number of paths to the leaf.
         It->second += Weight;
+        assert(It->second >= Weight && "Weight overflows");
 
         // If we still have uses that are not accounted for by the expression
         // then it is not safe to modify the value.
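
Regarding the assert added here: weights start at 1 and only grow by adding
path counts, so wrapping a 64-bit counter would require an expression DAG
with more than 2^64 root-to-leaf paths; the assert is a debug-build sanity
check for that practically unreachable case, using the usual idiom that an
unsigned sum which wrapped is smaller than either addend. The idiom in
isolation (addWeight is a hypothetical helper, not anything in the patch):

  #include <cstdint>
  #include <iostream>
  #include <limits>

  // Add Weight into Acc and report whether the unsigned addition stayed in
  // range, mirroring assert(It->second >= Weight) in the patch.
  static bool addWeight(uint64_t &Acc, uint64_t Weight) {
    Acc += Weight;
    return Acc >= Weight; // false exactly when the addition wrapped
  }

  int main() {
    uint64_t W = 1;
    std::cout << addWeight(W, 3) << '\n';                                    // 1 (ok), W == 4
    std::cout << addWeight(W, std::numeric_limits<uint64_t>::max()) << '\n'; // 0 (wrapped)
  }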


