[llvm] [InstCombine] Improve bitfield addition (PR #77184)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 21 09:20:30 PST 2024
================
@@ -3379,6 +3379,149 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
}
+/// Field masks for the two-field bit-field additions recognized by
+/// foldBitFieldArithmetic: Lower is the lowest field (it starts at bit 0) and
+/// Upper is the contiguous field directly above it. Both pointers are bound
+/// by `m_APInt` while matching.
+struct BitFieldAddBitMask {
+ const APInt *Lower;
+ const APInt *Upper;
+};
+/// Field masks for the form of foldBitFieldArithmetic whose lower fields are
+/// already merged in the shape ((X & Lower) + (Y & Lower)) ^ ((X ^ Y) & Upper)
+/// (or the variants `(X & Lower) + Y` / `X & Upper`). Upper ends up in the
+/// xor'ed mask of the emitted result (the per-field top bits) and Lower in the
+/// added mask (the remaining field bits).
+struct BitFieldOptBitMask {
+ const APInt *Lower;
+ const APInt *Upper;
+ // Mask of the additional contiguous field being merged directly above the
+ // Lower/Upper bits.
+ const APInt *New;
+};
+/// A matched bit-field addition: the operands X and Y of the combined
+/// addition that foldBitFieldArithmetic emits, plus the field masks.
+/// `opt` selects the union member that holds the masks: OptMask when true
+/// (lower fields already merged), AddMask otherwise.
+/// NOTE(review): populate and read only the member selected by `opt`. Writing
+/// a field through the other member (e.g. initializing AddMask and then
+/// assigning OptMask.New) starts OptMask's lifetime and leaves its remaining
+/// fields formally indeterminate; assign the member as a whole instead.
+struct BitFieldAddInfo {
+ // Operand whose fields are added.
+ Value *X;
+ // Addend of the combined addition; per-field constants are merged into a
+ // single constant before being stored here.
+ Value *Y;
+ // True when OptMask is the meaningful union member, false for AddMask.
+ bool opt;
+ union {
+ BitFieldAddBitMask AddMask;
+ BitFieldOptBitMask OptMask;
+ };
+};
+
+/// Fold a disjoint `or` that assembles a value from independent bit-field
+/// additions of the same two operands into a single SWAR ("SIMD within a
+/// register") addition:
+///
+///   ((X & SumMask) + (Y & SumMask)) ^ ((X ^ Y) & TopMask)
+///
+/// TopMask holds the most significant bit of every field and SumMask all
+/// other field bits. Because each field's top bit is cleared before the add,
+/// a carry can only land in that (zero) top-bit position and never crosses
+/// into the next field; the final xor then produces each field's top bit.
+///
+/// Recognized forms, where Lo is the lowest field (starting at bit 0), Up is
+/// the contiguous field directly above it, and C1/C2 are constants confined
+/// to Lo/Up respectively:
+///   1. ((X + Y) & Lo) | (((X & Up) + Y) & Up)
+///      (or the same with C1 / C2 in place of Y)
+///   2. ((X + C1) & Lo) | ((X & Up) + C2)        -- Up contains the sign bit
+///   3. ((X + C1) & Lo) | ((X + C2) & Up)
+///   4. <lower fields already in the SWAR shape above> | <upper field as in
+///      form 1 or 2>
+/// Form 4 matches this fold's own output, so a chain of fields collapses one
+/// field at a time.
+///
+/// Forms 2 and 3 require constants: with a variable Y the upper field's add
+/// would also see Y's lower-field bits (form 2) or carries out of the lower
+/// field (form 3), neither of which the SWAR addition reproduces.
+static Value *foldBitFieldArithmetic(BinaryOperator &I,
+                                     InstCombiner::BuilderTy &Builder) {
+  // Only a disjoint `or` is equivalent to adding its operands, which is what
+  // lets the per-field results be treated as one combined sum.
+  auto *Disjoint = dyn_cast<PossiblyDisjointInst>(&I);
+  if (!Disjoint || !Disjoint->isDisjoint())
+    return nullptr;
+
+  unsigned BitWidth = I.getType()->getScalarSizeInBits();
+
+  // Index of the most significant set bit of a non-zero mask.
+  auto TopBit = [BitWidth](const APInt &Mask) {
+    return BitWidth - Mask.countl_zero() - 1;
+  };
+  // All bits strictly above the most significant set bit of Mask.
+  auto BitsAbove = [BitWidth](const APInt &Mask) {
+    return APInt::getHighBitsSet(BitWidth, Mask.countl_zero());
+  };
+  // True for a ConstantInt or a splat vector constant.
+  auto IsConstInt = [](Value *V) {
+    const APInt *C;
+    return V && match(V, m_APInt(C));
+  };
+
+  // Merge the addend of the lower field(s) (LoY) and of the upper field (UpY)
+  // into the addend of the combined addition. They must either be the same
+  // value, or both be constants confined to their own fields; the fields are
+  // disjoint, so such constants are merged with `|`. Returns null otherwise.
+  // Constants are read via m_APInt, so other constants (e.g. non-splat
+  // vectors) are only accepted through the same-value case.
+  auto AccumulateY = [](Value *LoY, Value *UpY, const APInt &LoMask,
+                        const APInt &UpMask) -> Value * {
+    if (!LoY || !UpY)
+      return nullptr;
+    const APInt *CLoY, *CUpY;
+    bool LoIsConst = match(LoY, m_APInt(CLoY));
+    bool UpIsConst = match(UpY, m_APInt(CUpY));
+    if (LoIsConst != UpIsConst)
+      return nullptr;
+    if (!LoIsConst)
+      return LoY == UpY ? LoY : nullptr;
+    if (!CLoY->isSubsetOf(LoMask) || !CUpY->isSubsetOf(UpMask))
+      return nullptr;
+    return ConstantInt::get(LoY->getType(), *CLoY | *CUpY);
+  };
+
+  // Values bound by the patterns below.
+  Value *X = nullptr, *Y = nullptr, *UpY = nullptr;
+  const APInt *LoMask = nullptr, *UpMask = nullptr, *UpMask2 = nullptr;
+  const APInt *OptLoMask = nullptr, *OptLoMask2 = nullptr,
+              *OptUpMask = nullptr;
+
+  // Match I against one pattern from a clean slate, so no binding left over
+  // from an earlier failed attempt can reach the checks below. None of the
+  // patterns uses m_CombineOr, so after a successful match every variable a
+  // pattern mentions was bound on the successful path.
+  auto Try = [&](const auto &Pattern) {
+    X = Y = UpY = nullptr;
+    LoMask = UpMask = UpMask2 = nullptr;
+    OptLoMask = OptLoMask2 = OptUpMask = nullptr;
+    return match(&I, Pattern);
+  };
+
+  // Upper field, re-masked: ((X & UpMask) + UpY) & UpMask2. It is always the
+  // first operand pattern of an `or` below, so X is bound here and the other
+  // side must reuse the same X through m_Deferred.
+  auto UpperMasked =
+      m_And(m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY)),
+            m_APInt(UpMask2));
+  // Upper field, not re-masked: (X & UpMask) + UpY
+  auto UpperUnmasked =
+      m_c_Add(m_And(m_Value(X), m_APInt(UpMask)), m_Value(UpY));
+  // Lowest field of forms 1/2: (X + Y) & LoMask
+  auto Lower = m_And(m_c_Add(m_Deferred(X), m_Value(Y)), m_APInt(LoMask));
+  // Form 3: ((X + Y) & LoMask) | ((X + UpY) & UpMask)
+  auto BothFullWidth =
+      m_c_Or(m_And(m_c_Add(m_Value(X), m_Value(Y)), m_APInt(LoMask)),
+             m_And(m_c_Add(m_Deferred(X), m_Value(UpY)), m_APInt(UpMask)));
+  // Form 4 lower fields (the shape this fold emits):
+  //   ((X & OptLoMask) + (Y & OptLoMask2)) ^ ((X ^ Y) & OptUpMask)
+  auto SwarLower = m_c_Xor(
+      m_c_Add(m_And(m_Deferred(X), m_APInt(OptLoMask)),
+              m_And(m_Value(Y), m_APInt(OptLoMask2))),
+      m_And(m_c_Xor(m_Deferred(X), m_Deferred(Y)), m_APInt(OptUpMask)));
+  // Form 4 lower fields with a constant Y inside OptLoMask, whose masks and
+  // xor with the top bits fold away: ((X & OptLoMask) + Y) ^ (X & OptUpMask)
+  auto SwarLowerConst =
+      m_c_Xor(m_c_Add(m_And(m_Deferred(X), m_APInt(OptLoMask)), m_Value(Y)),
+              m_And(m_Deferred(X), m_APInt(OptUpMask)));
+
+  // A re-masked upper field must use the same mask on both sides of the add.
+  auto RemaskedUpperOK = [&] { return *UpMask == *UpMask2; };
+  // An upper field that is not re-masked keeps every bit its add produces.
+  // That equals the field sum only if UpY is a constant (AccumulateY then
+  // confines it to UpMask, so nothing lands below the field) and UpMask
+  // contains the sign bit (so a carry out of the field is discarded instead
+  // of landing above it).
+  auto UnmaskedUpperOK = [&] {
+    return IsConstInt(UpY) && UpMask->countl_zero() == 0;
+  };
+  // Forms 1-3: Lo and Up are disjoint contiguous runs of at least two bits
+  // that together cover every bit up to Up's top bit, i.e. Lo starts at
+  // bit 0 and Up sits directly above it.
+  auto IsLoUpFieldPair = [&] {
+    return LoMask->popcount() >= 2 && UpMask->popcount() >= 2 &&
+           LoMask->isShiftedMask() && UpMask->isShiftedMask() &&
+           (*LoMask & *UpMask) == 0 &&
+           (BitsAbove(*UpMask) ^ *LoMask ^ *UpMask).isAllOnes();
+  };
+  // Form 4: OptLoMask and OptUpMask partition every bit up to OptUpMask's
+  // top bit, and Up is a contiguous run of at least two bits directly above.
+  auto IsSwarFieldPair = [&] {
+    return UpMask->popcount() >= 2 && UpMask->isShiftedMask() &&
+           (*UpMask & (*OptLoMask | *OptUpMask)) == 0 &&
+           (~*OptLoMask ^ BitsAbove(*OptUpMask)) == *OptUpMask &&
+           (BitsAbove(*UpMask) ^ *UpMask ^ (*OptLoMask ^ *OptUpMask))
+               .isAllOnes();
+  };
+
+  auto MatchBitFieldAdd = [&]() -> std::optional<BitFieldAddInfo> {
+    // Forms 1-3. An attempt whose soundness check fails falls through to the
+    // next one, which starts over from a clean slate.
+    bool IsPlain =
+        (Try(m_c_Or(UpperMasked, Lower)) && RemaskedUpperOK()) ||
+        (Try(m_c_Or(UpperUnmasked, Lower)) && UnmaskedUpperOK()) ||
+        // Form 3: X + UpY stays inside Up only if UpY has no bits below Up,
+        // i.e. it is a constant confined to Up (checked by AccumulateY).
+        (Try(BothFullWidth) && IsConstInt(UpY));
+    if (IsPlain) {
+      if (!IsLoUpFieldPair())
+        return std::nullopt;
+      Value *CombinedY = AccumulateY(Y, UpY, *LoMask, *UpMask);
+      if (!CombinedY)
+        return std::nullopt;
+      return BitFieldAddInfo{X, CombinedY, false, {{LoMask, UpMask}}};
+    }
+
+    // Form 4. With the masked SWAR shape, Y is masked on the add side and
+    // xor'ed into the top bits, so a constant Y may use any merged-field bit.
+    // With the constant shape, Y is added unmasked and not xor'ed into the
+    // top bits, so it must stay within OptLoMask.
+    std::optional<APInt> LowerYMask;
+    if ((Try(m_c_Or(UpperMasked, SwarLower)) && RemaskedUpperOK() &&
+         *OptLoMask == *OptLoMask2) ||
+        (Try(m_c_Or(UpperUnmasked, SwarLower)) && UnmaskedUpperOK() &&
+         *OptLoMask == *OptLoMask2))
+      LowerYMask = *OptLoMask | *OptUpMask;
+    else if ((Try(m_c_Or(UpperMasked, SwarLowerConst)) && RemaskedUpperOK() &&
+              IsConstInt(Y)) ||
+             (Try(m_c_Or(UpperUnmasked, SwarLowerConst)) &&
+              UnmaskedUpperOK() && IsConstInt(Y)))
+      LowerYMask = *OptLoMask;
+    if (!LowerYMask || !IsSwarFieldPair())
+      return std::nullopt;
+
+    Value *CombinedY = AccumulateY(Y, UpY, *LowerYMask, *UpMask);
+    if (!CombinedY)
+      return std::nullopt;
+    BitFieldAddInfo Result{X, CombinedY, true, {{nullptr, nullptr}}};
+    // Assign OptMask as a whole so it becomes the active union member with
+    // all three fields set.
+    Result.OptMask = BitFieldOptBitMask{OptLoMask, OptUpMask, UpMask};
+    return Result;
+  };
+
+  std::optional<BitFieldAddInfo> Info = MatchBitFieldAdd();
+  if (!Info)
+    return nullptr;
+
+  // SumMask: every field bit except each field's top bit.
+  // TopMask: the top bit of every field.
+  APInt SumMask, TopMask;
+  if (Info->opt) {
+    // The merged lower fields already come split into OptMask.Lower (summed
+    // bits) and OptMask.Upper (top bits); split the new field the same way.
+    unsigned NewTop = TopBit(*Info->OptMask.New);
+    SumMask = *Info->OptMask.Lower | *Info->OptMask.New;
+    SumMask.clearBit(NewTop);
+    TopMask = *Info->OptMask.Upper;
+    TopMask.setBit(NewTop);
+  } else {
+    unsigned LowerTop = TopBit(*Info->AddMask.Lower);
+    unsigned UpperTop = TopBit(*Info->AddMask.Upper);
+    SumMask = *Info->AddMask.Lower | *Info->AddMask.Upper;
+    SumMask.clearBit(LowerTop);
+    SumMask.clearBit(UpperTop);
+    TopMask = APInt::getOneBitSet(BitWidth, LowerTop);
+    TopMask.setBit(UpperTop);
+  }
+
+  // ((X & SumMask) + (Y & SumMask)) ^ ((X ^ Y) & TopMask). The add cannot
+  // wrap: carries stop in the cleared top-bit positions, the highest of which
+  // is itself inside the value, hence `nuw`.
+  Value *SumX = Builder.CreateAnd(Info->X, SumMask);
+  Value *SumY = Builder.CreateAnd(Info->Y, SumMask);
+  Value *Sum = Builder.CreateNUWAdd(SumX, SumY);
+  Value *XorXY = Builder.CreateXor(Info->X, Info->Y);
+  Value *Tops = Builder.CreateAnd(XorXY, TopMask);
+  return Builder.CreateXor(Sum, Tops);
+}
----------------
goldsteinn wrote:
This entire impl is in desperate need of some comments.
https://github.com/llvm/llvm-project/pull/77184
More information about the llvm-commits mailing list