[llvm] f2f02b2 - [VectorCombine] foldShuffleOfBinops - only accept exact matching cmp predicates
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat Dec 28 01:54:39 PST 2024
Author: Simon Pilgrim
Date: 2024-12-28T09:21:31Z
New Revision: f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15
URL: https://github.com/llvm/llvm-project/commit/f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15
DIFF: https://github.com/llvm/llvm-project/commit/f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15.diff
LOG: [VectorCombine] foldShuffleOfBinops - only accept exact matching cmp predicates
m_SpecificCmp also accepts equivalent (but not identical) predicate+flags combinations, which are not necessarily still valid after folding "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
Fixes #121110
Added:
Modified:
llvm/lib/Transforms/Vectorize/VectorCombine.cpp
llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index ecbc13d489eb37..2460ccc61d84df 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1669,7 +1669,8 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
Value *X, *Y, *Z, *W;
bool IsCommutative = false;
- CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
auto *BO = cast<BinaryOperator>(LHS);
@@ -1677,8 +1678,9 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
return false;
IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
- } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
- match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) {
+ } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
+ match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
+ (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
IsCommutative = cast<CmpInst>(LHS)->isCommutative();
} else
return false;
@@ -1727,14 +1729,14 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W});
- if (Pred == CmpInst::BAD_ICMP_PREDICATE) {
+ if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
NewCost +=
TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
} else {
auto *ShuffleCmpTy =
FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
- ShuffleDstTy, Pred, CostKind);
+ ShuffleDstTy, PredLHS, CostKind);
}
LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
@@ -1750,10 +1752,10 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
- Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE
+ Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
? Builder.CreateBinOp(
cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
- : Builder.CreateCmp(Pred, Shuf0, Shuf1);
+ : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
// Intersect flags from the old binops.
if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
diff --git a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
index 6ee60287e62dc8..b8b2c6aef74a3e 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
@@ -276,11 +276,15 @@ define <4 x i32> @shuf_icmp_ugt_v4i32_use(<4 x i32> %x, <4 x i32> %y, <4 x i32>
ret <4 x i32> %r
}
-; TODO: PR121110 - don't merge equivalent (but not matching) predicates
+; PR121110 - don't merge equivalent (but not matching) predicates
+
define <2 x i1> @PR121110() {
; CHECK-LABEL: define <2 x i1> @PR121110(
; CHECK-SAME: ) #[[ATTR0]] {
-; CHECK-NEXT: ret <2 x i1> zeroinitializer
+; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
+; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
+; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT: ret <2 x i1> [[RES]]
;
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
@@ -291,7 +295,10 @@ define <2 x i1> @PR121110() {
define <2 x i1> @PR121110_commute() {
; CHECK-LABEL: define <2 x i1> @PR121110_commute(
; CHECK-SAME: ) #[[ATTR0]] {
-; CHECK-NEXT: ret <2 x i1> zeroinitializer
+; CHECK-NEXT: [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
+; CHECK-NEXT: [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
+; CHECK-NEXT: [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT: ret <2 x i1> [[RES]]
;
%sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
%ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
More information about the llvm-commits mailing list