[llvm] [SLP]Improve minbitwidth analysis for shifts. (PR #84356)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 11:30:10 PST 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/84356

>From 107d47c8dddaa4b106bac5267ddc31f841d66208 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 7 Mar 2024 18:25:36 +0000
Subject: [PATCH] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 112 +++++++++++++++---
 .../X86/reorder-possible-strided-node.ll      |   6 +-
 .../X86/reorder_diamond_match.ll              |  25 ++--
 3 files changed, 112 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 1889bc09e85028..3364f34d0148cc 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10094,16 +10094,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
         BitWidth = UserIt->second.second;
     }
   }
-  auto CheckBitwidth = [&](const TreeEntry &TE) {
-    Type *ScalarTy = TE.Scalars.front()->getType();
-    if (!ScalarTy->isIntegerTy())
-      return true;
-    unsigned TEBitWidth = DL->getTypeStoreSize(ScalarTy);
-    auto UserIt = MinBWs.find(TEUseEI.UserTE);
-    if (UserIt != MinBWs.end())
-      TEBitWidth = UserIt->second.second;
-    return BitWidth == TEBitWidth;
-  };
   SmallVector<SmallPtrSet<const TreeEntry *, 4>> UsedTEs;
   DenseMap<Value *, int> UsedValuesEntry;
   for (Value *V : VL) {
@@ -10138,8 +10128,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
           continue;
       }
 
-      if (!CheckBitwidth(*TEPtr))
-        continue;
       // Check if the user node of the TE comes after user node of TEPtr,
       // otherwise TEPtr depends on TE.
       if ((TEInsertBlock != InsertPt->getParent() ||
@@ -10157,7 +10145,7 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
           VTE = *It->getSecond().begin();
           // Iterate through all vectorized nodes.
           auto *MIt = find_if(It->getSecond(), [&](const TreeEntry *MTE) {
-            return MTE->State == TreeEntry::Vectorize && CheckBitwidth(*MTE);
+            return MTE->State == TreeEntry::Vectorize;
           });
           if (MIt == It->getSecond().end())
             continue;
@@ -10167,8 +10155,6 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
       Instruction &LastBundleInst = getLastInstructionInBundle(VTE);
       if (&LastBundleInst == TEInsertPt || !CheckOrdering(&LastBundleInst))
         continue;
-      if (!CheckBitwidth(*VTE))
-        continue;
       VToTEs.insert(VTE);
     }
     if (VToTEs.empty())
@@ -10216,6 +10202,45 @@ BoUpSLP::isGatherShuffledSingleRegisterEntry(
     return std::nullopt;
   }
 
+  if (BitWidth > 0) {
+    // Check if the used TEs supposed to be resized and choose the best
+    // candidates.
+    unsigned NodesBitWidth = 0;
+    auto CheckBitwidth = [&](const TreeEntry &TE) {
+      unsigned TEBitWidth = BitWidth;
+      auto UserIt = MinBWs.find(TEUseEI.UserTE);
+      if (UserIt != MinBWs.end())
+        TEBitWidth = UserIt->second.second;
+      if (BitWidth <= TEBitWidth) {
+        if (NodesBitWidth == 0)
+          NodesBitWidth = TEBitWidth;
+        return NodesBitWidth == TEBitWidth;
+      }
+      return false;
+    };
+    for (auto [Idx, Set] : enumerate(UsedTEs)) {
+      DenseSet<const TreeEntry *> ForRemoval;
+      for (const TreeEntry *TE : Set) {
+        if (!CheckBitwidth(*TE))
+          ForRemoval.insert(TE);
+      }
+      // All elements must be removed - remove the whole container.
+      if (ForRemoval.size() == Set.size()) {
+        Set.clear();
+        continue;
+      }
+      for (const TreeEntry *TE : ForRemoval)
+        Set.erase(TE);
+    }
+    for (auto *It = UsedTEs.begin(); It != UsedTEs.end();) {
+      if (It->empty()) {
+        UsedTEs.erase(It);
+        continue;
+      }
+      std::advance(It, 1);
+    }
+  }
+
   unsigned VF = 0;
   if (UsedTEs.size() == 1) {
     // Keep the order to avoid non-determinism.
@@ -13946,6 +13971,63 @@ bool BoUpSLP::collectValuesToDemote(
     MaxDepthLevel = std::max(Level1, Level2);
     break;
   }
+  case Instruction::Shl: {
+    // If we are truncating the result of this SHL, and if it's a shift of an
+    // inrange amount, we can always perform a SHL in a smaller type.
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
+  case Instruction::LShr: {
+    // If this is a truncate of a logical shr, we can truncate it to a smaller
+    // lshr iff we know that the bits we would otherwise be shifting in are
+    // already zeros.
+    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        !MaskedValueIsZero(I->getOperand(0), ShiftedBits, SimplifyQuery(*DL)) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
+  case Instruction::AShr: {
+    // If this is a truncate of an arithmetic shr, we can truncate it to a
+    // smaller ashr iff we know that all the bits from the sign bit of the
+    // original type and the sign bit of the truncate type are similar.
+    uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
+    unsigned Level1, Level2;
+    KnownBits AmtKnownBits = computeKnownBits(I->getOperand(1), *DL);
+    unsigned ShiftedBits = OrigBitWidth - BitWidth;
+    if (AmtKnownBits.getMaxValue().uge(BitWidth) ||
+        ShiftedBits >=
+            ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT) ||
+        !collectValuesToDemote(I->getOperand(0), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level1, IsProfitableToDemote) ||
+        !collectValuesToDemote(I->getOperand(1), IsProfitableToDemoteRoot,
+                               BitWidth, ToDemote, DemotedConsts, Visited,
+                               Level2, IsProfitableToDemote))
+      return false;
+    MaxDepthLevel = std::max(Level1, Level2);
+    break;
+  }
 
   // We can demote selects if we can demote their true and false values.
   case Instruction::Select: {
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
index 6f5d3d3785e0c8..6378f696b470d4 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-possible-strided-node.ll
@@ -10,10 +10,8 @@ define void @test() {
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[ARRAYIDX22]], align 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
 ; CHECK-NEXT:    [[TMP3:%.*]] = mul <4 x i32> [[TMP2]], [[TMP0]]
-; CHECK-NEXT:    [[TMP4:%.*]] = sext <4 x i32> [[TMP3]] to <4 x i64>
-; CHECK-NEXT:    [[TMP5:%.*]] = ashr <4 x i64> [[TMP4]], zeroinitializer
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i64> [[TMP5]] to <4 x i32>
-; CHECK-NEXT:    store <4 x i32> [[TMP6]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
+; CHECK-NEXT:    [[TMP4:%.*]] = ashr <4 x i32> [[TMP3]], zeroinitializer
+; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr getelementptr inbounds ([4 x i32], ptr null, i64 8, i64 0), align 16
 ; CHECK-NEXT:    ret void
 ;
 entry:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
index 86b1e1a801e32f..91ee4dba07009f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder_diamond_match.ll
@@ -5,18 +5,19 @@ define void @test() {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr undef, i64 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [4 x [4 x i32]], ptr undef, i64 0, i64 1, i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = shl nsw <4 x i32> [[TMP6]], zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = add nsw <4 x i32> [[TMP7]], zeroinitializer
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x i32> [[TMP8]], <4 x i32> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
-; CHECK-NEXT:    [[TMP10:%.*]] = add nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = sub nsw <4 x i32> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP10]], <4 x i32> [[TMP11]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
-; CHECK-NEXT:    [[TMP13:%.*]] = add nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = sub nsw <4 x i32> zeroinitializer, [[TMP12]]
-; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i32> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i8> [[TMP3]] to <4 x i16>
+; CHECK-NEXT:    [[TMP5:%.*]] = sub <4 x i16> zeroinitializer, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl <4 x i16> [[TMP5]], zeroinitializer
+; CHECK-NEXT:    [[TMP7:%.*]] = add <4 x i16> [[TMP6]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i16> [[TMP7]], <4 x i16> poison, <4 x i32> <i32 1, i32 0, i32 3, i32 2>
+; CHECK-NEXT:    [[TMP9:%.*]] = add nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = sub nsw <4 x i16> [[TMP7]], [[TMP8]]
+; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i16> [[TMP9]], <4 x i16> [[TMP10]], <4 x i32> <i32 1, i32 4, i32 3, i32 6>
+; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <4 x i16> zeroinitializer, [[TMP11]]
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i16> [[TMP12]], <4 x i16> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP15:%.*]] = zext <4 x i16> [[TMP14]] to <4 x i32>
 ; CHECK-NEXT:    store <4 x i32> [[TMP15]], ptr [[TMP2]], align 16
 ; CHECK-NEXT:    ret void
 ;



More information about the llvm-commits mailing list