[llvm] 22b63b9 - Revert "[Reassociate] Drop weight reduction to fix issue 91417 (#91469)" (#94210)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 06:40:10 PDT 2024


Author: Yingwei Zheng
Date: 2024-06-03T21:40:06+08:00
New Revision: 22b63b97ffd4e2d71c1ec2f8b4757b32587526c9

URL: https://github.com/llvm/llvm-project/commit/22b63b97ffd4e2d71c1ec2f8b4757b32587526c9
DIFF: https://github.com/llvm/llvm-project/commit/22b63b97ffd4e2d71c1ec2f8b4757b32587526c9.diff

LOG: Revert "[Reassociate] Drop weight reduction to fix issue 91417 (#91469)" (#94210)

Reverts
https://github.com/llvm/llvm-project/commit/3bcccb6af685c3132a9ee578b9e11b2503c35a5c
and
https://github.com/llvm/llvm-project/commit/9a282724a29899e84adc91bdeaf639010408a80d
because #91469 causes a miscompilation
https://github.com/llvm/llvm-project/pull/91469#discussion_r1623925158.

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/Reassociate.cpp
    llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll
    llvm/test/Transforms/Reassociate/repeats.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 04c54ed69e93f..c73d7c8d83bec 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -302,6 +302,97 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
   return Res;
 }
 
+/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael
+/// function: k is Bitwidth-1 for Bitwidth < 3 and Bitwidth-2 otherwise.  Hence
+/// x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.  For Bitwidth >= 1,
+/// 0 <= k < Bitwidth (Bitwidth == 0 would wrap); if Bitwidth > 3 then
+/// x^(2^k) = 0 for every even x in Bitwidth-bit arithmetic.
+static unsigned CarmichaelShift(unsigned Bitwidth) {
+  if (Bitwidth < 3)
+    return Bitwidth - 1;
+  return Bitwidth - 2;
+}
+
+/// Fold the extra weight 'RHS' into the existing weight 'LHS' (updated in
+/// place), reducing the result using any special properties of the operation.
+/// The existing weight LHS represents the computation X op X op ... op X where
+/// X occurs LHS times.  The combined weight represents X op X op ... op X with
+/// X occurring LHS + RHS times.  If op is "Xor" for example then the combined
+/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
+/// LHS is set to 1 in the first case and to 0 in the second.
+static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
+  // If we were working with infinite precision arithmetic then the combined
+  // weight would be LHS + RHS.  But we are using finite precision arithmetic,
+  // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
+  // for nilpotent operations and addition, but not for idempotent operations
+  // and multiplication), so it is important to correctly reduce the combined
+  // weight back into range if wrapping would be wrong.
+
+  // If RHS is zero then the weight didn't change.
+  if (RHS.isMinValue())
+    return;
+  // If LHS is zero then the combined weight is RHS.
+  if (LHS.isMinValue()) {
+    LHS = RHS;
+    return;
+  }
+  // From this point on we know that neither LHS nor RHS is zero.
+
+  if (Instruction::isIdempotent(Opcode)) {
+    // Idempotent means X op X === X, so any non-zero weight is equivalent to a
+    // weight of 1.  Keeping weights at zero or one also means that wrapping is
+    // not a problem.  Both weights are non-zero here, so both must be 1.
+    assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+    return; // Return a weight of 1.
+  }
+  if (Instruction::isNilpotent(Opcode)) {
+    // Nilpotent means X op X === 0, so weights are kept reduced modulo 2.
+    assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+    LHS = 0; // 1 + 1 === 0 modulo 2.
+    return;
+  }
+  if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) {
+    // TODO: Reduce the weight by exploiting nsw/nuw?
+    LHS += RHS;
+    return;
+  }
+
+  assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) &&
+         "Unknown associative operation!");
+  unsigned Bitwidth = LHS.getBitWidth();
+  // If CM is Carmichael's lambda of 2^Bitwidth then a weight W >= CM+Bitwidth
+  // can be replaced with W-CM.  That's because x^W=x^(W-CM) for every Bitwidth
+  // bit number x, since either x is odd in which case x^CM = 1, or x is even in
+  // which case both x^W and x^(W - CM) are zero.  By subtracting off multiples
+  // of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
+  // which by a happy accident means that they can always be represented using
+  // Bitwidth bits.
+  // TODO: Reduce the weight by exploiting nsw/nuw?  (Could do much better than
+  // the Carmichael lambda value).
+  if (Bitwidth > 3) {
+    // CM - The value of Carmichael's lambda function for 2^Bitwidth.
+    APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
+    // Any weight W >= Threshold can be replaced with W - CM.
+    APInt Threshold = CM + Bitwidth;
+    assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
+    // Sum <= 2*(Threshold-1) < 2^Bitwidth for Bitwidth >= 4, so no overflow.
+    LHS += RHS;
+    while (LHS.uge(Threshold))
+      LHS -= CM;
+  } else {
+    // Bitwidth <= 3: the APInt sum could wrap, so do everything the same as
+    // above but using a larger type (plain unsigned).
+    unsigned CM = 1U << CarmichaelShift(Bitwidth);
+    unsigned Threshold = CM + Bitwidth;
+    assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
+           "Weights not reduced!");
+    unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
+    while (Total >= Threshold)
+      Total -= CM;
+    LHS = Total;
+  }
+}
+
 using RepeatedValue = std::pair<Value*, APInt>;
 
 /// Given an associative binary expression, return the leaf
@@ -471,7 +562,7 @@ static bool LinearizeExprTree(Instruction *I,
                "In leaf map but not visited!");
 
         // Update the number of paths to the leaf.
-        It->second += Weight;
+        IncorporateWeight(It->second, Weight, Opcode);
 
         // If we still have uses that are not accounted for by the expression
         // then it is not safe to modify the value.

diff  --git a/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll b/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll
index bd0060cc5abbd..d4aa5c507ec8b 100644
--- a/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll
+++ b/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll
@@ -56,19 +56,20 @@ define <8 x i1> @vector2(<8 x i1> %a, <8 x i1> %b0, <8 x i1> %b1, <8 x i1> %b2,
 ; CHECK-NEXT:    [[OR5:%.*]] = or <8 x i1> [[B5]], [[A]]
 ; CHECK-NEXT:    [[OR6:%.*]] = or <8 x i1> [[B6]], [[A]]
 ; CHECK-NEXT:    [[OR7:%.*]] = or <8 x i1> [[B7]], [[A]]
-; CHECK-NEXT:    [[XOR0:%.*]] = xor <8 x i1> [[OR1]], [[OR0]]
-; CHECK-NEXT:    [[XOR2:%.*]] = xor <8 x i1> [[XOR0]], [[OR2]]
-; CHECK-NEXT:    [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR3]]
-; CHECK-NEXT:    [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR4]]
-; CHECK-NEXT:    [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR5]]
-; CHECK-NEXT:    [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR6]]
-; CHECK-NEXT:    [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR7]]
+; CHECK-NEXT:    [[XOR2:%.*]] = xor <8 x i1> [[OR1]], [[OR0]]
+; CHECK-NEXT:    [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR2]]
+; CHECK-NEXT:    [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR3]]
+; CHECK-NEXT:    [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR4]]
+; CHECK-NEXT:    [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR5]]
+; CHECK-NEXT:    [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR6]]
+; CHECK-NEXT:    [[XOR7:%.*]] = xor <8 x i1> [[XOR6]], [[OR7]]
 ; CHECK-NEXT:    [[OR4560:%.*]] = or <8 x i1> [[OR045]], [[XOR2]]
 ; CHECK-NEXT:    [[OR023:%.*]] = or <8 x i1> [[OR4560]], [[XOR3]]
 ; CHECK-NEXT:    [[OR001:%.*]] = or <8 x i1> [[OR023]], [[XOR4]]
 ; CHECK-NEXT:    [[OR0123:%.*]] = or <8 x i1> [[OR001]], [[XOR5]]
 ; CHECK-NEXT:    [[OR01234567:%.*]] = or <8 x i1> [[OR0123]], [[XOR6]]
-; CHECK-NEXT:    ret <8 x i1> [[OR01234567]]
+; CHECK-NEXT:    [[OR1234567:%.*]] = or <8 x i1> [[OR01234567]], [[XOR7]]
+; CHECK-NEXT:    ret <8 x i1> [[OR1234567]]
 ;
   %or0 = or <8 x i1> %b0, %a
   %or1 = or <8 x i1> %b1, %a

diff  --git a/llvm/test/Transforms/Reassociate/repeats.ll b/llvm/test/Transforms/Reassociate/repeats.ll
index 28177f1c0ba5e..ba25c4bfc643c 100644
--- a/llvm/test/Transforms/Reassociate/repeats.ll
+++ b/llvm/test/Transforms/Reassociate/repeats.ll
@@ -15,7 +15,7 @@ define i8 @nilpotent(i8 %x) {
 define i2 @idempotent(i2 %x) {
 ; CHECK-LABEL: define i2 @idempotent(
 ; CHECK-SAME: i2 [[X:%.*]]) {
-; CHECK-NEXT:    ret i2 -1
+; CHECK-NEXT:    ret i2 [[X]]
 ;
   %tmp1 = and i2 %x, %x
   %tmp2 = and i2 %tmp1, %x
@@ -60,8 +60,7 @@ define i3 @foo3x5(i3 %x) {
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]]
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]]
-; CHECK-NEXT:    ret i3 [[TMP5]]
+; CHECK-NEXT:    ret i3 [[TMP4]]
 ;
   %tmp1 = mul i3 %x, %x
   %tmp2 = mul i3 %tmp1, %x
@@ -75,8 +74,7 @@ define i3 @foo3x5_nsw(i3 %x) {
 ; CHECK-LABEL: define i3 @foo3x5_nsw(
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]]
 ; CHECK-NEXT:    ret i3 [[TMP4]]
 ;
   %tmp1 = mul i3 %x, %x
@@ -91,8 +89,7 @@ define i3 @foo3x6(i3 %x) {
 ; CHECK-LABEL: define i3 @foo3x6(
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i3 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i3 [[TMP2]]
 ;
   %tmp1 = mul i3 %x, %x
@@ -108,9 +105,7 @@ define i3 @foo3x7(i3 %x) {
 ; CHECK-LABEL: define i3 @foo3x7(
 ; CHECK-SAME: i3 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = mul i3 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP7:%.*]] = mul i3 [[TMP5]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i3 [[TMP7]], [[X]]
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i3 [[TMP3]], [[TMP7]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]]
 ; CHECK-NEXT:    ret i3 [[TMP6]]
 ;
   %tmp1 = mul i3 %x, %x
@@ -127,8 +122,7 @@ define i4 @foo4x8(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x8(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i4 [[TMP4]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -146,9 +140,8 @@ define i4 @foo4x9(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x9(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
-; CHECK-NEXT:    [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]]
 ; CHECK-NEXT:    ret i4 [[TMP8]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -167,8 +160,7 @@ define i4 @foo4x10(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x10(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
 ; CHECK-NEXT:    ret i4 [[TMP3]]
 ;
@@ -189,8 +181,7 @@ define i4 @foo4x11(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x11(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]]
 ; CHECK-NEXT:    ret i4 [[TMP10]]
@@ -213,9 +204,7 @@ define i4 @foo4x12(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x12(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
-; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]]
 ; CHECK-NEXT:    ret i4 [[TMP2]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -238,9 +227,7 @@ define i4 @foo4x13(i4 %x) {
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]]
-; CHECK-NEXT:    [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]]
 ; CHECK-NEXT:    ret i4 [[TMP12]]
 ;
   %tmp1 = mul i4 %x, %x
@@ -263,9 +250,7 @@ define i4 @foo4x14(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x14(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP5]], [[X]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]]
 ; CHECK-NEXT:    ret i4 [[TMP7]]
 ;
@@ -290,9 +275,7 @@ define i4 @foo4x15(i4 %x) {
 ; CHECK-LABEL: define i4 @foo4x15(
 ; CHECK-SAME: i4 [[X:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i4 [[X]], [[X]]
-; CHECK-NEXT:    [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]]
-; CHECK-NEXT:    [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]]
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP3]], [[X]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]]
 ; CHECK-NEXT:    ret i4 [[TMP14]]


        


More information about the llvm-commits mailing list