[llvm] 8d89dd4 - [SLP]Fix PR79743: Check that all users are demoted before trying to demote the tree entry.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 29 10:54:43 PST 2024


Author: Alexey Bataev
Date: 2024-01-29T10:51:20-08:00
New Revision: 8d89dd4a5872503d6d5b070bdb48d20973156e07

URL: https://github.com/llvm/llvm-project/commit/8d89dd4a5872503d6d5b070bdb48d20973156e07
DIFF: https://github.com/llvm/llvm-project/commit/8d89dd4a5872503d6d5b070bdb48d20973156e07.diff

LOG: [SLP]Fix PR79743: Check that all users are demoted before trying to
demote the tree entry.

Check that all user nodes are marked for demotion before demoting a node.
Otherwise, a user that is not demoted may still need the operand's full bit
width, and some data could be lost after vectorization.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-with-multi-users.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7ecf3b244ad2b6..3a639cb94fdfb4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13323,6 +13323,20 @@ void BoUpSLP::computeMinimumValueSizes() {
                           Visited);
   }
 
+  // Check that all users are marked for demotion.
+  DenseSet<Value *> Demoted(ToDemote.begin(), ToDemote.end());
+  DenseSet<const TreeEntry *> Visited;
+  for (Value *V: ToDemote) {
+    const TreeEntry *TE = getTreeEntry(V);
+    assert(TE && "Expected vectorized scalar.");
+    if (!Visited.insert(TE).second)
+      continue;
+    if (!all_of(TE->UserTreeIndices, [&](const EdgeInfo &EI) {
+          return all_of(EI.UserTE->Scalars,
+                        [&](Value *V) { return Demoted.contains(V); });
+        }))
+      return;
+  }
   // Finally, map the values we can demote to the maximum bit with we computed.
   for (auto *Scalar : ToDemote) {
     auto *TE = getTreeEntry(Scalar);

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-with-multi-users.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-with-multi-users.ll
index 958761484303f4..136ab64007732f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-with-multi-users.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-node-with-multi-users.ll
@@ -10,20 +10,16 @@ define void @test() {
 ; CHECK-NEXT:    [[TMP3:%.*]] = select i1 false, i32 0, i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <4 x i8> <i8 poison, i8 0, i8 poison, i8 poison>, i8 [[TMP1]], i32 0
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <4 x i32> <i32 0, i32 0, i32 0, i32 1>
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc <4 x i8> [[TMP5]] to <4 x i1>
-; CHECK-NEXT:    [[TMP7:%.*]] = zext <4 x i1> [[TMP6]] to <4 x i32>
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP9:%.*]] = or <4 x i8> [[TMP8]], zeroinitializer
-; CHECK-NEXT:    [[TMP10:%.*]] = sext <4 x i8> [[TMP9]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = zext <4 x i1> [[TMP6]] to <4 x i32>
-; CHECK-NEXT:    [[TMP12:%.*]] = or <4 x i32> zeroinitializer, [[TMP11]]
-; CHECK-NEXT:    [[TMP13:%.*]] = icmp eq <4 x i32> [[TMP10]], [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <4 x i32> <i32 0, i32 0, i32 poison, i32 0>, <4 x i32> [[TMP7]], <4 x i32> <i32 0, i32 1, i32 6, i32 3>
-; CHECK-NEXT:    [[TMP15:%.*]] = select <4 x i1> [[TMP13]], <4 x i32> [[TMP14]], <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP16:%.*]] = trunc <4 x i32> [[TMP15]] to <4 x i1>
-; CHECK-NEXT:    [[TMP17:%.*]] = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> [[TMP16]])
-; CHECK-NEXT:    [[TMP18:%.*]] = zext i1 [[TMP17]] to i32
-; CHECK-NEXT:    [[OP_RDX:%.*]] = and i32 0, [[TMP18]]
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i8> [[TMP5]] to <4 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = or <4 x i8> [[TMP7]], zeroinitializer
+; CHECK-NEXT:    [[TMP9:%.*]] = sext <4 x i8> [[TMP8]] to <4 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = or <4 x i32> zeroinitializer, [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp eq <4 x i32> [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <4 x i32> [[TMP6]], <4 x i32> <i32 0, i32 0, i32 poison, i32 0>, <4 x i32> <i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP13:%.*]] = select <4 x i1> [[TMP11]], <4 x i32> [[TMP12]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.and.v4i32(<4 x i32> [[TMP13]])
+; CHECK-NEXT:    [[OP_RDX:%.*]] = and i32 0, [[TMP14]]
 ; CHECK-NEXT:    store i32 [[OP_RDX]], ptr null, align 4
 ; CHECK-NEXT:    ret void
 ;


        


More information about the llvm-commits mailing list