[llvm] c1f6ce0 - [DemandedBits] Improve accuracy of Add propagator

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 17 05:05:31 PDT 2020


Looks like the git "Author:" line wasn't updated correctly here.


On Mon, Aug 17, 2020 at 2:57 PM Simon Pilgrim via llvm-commits
<llvm-commits at lists.llvm.org> wrote:
>
>
> Author: Simon Pilgrim
> Date: 2020-08-17T12:54:09+01:00
> New Revision: c1f6ce0c7322d47f1bb90169585fa54232231ede
>
> URL: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede
> DIFF: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede.diff
>
> LOG: [DemandedBits] Improve accuracy of Add propagator
>
> The current demand propagator for addition will mark all input bits at and to the right of the alive output bit as alive. But carry won't propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction), so a more accurate answer is possible given known bits.
>
> I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.
>
> This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.
>
> Known A:     0_______0_______0_______0_______
> Known B:     0_______0_______0_______0_______
> AOut:        00000000001000000000000000000000
> AB, current: 00000000001111111111111111111111
> AB, patch:   00000000001111111000000000000000
>
> Committed on behalf of: @rrika (Erika)
>
> Differential Revision: https://reviews.llvm.org/D72423
>
> Added:
>     llvm/unittests/IR/DemandedBitsTest.cpp
>     llvm/unittests/Support/KnownBitsTest.h
>
> Modified:
>     llvm/include/llvm/Analysis/DemandedBits.h
>     llvm/lib/Analysis/DemandedBits.cpp
>     llvm/test/Analysis/DemandedBits/add.ll
>     llvm/unittests/IR/CMakeLists.txt
>     llvm/unittests/Support/KnownBitsTest.cpp
>
> Removed:
>
>
>
> ################################################################################
> diff  --git a/llvm/include/llvm/Analysis/DemandedBits.h b/llvm/include/llvm/Analysis/DemandedBits.h
> index 04db3eb57c18..7a8618a27ce7 100644
> --- a/llvm/include/llvm/Analysis/DemandedBits.h
> +++ b/llvm/include/llvm/Analysis/DemandedBits.h
> @@ -61,6 +61,20 @@ class DemandedBits {
>
>    void print(raw_ostream &OS);
>
> +  /// Compute alive bits of one addition operand from alive output and known
> +  /// operand bits
> +  static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
> +                                           const APInt &AOut,
> +                                           const KnownBits &LHS,
> +                                           const KnownBits &RHS);
> +
> +  /// Compute alive bits of one subtraction operand from alive output and known
> +  /// operand bits
> +  static APInt determineLiveOperandBitsSub(unsigned OperandNo,
> +                                           const APInt &AOut,
> +                                           const KnownBits &LHS,
> +                                           const KnownBits &RHS);
> +
>  private:
>    void performAnalysis();
>    void determineLiveOperandBits(const Instruction *UserI,
>
> diff  --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
> index aaee8c21f289..62e08f3f8a8b 100644
> --- a/llvm/lib/Analysis/DemandedBits.cpp
> +++ b/llvm/lib/Analysis/DemandedBits.cpp
> @@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits(
>        }
>      break;
>    case Instruction::Add:
> +    if (AOut.isMask()) {
> +      AB = AOut;
> +    } else {
> +      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
> +      AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
> +    }
> +    break;
>    case Instruction::Sub:
> +    if (AOut.isMask()) {
> +      AB = AOut;
> +    } else {
> +      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
> +      AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
> +    }
> +    break;
>    case Instruction::Mul:
>      // Find the highest live output bit. We don't need any more input
>      // bits than that (adds, and thus subtracts, ripple only to the
> @@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) {
>    }
>  }
>
> +static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
> +                                              const APInt &AOut,
> +                                              const KnownBits &LHS,
> +                                              const KnownBits &RHS,
> +                                              bool CarryZero, bool CarryOne) {
> +  assert(!(CarryZero && CarryOne) &&
> +         "Carry can't be zero and one at the same time");
> +
> +  // The following check should be done by the caller, as it also indicates
> +  // that LHS and RHS don't need to be computed.
> +  //
> +  // if (AOut.isMask())
> +  //   return AOut;
> +
> +  // Boundary bits' carry out is unaffected by their carry in.
> +  APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
> +
> +  // First, the alive carry bits are determined from the alive output bits:
> +  // Let demand ripple to the right but only up to any set bit in Bound.
> +  //   AOut         = -1----
> +  //   Bound        = ----1-
> +  //   ACarry&~AOut = --111-
> +  APInt RBound = Bound.reverseBits();
> +  APInt RAOut = AOut.reverseBits();
> +  APInt RProp = RAOut + (RAOut | ~RBound);
> +  APInt RACarry = RProp ^ ~RBound;
> +  APInt ACarry = RACarry.reverseBits();
> +
> +  // Then, the alive input bits are determined from the alive carry bits:
> +  APInt NeededToMaintainCarryZero;
> +  APInt NeededToMaintainCarryOne;
> +  if (OperandNo == 0) {
> +    NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
> +    NeededToMaintainCarryOne = LHS.One | ~RHS.One;
> +  } else {
> +    NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
> +    NeededToMaintainCarryOne = RHS.One | ~LHS.One;
> +  }
> +
> +  // As in computeForAddCarry
> +  APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
> +  APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
> +
> +  // The below is simplified from
> +  //
> +  // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
> +  // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
> +  // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
> +  //
> +  // APInt NeededToMaintainCarry =
> +  //   (CarryKnownZero & NeededToMaintainCarryZero) |
> +  //   (CarryKnownOne  & NeededToMaintainCarryOne) |
> +  //   CarryUnknown;
> +
> +  APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
> +                                (PossibleSumOne | NeededToMaintainCarryOne);
> +
> +  APInt AB = AOut | (ACarry & NeededToMaintainCarry);
> +  return AB;
> +}
> +
> +APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
> +                                                const APInt &AOut,
> +                                                const KnownBits &LHS,
> +                                                const KnownBits &RHS) {
> +  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
> +                                          false);
> +}
> +
> +APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
> +                                                const APInt &AOut,
> +                                                const KnownBits &LHS,
> +                                                const KnownBits &RHS) {
> +  KnownBits NRHS;
> +  NRHS.Zero = RHS.One;
> +  NRHS.One = RHS.Zero;
> +  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
> +                                          true);
> +}
> +
>  FunctionPass *llvm::createDemandedBitsWrapperPass() {
>    return new DemandedBitsWrapperPass();
>  }
>
> diff  --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll
> index 9203ed15d627..01673f82c2b3 100644
> --- a/llvm/test/Analysis/DemandedBits/add.ll
> +++ b/llvm/test/Analysis/DemandedBits/add.ll
> @@ -1,22 +1,22 @@
> -; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
> -; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
> -
> -; CHECK-DAG: DemandedBits: 0x1f for   %1 = and i32 %a, 9
> -; CHECK-DAG: DemandedBits: 0x1f for   %2 = and i32 %b, 9
> -; CHECK-DAG: DemandedBits: 0x1f for   %3 = and i32 %c, 13
> -; CHECK-DAG: DemandedBits: 0x1f for   %4 = and i32 %d, 4
> -; CHECK-DAG: DemandedBits: 0x1f for   %5 = or i32 %2, %3
> -; CHECK-DAG: DemandedBits: 0x1f for   %6 = or i32 %4, %5
> +; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
> +; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
> +
> +; CHECK-DAG: DemandedBits: 0x1e for   %1 = and i32 %a, 9
> +; CHECK-DAG: DemandedBits: 0x1a for   %2 = and i32 %b, 9
> +; CHECK-DAG: DemandedBits: 0x1a for   %3 = and i32 %c, 13
> +; CHECK-DAG: DemandedBits: 0x1a for   %4 = and i32 %d, 4
> +; CHECK-DAG: DemandedBits: 0x1a for   %5 = or i32 %2, %3
> +; CHECK-DAG: DemandedBits: 0x1a for   %6 = or i32 %4, %5
>  ; CHECK-DAG: DemandedBits: 0x10 for   %7 = add i32 %1, %6
>  ; CHECK-DAG: DemandedBits: 0xffffffff for   %8 = and i32 %7, 16
> -define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
> -  %1 = and i32 %a, 9
> -  %2 = and i32 %b, 9
> -  %3 = and i32 %c, 13
> -  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
> -  %5 = or i32 %2, %3
> -  %6 = or i32 %4, %5
> -  %7 = add i32 %1, %6
> -  %8 = and i32 %7, 16
> -  ret i32 %8
> -}
> \ No newline at end of file
> +define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
> +  %1 = and i32 %a, 9
> +  %2 = and i32 %b, 9
> +  %3 = and i32 %c, 13
> +  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
> +  %5 = or i32 %2, %3
> +  %6 = or i32 %4, %5
> +  %7 = add i32 %1, %6
> +  %8 = and i32 %7, 16
> +  ret i32 %8
> +}
>
> diff  --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
> index 4634bf89059a..c4386fed6174 100644
> --- a/llvm/unittests/IR/CMakeLists.txt
> +++ b/llvm/unittests/IR/CMakeLists.txt
> @@ -18,6 +18,7 @@ add_llvm_unittest(IRTests
>    DataLayoutTest.cpp
>    DebugInfoTest.cpp
>    DebugTypeODRUniquingTest.cpp
> +  DemandedBitsTest.cpp
>    DominatorTreeTest.cpp
>    DominatorTreeBatchUpdatesTest.cpp
>    FunctionTest.cpp
>
> diff  --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp
> new file mode 100644
> index 000000000000..4d15e8189961
> --- /dev/null
> +++ b/llvm/unittests/IR/DemandedBitsTest.cpp
> @@ -0,0 +1,66 @@
> +//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
> +//
> +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
> +// See https://llvm.org/LICENSE.txt for license information.
> +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
> +//
> +//===----------------------------------------------------------------------===//
> +
> +#include "llvm/Analysis/DemandedBits.h"
> +#include "../Support/KnownBitsTest.h"
> +#include "llvm/Support/KnownBits.h"
> +#include "gtest/gtest.h"
> +
> +using namespace llvm;
> +
> +namespace {
> +
> +template <typename Fn1, typename Fn2>
> +static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
> +  unsigned Bits = 4;
> +  unsigned Max = 1 << Bits;
> +  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
> +    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
> +      for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) {
> +        APInt AOut(Bits, AOut_);
> +        APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
> +        APInt AB2 = PropagateFn(1, AOut, Known1, Known2);
> +        {
> +          // If the propagator claims that certain known bits
> +          // didn't matter, check it doesn't change its mind
> +          // when they become unknown.
> +          KnownBits Known1Redacted;
> +          KnownBits Known2Redacted;
> +          Known1Redacted.Zero = Known1.Zero & AB1;
> +          Known1Redacted.One = Known1.One & AB1;
> +          Known2Redacted.Zero = Known2.Zero & AB2;
> +          Known2Redacted.One = Known2.One & AB2;
> +
> +          APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
> +          APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
> +          EXPECT_EQ(AB1, AB1R);
> +          EXPECT_EQ(AB2, AB2R);
> +        }
> +        ForeachNumInKnownBits(Known1, [&](APInt Value1) {
> +          ForeachNumInKnownBits(Known2, [&](APInt Value2) {
> +            APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
> +            APInt Result = EvalFn(Value1, Value2);
> +            EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
> +          });
> +        });
> +      }
> +    });
> +  });
> +}
> +
> +TEST(DemandedBitsTest, Add) {
> +  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd,
> +                      [](APInt N1, APInt N2) -> APInt { return N1 + N2; });
> +}
> +
> +TEST(DemandedBitsTest, Sub) {
> +  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub,
> +                      [](APInt N1, APInt N2) -> APInt { return N1 - N2; });
> +}
> +
> +} // anonymous namespace
>
> diff  --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
> index bfd8eb204caf..694e5c4dcc71 100644
> --- a/llvm/unittests/Support/KnownBitsTest.cpp
> +++ b/llvm/unittests/Support/KnownBitsTest.cpp
> @@ -11,41 +11,13 @@
>  //===----------------------------------------------------------------------===//
>
>  #include "llvm/Support/KnownBits.h"
> +#include "KnownBitsTest.h"
>  #include "gtest/gtest.h"
>
>  using namespace llvm;
>
>  namespace {
>
> -template<typename FnTy>
> -void ForeachKnownBits(unsigned Bits, FnTy Fn) {
> -  unsigned Max = 1 << Bits;
> -  KnownBits Known(Bits);
> -  for (unsigned Zero = 0; Zero < Max; ++Zero) {
> -    for (unsigned One = 0; One < Max; ++One) {
> -      Known.Zero = Zero;
> -      Known.One = One;
> -      if (Known.hasConflict())
> -        continue;
> -
> -      Fn(Known);
> -    }
> -  }
> -}
> -
> -template<typename FnTy>
> -void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
> -  unsigned Bits = Known.getBitWidth();
> -  unsigned Max = 1 << Bits;
> -  for (unsigned N = 0; N < Max; ++N) {
> -    APInt Num(Bits, N);
> -    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
> -      continue;
> -
> -    Fn(Num);
> -  }
> -}
> -
>  TEST(KnownBitsTest, AddCarryExhaustive) {
>    unsigned Bits = 4;
>    ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
>
> diff  --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h
> new file mode 100644
> index 000000000000..bc291898814b
> --- /dev/null
> +++ b/llvm/unittests/Support/KnownBitsTest.h
> @@ -0,0 +1,52 @@
> +//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
> +//
> +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
> +// See https://llvm.org/LICENSE.txt for license information.
> +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
> +//
> +//===----------------------------------------------------------------------===//
> +//
> +// This file implements helpers for KnownBits and DemandedBits unit tests.
> +//
> +//===----------------------------------------------------------------------===//
> +
> +#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
> +#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
> +
> +#include "llvm/Support/KnownBits.h"
> +
> +namespace {
> +
> +using namespace llvm;
> +
> +template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
> +  unsigned Max = 1 << Bits;
> +  KnownBits Known(Bits);
> +  for (unsigned Zero = 0; Zero < Max; ++Zero) {
> +    for (unsigned One = 0; One < Max; ++One) {
> +      Known.Zero = Zero;
> +      Known.One = One;
> +      if (Known.hasConflict())
> +        continue;
> +
> +      Fn(Known);
> +    }
> +  }
> +}
> +
> +template <typename FnTy>
> +void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
> +  unsigned Bits = Known.getBitWidth();
> +  unsigned Max = 1 << Bits;
> +  for (unsigned N = 0; N < Max; ++N) {
> +    APInt Num(Bits, N);
> +    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
> +      continue;
> +
> +    Fn(Num);
> +  }
> +}
> +
> +} // end anonymous namespace
> +
> +#endif
>
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits


More information about the llvm-commits mailing list