[llvm] 55a11b5 - [VectorUtils] Add getShuffleDemandedElts helper

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 30 10:04:04 PDT 2022


Author: Simon Pilgrim
Date: 2022-10-30T17:03:55Z
New Revision: 55a11b542e055744cac3beb87714fc90cacba4c0

URL: https://github.com/llvm/llvm-project/commit/55a11b542e055744cac3beb87714fc90cacba4c0
DIFF: https://github.com/llvm/llvm-project/commit/55a11b542e055744cac3beb87714fc90cacba4c0.diff

LOG: [VectorUtils] Add getShuffleDemandedElts helper

We have similar code to translate a demanded elements mask for a shuffle's operands in multiple places - this patch adds a helper function to VectorUtils and updates a number of locations to use it directly.

Differential Revision: https://reviews.llvm.org/D136832

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/unittests/Analysis/VectorUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 5119208761d4c..0f2346cf7c57e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -366,6 +366,14 @@ Value *getSplatValue(const Value *V);
 /// not limited by finding a scalar source value to a splatted vector.
 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
 
+/// Transform a shuffle mask's output demanded element mask into demanded
+/// element masks for the 2 operands, each reset to a zero mask of \p SrcWidth
+/// bits first. Returns false if a demanded mask index is undef ("-1"), unless
+/// \p AllowUndefElts is set, in which case such indices are simply skipped.
+bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
+                            const APInt &DemandedElts, APInt &DemandedLHS,
+                            APInt &DemandedRHS, bool AllowUndefElts = false);
+
 /// Replace each shuffle mask index with the scaled sequential indices for an
 /// equivalent mask of narrowed elements. Mask elements that are less than 0
 /// (sentinel values) are repeated in the output mask.

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 64594d9012440..49cfff7f2a68c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -34,6 +34,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
@@ -174,32 +175,8 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
 
   int NumElts =
       cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
-  int NumMaskElts = cast<FixedVectorType>(Shuf->getType())->getNumElements();
-  DemandedLHS = DemandedRHS = APInt::getZero(NumElts);
-  if (DemandedElts.isZero())
-    return true;
-  // Simple case of a shuffle with zeroinitializer.
-  if (all_of(Shuf->getShuffleMask(), [](int Elt) { return Elt == 0; })) {
-    DemandedLHS.setBit(0);
-    return true;
-  }
-  for (int i = 0; i != NumMaskElts; ++i) {
-    if (!DemandedElts[i])
-      continue;
-    int M = Shuf->getMaskValue(i);
-    assert(M < (NumElts * 2) && "Invalid shuffle mask constant");
-
-    // For undef elements, we don't know anything about the common state of
-    // the shuffle result.
-    if (M == -1)
-      return false;
-    if (M < NumElts)
-      DemandedLHS.setBit(M % NumElts);
-    else
-      DemandedRHS.setBit(M % NumElts);
-  }
-
-  return true;
+  return llvm::getShuffleDemandedElts(NumElts, Shuf->getShuffleMask(),
+                                      DemandedElts, DemandedLHS, DemandedRHS);
 }
 
 static void computeKnownBits(const Value *V, const APInt &DemandedElts,

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index b4398170a34c5..cca347560e2d7 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -429,6 +429,43 @@ bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
   return false;
 }
 
+bool llvm::getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
+                                  const APInt &DemandedElts, APInt &DemandedLHS,
+                                  APInt &DemandedRHS, bool AllowUndefElts) {
+  DemandedLHS = DemandedRHS = APInt::getZero(SrcWidth);
+
+  // Early out if we don't demand any elements.
+  if (DemandedElts.isZero())
+    return true;
+
+  // Simple case of a shuffle with zeroinitializer.
+  if (all_of(Mask, [](int Elt) { return Elt == 0; })) {
+    DemandedLHS.setBit(0);
+    return true;
+  }
+
+  for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
+    int M = Mask[I];
+    assert((-1 <= M) && (M < (SrcWidth * 2)) &&
+           "Invalid shuffle mask constant");
+
+    if (!DemandedElts[I] || (AllowUndefElts && (M < 0)))
+      continue;
+
+    // For undef elements, we don't know anything about the common state of
+    // the shuffle result.
+    if (M < 0)
+      return false;
+
+    if (M < SrcWidth)
+      DemandedLHS.setBit(M);
+    else
+      DemandedRHS.setBit(M - SrcWidth);
+  }
+
+  return true;
+}
+
 void llvm::narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask) {
   assert(Scale > 0 && "Unexpected scaling factor");

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b908f2b4574a6..c849722dab4be 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
@@ -2978,30 +2979,15 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
   case ISD::VECTOR_SHUFFLE: {
     // Collect the known bits that are shared by every vector element referenced
     // by the shuffle.
-    APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
-    Known.Zero.setAllBits(); Known.One.setAllBits();
+    APInt DemandedLHS, DemandedRHS;
     const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
     assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
-    for (unsigned i = 0; i != NumElts; ++i) {
-      if (!DemandedElts[i])
-        continue;
-
-      int M = SVN->getMaskElt(i);
-      if (M < 0) {
-        // For UNDEF elements, we don't know anything about the common state of
-        // the shuffle result.
-        Known.resetAll();
-        DemandedLHS.clearAllBits();
-        DemandedRHS.clearAllBits();
-        break;
-      }
+    if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
+                                DemandedLHS, DemandedRHS))
+      break;
 
-      if ((unsigned)M < NumElts)
-        DemandedLHS.setBit((unsigned)M % NumElts);
-      else
-        DemandedRHS.setBit((unsigned)M % NumElts);
-    }
     // Known bits are the values that are shared by every demanded element.
+    Known.Zero.setAllBits(); Known.One.setAllBits();
     if (!!DemandedLHS) {
       SDValue LHS = Op.getOperand(0);
       Known2 = computeKnownBits(LHS, DemandedLHS, Depth + 1);
@@ -3984,22 +3970,13 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
   case ISD::VECTOR_SHUFFLE: {
     // Collect the minimum number of sign bits that are shared by every vector
     // element referenced by the shuffle.
-    APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
+    APInt DemandedLHS, DemandedRHS;
     const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
     assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
-    for (unsigned i = 0; i != NumElts; ++i) {
-      int M = SVN->getMaskElt(i);
-      if (!DemandedElts[i])
-        continue;
-      // For UNDEF elements, we don't know anything about the common state of
-      // the shuffle result.
-      if (M < 0)
-        return 1;
-      if ((unsigned)M < NumElts)
-        DemandedLHS.setBit((unsigned)M % NumElts);
-      else
-        DemandedRHS.setBit((unsigned)M % NumElts);
-    }
+    if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
+                                DemandedLHS, DemandedRHS))
+      return 1;
+
     Tmp = std::numeric_limits<unsigned>::max();
     if (!!DemandedLHS)
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedLHS, Depth + 1);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cc0789423cd44..4a34909dbcb69 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12,6 +12,7 @@
 
 #include "llvm/CodeGen/TargetLowering.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/CodeGenCommonISel.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -1291,25 +1292,10 @@ bool TargetLowering::SimplifyDemandedBits(
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // Collect demanded elements from shuffle operands..
-    APInt DemandedLHS(NumElts, 0);
-    APInt DemandedRHS(NumElts, 0);
-    for (unsigned i = 0; i != NumElts; ++i) {
-      if (!DemandedElts[i])
-        continue;
-      int M = ShuffleMask[i];
-      if (M < 0) {
-        // For UNDEF elements, we don't know anything about the common state of
-        // the shuffle result.
-        DemandedLHS.clearAllBits();
-        DemandedRHS.clearAllBits();
-        break;
-      }
-      assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
-      if (M < (int)NumElts)
-        DemandedLHS.setBit(M);
-      else
-        DemandedRHS.setBit(M - NumElts);
-    }
+    APInt DemandedLHS, DemandedRHS;
+    if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
+                                DemandedRHS))
+      break;
 
     if (!!DemandedLHS || !!DemandedRHS) {
       SDValue Op0 = Op.getOperand(0);

diff  --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index 0ee95b10e867c..2e1ec806b73a0 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -166,6 +166,38 @@ TEST_F(BasicTest, widenShuffleMaskElts) {
   EXPECT_EQ(makeArrayRef(WideMask), makeArrayRef({-2,-3}));
 }
 
+TEST_F(BasicTest, getShuffleDemandedElts) {
+  APInt LHS, RHS;
+
+  // broadcast zero
+  EXPECT_TRUE(getShuffleDemandedElts(4, {0, 0, 0, 0}, APInt(4,0xf), LHS, RHS));
+  EXPECT_EQ(LHS.getZExtValue(), 0x1);
+  EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+  // broadcast zero (with non-permitted undefs)
+  EXPECT_FALSE(getShuffleDemandedElts(2, {0, -1}, APInt(2, 0x3), LHS, RHS));
+
+  // broadcast zero (with permitted undefs)
+  EXPECT_TRUE(getShuffleDemandedElts(3, {0, 0, -1}, APInt(3, 0x7), LHS, RHS, true));
+  EXPECT_EQ(LHS.getZExtValue(), 0x1);
+  EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+  // broadcast one in demanded
+  EXPECT_TRUE(getShuffleDemandedElts(4, {1, 1, 1, -1}, APInt(4, 0x7), LHS, RHS));
+  EXPECT_EQ(LHS.getZExtValue(), 0x2);
+  EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+  // broadcast 7 in demanded
+  EXPECT_TRUE(getShuffleDemandedElts(4, {7, 0, 7, 7}, APInt(4, 0xd), LHS, RHS));
+  EXPECT_EQ(LHS.getZExtValue(), 0x0);
+  EXPECT_EQ(RHS.getZExtValue(), 0x8);
+
+  // general test
+  EXPECT_TRUE(getShuffleDemandedElts(4, {4, 2, 7, 3}, APInt(4, 0xf), LHS, RHS));
+  EXPECT_EQ(LHS.getZExtValue(), 0xc);
+  EXPECT_EQ(RHS.getZExtValue(), 0x9);
+}
+
 TEST_F(BasicTest, getSplatIndex) {
   EXPECT_EQ(getSplatIndex({0,0,0}), 0);
   EXPECT_EQ(getSplatIndex({1,0,0}), -1);     // no splat


        


More information about the llvm-commits mailing list