[llvm] Fold patterns which uses <2N x iM> type for comparisons on <N x i2M> type (PR #184328)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 10 04:53:44 PDT 2026
================
@@ -1296,6 +1296,181 @@ Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
return SimplifySelectsFeedingBinaryOp(I, LHS, RHS);
}
+// Folds patterns that use comparisons on the <2N x iM> type to implement an
+// <N x i2M> equality comparison.
+//
+// (A1, ..., AN) -> (A1Lower, A1Upper, ..., ANLower, ANUpper)
+// (B1, ..., BN) -> (B1Lower, B1Upper, ..., BNLower, BNUpper)
+// (Result1, ..., ResultN) -> (Result1, Result1, ..., ResultN, ResultN)
+//
+// where,
+//
+// ResultX = EqLowerX & EqUpperX
+// EqLowerX = AXLower == BXLower
+// EqUpperX = AXUpper == BXUpper
+//
+// The bitwise AND of the upper and lower parts can be achieved by performing
+// the operation between the original and the shuffled equality vectors.
+Instruction *InstCombinerImpl::foldVni2mCmpEqUsingV2nim(Instruction &I) {
+ auto *ResultVecType = dyn_cast<VectorType>(I.getType());
+
+ if (!ResultVecType || ResultVecType->isScalableTy() ||
+ !ResultVecType->getElementType()->isIntegerTy() ||
+ ResultVecType->getElementCount().getFixedValue() % 2 != 0)
+ return nullptr;
+
+ // Check pattern existance
+ Value *L, *R;
+ CmpPredicate Pred;
+ ArrayRef<int> Mask;
+
+ auto Equal = m_SExtOrSelf(m_ICmp(Pred, m_Value(L), m_Value(R)));
+ auto Shuffle = m_SExtOrSelf(m_Shuffle(Equal, m_Poison(), m_Mask(Mask)));
+ if (!match(&I,
+ m_SExtOrSelf(m_CombineOr(m_c_And(Equal, Shuffle),
+ m_Select(Equal, Shuffle, m_Zero())))) ||
+ Pred != CmpInst::ICMP_EQ)
+ return nullptr;
+
+ auto *OldVecType = cast<VectorType>(L->getType());
+
+ if (OldVecType != ResultVecType)
+ return nullptr;
+
+ // Example shuffle mask: {1, 0, 3, 2}
+ for (auto I = 0; I < static_cast<int>(Mask.size()); I += 2)
+ if (Mask[I] != I + 1 || Mask[I + 1] != I)
+ return nullptr;
+
+ LLVM_DEBUG(dbgs() << "IC: Folding Vn2im CmpEq using V2nim CmpEq pattern"
+ << '\n');
+
+ // Perform folding
+ auto OldElementCount = OldVecType->getElementCount().getFixedValue();
+ auto OldElementWidth = OldVecType->getElementType()->getIntegerBitWidth();
+ auto *NewElementType = IntegerType::get(I.getContext(), OldElementWidth * 2);
+ auto *NewVecType =
+ VectorType::get(NewElementType, OldElementCount / 2, false);
+ auto *BitCastL = Builder.CreateBitCast(L, NewVecType);
+ auto *BitCastR = Builder.CreateBitCast(R, NewVecType);
+ auto *Cmp = Builder.CreateICmp(Pred, BitCastL, BitCastR);
+ auto *SExt = Builder.CreateSExt(Cmp, NewVecType);
+ auto *BitCastCmp = Builder.CreateBitCast(SExt, OldVecType);
----------------
RKSimon wrote:
(style) avoid auto
https://github.com/llvm/llvm-project/pull/184328
More information about the llvm-commits
mailing list