[llvm] [SLP] Cluster SortedBases before sorting. (PR #101144)
    David Green via llvm-commits 
    llvm-commits at lists.llvm.org
       
    Tue Jul 30 11:16:58 PDT 2024
    
    
  
https://github.com/davemgreen updated https://github.com/llvm/llvm-project/pull/101144
>From 62f69dc7eff5c420fec19479c1d6a5c0ea568b62 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 30 Jul 2024 08:15:48 +0100
Subject: [PATCH 1/2] [SLP] Cluster SortedBases before sorting.
In order to enforce a strict weak ordering, this patch clusters the bases being
sorted by their root (the first value in a gep chain). Sorting is then
performed within each cluster.
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 45 ++++++++++++-------
 .../SLPVectorizer/AArch64/loadorder.ll        |  8 ++--
 2 files changed, 34 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6501a14d87789..f716667ad99a8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4843,25 +4843,40 @@ static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
     return false;
 
   // If we have a better order, also sort the base pointers by increasing
-  // (variable) values if possible, to try and keep the order more regular.
-  SmallVector<std::pair<Value *, Value *>> SortedBases;
-  for (auto &Base : Bases)
-    SortedBases.emplace_back(Base.first,
-                             Base.first->stripInBoundsConstantOffsets());
-  llvm::stable_sort(SortedBases, [](std::pair<Value *, Value *> V1,
-                                    std::pair<Value *, Value *> V2) {
-    const Value *V = V2.second;
-    while (auto *Gep = dyn_cast<GetElementPtrInst>(V)) {
-      if (Gep->getOperand(0) == V1.second)
-        return true;
-      V = Gep->getOperand(0);
+  // (variable) values if possible, to try and keep the order more regular. In
+  // order to create a valid strict-weak order we cluster by the Root of gep
+  // chains and sort within each.
+  SmallVector<std::tuple<Value *, Value *, Value *>> SortedBases;
+  for (auto &Base : Bases) {
+    Value *Strip = Base.first->stripInBoundsConstantOffsets();
+    Value *Root = Strip;
+    while (auto *Gep = dyn_cast<GetElementPtrInst>(Root))
+      Root = Gep->getOperand(0);
+    SortedBases.emplace_back(Base.first, Strip, Root);
+  }
+  if (SortedBases.size() <= 16) {
+    auto Begin = SortedBases.begin();
+    auto End = SortedBases.end();
+    while (Begin != End) {
+      Value *Root = std::get<2>(*Begin);
+      auto Mid = std::stable_partition(
+          Begin, End, [&Root](auto V) { return std::get<2>(V) == Root; });
+      std::stable_sort(Begin, Mid, [](auto V1, auto V2) {
+        const Value *V = std::get<1>(V2);
+        while (auto *Gep = dyn_cast<GetElementPtrInst>(V)) {
+          if (Gep->getOperand(0) == std::get<1>(V1))
+            return true;
+          V = Gep->getOperand(0);
+        }
+        return false;
+      });
+      Begin = Mid;
     }
-    return false;
-  });
+  }
 
   // Collect the final order of sorted indices
   for (auto Base : SortedBases)
-    for (auto &T : Bases[Base.first])
+    for (auto &T : Bases[std::get<0>(Base)])
       SortedIndices.push_back(std::get<2>(T));
 
   assert(SortedIndices.size() == VL.size() &&
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll
index 6b5503f26fabf..d79aed89b0be7 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/loadorder.ll
@@ -428,14 +428,14 @@ define i32 @reduce_blockstrided4x4(ptr nocapture noundef readonly %p1, i32 nound
 ; CHECK-NEXT:    [[TMP5:%.*]] = load <4 x i8>, ptr [[ADD_PTR64]], align 1
 ; CHECK-NEXT:    [[TMP6:%.*]] = load <4 x i8>, ptr [[ARRAYIDX3_1]], align 1
 ; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x i8>, ptr [[ARRAYIDX5_1]], align 1
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> [[TMP1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> [[TMP4]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <16 x i8> [[TMP8]], <16 x i8> [[TMP9]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP5]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <16 x i8> [[TMP10]], <16 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
 ; CHECK-NEXT:    [[TMP13:%.*]] = zext <16 x i8> [[TMP12]] to <16 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP3]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i8> [[TMP6]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP6]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP15:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <16 x i8> [[TMP14]], <16 x i8> [[TMP15]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP17:%.*]] = shufflevector <4 x i8> [[TMP7]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
 ; CHECK-NEXT:    [[TMP18:%.*]] = shufflevector <16 x i8> [[TMP16]], <16 x i8> [[TMP17]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
>From 0a8e6e4c00974257b53a3e1a27a064d04d48ac73 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 30 Jul 2024 19:16:46 +0100
Subject: [PATCH 2/2] Compute LessThan by walking up gep chains
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 40 ++++++++++---------
 1 file changed, 22 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f716667ad99a8..73fcaacbc58a7 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4854,24 +4854,28 @@ static bool clusterSortPtrAccesses(ArrayRef<Value *> VL, Type *ElemTy,
       Root = Gep->getOperand(0);
     SortedBases.emplace_back(Base.first, Strip, Root);
   }
-  if (SortedBases.size() <= 16) {
-    auto Begin = SortedBases.begin();
-    auto End = SortedBases.end();
-    while (Begin != End) {
-      Value *Root = std::get<2>(*Begin);
-      auto Mid = std::stable_partition(
-          Begin, End, [&Root](auto V) { return std::get<2>(V) == Root; });
-      std::stable_sort(Begin, Mid, [](auto V1, auto V2) {
-        const Value *V = std::get<1>(V2);
-        while (auto *Gep = dyn_cast<GetElementPtrInst>(V)) {
-          if (Gep->getOperand(0) == std::get<1>(V1))
-            return true;
-          V = Gep->getOperand(0);
-        }
-        return false;
-      });
-      Begin = Mid;
-    }
+  auto *Begin = SortedBases.begin();
+  auto *End = SortedBases.end();
+  while (Begin != End) {
+    Value *Root = std::get<2>(*Begin);
+    auto *Mid = std::stable_partition(
+        Begin, End, [&Root](auto V) { return std::get<2>(V) == Root; });
+    DenseMap<Value *, DenseMap<Value *, bool>> LessThan;
+    SmallPtrSet<Value *, 4> Geps;
+    for (auto I = Begin; I < Mid; ++I)
+      Geps.insert(std::get<1>(*I));
+    for (auto I = Begin; I < Mid; ++I) {
+      Value *V = std::get<1>(*I);
+      while (auto *Gep = dyn_cast<GetElementPtrInst>(V)) {
+        V = Gep->getOperand(0);
+        if (Geps.contains(V))
+          LessThan[V][std::get<1>(*I)] = true;
+      }
+    }
+    std::stable_sort(Begin, Mid, [&LessThan](auto &V1, auto &V2) {
+      return LessThan[std::get<1>(V1)][std::get<1>(V2)];
+    });
+    Begin = Mid;
   }
 
   // Collect the final order of sorted indices
    
    
More information about the llvm-commits
mailing list