[llvm] fd2bb51 - [ADT] Add APInt/MathExtras isShiftedMask variant returning mask offset/length
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 8 04:04:30 PST 2022
Author: Simon Pilgrim
Date: 2022-02-08T12:04:13Z
New Revision: fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4
URL: https://github.com/llvm/llvm-project/commit/fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4
DIFF: https://github.com/llvm/llvm-project/commit/fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4.diff
LOG: [ADT] Add APInt/MathExtras isShiftedMask variant returning mask offset/length
In many cases, calls to isShiftedMask are immediately followed with checks to determine the size and position of the bitmask.
This patch adds variants of APInt::isShiftedMask, isShiftedMask_32 and isShiftedMask_64 that return these values as additional arguments.
I've updated a number of cases that were either performing separate size/position calculations or had created their own local wrapper versions of these.
Differential Revision: https://reviews.llvm.org/D119019
Added:
Modified:
llvm/include/llvm/ADT/APInt.h
llvm/include/llvm/Support/MathExtras.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
llvm/lib/Target/Mips/MipsISelLowering.cpp
llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
llvm/unittests/ADT/APIntTest.cpp
llvm/unittests/Support/MathExtrasTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 0fdcbb6c3dfbc..a475e27c797d2 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -496,6 +496,23 @@ class LLVM_NODISCARD APInt {
return (Ones + LeadZ + countTrailingZeros()) == BitWidth;
}
+ /// Return true if this APInt value contains a non-empty sequence of ones with
+ /// the remainder zero. If true, \p MaskIdx will specify the index of the
+ /// lowest set bit and \p MaskLen is updated to specify the length of the
+ /// mask, else neither are updated.
+ bool isShiftedMask(unsigned &MaskIdx, unsigned &MaskLen) const {
+ if (isSingleWord())
+ return isShiftedMask_64(U.VAL, MaskIdx, MaskLen);
+ unsigned Ones = countPopulationSlowCase();
+ unsigned LeadZ = countLeadingZerosSlowCase();
+ unsigned TrailZ = countTrailingZerosSlowCase();
+ if ((Ones + LeadZ + TrailZ) != BitWidth)
+ return false;
+ MaskLen = Ones;
+ MaskIdx = TrailZ;
+ return true;
+ }
+
/// Compute an APInt containing numBits highbits from this APInt.
///
/// Get an APInt with the same BitWidth as this APInt, just zero mask the low
diff --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 753b1998c40c0..ccb0f5594ebd5 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -571,6 +571,33 @@ inline unsigned countPopulation(T Value) {
return detail::PopulationCounter<T, sizeof(T)>::count(Value);
}
+/// Return true if the argument contains a non-empty sequence of ones with the
+/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
+/// If true, \p MaskIdx will specify the index of the lowest set bit and \p
+/// MaskLen is updated to specify the length of the mask, else neither are
+/// updated.
+inline bool isShiftedMask_32(uint32_t Value, unsigned &MaskIdx,
+ unsigned &MaskLen) {
+ if (!isShiftedMask_32(Value))
+ return false;
+ MaskIdx = countTrailingZeros(Value);
+ MaskLen = countPopulation(Value);
+ return true;
+}
+
+/// Return true if the argument contains a non-empty sequence of ones with the
+/// remainder zero (64 bit version.) If true, \p MaskIdx will specify the index
+/// of the lowest set bit and \p MaskLen is updated to specify the length of the
+/// mask, else neither are updated.
+inline bool isShiftedMask_64(uint64_t Value, unsigned &MaskIdx,
+ unsigned &MaskLen) {
+ if (!isShiftedMask_64(Value))
+ return false;
+ MaskIdx = countTrailingZeros(Value);
+ MaskLen = countPopulation(Value);
+ return true;
+}
+
/// Compile time Log2.
/// Valid only for positive powers of two.
template <size_t kValue> constexpr inline size_t CTLog2() {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8539c5c681f64..f2a562e49c834 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12254,10 +12254,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
unsigned ActiveBits = 0;
if (Mask.isMask()) {
ActiveBits = Mask.countTrailingOnes();
- } else if (Mask.isShiftedMask()) {
- ShAmt = Mask.countTrailingZeros();
- APInt ShiftedMask = Mask.lshr(ShAmt);
- ActiveBits = ShiftedMask.countTrailingOnes();
+ } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
HasShiftedOffset = true;
} else
return SDValue();
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index b9d0655feef72..23d970f6d1bff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3281,8 +3281,9 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
// this improves the ability to match BFE patterns in isel.
if (LHS.getOpcode() == ISD::AND) {
if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
- if (Mask->getAPIntValue().isShiftedMask() &&
- Mask->getAPIntValue().countTrailingZeros() == ShiftAmt) {
+ unsigned MaskIdx, MaskLen;
+ if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+ MaskIdx == ShiftAmt) {
return DAG.getNode(
ISD::AND, SL, VT,
DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 0c2e129b8f1fc..a0e4bbc730354 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -94,18 +94,6 @@ static const MCPhysReg Mips64DPRegs[8] = {
Mips::D16_64, Mips::D17_64, Mips::D18_64, Mips::D19_64
};
-// If I is a shifted mask, set the size (Size) and the first bit of the
-// mask (Pos), and return true.
-// For example, if I is 0x003ff800, (Pos, Size) = (11, 11).
-static bool isShiftedMask(uint64_t I, uint64_t &Pos, uint64_t &Size) {
- if (!isShiftedMask_64(I))
- return false;
-
- Size = countPopulation(I);
- Pos = countTrailingZeros(I);
- return true;
-}
-
// The MIPS MSA ABI passes vector arguments in the integer register set.
// The number of integer registers used is dependant on the ABI used.
MVT MipsTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
@@ -794,14 +782,15 @@ static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
EVT ValTy = N->getValueType(0);
SDLoc DL(N);
- uint64_t Pos = 0, SMPos, SMSize;
+ uint64_t Pos = 0;
+ unsigned SMPos, SMSize;
ConstantSDNode *CN;
SDValue NewOperand;
unsigned Opc;
// Op's second operand must be a shifted mask.
if (!(CN = dyn_cast<ConstantSDNode>(Mask)) ||
- !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+ !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
return SDValue();
if (FirstOperandOpc == ISD::SRA || FirstOperandOpc == ISD::SRL) {
@@ -875,7 +864,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
SDValue And0 = N->getOperand(0), And1 = N->getOperand(1);
- uint64_t SMPos0, SMSize0, SMPos1, SMSize1;
+ unsigned SMPos0, SMSize0, SMPos1, SMSize1;
ConstantSDNode *CN, *CN1;
// See if Op's first operand matches (and $src1 , mask0).
@@ -883,7 +872,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
if (!(CN = dyn_cast<ConstantSDNode>(And0.getOperand(1))) ||
- !isShiftedMask(~CN->getSExtValue(), SMPos0, SMSize0))
+ !isShiftedMask_64(~CN->getSExtValue(), SMPos0, SMSize0))
return SDValue();
// See if Op's second operand matches (and (shl $src, pos), mask1).
@@ -891,7 +880,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
And1.getOperand(0).getOpcode() == ISD::SHL) {
if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
- !isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
+ !isShiftedMask_64(CN->getZExtValue(), SMPos1, SMSize1))
return SDValue();
// The shift masks must have the same position and size.
@@ -1118,7 +1107,8 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
EVT ValTy = N->getValueType(0);
SDLoc DL(N);
- uint64_t Pos = 0, SMPos, SMSize;
+ uint64_t Pos = 0;
+ unsigned SMPos, SMSize;
ConstantSDNode *CN;
SDValue NewOperand;
@@ -1136,7 +1126,7 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
// AND's second operand must be a shifted mask.
if (!(CN = dyn_cast<ConstantSDNode>(FirstOperand.getOperand(1))) ||
- !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+ !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
return SDValue();
// Return if the shifted mask does not start at bit 0 or the sum of its size
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index 1b021ada6b668..cdbb1d7584b60 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -996,20 +996,18 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
return IC.replaceInstUsesWith(II, II.getArgOperand(0));
}
- if (MaskC->getValue().isShiftedMask()) {
+ unsigned MaskIdx, MaskLen;
+ if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
// any single contingous sequence of 1s anywhere in the mask simply
// describes a subset of the input bits shifted to the appropriate
// position. Replace with the straight forward IR.
- unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
Value *Input = II.getArgOperand(0);
Value *Masked = IC.Builder.CreateAnd(Input, II.getArgOperand(1));
- Value *Shifted = IC.Builder.CreateLShr(Masked,
- ConstantInt::get(II.getType(),
- ShiftAmount));
+ Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+ Value *Shifted = IC.Builder.CreateLShr(Masked, ShiftAmt);
return IC.replaceInstUsesWith(II, Shifted);
}
-
if (auto *SrcC = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
uint64_t Src = SrcC->getZExtValue();
uint64_t Mask = MaskC->getZExtValue();
@@ -1041,15 +1039,15 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
if (MaskC->isAllOnesValue()) {
return IC.replaceInstUsesWith(II, II.getArgOperand(0));
}
- if (MaskC->getValue().isShiftedMask()) {
+
+ unsigned MaskIdx, MaskLen;
+ if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
// any single contingous sequence of 1s anywhere in the mask simply
// describes a subset of the input bits shifted to the appropriate
// position. Replace with the straight forward IR.
- unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
Value *Input = II.getArgOperand(0);
- Value *Shifted = IC.Builder.CreateShl(Input,
- ConstantInt::get(II.getType(),
- ShiftAmount));
+ Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+ Value *Shifted = IC.Builder.CreateShl(Input, ShiftAmt);
Value *Masked = IC.Builder.CreateAnd(Shifted, II.getArgOperand(1));
return IC.replaceInstUsesWith(II, Masked);
}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 95cb213aa8512..566ceabf3f569 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1746,21 +1746,43 @@ TEST(APIntTest, isShiftedMask) {
EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask());
EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask());
+ unsigned MaskIdx, MaskLen;
+ EXPECT_FALSE(APInt(32, 0x01010101).isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_TRUE(APInt(32, 0xf0000000).isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(28, MaskIdx);
+ EXPECT_EQ(4, MaskLen);
+ EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(16, MaskIdx);
+ EXPECT_EQ(16, MaskLen);
+ EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(1, MaskIdx);
+ EXPECT_EQ(8, MaskLen);
+
for (int N : { 1, 2, 3, 4, 7, 8, 16, 32, 64, 127, 128, 129, 256 }) {
EXPECT_FALSE(APInt(N, 0).isShiftedMask());
+ EXPECT_FALSE(APInt(N, 0).isShiftedMask(MaskIdx, MaskLen));
APInt One(N, 1);
for (int I = 1; I < N; ++I) {
APInt MaskVal = One.shl(I) - 1;
EXPECT_TRUE(MaskVal.isShiftedMask());
+ EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(0, MaskIdx);
+ EXPECT_EQ(I, MaskLen);
}
for (int I = 1; I < N - 1; ++I) {
APInt MaskVal = One.shl(I);
EXPECT_TRUE(MaskVal.isShiftedMask());
+ EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(I, MaskIdx);
+ EXPECT_EQ(1, MaskLen);
}
for (int I = 1; I < N; ++I) {
APInt MaskVal = APInt::getHighBitsSet(N, I);
EXPECT_TRUE(MaskVal.isShiftedMask());
+ EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+ EXPECT_EQ(N - I, MaskIdx);
+ EXPECT_EQ(I, MaskLen);
}
}
}
diff --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index cf055c6804da9..b9ed0dffcce20 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -180,6 +180,18 @@ TEST(MathExtras, isShiftedMask_32) {
EXPECT_TRUE(isShiftedMask_32(0xf0000000));
EXPECT_TRUE(isShiftedMask_32(0xffff0000));
EXPECT_TRUE(isShiftedMask_32(0xff << 1));
+
+ unsigned MaskIdx, MaskLen;
+ EXPECT_FALSE(isShiftedMask_32(0x01010101, MaskIdx, MaskLen));
+ EXPECT_TRUE(isShiftedMask_32(0xf0000000, MaskIdx, MaskLen));
+ EXPECT_EQ(28, MaskIdx);
+ EXPECT_EQ(4, MaskLen);
+ EXPECT_TRUE(isShiftedMask_32(0xffff0000, MaskIdx, MaskLen));
+ EXPECT_EQ(16, MaskIdx);
+ EXPECT_EQ(16, MaskLen);
+ EXPECT_TRUE(isShiftedMask_32(0xff << 1, MaskIdx, MaskLen));
+ EXPECT_EQ(1, MaskIdx);
+ EXPECT_EQ(8, MaskLen);
}
TEST(MathExtras, isShiftedMask_64) {
@@ -187,6 +199,18 @@ TEST(MathExtras, isShiftedMask_64) {
EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull));
EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull));
EXPECT_TRUE(isShiftedMask_64(0xffull << 55));
+
+ unsigned MaskIdx, MaskLen;
+ EXPECT_FALSE(isShiftedMask_64(0x0101010101010101ull, MaskIdx, MaskLen));
+ EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull, MaskIdx, MaskLen));
+ EXPECT_EQ(60, MaskIdx);
+ EXPECT_EQ(4, MaskLen);
+ EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull, MaskIdx, MaskLen));
+ EXPECT_EQ(48, MaskIdx);
+ EXPECT_EQ(16, MaskLen);
+ EXPECT_TRUE(isShiftedMask_64(0xffull << 55, MaskIdx, MaskLen));
+ EXPECT_EQ(55, MaskIdx);
+ EXPECT_EQ(8, MaskLen);
}
TEST(MathExtras, isPowerOf2_32) {
More information about the llvm-commits
mailing list