[llvm] fd2bb51 - [ADT] Add APInt/MathExtras isShiftedMask variant returning mask offset/length

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 8 04:04:30 PST 2022


Author: Simon Pilgrim
Date: 2022-02-08T12:04:13Z
New Revision: fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4

URL: https://github.com/llvm/llvm-project/commit/fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4
DIFF: https://github.com/llvm/llvm-project/commit/fd2bb51f1ec3f53e5dd1e69eb48bf191b49edda4.diff

LOG: [ADT] Add APInt/MathExtras isShiftedMask variant returning mask offset/length

In many cases, calls to isShiftedMask are immediately followed with checks to determine the size and position of the bitmask.

This patch adds variants of APInt::isShiftedMask, isShiftedMask_32 and isShiftedMask_64 that return these values as additional arguments.

I've updated a number of cases that were either performing separate size/position calculations or had created their own local wrapper versions of these.

Differential Revision: https://reviews.llvm.org/D119019

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/include/llvm/Support/MathExtras.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
    llvm/lib/Target/Mips/MipsISelLowering.cpp
    llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
    llvm/unittests/ADT/APIntTest.cpp
    llvm/unittests/Support/MathExtrasTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 0fdcbb6c3dfbc..a475e27c797d2 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -496,6 +496,23 @@ class LLVM_NODISCARD APInt {
     return (Ones + LeadZ + countTrailingZeros()) == BitWidth;
   }
 
+  /// Return true if this APInt value contains a single non-empty contiguous
+  /// run of ones with all other bits zero. If true, \p MaskIdx is set to the
+  /// index of the lowest set bit and \p MaskLen to the length of the mask
+  /// (its population count); otherwise neither argument is modified.
+  bool isShiftedMask(unsigned &MaskIdx, unsigned &MaskLen) const {
+    if (isSingleWord())
+      return isShiftedMask_64(U.VAL, MaskIdx, MaskLen);
+    unsigned Ones = countPopulationSlowCase();
+    unsigned LeadZ = countLeadingZerosSlowCase();
+    unsigned TrailZ = countTrailingZerosSlowCase();
+    if ((Ones + LeadZ + TrailZ) != BitWidth)
+      return false;
+    MaskLen = Ones;
+    MaskIdx = TrailZ;
+    return true;
+  }
+
   /// Compute an APInt containing numBits highbits from this APInt.
   ///
   /// Get an APInt with the same BitWidth as this APInt, just zero mask the low

diff  --git a/llvm/include/llvm/Support/MathExtras.h b/llvm/include/llvm/Support/MathExtras.h
index 753b1998c40c0..ccb0f5594ebd5 100644
--- a/llvm/include/llvm/Support/MathExtras.h
+++ b/llvm/include/llvm/Support/MathExtras.h
@@ -571,6 +571,33 @@ inline unsigned countPopulation(T Value) {
   return detail::PopulationCounter<T, sizeof(T)>::count(Value);
 }
 
+/// Return true if the argument is a single non-empty contiguous run of ones
+/// with all other bits zero (32 bit version), e.g. 0x0000FF00U. If true,
+/// \p MaskIdx is set to the index of the lowest set bit and \p MaskLen to the
+/// length of the run (e.g. 8 and 8 for 0x0000FF00U); otherwise neither
+/// argument is modified.
+inline bool isShiftedMask_32(uint32_t Value, unsigned &MaskIdx,
+                             unsigned &MaskLen) {
+  if (!isShiftedMask_32(Value))
+    return false;
+  MaskIdx = countTrailingZeros(Value);
+  MaskLen = countPopulation(Value);
+  return true;
+}
+
+/// Return true if the argument is a single non-empty contiguous run of ones
+/// with all other bits zero (64 bit version). If true, \p MaskIdx is set to
+/// the index of the lowest set bit and \p MaskLen to the length of the run;
+/// otherwise neither argument is modified.
+inline bool isShiftedMask_64(uint64_t Value, unsigned &MaskIdx,
+                             unsigned &MaskLen) {
+  if (!isShiftedMask_64(Value))
+    return false;
+  MaskIdx = countTrailingZeros(Value);
+  MaskLen = countPopulation(Value);
+  return true;
+}
+
 /// Compile time Log2.
 /// Valid only for positive powers of two.
 template <size_t kValue> constexpr inline size_t CTLog2() {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8539c5c681f64..f2a562e49c834 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12254,10 +12254,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
     unsigned ActiveBits = 0;
     if (Mask.isMask()) {
       ActiveBits = Mask.countTrailingOnes();
-    } else if (Mask.isShiftedMask()) {
-      ShAmt = Mask.countTrailingZeros();
-      APInt ShiftedMask = Mask.lshr(ShAmt);
-      ActiveBits = ShiftedMask.countTrailingOnes();
+    } else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
       HasShiftedOffset = true;
     } else
       return SDValue();

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index b9d0655feef72..23d970f6d1bff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3281,8 +3281,9 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
   // this improves the ability to match BFE patterns in isel.
   if (LHS.getOpcode() == ISD::AND) {
     if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      if (Mask->getAPIntValue().isShiftedMask() &&
-          Mask->getAPIntValue().countTrailingZeros() == ShiftAmt) {
+      unsigned MaskIdx, MaskLen;
+      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+          MaskIdx == ShiftAmt) {
         return DAG.getNode(
             ISD::AND, SL, VT,
             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),

diff  --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 0c2e129b8f1fc..a0e4bbc730354 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -94,18 +94,6 @@ static const MCPhysReg Mips64DPRegs[8] = {
   Mips::D16_64, Mips::D17_64, Mips::D18_64, Mips::D19_64
 };
 
-// If I is a shifted mask, set the size (Size) and the first bit of the
-// mask (Pos), and return true.
-// For example, if I is 0x003ff800, (Pos, Size) = (11, 11).
-static bool isShiftedMask(uint64_t I, uint64_t &Pos, uint64_t &Size) {
-  if (!isShiftedMask_64(I))
-    return false;
-
-  Size = countPopulation(I);
-  Pos = countTrailingZeros(I);
-  return true;
-}
-
 // The MIPS MSA ABI passes vector arguments in the integer register set.
 // The number of integer registers used is dependant on the ABI used.
 MVT MipsTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
@@ -794,14 +782,15 @@ static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
   EVT ValTy = N->getValueType(0);
   SDLoc DL(N);
 
-  uint64_t Pos = 0, SMPos, SMSize;
+  uint64_t Pos = 0;
+  unsigned SMPos, SMSize;
   ConstantSDNode *CN;
   SDValue NewOperand;
   unsigned Opc;
 
   // Op's second operand must be a shifted mask.
   if (!(CN = dyn_cast<ConstantSDNode>(Mask)) ||
-      !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+      !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
     return SDValue();
 
   if (FirstOperandOpc == ISD::SRA || FirstOperandOpc == ISD::SRL) {
@@ -875,7 +864,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDValue And0 = N->getOperand(0), And1 = N->getOperand(1);
-  uint64_t SMPos0, SMSize0, SMPos1, SMSize1;
+  unsigned SMPos0, SMSize0, SMPos1, SMSize1;
   ConstantSDNode *CN, *CN1;
 
   // See if Op's first operand matches (and $src1 , mask0).
@@ -883,7 +872,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   if (!(CN = dyn_cast<ConstantSDNode>(And0.getOperand(1))) ||
-      !isShiftedMask(~CN->getSExtValue(), SMPos0, SMSize0))
+      !isShiftedMask_64(~CN->getSExtValue(), SMPos0, SMSize0))
     return SDValue();
 
   // See if Op's second operand matches (and (shl $src, pos), mask1).
@@ -891,7 +880,7 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
       And1.getOperand(0).getOpcode() == ISD::SHL) {
 
     if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
-        !isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
+        !isShiftedMask_64(CN->getZExtValue(), SMPos1, SMSize1))
       return SDValue();
 
     // The shift masks must have the same position and size.
@@ -1118,7 +1107,8 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
   EVT ValTy = N->getValueType(0);
   SDLoc DL(N);
 
-  uint64_t Pos = 0, SMPos, SMSize;
+  uint64_t Pos = 0;
+  unsigned SMPos, SMSize;
   ConstantSDNode *CN;
   SDValue NewOperand;
 
@@ -1136,7 +1126,7 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
 
   // AND's second operand must be a shifted mask.
   if (!(CN = dyn_cast<ConstantSDNode>(FirstOperand.getOperand(1))) ||
-      !isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
+      !isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
     return SDValue();
 
   // Return if the shifted mask does not start at bit 0 or the sum of its size

diff  --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index 1b021ada6b668..cdbb1d7584b60 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -996,20 +996,18 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
         return IC.replaceInstUsesWith(II, II.getArgOperand(0));
       }
 
-      if (MaskC->getValue().isShiftedMask()) {
+      unsigned MaskIdx, MaskLen;
+      if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
         // any single contingous sequence of 1s anywhere in the mask simply
         // describes a subset of the input bits shifted to the appropriate
         // position.  Replace with the straight forward IR.
-        unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
         Value *Input = II.getArgOperand(0);
         Value *Masked = IC.Builder.CreateAnd(Input, II.getArgOperand(1));
-        Value *Shifted = IC.Builder.CreateLShr(Masked,
-                                               ConstantInt::get(II.getType(),
-                                                                ShiftAmount));
+        Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+        Value *Shifted = IC.Builder.CreateLShr(Masked, ShiftAmt);
         return IC.replaceInstUsesWith(II, Shifted);
       }
 
-
       if (auto *SrcC = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
         uint64_t Src = SrcC->getZExtValue();
         uint64_t Mask = MaskC->getZExtValue();
@@ -1041,15 +1039,15 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       if (MaskC->isAllOnesValue()) {
         return IC.replaceInstUsesWith(II, II.getArgOperand(0));
       }
-      if (MaskC->getValue().isShiftedMask()) {
+
+      unsigned MaskIdx, MaskLen;
+      if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
         // any single contingous sequence of 1s anywhere in the mask simply
         // describes a subset of the input bits shifted to the appropriate
         // position.  Replace with the straight forward IR.
-        unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
         Value *Input = II.getArgOperand(0);
-        Value *Shifted = IC.Builder.CreateShl(Input,
-                                              ConstantInt::get(II.getType(),
-                                                               ShiftAmount));
+        Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
+        Value *Shifted = IC.Builder.CreateShl(Input, ShiftAmt);
         Value *Masked = IC.Builder.CreateAnd(Shifted, II.getArgOperand(1));
         return IC.replaceInstUsesWith(II, Masked);
       }

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 95cb213aa8512..566ceabf3f569 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1746,21 +1746,43 @@ TEST(APIntTest, isShiftedMask) {
   EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask());
   EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask());
 
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(APInt(32, 0x01010101).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_TRUE(APInt(32, 0xf0000000).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(28, MaskIdx);
+  EXPECT_EQ(4, MaskLen);
+  EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(16, MaskIdx);
+  EXPECT_EQ(16, MaskLen);
+  EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask(MaskIdx, MaskLen));
+  EXPECT_EQ(1, MaskIdx);
+  EXPECT_EQ(8, MaskLen);
+
   for (int N : { 1, 2, 3, 4, 7, 8, 16, 32, 64, 127, 128, 129, 256 }) {
     EXPECT_FALSE(APInt(N, 0).isShiftedMask());
+    EXPECT_FALSE(APInt(N, 0).isShiftedMask(MaskIdx, MaskLen));
 
     APInt One(N, 1);
     for (int I = 1; I < N; ++I) {
       APInt MaskVal = One.shl(I) - 1;
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(0, MaskIdx);
+      EXPECT_EQ(I, MaskLen);
     }
     for (int I = 1; I < N - 1; ++I) {
       APInt MaskVal = One.shl(I);
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(I, MaskIdx);
+      EXPECT_EQ(1, MaskLen);
     }
     for (int I = 1; I < N; ++I) {
       APInt MaskVal = APInt::getHighBitsSet(N, I);
       EXPECT_TRUE(MaskVal.isShiftedMask());
+      EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
+      EXPECT_EQ(N - I, MaskIdx);
+      EXPECT_EQ(I, MaskLen);
     }
   }
 }

diff  --git a/llvm/unittests/Support/MathExtrasTest.cpp b/llvm/unittests/Support/MathExtrasTest.cpp
index cf055c6804da9..b9ed0dffcce20 100644
--- a/llvm/unittests/Support/MathExtrasTest.cpp
+++ b/llvm/unittests/Support/MathExtrasTest.cpp
@@ -180,6 +180,18 @@ TEST(MathExtras, isShiftedMask_32) {
   EXPECT_TRUE(isShiftedMask_32(0xf0000000));
   EXPECT_TRUE(isShiftedMask_32(0xffff0000));
   EXPECT_TRUE(isShiftedMask_32(0xff << 1));
+
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(isShiftedMask_32(0x01010101, MaskIdx, MaskLen));
+  EXPECT_TRUE(isShiftedMask_32(0xf0000000, MaskIdx, MaskLen));
+  EXPECT_EQ(28, MaskIdx);
+  EXPECT_EQ(4, MaskLen);
+  EXPECT_TRUE(isShiftedMask_32(0xffff0000, MaskIdx, MaskLen));
+  EXPECT_EQ(16, MaskIdx);
+  EXPECT_EQ(16, MaskLen);
+  EXPECT_TRUE(isShiftedMask_32(0xff << 1, MaskIdx, MaskLen));
+  EXPECT_EQ(1, MaskIdx);
+  EXPECT_EQ(8, MaskLen);
 }
 
 TEST(MathExtras, isShiftedMask_64) {
@@ -187,6 +199,18 @@ TEST(MathExtras, isShiftedMask_64) {
   EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull));
   EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull));
   EXPECT_TRUE(isShiftedMask_64(0xffull << 55));
+
+  unsigned MaskIdx, MaskLen;
+  EXPECT_FALSE(isShiftedMask_64(0x0101010101010101ull, MaskIdx, MaskLen));
+  EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull, MaskIdx, MaskLen));
+  EXPECT_EQ(60, MaskIdx);
+  EXPECT_EQ(4, MaskLen);
+  EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull, MaskIdx, MaskLen));
+  EXPECT_EQ(48, MaskIdx);
+  EXPECT_EQ(16, MaskLen);
+  EXPECT_TRUE(isShiftedMask_64(0xffull << 55, MaskIdx, MaskLen));
+  EXPECT_EQ(55, MaskIdx);
+  EXPECT_EQ(8, MaskLen);
 }
 
 TEST(MathExtras, isPowerOf2_32) {


        


More information about the llvm-commits mailing list