[llvm] c1f6ce0 - [DemandedBits] Improve accuracy of Add propagator

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 17 04:57:07 PDT 2020


Author: Simon Pilgrim
Date: 2020-08-17T12:54:09+01:00
New Revision: c1f6ce0c7322d47f1bb90169585fa54232231ede

URL: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede
DIFF: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede.diff

LOG: [DemandedBits] Improve accuracy of Add propagator

The current demand propagator for addition marks all input bits at and to the right of an alive output bit as alive. But a carry won't propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction), so a more accurate answer is possible given known bits.

I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.

This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.

Known A:     0_______0_______0_______0_______
Known B:     0_______0_______0_______0_______
AOut:        00000000001000000000000000000000
AB, current: 00000000001111111111111111111111
AB, patch:   00000000001111111000000000000000

Committed on behalf of: @rrika (Erika)

Differential Revision: https://reviews.llvm.org/D72423

Added: 
    llvm/unittests/IR/DemandedBitsTest.cpp
    llvm/unittests/Support/KnownBitsTest.h

Modified: 
    llvm/include/llvm/Analysis/DemandedBits.h
    llvm/lib/Analysis/DemandedBits.cpp
    llvm/test/Analysis/DemandedBits/add.ll
    llvm/unittests/IR/CMakeLists.txt
    llvm/unittests/Support/KnownBitsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/DemandedBits.h b/llvm/include/llvm/Analysis/DemandedBits.h
index 04db3eb57c18..7a8618a27ce7 100644
--- a/llvm/include/llvm/Analysis/DemandedBits.h
+++ b/llvm/include/llvm/Analysis/DemandedBits.h
@@ -61,6 +61,20 @@ class DemandedBits {
 
   void print(raw_ostream &OS);
 
+  /// Compute the alive bits of addition operand OperandNo from the alive
+  /// output bits (AOut) and the known bits of both operands (LHS, RHS).
+  static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
+  /// Compute the alive bits of subtraction operand OperandNo from the alive
+  /// output bits (AOut) and the known bits of both operands (LHS, RHS).
+  static APInt determineLiveOperandBitsSub(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
 private:
   void performAnalysis();
   void determineLiveOperandBits(const Instruction *UserI,

diff  --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
index aaee8c21f289..62e08f3f8a8b 100644
--- a/llvm/lib/Analysis/DemandedBits.cpp
+++ b/llvm/lib/Analysis/DemandedBits.cpp
@@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits(
       }
     break;
   case Instruction::Add:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Sub:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Mul:
     // Find the highest live output bit. We don't need any more input
     // bits than that (adds, and thus subtracts, ripple only to the
@@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) {
   }
 }
 
+static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
+                                              const APInt &AOut,
+                                              const KnownBits &LHS,
+                                              const KnownBits &RHS,
+                                              bool CarryZero, bool CarryOne) {
+  assert(!(CarryZero && CarryOne) &&
+         "Carry can't be zero and one at the same time");
+
+  // The following check should be done by the caller, as it also indicates
+  // that LHS and RHS don't need to be computed.
+  //
+  // if (AOut.isMask())
+  //   return AOut;
+
+  // At boundary bits (operands known equal), carry out ignores carry in.
+  APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
+
+  // First, the alive carry bits are determined from the alive output bits:
+  // Let demand ripple to the right but only up to any set bit in Bound.
+  //   AOut         = -1----
+  //   Bound        = ----1-
+  //   ACarry&~AOut = --111-
+  APInt RBound = Bound.reverseBits();
+  APInt RAOut = AOut.reverseBits();
+  APInt RProp = RAOut + (RAOut | ~RBound);
+  APInt RACarry = RProp ^ ~RBound;
+  APInt ACarry = RACarry.reverseBits();
+
+  // Then, the alive input bits are determined from the alive carry bits:
+  APInt NeededToMaintainCarryZero;
+  APInt NeededToMaintainCarryOne;
+  if (OperandNo == 0) {
+    NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
+    NeededToMaintainCarryOne = LHS.One | ~RHS.One;
+  } else {
+    NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
+    NeededToMaintainCarryOne = RHS.One | ~LHS.One;
+  }
+
+  // As in computeForAddCarry: sums of the max and min possible inputs.
+  APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
+  APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
+
+  // The below is simplified from
+  //
+  // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
+  // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
+  // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
+  //
+  // APInt NeededToMaintainCarry =
+  //   (CarryKnownZero & NeededToMaintainCarryZero) |
+  //   (CarryKnownOne  & NeededToMaintainCarryOne) |
+  //   CarryUnknown;
+
+  APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
+                                (PossibleSumOne | NeededToMaintainCarryOne);
+
+  APInt AB = AOut | (ACarry & NeededToMaintainCarry);
+  return AB;
+}
+
+APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
+                                          false);
+}
+
+APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  KnownBits NRHS;
+  NRHS.Zero = RHS.One;
+  NRHS.One = RHS.Zero;
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
+                                          true);
+}
+
 FunctionPass *llvm::createDemandedBitsWrapperPass() {
   return new DemandedBitsWrapperPass();
 }

diff  --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll
index 9203ed15d627..01673f82c2b3 100644
--- a/llvm/test/Analysis/DemandedBits/add.ll
+++ b/llvm/test/Analysis/DemandedBits/add.ll
@@ -1,22 +1,22 @@
-; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
-; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
-
-; CHECK-DAG: DemandedBits: 0x1f for   %1 = and i32 %a, 9
-; CHECK-DAG: DemandedBits: 0x1f for   %2 = and i32 %b, 9
-; CHECK-DAG: DemandedBits: 0x1f for   %3 = and i32 %c, 13
-; CHECK-DAG: DemandedBits: 0x1f for   %4 = and i32 %d, 4
-; CHECK-DAG: DemandedBits: 0x1f for   %5 = or i32 %2, %3
-; CHECK-DAG: DemandedBits: 0x1f for   %6 = or i32 %4, %5
+; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
+; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
+
+; CHECK-DAG: DemandedBits: 0x1e for   %1 = and i32 %a, 9
+; CHECK-DAG: DemandedBits: 0x1a for   %2 = and i32 %b, 9
+; CHECK-DAG: DemandedBits: 0x1a for   %3 = and i32 %c, 13
+; CHECK-DAG: DemandedBits: 0x1a for   %4 = and i32 %d, 4
+; CHECK-DAG: DemandedBits: 0x1a for   %5 = or i32 %2, %3
+; CHECK-DAG: DemandedBits: 0x1a for   %6 = or i32 %4, %5
 ; CHECK-DAG: DemandedBits: 0x10 for   %7 = add i32 %1, %6
 ; CHECK-DAG: DemandedBits: 0xffffffff for   %8 = and i32 %7, 16
-define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
-  %1 = and i32 %a, 9
-  %2 = and i32 %b, 9
-  %3 = and i32 %c, 13
-  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
-  %5 = or i32 %2, %3
-  %6 = or i32 %4, %5
-  %7 = add i32 %1, %6
-  %8 = and i32 %7, 16
-  ret i32 %8
-}
\ No newline at end of file
+define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
+  %1 = and i32 %a, 9
+  %2 = and i32 %b, 9
+  %3 = and i32 %c, 13
+  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
+  %5 = or i32 %2, %3
+  %6 = or i32 %4, %5
+  %7 = add i32 %1, %6
+  %8 = and i32 %7, 16
+  ret i32 %8
+}

diff  --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
index 4634bf89059a..c4386fed6174 100644
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -18,6 +18,7 @@ add_llvm_unittest(IRTests
   DataLayoutTest.cpp
   DebugInfoTest.cpp
   DebugTypeODRUniquingTest.cpp
+  DemandedBitsTest.cpp
   DominatorTreeTest.cpp
   DominatorTreeBatchUpdatesTest.cpp
   FunctionTest.cpp

diff  --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp
new file mode 100644
index 000000000000..4d15e8189961
--- /dev/null
+++ b/llvm/unittests/IR/DemandedBitsTest.cpp
@@ -0,0 +1,66 @@
+//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DemandedBits.h"
+#include "../Support/KnownBitsTest.h"
+#include "llvm/Support/KnownBits.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+template <typename Fn1, typename Fn2>
+static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
+  unsigned Bits = 4;
+  unsigned Max = 1 << Bits;
+  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+      for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) {
+        APInt AOut(Bits, AOut_);
+        APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
+        APInt AB2 = PropagateFn(1, AOut, Known1, Known2);
+        {
+          // If the propagator claims that certain known bits
+          // didn't matter, check it doesn't change its mind
+          // when they become unknown.
+          KnownBits Known1Redacted;
+          KnownBits Known2Redacted;
+          Known1Redacted.Zero = Known1.Zero & AB1;
+          Known1Redacted.One = Known1.One & AB1;
+          Known2Redacted.Zero = Known2.Zero & AB2;
+          Known2Redacted.One = Known2.One & AB2;
+
+          APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
+          APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
+          EXPECT_EQ(AB1, AB1R);
+          EXPECT_EQ(AB2, AB2R);
+        }
+        ForeachNumInKnownBits(Known1, [&](APInt Value1) {
+          ForeachNumInKnownBits(Known2, [&](APInt Value2) {
+            APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
+            APInt Result = EvalFn(Value1, Value2);
+            EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
+          });
+        });
+      }
+    });
+  });
+}
+
+TEST(DemandedBitsTest, Add) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd,
+                      [](APInt N1, APInt N2) -> APInt { return N1 + N2; });
+}
+
+TEST(DemandedBitsTest, Sub) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub,
+                      [](APInt N1, APInt N2) -> APInt { return N1 - N2; });
+}
+
+} // anonymous namespace

diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index bfd8eb204caf..694e5c4dcc71 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -11,41 +11,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/KnownBits.h"
+#include "KnownBitsTest.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
 namespace {
 
-template<typename FnTy>
-void ForeachKnownBits(unsigned Bits, FnTy Fn) {
-  unsigned Max = 1 << Bits;
-  KnownBits Known(Bits);
-  for (unsigned Zero = 0; Zero < Max; ++Zero) {
-    for (unsigned One = 0; One < Max; ++One) {
-      Known.Zero = Zero;
-      Known.One = One;
-      if (Known.hasConflict())
-        continue;
-
-      Fn(Known);
-    }
-  }
-}
-
-template<typename FnTy>
-void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
-  unsigned Bits = Known.getBitWidth();
-  unsigned Max = 1 << Bits;
-  for (unsigned N = 0; N < Max; ++N) {
-    APInt Num(Bits, N);
-    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
-      continue;
-
-    Fn(Num);
-  }
-}
-
 TEST(KnownBitsTest, AddCarryExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {

diff  --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h
new file mode 100644
index 000000000000..bc291898814b
--- /dev/null
+++ b/llvm/unittests/Support/KnownBitsTest.h
@@ -0,0 +1,52 @@
+//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helpers for KnownBits and DemandedBits unit tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+
+#include "llvm/Support/KnownBits.h"
+
+namespace {
+
+using namespace llvm;
+
+template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
+  unsigned Max = 1 << Bits;
+  KnownBits Known(Bits);
+  for (unsigned Zero = 0; Zero < Max; ++Zero) {
+    for (unsigned One = 0; One < Max; ++One) {
+      Known.Zero = Zero;
+      Known.One = One;
+      if (Known.hasConflict())
+        continue;
+
+      Fn(Known);
+    }
+  }
+}
+
+template <typename FnTy>
+void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
+  unsigned Bits = Known.getBitWidth();
+  unsigned Max = 1 << Bits;
+  for (unsigned N = 0; N < Max; ++N) {
+    APInt Num(Bits, N);
+    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
+      continue;
+
+    Fn(Num);
+  }
+}
+
+} // end anonymous namespace
+
+#endif


        


More information about the llvm-commits mailing list