[llvm] 58a94b1 - [SLP]Fix PR91467: Look through scalar cast, when trying to cast to another type.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu May 9 04:21:29 PDT 2024


Author: Alexey Bataev
Date: 2024-05-09T04:19:43-07:00
New Revision: 58a94b1d0ad8df85bc6b1edb22c74ffb718ca1a1

URL: https://github.com/llvm/llvm-project/commit/58a94b1d0ad8df85bc6b1edb22c74ffb718ca1a1
DIFF: https://github.com/llvm/llvm-project/commit/58a94b1d0ad8df85bc6b1edb22c74ffb718ca1a1.diff

LOG: [SLP]Fix PR91467: Look through scalar cast, when trying to cast to another type.

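When gathered scalars are SExt/ZExt instructions whose width is being
reduced after the minbitwidth analysis, look through the extension and
cast its operand directly (unless that operand was deleted or is already
vectorized); otherwise the vectorizer keeps trying to revectorize such
gathered instructions.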

Added: 
    llvm/test/Transforms/SLPVectorizer/X86/extended-vectorized-gathered-inst.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll
    llvm/test/Transforms/SLPVectorizer/AArch64/user-node-not-in-bitwidths.ll
    llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-root-trunc.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 98561f9ca0442..2e0a39c4b4fdc 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11419,8 +11419,16 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
     if (Scalar->getType() != Ty) {
       assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() &&
              "Expected integer types only.");
+      Value *V = Scalar;
+      if (auto *CI = dyn_cast<CastInst>(Scalar);
+          isa_and_nonnull<SExtInst, ZExtInst>(CI)) {
+        Value *Op = CI->getOperand(0);
+        if (auto *IOp = dyn_cast<Instruction>(Op);
+            !IOp || !(isDeleted(IOp) || getTreeEntry(IOp)))
+          V = Op;
+      }
       Scalar = Builder.CreateIntCast(
-          Scalar, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
+          V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
     }
 
     Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));

diff  --git a/llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll
index 76bb882171b17..3ebe920d17343 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/gather-with-minbith-user.ll
@@ -5,14 +5,7 @@ define void @h() {
 ; CHECK-LABEL: define void @h() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr null, i64 16
-; CHECK-NEXT:    [[TMP6:%.*]] = trunc i32 0 to i1
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 poison, i1 false, i1 false, i1 false>, i1 [[TMP6]], i32 4
-; CHECK-NEXT:    [[TMP1:%.*]] = sub <8 x i1> [[TMP0]], zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = add <8 x i1> [[TMP0]], zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <8 x i1> [[TMP1]], <8 x i1> [[TMP2]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15>
-; CHECK-NEXT:    [[TMP5:%.*]] = or <8 x i1> [[TMP3]], zeroinitializer
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <8 x i1> [[TMP5]] to <8 x i16>
-; CHECK-NEXT:    store <8 x i16> [[TMP4]], ptr [[ARRAYIDX2]], align 2
+; CHECK-NEXT:    store <8 x i16> zeroinitializer, ptr [[ARRAYIDX2]], align 2
 ; CHECK-NEXT:    ret void
 ;
 entry:

diff  --git a/llvm/test/Transforms/SLPVectorizer/AArch64/user-node-not-in-bitwidths.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/user-node-not-in-bitwidths.ll
index 2ab6e919c23b2..6404cf4a2cd1d 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/user-node-not-in-bitwidths.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/user-node-not-in-bitwidths.ll
@@ -5,12 +5,7 @@ define void @h() {
 ; CHECK-LABEL: define void @h() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr i8, ptr null, i64 16
-; CHECK-NEXT:    [[TMP0:%.*]] = trunc i32 0 to i1
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 poison, i1 false, i1 false, i1 false>, i1 [[TMP0]], i32 4
-; CHECK-NEXT:    [[TMP2:%.*]] = or <8 x i1> zeroinitializer, [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = or <8 x i1> zeroinitializer, [[TMP2]]
-; CHECK-NEXT:    [[TMP4:%.*]] = zext <8 x i1> [[TMP3]] to <8 x i16>
-; CHECK-NEXT:    store <8 x i16> [[TMP4]], ptr [[ARRAYIDX2]], align 2
+; CHECK-NEXT:    store <8 x i16> zeroinitializer, ptr [[ARRAYIDX2]], align 2
 ; CHECK-NEXT:    ret void
 ;
 entry:

diff  --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-root-trunc.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-root-trunc.ll
index 1bb87bf6205f1..3c8e98485ffc1 100644
--- a/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-root-trunc.ll
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-root-trunc.ll
@@ -4,10 +4,9 @@
 define void @test(ptr %a, i8 %0, i16 %b.promoted.i) {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ptr [[A:%.*]], i8 [[TMP0:%.*]], i16 [[B_PROMOTED_I:%.*]]) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP0]] to i128
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i16> poison, i16 [[B_PROMOTED_I]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i16> [[TMP3]], <4 x i16> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP5:%.*]] = trunc i128 [[TMP2]] to i16
+; CHECK-NEXT:    [[TMP5:%.*]] = zext i8 [[TMP0]] to i16
 ; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <4 x i16> poison, i16 [[TMP5]], i32 0
 ; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i16> [[TMP6]], <4 x i16> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP8:%.*]] = or <4 x i16> [[TMP4]], [[TMP7]]

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/extended-vectorized-gathered-inst.ll b/llvm/test/Transforms/SLPVectorizer/X86/extended-vectorized-gathered-inst.ll
new file mode 100644
index 0000000000000..2d028060f4914
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extended-vectorized-gathered-inst.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux < %s | FileCheck %s
+
+define void @test(ptr %top) {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ptr [[TOP:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[TOP]], align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <4 x i8> [[TMP0]], zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i8> [[TMP0]], i32 2
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP2]] to i32
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[TMP3]] to i8
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x i8> <i8 0, i8 0, i8 0, i8 poison>, i8 [[TMP4]], i32 3
+; CHECK-NEXT:    [[TMP6:%.*]] = or <4 x i8> [[TMP1]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = or <4 x i8> [[TMP6]], zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr <4 x i8> [[TMP7]], <i8 2, i8 2, i8 2, i8 2>
+; CHECK-NEXT:    br label [[FOR_COND_I:%.*]]
+; CHECK:       for.cond.i:
+; CHECK-NEXT:    store <4 x i8> [[TMP8]], ptr null, align 1
+; CHECK-NEXT:    br label [[FOR_COND_I]]
+;
+entry:
+  %0 = load i8, ptr %top, align 1
+  %conv2.i = zext i8 %0 to i32
+  %mul.i = mul i32 %conv2.i, 0
+  %add.i = or i32 %mul.i, 0
+  %arrayidx3.i = getelementptr i8, ptr %top, i64 1
+  %1 = load i8, ptr %arrayidx3.i, align 1
+  %conv4.i = zext i8 %1 to i32
+  %add5.i = or i32 %add.i, 0
+  %shr.i = lshr i32 %add5.i, 2
+  %conv7.i = trunc i32 %shr.i to i8
+  %mul12.i = mul i32 %conv4.i, 0
+  %arrayidx14.i = getelementptr i8, ptr %top, i64 2
+  %2 = load i8, ptr %arrayidx14.i, align 1
+  %conv15.i = zext i8 %2 to i32
+  %add16.i = or i32 %mul12.i, 0
+  %add17.i = or i32 %add16.i, 0
+  %shr18.i = lshr i32 %add17.i, 2
+  %conv19.i = trunc i32 %shr18.i to i8
+  %mul25.i = mul i32 %conv15.i, 0
+  %arrayidx27.i = getelementptr i8, ptr %top, i64 3
+  %3 = load i8, ptr %arrayidx27.i, align 1
+  %conv28.i = zext i8 %3 to i32
+  %add29.i = or i32 %mul25.i, 0
+  %add30.i = or i32 %add29.i, 0
+  %shr31.i = lshr i32 %add30.i, 2
+  %conv32.i = trunc i32 %shr31.i to i8
+  %mul38.i = mul i32 %conv28.i, 0
+  %add39.i = or i32 %mul38.i, %conv15.i
+  %add42.i = or i32 %add39.i, 0
+  %shr44.i = lshr i32 %add42.i, 2
+  %conv45.i = trunc i32 %shr44.i to i8
+  br label %for.cond.i
+
+for.cond.i:
+  store i8 %conv7.i, ptr null, align 1
+  %vals.sroa.5.0.add.ptr.sroa_idx.i = getelementptr i8, ptr null, i64 1
+  store i8 %conv19.i, ptr %vals.sroa.5.0.add.ptr.sroa_idx.i, align 1
+  %vals.sroa.7.0.add.ptr.sroa_idx.i = getelementptr i8, ptr null, i64 2
+  store i8 %conv32.i, ptr %vals.sroa.7.0.add.ptr.sroa_idx.i, align 1
+  %vals.sroa.9.0.add.ptr.sroa_idx.i = getelementptr i8, ptr null, i64 3
+  store i8 %conv45.i, ptr %vals.sroa.9.0.add.ptr.sroa_idx.i, align 1
+  br label %for.cond.i
+}


        


More information about the llvm-commits mailing list