[llvm] 9db2082 - [APInt] Add APIntOps::ScaleBitMask helper

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 13 08:27:52 PDT 2021


Author: Simon Pilgrim
Date: 2021-09-13T16:27:12+01:00
New Revision: 9db20822f795f1057d5ba31bdc2aa82a6d96a6a0

URL: https://github.com/llvm/llvm-project/commit/9db20822f795f1057d5ba31bdc2aa82a6d96a6a0
DIFF: https://github.com/llvm/llvm-project/commit/9db20822f795f1057d5ba31bdc2aa82a6d96a6a0.diff

LOG: [APInt] Add APIntOps::ScaleBitMask helper

APInt is used to describe a bit mask in a variety of value tracking and demanded bits/elts functions.

When traversing through dst/src operands, we have a number of places where these masks need to be widened/narrowed to translate through bitcasts, reductions etc. to a different type.

This patch adds a common APIntOps::ScaleBitMask helper, adds unit test coverage, and updates a number of cases to use the helper instead of their own implementation.

This came up on D109065 where we currently have to add yet another implementation of the same code.

Differential Revision: https://reviews.llvm.org/D109683

Added: 
    

Modified: 
    llvm/include/llvm/ADT/APInt.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Support/APInt.cpp
    llvm/unittests/ADT/APIntTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 4aad2c950f744..e00709a54465f 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -2206,6 +2206,15 @@ Optional<APInt> SolveQuadraticEquationWrap(APInt A, APInt B, APInt C,
 Optional<unsigned> GetMostSignificantDifferentBit(const APInt &A,
                                                   const APInt &B);
 
+/// Splat/Merge neighboring bits to widen/narrow the bitmask represented
+/// by \param A to \param NewBitWidth bits.
+///
+/// e.g. ScaleBitMask(0b0101, 8) -> 0b00110011
+/// e.g. ScaleBitMask(0b00011011, 4) -> 0b0111
+/// A.getBitWidth() or NewBitWidth must be a whole multiple of the other.
+///
+/// TODO: Do we need a mode where all bits must be set when merging down?
+APInt ScaleBitMask(const APInt &A, unsigned NewBitWidth);
 } // namespace APIntOps
 
 // See friend declaration above. This additional declaration is required in

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 08ff77b5ef33e..6808b212bb2b2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2990,11 +2990,8 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
       // bits from the overlapping larger input elements and extracting the
       // sub sections we actually care about.
       unsigned SubScale = SubBitWidth / BitWidth;
-      APInt SubDemandedElts(NumElts / SubScale, 0);
-      for (unsigned i = 0; i != NumElts; ++i)
-        if (DemandedElts[i])
-          SubDemandedElts.setBit(i / SubScale);
-
+      APInt SubDemandedElts =
+          APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
       Known2 = computeKnownBits(N0, SubDemandedElts, Depth + 1);
 
       Known.Zero.setAllBits(); Known.One.setAllBits();
@@ -3802,10 +3799,8 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       assert(VT.isVector() && "Expected bitcast to vector");
 
       unsigned Scale = SrcBits / VTBits;
-      APInt SrcDemandedElts(NumElts / Scale, 0);
-      for (unsigned i = 0; i != NumElts; ++i)
-        if (DemandedElts[i])
-          SrcDemandedElts.setBit(i / Scale);
+      APInt SrcDemandedElts =
+          APIntOps::ScaleBitMask(DemandedElts, NumElts / Scale);
 
       // Fast case - sign splat can be simply split across the small elements.
       Tmp = ComputeNumSignBits(N0, SrcDemandedElts, Depth + 1);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 6a3eba1cb6231..1b84aff7ac2fc 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2469,17 +2469,13 @@ bool TargetLowering::SimplifyDemandedVectorElts(
       return SimplifyDemandedVectorElts(Src, DemandedElts, KnownUndef,
                                         KnownZero, TLO, Depth + 1);
 
-    APInt SrcZero, SrcUndef;
-    APInt SrcDemandedElts = APInt::getZero(NumSrcElts);
+    APInt SrcDemandedElts, SrcZero, SrcUndef;
 
     // Bitcast from 'large element' src vector to 'small element' vector, we
     // must demand a source element if any DemandedElt maps to it.
     if ((NumElts % NumSrcElts) == 0) {
       unsigned Scale = NumElts / NumSrcElts;
-      for (unsigned i = 0; i != NumElts; ++i)
-        if (DemandedElts[i])
-          SrcDemandedElts.setBit(i / Scale);
-
+      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
       if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                      TLO, Depth + 1))
         return true;
@@ -2519,10 +2515,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     // of this vector.
     if ((NumSrcElts % NumElts) == 0) {
       unsigned Scale = NumSrcElts / NumElts;
-      for (unsigned i = 0; i != NumElts; ++i)
-        if (DemandedElts[i])
-          SrcDemandedElts.setBits(i * Scale, (i + 1) * Scale);
-
+      SrcDemandedElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
       if (SimplifyDemandedVectorElts(Src, SrcDemandedElts, SrcUndef, SrcZero,
                                      TLO, Depth + 1))
         return true;

diff  --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 39824905434cc..a630050c0157a 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2948,6 +2948,40 @@ llvm::APIntOps::GetMostSignificantDifferentBit(const APInt &A, const APInt &B) {
   return A.getBitWidth() - ((A ^ B).countLeadingZeros() + 1);
 }
 
+APInt llvm::APIntOps::ScaleBitMask(const APInt &A, unsigned NewBitWidth) {
+  unsigned OldBitWidth = A.getBitWidth();
+  assert((((OldBitWidth % NewBitWidth) == 0) ||
+          ((NewBitWidth % OldBitWidth) == 0)) &&
+         "One size should be a multiple of the other one. "
+         "Can't do fractional scaling.");
+
+  // Check for matching bitwidths.
+  if (OldBitWidth == NewBitWidth)
+    return A;
+
+  APInt NewA = APInt::getNullValue(NewBitWidth);
+
+  // Check for null input.
+  if (A.isNullValue())
+    return NewA;
+
+  if (NewBitWidth > OldBitWidth) {
+    // Repeat bits.
+    unsigned Scale = NewBitWidth / OldBitWidth;
+    for (unsigned i = 0; i != OldBitWidth; ++i)
+      if (A[i])
+        NewA.setBits(i * Scale, (i + 1) * Scale);
+  } else {
+    // Merge bits - if any old bit is set, then set scale equivalent new bit.
+    unsigned Scale = OldBitWidth / NewBitWidth;
+    for (unsigned i = 0; i != NewBitWidth; ++i)
+      if (!A.extractBits(Scale, i * Scale).isNullValue())
+        NewA.setBit(i);
+  }
+
+  return NewA;
+}
+
 /// StoreIntToMemory - Fills the StoreBytes bytes of memory starting from Dst
 /// with the integer held in IntVal.
 void llvm::StoreIntToMemory(const APInt &IntVal, uint8_t *Dst,

diff  --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 0f8f4626692b4..6da9d2ddeb993 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2982,4 +2982,24 @@ TEST(APIntTest, ZeroWidth) {
   EXPECT_EQ(0U, MZW1.getBitWidth());
 }
 
+TEST(APIntTest, ScaleBitMask) {
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x00), 8), APInt(8, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x01), 8), APInt(8, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x02), 8), APInt(8, 0xF0));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(2, 0x03), 8), APInt(8, 0xFF));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 4), APInt(4, 0x00));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xFF), 4), APInt(4, 0x0F));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0xE4), 4), APInt(4, 0x0E));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt(8, 0x00), 8), APInt(8, 0x00));
+
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getNullValue(1024), 4096),
+            APInt::getNullValue(4096));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getAllOnes(4096), 256),
+            APInt::getAllOnes(256));
+  EXPECT_EQ(APIntOps::ScaleBitMask(APInt::getOneBitSet(4096, 32), 256),
+            APInt::getOneBitSet(256, 2));
+}
+
 } // end anonymous namespace


        


More information about the llvm-commits mailing list