[llvm] f2f02b2 - [VectorCombine] foldShuffleOfBinops - only accept exact matching cmp predicates

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 28 01:54:39 PST 2024


Author: Simon Pilgrim
Date: 2024-12-28T09:21:31Z
New Revision: f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15

URL: https://github.com/llvm/llvm-project/commit/f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15
DIFF: https://github.com/llvm/llvm-project/commit/f2f02b21cd581057e3c9b4a7a27e0014eeb9ba15.diff

LOG: [VectorCombine] foldShuffleOfBinops - only accept exact matching cmp predicates

m_SpecificCmp also accepted equivalent (but not identical) predicate+flag combinations (e.g. "samesign ugt" vs "sgt"), which are not necessarily still valid after "shuffle (cmpop), (cmpop)" has been folded into "cmpop (shuffle), (shuffle)".

Fixes #121110

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp
    llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index ecbc13d489eb37..2460ccc61d84df 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1669,7 +1669,8 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
 
   Value *X, *Y, *Z, *W;
   bool IsCommutative = false;
-  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
+  CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
   if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
       match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
     auto *BO = cast<BinaryOperator>(LHS);
@@ -1677,8 +1678,9 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
     if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
       return false;
     IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
-  } else if (match(LHS, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
-             match(RHS, m_SpecificCmp(Pred, m_Value(Z), m_Value(W)))) {
+  } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
+             match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
+             (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
     IsCommutative = cast<CmpInst>(LHS)->isCommutative();
   } else
     return false;
@@ -1727,14 +1729,14 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
       TTI.getShuffleCost(SK0, BinOpTy, NewMask0, CostKind, 0, nullptr, {X, Z}) +
       TTI.getShuffleCost(SK1, BinOpTy, NewMask1, CostKind, 0, nullptr, {Y, W});
 
-  if (Pred == CmpInst::BAD_ICMP_PREDICATE) {
+  if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
     NewCost +=
         TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
   } else {
     auto *ShuffleCmpTy =
         FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
     NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
-                                      ShuffleDstTy, Pred, CostKind);
+                                      ShuffleDstTy, PredLHS, CostKind);
   }
 
   LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
@@ -1750,10 +1752,10 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
 
   Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
   Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
-  Value *NewBO = Pred == CmpInst::BAD_ICMP_PREDICATE
+  Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
                      ? Builder.CreateBinOp(
                            cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
-                     : Builder.CreateCmp(Pred, Shuf0, Shuf1);
+                     : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
 
   // Intersect flags from the old binops.
   if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {

diff  --git a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
index 6ee60287e62dc8..b8b2c6aef74a3e 100644
--- a/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffle-of-cmps.ll
@@ -276,11 +276,15 @@ define <4 x i32> @shuf_icmp_ugt_v4i32_use(<4 x i32> %x, <4 x i32> %y, <4 x i32>
   ret <4 x i32> %r
 }
 
-; TODO: PR121110 - don't merge equivalent (but not matching) predicates
+; PR121110 - don't merge equivalent (but not matching) predicates
+
 define <2 x i1> @PR121110() {
 ; CHECK-LABEL: define <2 x i1> @PR121110(
 ; CHECK-SAME: ) #[[ATTR0]] {
-; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+; CHECK-NEXT:    [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
+; CHECK-NEXT:    [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
+; CHECK-NEXT:    [[RES:%.*]] = shufflevector <2 x i1> [[UGT]], <2 x i1> [[SGT]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT:    ret <2 x i1> [[RES]]
 ;
   %ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >
   %sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
@@ -291,7 +295,10 @@ define <2 x i1> @PR121110() {
 define <2 x i1> @PR121110_commute() {
 ; CHECK-LABEL: define <2 x i1> @PR121110_commute(
 ; CHECK-SAME: ) #[[ATTR0]] {
-; CHECK-NEXT:    ret <2 x i1> zeroinitializer
+; CHECK-NEXT:    [[SGT:%.*]] = icmp sgt <2 x i32> zeroinitializer, <i32 6, i32 -4>
+; CHECK-NEXT:    [[UGT:%.*]] = icmp samesign ugt <2 x i32> zeroinitializer, zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = shufflevector <2 x i1> [[SGT]], <2 x i1> [[UGT]], <2 x i32> <i32 0, i32 3>
+; CHECK-NEXT:    ret <2 x i1> [[RES]]
 ;
   %sgt = icmp sgt <2 x i32> < i32 0, i32 0 >, < i32 6, i32 4294967292 >
   %ugt = icmp samesign ugt <2 x i32> < i32 0, i32 0 >, < i32 0, i32 0 >


        


More information about the llvm-commits mailing list