[llvm] [SLP] Improve minbitwidth analysis for abs/smin/smax/umin/umax intrinsics. (PR #86135)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 22 13:47:30 PDT 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/86135

>From 80a78476a3611e793a054d9aa2fe78fff2f5ceb1 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 21 Mar 2024 15:32:57 +0000
Subject: [PATCH 1/2] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 123 ++++++++++++++++--
 .../cmp-after-intrinsic-call-minbitwidth.ll   |  12 +-
 .../X86/store-abs-minbitwidth.ll              |   9 +-
 3 files changed, 124 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 36b446962c4a63..7f680b7af9b565 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6994,14 +6994,11 @@ bool BoUpSLP::areAllUsersVectorized(
 
 static std::pair<InstructionCost, InstructionCost>
 getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
-                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI) {
+                   TargetTransformInfo *TTI, TargetLibraryInfo *TLI,
+                   ArrayRef<Type *> VecTys) {
   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
   // Calculate the cost of the scalar and vector calls.
-  SmallVector<Type *, 4> VecTys;
-  for (Use &Arg : CI->args())
-    VecTys.push_back(
-        FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
   FastMathFlags FMF;
   if (auto *FPCI = dyn_cast<FPMathOperator>(CI))
     FMF = FPCI->getFastMathFlags();
@@ -9009,7 +9006,25 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       auto *CI = cast<CallInst>(VL0);
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
+      SmallVector<Type *> VecTys;
+      for (auto [Idx, Arg] : enumerate(CI->args())) {
+        if (ID != Intrinsic::not_intrinsic) {
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+            VecTys.push_back(Arg->getType());
+            continue;
+          }
+          if (It != MinBWs.end()) {
+            VecTys.push_back(FixedVectorType::get(
+                IntegerType::get(CI->getContext(), It->second.first),
+                VecTy->getNumElements()));
+            continue;
+          }
+        }
+        VecTys.push_back(
+            FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+      }
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
       return std::min(VecCallCosts.first, VecCallCosts.second) + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -12462,7 +12477,24 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
 
       Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
 
-      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI);
+      SmallVector<Type *> VecTys;
+      for (auto [Idx, Arg] : enumerate(CI->args())) {
+        if (ID != Intrinsic::not_intrinsic) {
+          if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+            VecTys.push_back(Arg->getType());
+            continue;
+          }
+          if (It != MinBWs.end()) {
+            VecTys.push_back(FixedVectorType::get(
+                IntegerType::get(CI->getContext(), It->second.first),
+                VecTy->getNumElements()));
+            continue;
+          }
+        }
+        VecTys.push_back(
+            FixedVectorType::get(Arg->getType(), VecTy->getNumElements()));
+      }
+      auto VecCallCosts = getVectorCallCosts(CI, VecTy, TTI, TLI, VecTys);
       bool UseIntrinsic = ID != Intrinsic::not_intrinsic &&
                           VecCallCosts.first <= VecCallCosts.second;
 
@@ -12471,14 +12503,13 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       SmallVector<Type *, 2> TysForDecl;
       // Add return type if intrinsic is overloaded on it.
       if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, -1))
-        TysForDecl.push_back(
-            FixedVectorType::get(CI->getType(), E->Scalars.size()));
+        TysForDecl.push_back(VecTy);
+      auto *CEI = cast<CallInst>(VL0);
       for (unsigned I : seq<unsigned>(0, CI->arg_size())) {
         ValueList OpVL;
         // Some intrinsics have scalar arguments. This argument should not be
         // vectorized.
         if (UseIntrinsic && isVectorIntrinsicWithScalarOpAtArg(ID, I)) {
-          CallInst *CEI = cast<CallInst>(VL0);
           ScalarArg = CEI->getArgOperand(I);
           OpVecs.push_back(CEI->getArgOperand(I));
           if (isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -12491,6 +12522,15 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           LLVM_DEBUG(dbgs() << "SLP: Diamond merged for " << *VL0 << ".\n");
           return E->VectorizedValue;
         }
+        ScalarArg = CEI->getArgOperand(I);
+        if (cast<VectorType>(OpVec->getType())->getElementType() !=
+            ScalarArg->getType() && It == MinBWs.end()) {
+          auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
+                                              VecTy->getNumElements());
+          OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));
+        } else if (It != MinBWs.end()) {
+          OpVec = Builder.CreateIntCast(OpVec, VecTy, GetOperandSignedness(I));
+        }
         LLVM_DEBUG(dbgs() << "SLP: OpVec[" << I << "]: " << *OpVec << "\n");
         OpVecs.push_back(OpVec);
         if (UseIntrinsic && isVectorIntrinsicWithOverloadTypeAtArg(ID, I))
@@ -14195,6 +14235,69 @@ bool BoUpSLP::collectValuesToDemote(
     break;
   }
 
+  case Instruction::Call: {
+    if (ITE->UserTreeIndices.size() > 1 && !IsPotentiallyTruncated(I, BitWidth))
+      return false;
+    if (auto *IC = dyn_cast<IntrinsicInst>(I)) {
+      Intrinsic::ID ID = getVectorIntrinsicIDForCall(IC, TLI);
+      if (ID == Intrinsic::abs || ID == Intrinsic::smin ||
+          ID == Intrinsic::smax || ID == Intrinsic::umin ||
+          ID == Intrinsic::umax) {
+        InstructionCost BestCost =
+            std::numeric_limits<InstructionCost::CostType>::max();
+        unsigned BestBitWidth = BitWidth;
+        unsigned VF = ITE->Scalars.size();
+        // Choose the best bitwidth based on cost estimations.
+        (void)AttemptCheckBitwidth(
+            [&](unsigned BitWidth, unsigned) {
+              SmallVector<Type *> VecTys;
+              auto *ITy =
+                  IntegerType::get(IC->getContext(), PowerOf2Ceil(BitWidth));
+              for (auto [Idx, Arg] : enumerate(IC->args())) {
+                if (isVectorIntrinsicWithScalarOpAtArg(ID, Idx)) {
+                  VecTys.push_back(Arg->getType());
+                  continue;
+                }
+                VecTys.push_back(FixedVectorType::get(ITy, VF));
+              }
+              auto VecCallCosts = getVectorCallCosts(
+                  IC, FixedVectorType::get(ITy, VF), TTI, TLI, VecTys);
+              InstructionCost Cost =
+                  std::min(VecCallCosts.first, VecCallCosts.second);
+              if (Cost < BestCost) {
+                BestCost = Cost;
+                BestBitWidth = BitWidth;
+              }
+              return false;
+            },
+            NeedToExit);
+        NeedToExit = false;
+        BitWidth = BestBitWidth;
+        switch (ID) {
+        case Intrinsic::abs:
+          End = 1;
+          if (!ProcessOperands(IC->getArgOperand(0), NeedToExit))
+            return false;
+          break;
+        case Intrinsic::smin:
+        case Intrinsic::smax:
+        case Intrinsic::umin:
+        case Intrinsic::umax:
+          End = 2;
+          if (!ProcessOperands({IC->getArgOperand(0), IC->getArgOperand(1)},
+                               NeedToExit))
+            return false;
+          break;
+        default:
+          llvm_unreachable("Unexpected intrinsic.");
+        }
+        break;
+      }
+    }
+    MaxDepthLevel = 1;
+    return FinalAnalysis();
+  }
+
   // Otherwise, conservatively give up.
   default:
     MaxDepthLevel = 1;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
index a05d4fdd6315b5..9fa88084aaa0af 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/cmp-after-intrinsic-call-minbitwidth.ll
@@ -5,12 +5,14 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i32> @llvm.smin.v2i32(<2 x i32> zeroinitializer, <2 x i32> zeroinitializer)
-; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i32> zeroinitializer, <2 x i32> [[TMP0]]
-; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i32> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = extractelement <2 x i32> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP0:%.*]] = call <2 x i2> @llvm.smin.v2i2(<2 x i2> zeroinitializer, <2 x i2> zeroinitializer)
+; CHECK-NEXT:    [[TMP1:%.*]] = select <2 x i1> zeroinitializer, <2 x i2> zeroinitializer, <2 x i2> [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i2> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i2> [[TMP2]], i32 1
+; CHECK-NEXT:    [[ADD:%.*]] = zext i2 [[TMP3]] to i32
 ; CHECK-NEXT:    [[SHR:%.*]] = ashr i32 [[ADD]], 0
-; CHECK-NEXT:    [[ADD45:%.*]] = extractelement <2 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i2> [[TMP2]], i32 0
+; CHECK-NEXT:    [[ADD45:%.*]] = zext i2 [[TMP5]] to i32
 ; CHECK-NEXT:    [[ADD152:%.*]] = or i32 [[ADD45]], [[ADD]]
 ; CHECK-NEXT:    [[IDXPROM153:%.*]] = sext i32 [[ADD152]] to i64
 ; CHECK-NEXT:    [[ARRAYIDX154:%.*]] = getelementptr i8, ptr null, i64 [[IDXPROM153]]
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
index e8b854b7cea6cb..60bec6668d23ba 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/store-abs-minbitwidth.ll
@@ -13,14 +13,13 @@ define i32 @test(ptr noalias %in, ptr noalias %inn, ptr %out) {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <2 x i8> [[TMP3]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <2 x i8> [[TMP2]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> [[TMP6]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i32>
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i8> [[TMP7]] to <4 x i16>
 ; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i8> [[TMP1]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i8> [[TMP4]], <2 x i8> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP9]], <4 x i8> [[TMP10]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i32>
-; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i32> [[TMP12]], [[TMP8]]
-; CHECK-NEXT:    [[TMP14:%.*]] = call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[TMP13]], i1 true)
-; CHECK-NEXT:    [[TMP15:%.*]] = trunc <4 x i32> [[TMP14]] to <4 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i8> [[TMP11]] to <4 x i16>
+; CHECK-NEXT:    [[TMP13:%.*]] = sub <4 x i16> [[TMP12]], [[TMP8]]
+; CHECK-NEXT:    [[TMP15:%.*]] = call <4 x i16> @llvm.abs.v4i16(<4 x i16> [[TMP13]], i1 true)
 ; CHECK-NEXT:    store <4 x i16> [[TMP15]], ptr [[OUT:%.*]], align 2
 ; CHECK-NEXT:    ret i32 undef
 ;

>From eb6ccb0c1ab58df2bb32d216535ef8e8a14e37c4 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 21 Mar 2024 15:38:15 +0000
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.5
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7f680b7af9b565..b878eb95ed3466 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -12524,7 +12524,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         }
         ScalarArg = CEI->getArgOperand(I);
         if (cast<VectorType>(OpVec->getType())->getElementType() !=
-            ScalarArg->getType() && It == MinBWs.end()) {
+                ScalarArg->getType() &&
+            It == MinBWs.end()) {
           auto *CastTy = FixedVectorType::get(ScalarArg->getType(),
                                               VecTy->getNumElements());
           OpVec = Builder.CreateIntCast(OpVec, CastTy, GetOperandSignedness(I));



More information about the llvm-commits mailing list