[llvm] c1f6ce0 - [DemandedBits] Improve accuracy of Add propagator
Roman Lebedev via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 17 05:05:31 PDT 2020
Looks like the git "Author:" line wasn't updated correctly here.
On Mon, Aug 17, 2020 at 2:57 PM Simon Pilgrim via llvm-commits
<llvm-commits at lists.llvm.org> wrote:
>
>
> Author: Simon Pilgrim
> Date: 2020-08-17T12:54:09+01:00
> New Revision: c1f6ce0c7322d47f1bb90169585fa54232231ede
>
> URL: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede
> DIFF: https://github.com/llvm/llvm-project/commit/c1f6ce0c7322d47f1bb90169585fa54232231ede.diff
>
> LOG: [DemandedBits] Improve accuracy of Add propagator
>
> The current demand propagator for addition marks all input bits at and to the right of the alive output bit as alive. However, a carry won't propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction), so a more accurate answer is possible given known bits.
>
> I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.
>
> This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.
>
> Known A: 0_______0_______0_______0_______
> Known B: 0_______0_______0_______0_______
> AOut: 00000000001000000000000000000000
> AB, current: 00000000001111111111111111111111
> AB, patch: 00000000001111111000000000000000
>
> Committed on behalf of: @rrika (Erika)
>
> Differential Revision: https://reviews.llvm.org/D72423
>
> Added:
> llvm/unittests/IR/DemandedBitsTest.cpp
> llvm/unittests/Support/KnownBitsTest.h
>
> Modified:
> llvm/include/llvm/Analysis/DemandedBits.h
> llvm/lib/Analysis/DemandedBits.cpp
> llvm/test/Analysis/DemandedBits/add.ll
> llvm/unittests/IR/CMakeLists.txt
> llvm/unittests/Support/KnownBitsTest.cpp
>
> Removed:
>
>
>
> ################################################################################
> diff --git a/llvm/include/llvm/Analysis/DemandedBits.h b/llvm/include/llvm/Analysis/DemandedBits.h
> index 04db3eb57c18..7a8618a27ce7 100644
> --- a/llvm/include/llvm/Analysis/DemandedBits.h
> +++ b/llvm/include/llvm/Analysis/DemandedBits.h
> @@ -61,6 +61,20 @@ class DemandedBits {
>
> void print(raw_ostream &OS);
>
> + /// Compute alive bits of one addition operand from alive output and known
> + /// operand bits
> + static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
> + const APInt &AOut,
> + const KnownBits &LHS,
> + const KnownBits &RHS);
> +
> + /// Compute alive bits of one subtraction operand from alive output and known
> + /// operand bits
> + static APInt determineLiveOperandBitsSub(unsigned OperandNo,
> + const APInt &AOut,
> + const KnownBits &LHS,
> + const KnownBits &RHS);
> +
> private:
> void performAnalysis();
> void determineLiveOperandBits(const Instruction *UserI,
>
> diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp
> index aaee8c21f289..62e08f3f8a8b 100644
> --- a/llvm/lib/Analysis/DemandedBits.cpp
> +++ b/llvm/lib/Analysis/DemandedBits.cpp
> @@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits(
> }
> break;
> case Instruction::Add:
> + if (AOut.isMask()) {
> + AB = AOut;
> + } else {
> + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
> + AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
> + }
> + break;
> case Instruction::Sub:
> + if (AOut.isMask()) {
> + AB = AOut;
> + } else {
> + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
> + AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
> + }
> + break;
> case Instruction::Mul:
> // Find the highest live output bit. We don't need any more input
> // bits than that (adds, and thus subtracts, ripple only to the
> @@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) {
> }
> }
>
> +static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
> + const APInt &AOut,
> + const KnownBits &LHS,
> + const KnownBits &RHS,
> + bool CarryZero, bool CarryOne) {
> + assert(!(CarryZero && CarryOne) &&
> + "Carry can't be zero and one at the same time");
> +
> + // The following check should be done by the caller, as it also indicates
> + // that LHS and RHS don't need to be computed.
> + //
> + // if (AOut.isMask())
> + // return AOut;
> +
> + // Boundary bits' carry out is unaffected by their carry in.
> + APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
> +
> + // First, the alive carry bits are determined from the alive output bits:
> + // Let demand ripple to the right but only up to any set bit in Bound.
> + // AOut = -1----
> + // Bound = ----1-
> + // ACarry&~AOut = --111-
> + APInt RBound = Bound.reverseBits();
> + APInt RAOut = AOut.reverseBits();
> + APInt RProp = RAOut + (RAOut | ~RBound);
> + APInt RACarry = RProp ^ ~RBound;
> + APInt ACarry = RACarry.reverseBits();
> +
> + // Then, the alive input bits are determined from the alive carry bits:
> + APInt NeededToMaintainCarryZero;
> + APInt NeededToMaintainCarryOne;
> + if (OperandNo == 0) {
> + NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
> + NeededToMaintainCarryOne = LHS.One | ~RHS.One;
> + } else {
> + NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
> + NeededToMaintainCarryOne = RHS.One | ~LHS.One;
> + }
> +
> + // As in computeForAddCarry
> + APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
> + APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
> +
> + // The below is simplified from
> + //
> + // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
> + // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
> + // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
> + //
> + // APInt NeededToMaintainCarry =
> + // (CarryKnownZero & NeededToMaintainCarryZero) |
> + // (CarryKnownOne & NeededToMaintainCarryOne) |
> + // CarryUnknown;
> +
> + APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
> + (PossibleSumOne | NeededToMaintainCarryOne);
> +
> + APInt AB = AOut | (ACarry & NeededToMaintainCarry);
> + return AB;
> +}
> +
> +APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
> + const APInt &AOut,
> + const KnownBits &LHS,
> + const KnownBits &RHS) {
> + return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
> + false);
> +}
> +
> +APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
> + const APInt &AOut,
> + const KnownBits &LHS,
> + const KnownBits &RHS) {
> + KnownBits NRHS;
> + NRHS.Zero = RHS.One;
> + NRHS.One = RHS.Zero;
> + return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
> + true);
> +}
> +
> FunctionPass *llvm::createDemandedBitsWrapperPass() {
> return new DemandedBitsWrapperPass();
> }
>
> diff --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll
> index 9203ed15d627..01673f82c2b3 100644
> --- a/llvm/test/Analysis/DemandedBits/add.ll
> +++ b/llvm/test/Analysis/DemandedBits/add.ll
> @@ -1,22 +1,22 @@
> -; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
> -; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
> -
> -; CHECK-DAG: DemandedBits: 0x1f for %1 = and i32 %a, 9
> -; CHECK-DAG: DemandedBits: 0x1f for %2 = and i32 %b, 9
> -; CHECK-DAG: DemandedBits: 0x1f for %3 = and i32 %c, 13
> -; CHECK-DAG: DemandedBits: 0x1f for %4 = and i32 %d, 4
> -; CHECK-DAG: DemandedBits: 0x1f for %5 = or i32 %2, %3
> -; CHECK-DAG: DemandedBits: 0x1f for %6 = or i32 %4, %5
> +; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
> +; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
> +
> +; CHECK-DAG: DemandedBits: 0x1e for %1 = and i32 %a, 9
> +; CHECK-DAG: DemandedBits: 0x1a for %2 = and i32 %b, 9
> +; CHECK-DAG: DemandedBits: 0x1a for %3 = and i32 %c, 13
> +; CHECK-DAG: DemandedBits: 0x1a for %4 = and i32 %d, 4
> +; CHECK-DAG: DemandedBits: 0x1a for %5 = or i32 %2, %3
> +; CHECK-DAG: DemandedBits: 0x1a for %6 = or i32 %4, %5
> ; CHECK-DAG: DemandedBits: 0x10 for %7 = add i32 %1, %6
> ; CHECK-DAG: DemandedBits: 0xffffffff for %8 = and i32 %7, 16
> -define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
> - %1 = and i32 %a, 9
> - %2 = and i32 %b, 9
> - %3 = and i32 %c, 13
> - %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
> - %5 = or i32 %2, %3
> - %6 = or i32 %4, %5
> - %7 = add i32 %1, %6
> - %8 = and i32 %7, 16
> - ret i32 %8
> -}
> \ No newline at end of file
> +define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
> + %1 = and i32 %a, 9
> + %2 = and i32 %b, 9
> + %3 = and i32 %c, 13
> + %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
> + %5 = or i32 %2, %3
> + %6 = or i32 %4, %5
> + %7 = add i32 %1, %6
> + %8 = and i32 %7, 16
> + ret i32 %8
> +}
>
> diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
> index 4634bf89059a..c4386fed6174 100644
> --- a/llvm/unittests/IR/CMakeLists.txt
> +++ b/llvm/unittests/IR/CMakeLists.txt
> @@ -18,6 +18,7 @@ add_llvm_unittest(IRTests
> DataLayoutTest.cpp
> DebugInfoTest.cpp
> DebugTypeODRUniquingTest.cpp
> + DemandedBitsTest.cpp
> DominatorTreeTest.cpp
> DominatorTreeBatchUpdatesTest.cpp
> FunctionTest.cpp
>
> diff --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp
> new file mode 100644
> index 000000000000..4d15e8189961
> --- /dev/null
> +++ b/llvm/unittests/IR/DemandedBitsTest.cpp
> @@ -0,0 +1,66 @@
> +//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
> +//
> +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
> +// See https://llvm.org/LICENSE.txt for license information.
> +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
> +//
> +//===----------------------------------------------------------------------===//
> +
> +#include "llvm/Analysis/DemandedBits.h"
> +#include "../Support/KnownBitsTest.h"
> +#include "llvm/Support/KnownBits.h"
> +#include "gtest/gtest.h"
> +
> +using namespace llvm;
> +
> +namespace {
> +
> +template <typename Fn1, typename Fn2>
> +static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
> + unsigned Bits = 4;
> + unsigned Max = 1 << Bits;
> + ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
> + ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
> + for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) {
> + APInt AOut(Bits, AOut_);
> + APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
> + APInt AB2 = PropagateFn(1, AOut, Known1, Known2);
> + {
> + // If the propagator claims that certain known bits
> + // didn't matter, check it doesn't change its mind
> + // when they become unknown.
> + KnownBits Known1Redacted;
> + KnownBits Known2Redacted;
> + Known1Redacted.Zero = Known1.Zero & AB1;
> + Known1Redacted.One = Known1.One & AB1;
> + Known2Redacted.Zero = Known2.Zero & AB2;
> + Known2Redacted.One = Known2.One & AB2;
> +
> + APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
> + APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
> + EXPECT_EQ(AB1, AB1R);
> + EXPECT_EQ(AB2, AB2R);
> + }
> + ForeachNumInKnownBits(Known1, [&](APInt Value1) {
> + ForeachNumInKnownBits(Known2, [&](APInt Value2) {
> + APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
> + APInt Result = EvalFn(Value1, Value2);
> + EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
> + });
> + });
> + }
> + });
> + });
> +}
> +
> +TEST(DemandedBitsTest, Add) {
> + TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd,
> + [](APInt N1, APInt N2) -> APInt { return N1 + N2; });
> +}
> +
> +TEST(DemandedBitsTest, Sub) {
> + TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub,
> + [](APInt N1, APInt N2) -> APInt { return N1 - N2; });
> +}
> +
> +} // anonymous namespace
>
> diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
> index bfd8eb204caf..694e5c4dcc71 100644
> --- a/llvm/unittests/Support/KnownBitsTest.cpp
> +++ b/llvm/unittests/Support/KnownBitsTest.cpp
> @@ -11,41 +11,13 @@
> //===----------------------------------------------------------------------===//
>
> #include "llvm/Support/KnownBits.h"
> +#include "KnownBitsTest.h"
> #include "gtest/gtest.h"
>
> using namespace llvm;
>
> namespace {
>
> -template<typename FnTy>
> -void ForeachKnownBits(unsigned Bits, FnTy Fn) {
> - unsigned Max = 1 << Bits;
> - KnownBits Known(Bits);
> - for (unsigned Zero = 0; Zero < Max; ++Zero) {
> - for (unsigned One = 0; One < Max; ++One) {
> - Known.Zero = Zero;
> - Known.One = One;
> - if (Known.hasConflict())
> - continue;
> -
> - Fn(Known);
> - }
> - }
> -}
> -
> -template<typename FnTy>
> -void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
> - unsigned Bits = Known.getBitWidth();
> - unsigned Max = 1 << Bits;
> - for (unsigned N = 0; N < Max; ++N) {
> - APInt Num(Bits, N);
> - if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
> - continue;
> -
> - Fn(Num);
> - }
> -}
> -
> TEST(KnownBitsTest, AddCarryExhaustive) {
> unsigned Bits = 4;
> ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
>
> diff --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h
> new file mode 100644
> index 000000000000..bc291898814b
> --- /dev/null
> +++ b/llvm/unittests/Support/KnownBitsTest.h
> @@ -0,0 +1,52 @@
> +//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
> +//
> +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
> +// See https://llvm.org/LICENSE.txt for license information.
> +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
> +//
> +//===----------------------------------------------------------------------===//
> +//
> +// This file implements helpers for KnownBits and DemandedBits unit tests.
> +//
> +//===----------------------------------------------------------------------===//
> +
> +#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
> +#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
> +
> +#include "llvm/Support/KnownBits.h"
> +
> +namespace {
> +
> +using namespace llvm;
> +
> +template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
> + unsigned Max = 1 << Bits;
> + KnownBits Known(Bits);
> + for (unsigned Zero = 0; Zero < Max; ++Zero) {
> + for (unsigned One = 0; One < Max; ++One) {
> + Known.Zero = Zero;
> + Known.One = One;
> + if (Known.hasConflict())
> + continue;
> +
> + Fn(Known);
> + }
> + }
> +}
> +
> +template <typename FnTy>
> +void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
> + unsigned Bits = Known.getBitWidth();
> + unsigned Max = 1 << Bits;
> + for (unsigned N = 0; N < Max; ++N) {
> + APInt Num(Bits, N);
> + if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
> + continue;
> +
> + Fn(Num);
> + }
> +}
> +
> +} // end anonymous namespace
> +
> +#endif
>
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits
More information about the llvm-commits
mailing list