[llvm] InstSimplify: lookthru casts, binops in folding shuffles (PR #92668)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Sat May 18 12:29:14 PDT 2024
================
@@ -5315,55 +5315,104 @@ Value *llvm::simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
return ::simplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit);
}
+using ReplacementTy = std::optional<std::pair<Value *, Value *>>;
+
/// For the given destination element of a shuffle, peek through shuffles to
/// match a root vector source operand that contains that element in the same
/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s).
-static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1,
- int MaskVal, Value *RootVec,
- unsigned MaxRecurse) {
+static std::pair<Value *, ReplacementTy>
+foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, int MaskVal,
+ Value *RootVec, unsigned MaxRecurse) {
if (!MaxRecurse--)
- return nullptr;
+ return {nullptr, std::nullopt};
// Bail out if any mask value is undefined. That kind of shuffle may be
// simplified further based on demanded bits or other folds.
if (MaskVal == -1)
- return nullptr;
+ return {nullptr, std::nullopt};
// The mask value chooses which source operand we need to look at next.
- int InVecNumElts = cast<FixedVectorType>(Op0->getType())->getNumElements();
+ unsigned InVecNumElts =
+ cast<FixedVectorType>(Op0->getType())->getNumElements();
int RootElt = MaskVal;
Value *SourceOp = Op0;
- if (MaskVal >= InVecNumElts) {
+ if (MaskVal > -1 && static_cast<unsigned>(MaskVal) >= InVecNumElts) {
RootElt = MaskVal - InVecNumElts;
SourceOp = Op1;
}
+ // The next RootVec preseves an existing RootVec for casts and binops, so that
+ // the final RootVec is the last cast or binop in the chain.
+ Value *NextRootVec = RootVec ? RootVec : SourceOp;
+
+ // Look through a cast instruction that preserves number of elements in the
+ // vector. Set the RootVec to the cast, the SourceOp to the operand and
+ // recurse. If, in a later stack frame, an appropriate ShuffleVector is
+ // matched, the example will reduce.
+ if (auto *SourceCast = dyn_cast<CastInst>(SourceOp))
+ if (auto *CastTy = dyn_cast<FixedVectorType>(SourceCast->getSrcTy()))
+ if (CastTy->getNumElements() == InVecNumElts)
+ return foldIdentityShuffles(DestElt, SourceCast->getOperand(0), Op1,
+ MaskVal, NextRootVec, MaxRecurse);
+
+ // Look through a binary operator, with two identical operands. Set the
+ // RootVec to the binop, the SourceOp to the operand, and recurse. If, in a
+ // later stack frame, an appropriate ShuffleVector is matched, the example
+ // will reduce.
+ Value *BinOpLHS = nullptr, *BinOpRHS = nullptr;
+ if (match(SourceOp, m_BinOp(m_Value(BinOpLHS), m_Value(BinOpRHS))) &&
+ BinOpLHS == BinOpRHS)
+ return foldIdentityShuffles(DestElt, BinOpLHS, Op1, MaskVal, NextRootVec,
+ MaxRecurse);
+
// If the source operand is a shuffle itself, look through it to find the
// matching root vector.
if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) {
+ // Here, we use RootVec, because there is no requirement for finding the
+ // last shuffle in a chain. In fact, the zeroth operand of the first shuffle
+ // in the chain will be used as the RootVec for folding.
return foldIdentityShuffles(
DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1),
SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse);
}
- // TODO: Look through bitcasts? What if the bitcast changes the vector element
- // size?
-
- // The source operand is not a shuffle. Initialize the root vector value for
- // this shuffle if that has not been done yet.
- if (!RootVec)
- RootVec = SourceOp;
-
- // Give up as soon as a source operand does not match the existing root value.
- if (RootVec != SourceOp)
- return nullptr;
-
// The element must be coming from the same lane in the source vector
// (although it may have crossed lanes in intermediate shuffles).
if (RootElt != DestElt)
- return nullptr;
+ return {nullptr, std::nullopt};
+
+ // If NextRootVec is equal to SourceOp, no replacements are required. Just
+ // return NextRootVec as the leaf value of the recursion.
+ if (NextRootVec == SourceOp)
+ return {NextRootVec, std::nullopt};
+
+ // We again have to match the condition for a vector-num-element-preserving
+ // cast or binop with equal operands, as we are not assured of the recursion
+ // happening from the call after the previous match. The RootVec was set to
+ // the last cast or binop in the chain.
+ if ((isa<CastInst>(RootVec) &&
+ isa<FixedVectorType>(cast<CastInst>(RootVec)->getSrcTy()) &&
----------------
arsenm wrote:
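Use dyn_cast instead of isa + cast, e.g. `if (auto *Cast = dyn_cast<CastInst>(RootVec))`, so the type check and the conversion happen in one step.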
https://github.com/llvm/llvm-project/pull/92668
More information about the llvm-commits mailing list