[llvm] 55a11b5 - [VectorUtils] Add getShuffleDemandedElts helper
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sun Oct 30 10:04:04 PDT 2022
Author: Simon Pilgrim
Date: 2022-10-30T17:03:55Z
New Revision: 55a11b542e055744cac3beb87714fc90cacba4c0
URL: https://github.com/llvm/llvm-project/commit/55a11b542e055744cac3beb87714fc90cacba4c0
DIFF: https://github.com/llvm/llvm-project/commit/55a11b542e055744cac3beb87714fc90cacba4c0.diff
LOG: [VectorUtils] Add getShuffleDemandedElts helper
We have similar code to translate a demanded elements mask for a shuffle's operands in multiple places - this patch adds a helper function to VectorUtils and updates a number of locations to use it directly.
Differential Revision: https://reviews.llvm.org/D136832
Added:
Modified:
llvm/include/llvm/Analysis/VectorUtils.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/Analysis/VectorUtils.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/unittests/Analysis/VectorUtilsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 5119208761d4c..0f2346cf7c57e 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -366,6 +366,14 @@ Value *getSplatValue(const Value *V);
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
+/// Transform a shuffle mask's output demanded element mask into demanded
+/// element masks for the 2 operands, returns false if a demanded index is undef.
+/// Both \p DemandedLHS and \p DemandedRHS are reset to zero masks of \p SrcWidth bits.
+/// \p AllowUndefElts permits "-1" (undef) indices to be skipped rather than failing.
+bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
+ const APInt &DemandedElts, APInt &DemandedLHS,
+ APInt &DemandedRHS, bool AllowUndefElts = false);
+
/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 64594d9012440..49cfff7f2a68c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -34,6 +34,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
@@ -174,32 +175,8 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
int NumElts =
cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
- int NumMaskElts = cast<FixedVectorType>(Shuf->getType())->getNumElements();
- DemandedLHS = DemandedRHS = APInt::getZero(NumElts);
- if (DemandedElts.isZero())
- return true;
- // Simple case of a shuffle with zeroinitializer.
- if (all_of(Shuf->getShuffleMask(), [](int Elt) { return Elt == 0; })) {
- DemandedLHS.setBit(0);
- return true;
- }
- for (int i = 0; i != NumMaskElts; ++i) {
- if (!DemandedElts[i])
- continue;
- int M = Shuf->getMaskValue(i);
- assert(M < (NumElts * 2) && "Invalid shuffle mask constant");
-
- // For undef elements, we don't know anything about the common state of
- // the shuffle result.
- if (M == -1)
- return false;
- if (M < NumElts)
- DemandedLHS.setBit(M % NumElts);
- else
- DemandedRHS.setBit(M % NumElts);
- }
-
- return true;
+ return llvm::getShuffleDemandedElts(NumElts, Shuf->getShuffleMask(),
+ DemandedElts, DemandedLHS, DemandedRHS);
}
static void computeKnownBits(const Value *V, const APInt &DemandedElts,
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index b4398170a34c5..cca347560e2d7 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -429,6 +429,43 @@ bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
return false;
}
+bool llvm::getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
+ const APInt &DemandedElts, APInt &DemandedLHS,
+ APInt &DemandedRHS, bool AllowUndefElts) {
+ DemandedLHS = DemandedRHS = APInt::getZero(SrcWidth);
+
+ // Early out if we don't demand any elements - both operand masks stay zero.
+ if (DemandedElts.isZero())
+ return true;
+
+ // Simple case of a zeroinitializer mask: every lane reads LHS element 0.
+ if (all_of(Mask, [](int Elt) { return Elt == 0; })) {
+ DemandedLHS.setBit(0);
+ return true;
+ }
+
+ for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
+ int M = Mask[I];
+ assert((-1 <= M) && (M < (SrcWidth * 2)) &&
+ "Invalid shuffle mask constant");
+
+ if (!DemandedElts[I] || (AllowUndefElts && (M < 0)))
+ continue;
+
+ // A demanded undef element (without AllowUndefElts) means we don't know
+ // anything about the common state of the shuffle result - report failure.
+ if (M < 0)
+ return false;
+
+ if (M < SrcWidth)
+ DemandedLHS.setBit(M);
+ else
+ DemandedRHS.setBit(M - SrcWidth);
+ }
+
+ return true;
+}
+
void llvm::narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
SmallVectorImpl<int> &ScaledMask) {
assert(Scale > 0 && "Unexpected scaling factor");
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b908f2b4574a6..c849722dab4be 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -27,6 +27,7 @@
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/ISDOpcodes.h"
@@ -2978,30 +2979,15 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
case ISD::VECTOR_SHUFFLE: {
// Collect the known bits that are shared by every vector element referenced
// by the shuffle.
- APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
- Known.Zero.setAllBits(); Known.One.setAllBits();
+ APInt DemandedLHS, DemandedRHS;
const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
- for (unsigned i = 0; i != NumElts; ++i) {
- if (!DemandedElts[i])
- continue;
-
- int M = SVN->getMaskElt(i);
- if (M < 0) {
- // For UNDEF elements, we don't know anything about the common state of
- // the shuffle result.
- Known.resetAll();
- DemandedLHS.clearAllBits();
- DemandedRHS.clearAllBits();
- break;
- }
+ if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
+ DemandedLHS, DemandedRHS))
+ break;
- if ((unsigned)M < NumElts)
- DemandedLHS.setBit((unsigned)M % NumElts);
- else
- DemandedRHS.setBit((unsigned)M % NumElts);
- }
// Known bits are the values that are shared by every demanded element.
+ Known.Zero.setAllBits(); Known.One.setAllBits();
if (!!DemandedLHS) {
SDValue LHS = Op.getOperand(0);
Known2 = computeKnownBits(LHS, DemandedLHS, Depth + 1);
@@ -3984,22 +3970,13 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::VECTOR_SHUFFLE: {
// Collect the minimum number of sign bits that are shared by every vector
// element referenced by the shuffle.
- APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
+ APInt DemandedLHS, DemandedRHS;
const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op);
assert(NumElts == SVN->getMask().size() && "Unexpected vector size");
- for (unsigned i = 0; i != NumElts; ++i) {
- int M = SVN->getMaskElt(i);
- if (!DemandedElts[i])
- continue;
- // For UNDEF elements, we don't know anything about the common state of
- // the shuffle result.
- if (M < 0)
- return 1;
- if ((unsigned)M < NumElts)
- DemandedLHS.setBit((unsigned)M % NumElts);
- else
- DemandedRHS.setBit((unsigned)M % NumElts);
- }
+ if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts,
+ DemandedLHS, DemandedRHS))
+ return 1;
+
Tmp = std::numeric_limits<unsigned>::max();
if (!!DemandedLHS)
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedLHS, Depth + 1);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cc0789423cd44..4a34909dbcb69 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12,6 +12,7 @@
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/CodeGenCommonISel.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
@@ -1291,25 +1292,10 @@ bool TargetLowering::SimplifyDemandedBits(
ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
// Collect demanded elements from shuffle operands..
- APInt DemandedLHS(NumElts, 0);
- APInt DemandedRHS(NumElts, 0);
- for (unsigned i = 0; i != NumElts; ++i) {
- if (!DemandedElts[i])
- continue;
- int M = ShuffleMask[i];
- if (M < 0) {
- // For UNDEF elements, we don't know anything about the common state of
- // the shuffle result.
- DemandedLHS.clearAllBits();
- DemandedRHS.clearAllBits();
- break;
- }
- assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range");
- if (M < (int)NumElts)
- DemandedLHS.setBit(M);
- else
- DemandedRHS.setBit(M - NumElts);
- }
+ APInt DemandedLHS, DemandedRHS;
+ if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS,
+ DemandedRHS))
+ break;
if (!!DemandedLHS || !!DemandedRHS) {
SDValue Op0 = Op.getOperand(0);
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index 0ee95b10e867c..2e1ec806b73a0 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -166,6 +166,38 @@ TEST_F(BasicTest, widenShuffleMaskElts) {
EXPECT_EQ(makeArrayRef(WideMask), makeArrayRef({-2,-3}));
}
+TEST_F(BasicTest, getShuffleDemandedElts) {
+ APInt LHS, RHS;
+
+ // broadcast zero: an all-zero mask only demands LHS element 0
+ EXPECT_TRUE(getShuffleDemandedElts(4, {0, 0, 0, 0}, APInt(4,0xf), LHS, RHS));
+ EXPECT_EQ(LHS.getZExtValue(), 0x1);
+ EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+ // broadcast zero (with non-permitted undefs): demanded undef lane 1 fails
+ EXPECT_FALSE(getShuffleDemandedElts(2, {0, -1}, APInt(2, 0x3), LHS, RHS));
+
+ // broadcast zero (with permitted undefs): demanded undef lane 2 is skipped
+ EXPECT_TRUE(getShuffleDemandedElts(3, {0, 0, -1}, APInt(3, 0x7), LHS, RHS, true));
+ EXPECT_EQ(LHS.getZExtValue(), 0x1);
+ EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+ // broadcast one in demanded: the undef lane 3 is not demanded, so it is ignored
+ EXPECT_TRUE(getShuffleDemandedElts(4, {1, 1, 1, -1}, APInt(4, 0x7), LHS, RHS));
+ EXPECT_EQ(LHS.getZExtValue(), 0x2);
+ EXPECT_EQ(RHS.getZExtValue(), 0x0);
+
+ // broadcast 7 in demanded: index 7 is RHS element 3; lane 1 is not demanded
+ EXPECT_TRUE(getShuffleDemandedElts(4, {7, 0, 7, 7}, APInt(4, 0xd), LHS, RHS));
+ EXPECT_EQ(LHS.getZExtValue(), 0x0);
+ EXPECT_EQ(RHS.getZExtValue(), 0x8);
+
+ // general test: lanes read RHS[0], LHS[2], RHS[3] and LHS[3] respectively
+ EXPECT_TRUE(getShuffleDemandedElts(4, {4, 2, 7, 3}, APInt(4, 0xf), LHS, RHS));
+ EXPECT_EQ(LHS.getZExtValue(), 0xc);
+ EXPECT_EQ(RHS.getZExtValue(), 0x9);
+}
+
TEST_F(BasicTest, getSplatIndex) {
EXPECT_EQ(getSplatIndex({0,0,0}), 0);
EXPECT_EQ(getSplatIndex({1,0,0}), -1); // no splat
More information about the llvm-commits
mailing list