[llvm] InstSimplify: lookthru casts, binops in folding shuffles (PR #92668)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Mon May 20 05:56:37 PDT 2024
================
@@ -5315,55 +5315,107 @@ Value *llvm::simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
return ::simplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit);
}
+using ReplacementTy = std::optional<std::pair<Value *, Value *>>;
+
/// For the given destination element of a shuffle, peek through shuffles to
/// match a root vector source operand that contains that element in the same
/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s).
-static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1,
- int MaskVal, Value *RootVec,
- unsigned MaxRecurse) {
+static std::pair<Value *, ReplacementTy>
+foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, int MaskVal,
+ Value *RootVec, unsigned MaxRecurse) {
if (!MaxRecurse--)
- return nullptr;
+ return {nullptr, std::nullopt};
// Bail out if any mask value is undefined. That kind of shuffle may be
// simplified further based on demanded bits or other folds.
if (MaskVal == -1)
- return nullptr;
+ return {nullptr, std::nullopt};
// The mask value chooses which source operand we need to look at next.
- int InVecNumElts = cast<FixedVectorType>(Op0->getType())->getNumElements();
+ unsigned InVecNumElts =
+ cast<FixedVectorType>(Op0->getType())->getNumElements();
int RootElt = MaskVal;
Value *SourceOp = Op0;
- if (MaskVal >= InVecNumElts) {
+ if (MaskVal > -1 && static_cast<unsigned>(MaskVal) >= InVecNumElts) {
RootElt = MaskVal - InVecNumElts;
SourceOp = Op1;
}
+ // The next RootVec preserves an existing RootVec for casts and binops, so that
+ // the final RootVec is the last cast or binop in the chain.
+ Value *NextRootVec = RootVec ? RootVec : SourceOp;
+
+ // Look through a cast instruction that preserves number of elements in the
+ // vector. Set the RootVec to the cast, the SourceOp to the operand and
+ // recurse. If, in a later stack frame, an appropriate ShuffleVector is
+ // matched, the example will reduce.
+ if (auto *SourceCast = dyn_cast<CastInst>(SourceOp)) {
+ if (auto *CastTy = dyn_cast<FixedVectorType>(SourceCast->getSrcTy()))
+ if (CastTy->getNumElements() == InVecNumElts)
+ return foldIdentityShuffles(DestElt, SourceCast->getOperand(0), Op1,
+ MaskVal, NextRootVec, MaxRecurse);
+ }
+
+ // Look through a binary operator, with two identical operands. Set the
+ // RootVec to the binop, the SourceOp to the operand, and recurse. If, in a
+ // later stack frame, an appropriate ShuffleVector is matched, the example
+ // will reduce.
+ Value *BinOpLHS = nullptr, *BinOpRHS = nullptr;
+ if (match(SourceOp, m_BinOp(m_Value(BinOpLHS), m_Value(BinOpRHS))) &&
+ BinOpLHS == BinOpRHS)
+ return foldIdentityShuffles(DestElt, BinOpLHS, Op1, MaskVal, NextRootVec,
+ MaxRecurse);
+
// If the source operand is a shuffle itself, look through it to find the
// matching root vector.
if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) {
+ // Here, we use RootVec, because there is no requirement for finding the
+ // last shuffle in a chain. In fact, the zeroth operand of the first shuffle
+ // in the chain will be used as the RootVec for folding.
return foldIdentityShuffles(
DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1),
SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse);
}
- // TODO: Look through bitcasts? What if the bitcast changes the vector element
- // size?
-
- // The source operand is not a shuffle. Initialize the root vector value for
- // this shuffle if that has not been done yet.
- if (!RootVec)
- RootVec = SourceOp;
-
- // Give up as soon as a source operand does not match the existing root value.
- if (RootVec != SourceOp)
- return nullptr;
-
// The element must be coming from the same lane in the source vector
// (although it may have crossed lanes in intermediate shuffles).
if (RootElt != DestElt)
- return nullptr;
+ return {nullptr, std::nullopt};
+
+ // If NextRootVec is equal to SourceOp, no replacements are required. Just
+ // return NextRootVec as the leaf value of the recursion.
+ if (NextRootVec == SourceOp)
+ return {NextRootVec, std::nullopt};
+
+ auto *RootCast = dyn_cast<CastInst>(RootVec);
+ auto *CastTy =
+ RootCast ? dyn_cast<FixedVectorType>(RootCast->getSrcTy()) : nullptr;
+
+ // We again have to match the condition for a vector-num-element-preserving
+ // cast or binop with equal operands, as we are not assured of the recursion
+ // happening from the call after the previous match. The RootVec was set to
+ // the last cast or binop in the chain.
+ if ((CastTy && CastTy->getNumElements() == InVecNumElts) ||
+ (match(RootVec, m_BinOp(m_Value(BinOpLHS), m_Value(BinOpRHS))) &&
+ BinOpLHS == BinOpRHS)) {
+ // ReplacementSrc should be the User of the first cast or binop in the
+ // chain. SourceOp is the reduced value, which should replace
+ // ReplacementSrc.
+ if (auto *ReplacementSrc = SourceOp->getUniqueUndroppableUser()) {
+ // If ReplacementSrc equals RootVec, it means that we didn't recurse after
+ // matching a cast or binop, and we should terminate the recursion and
+ // return the leaf value here.
+ if (ReplacementSrc == RootVec)
+ return {RootVec, std::nullopt};
+ // The User of ReplacementSrc is the first cast or binop in the chain.
+ // There could be other Users, but we constrain it to a unique User, since
+ // we perform RAUW later.
+ if (ReplacementSrc->hasOneUser())
----------------
arsenm wrote:
Checking for users in InstSimplify feels weird; can you check for hasOneUse on the input instead?
https://github.com/llvm/llvm-project/pull/92668
More information about the llvm-commits
mailing list