[llvm] [SLP] Consider (f)sub commutative when it is an operand of llvm.(f)abs or of icmp eq/ne 0. (PR #86196)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 11 11:36:40 PDT 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/86196

>From 7783e6103d16a0e27a2ece5b4909129b05a68d1c Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 21 Mar 2024 20:45:20 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 56 +++++++++++++++++--
 .../PhaseOrdering/X86/vector-reductions.ll    |  2 +-
 .../X86/reorder_diamond_match.ll              |  8 +--
 .../X86/store-abs-minbitwidth.ll              | 14 +----
 4 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36b446962c4a63..6507a32ed49eda 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -305,7 +305,32 @@ static bool isCommutative(Instruction *I) {
   if (auto *Cmp = dyn_cast<CmpInst>(I))
     return Cmp->isCommutative();
   if (auto *BO = dyn_cast<BinaryOperator>(I))
-    return BO->isCommutative();
+    return BO->isCommutative() ||
+           (BO->getOpcode() == Instruction::Sub &&
+            !BO->hasNUsesOrMore(UsesLimit) &&
+            all_of(BO->uses(),
+                   [](const Use &U) {
+                     // Commutative, if icmp eq/ne sub, 0
+                     ICmpInst::Predicate Pred;
+                     if (match(U.getUser(),
+                               m_ICmp(Pred, m_Specific(U.get()), m_Zero())) &&
+                         (Pred == ICmpInst::ICMP_EQ ||
+                          Pred == ICmpInst::ICMP_NE))
+                       return true;
+                     // Commutative, if abs(sub, flag).
+                     if (U.getOperandNo() != 0)
+                       return false;
+                     const auto *IC = dyn_cast<IntrinsicInst>(U.getUser());
+                     return IC && IC->getIntrinsicID() == Intrinsic::abs;
+                   })) ||
+           (BO->getOpcode() == Instruction::FSub &&
+            !BO->hasNUsesOrMore(UsesLimit) &&
+            all_of(BO->uses(), [](const Use &U) {
+              if (U.getOperandNo() != 0)
+                return false;
+              const auto *IC = dyn_cast<IntrinsicInst>(U.getUser());
+              return IC && IC->getIntrinsicID() == Intrinsic::fabs;
+            }));
   // TODO: This should check for generic Instruction::isCommutative(), but
   //       we need to confirm that the caller code correctly handles Intrinsics
   //       for example (does not have 2 operands).
@@ -6671,7 +6696,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
 
       // Sort operands of the instructions so that each side is more likely to
       // have the same opcode.
-      if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) {
+      if (isa<BinaryOperator>(VL0) && isCommutative(VL0)) {
         ValueList Left, Right;
         reorderInputsAccordingToOpcode(VL, Left, Right, *TLI, *DL, *SE, *this);
         TE->setOperand(0, Left);
@@ -12318,7 +12343,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Value *V = Builder.CreateBinOp(
           static_cast<Instruction::BinaryOps>(E->getOpcode()), LHS,
           RHS);
-      propagateIRFlags(V, E->Scalars, VL0, !MinBWs.contains(E));
+      propagateIRFlags(V, E->Scalars, VL0,
+                       !MinBWs.contains(E) &&
+                           (ShuffleOrOp != Instruction::Sub ||
+                            none_of(E->Scalars, [](Value *V) {
+                              return isCommutative(cast<Instruction>(V));
+                            })));
       if (auto *I = dyn_cast<Instruction>(V))
         V = propagateMetadata(I, E->Scalars);
 
@@ -12612,8 +12642,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           },
           Mask, &OpScalars, &AltScalars);
 
-      propagateIRFlags(V0, OpScalars);
-      propagateIRFlags(V1, AltScalars);
+      propagateIRFlags(V0, OpScalars, VL0,
+                       !isa<BinaryOperator>(VL0) ||
+                           (!MinBWs.contains(E) &&
+                            (E->getOpcode() != Instruction::Sub ||
+                             none_of(E->Scalars, [](Value *V) {
+                               auto *I = cast<Instruction>(V);
+                               return I->getOpcode() == Instruction::Sub &&
+                                      isCommutative(cast<Instruction>(V));
+                             }))));
+      propagateIRFlags(V1, AltScalars, E->getAltOp(),
+                       !isa<BinaryOperator>(VL0) ||
+                           (!MinBWs.contains(E) &&
+                            (E->getAltOpcode() != Instruction::Sub ||
+                             none_of(E->Scalars, [](Value *V) {
+                               auto *I = cast<Instruction>(V);
+                               return I->getOpcode() == Instruction::Sub &&
+                                      isCommutative(cast<Instruction>(V));
+                             }))));
 
       Value *V = Builder.CreateShuffleVector(V0, V1, Mask);
       if (auto *I = dyn_cast<Instruction>(V)) {
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
index f28f297591e7b1..7acc71ca2476ac 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
@@ -63,7 +63,7 @@ define i32 @TestVectorsEqual(ptr noalias %Vec0, ptr noalias %Vec1, i32 %Toleranc
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[VEC0:%.*]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[VEC1:%.*]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = sub nsw <4 x i32> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = sub <4 x i32> [[TMP0]], [[TMP1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP2]], i1 true)
 ; CHECK-NEXT:    [[TMP4:%.*]] = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP3]])
 ; CHECK-NEXT:    [[CMP5_NOT:%.*]] = icmp sle i32 [[TMP4]], [[TOLERANCE:%.*]]
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index dce85b4b2a195e..9682567b173c3e 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -11,11 +11,11 @@ define void @test() {
 ; CHECK-NEXT:    [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
 ; CHECK-NEXT:    [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT:    [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
-; CHECK-NEXT:    [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP9:%.*]] = add <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sub <4 x i16> [[TMP7]], [[TMP8]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP12:%.*]] = add <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i16> zeroinitializer, [[TMP11]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
 ; CHECK-NEXT:    [[TMP15:%.*]] = sext <4 x i16> [[TMP14]] to <4 x i32>
 ; CHECK-NEXT:    store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6cb..8dcca6ec2bb3a3 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -4,19 +4,9 @@
 
 define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
 ; CHECK-LABEL: @test(
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i8>, ptr [[IN:%.*]], align 1
-; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[IN]], i64 2
-; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i8>, ptr [[GEP_2]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = load <2 x i8>, ptr [[INN:%.*]], align 1
-; CHECK-NEXT:    [[GEP_5:%.*]] = getelementptr inbounds i8, ptr [[INN]], i64 2
-; CHECK-NEXT:    [[TMP4:%.*]] = load <2 x i8>, ptr [[GEP_5]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x i8>, ptr [[IN:%.*]], align 1
+; CHECK-NEXT:    [[TMP11:%.*]] = load <4 x i8>, ptr [[INN:%.*]], align 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
 ; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)



More information about the llvm-commits mailing list