[llvm] a612524 - [SLP]Fix the cost of the reduction result to the final type.

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 7 06:51:51 PDT 2024

Author: Alexey Bataev
Date: 2024-04-07T09:51:47-04:00
New Revision: a61252419779a6d4a5ebf71e7e2fc4adc75cfddd

URL: https://github.com/llvm/llvm-project/commit/a61252419779a6d4a5ebf71e7e2fc4adc75cfddd
DIFF: https://github.com/llvm/llvm-project/commit/a61252419779a6d4a5ebf71e7e2fc4adc75cfddd.diff

LOG: [SLP]Fix the cost of the reduction result to the final type.

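The cost of casting the reduction result to its final type was calculated
incorrectly. The extend/truncate comparison was inverted, so the wrong cast
opcode could be selected, which led to an over-optimistic vector cost. In
addition, the reduction type size (ReductionBitWidth, when set) must be used
as the source width of that cast.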

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: https://github.com/llvm/llvm-project/pull/87528




diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 332877f35081bd..6a662b2791bdc4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9824,11 +9824,13 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
     if (BWIt != MinBWs.end()) {
       Type *DstTy = Root.Scalars.front()->getType();
       unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
-      if (OriginalSz != BWIt->second.first) {
+      unsigned SrcSz =
+          ReductionBitWidth == 0 ? BWIt->second.first : ReductionBitWidth;
+      if (OriginalSz != SrcSz) {
         unsigned Opcode = Instruction::Trunc;
-        if (OriginalSz < BWIt->second.first)
+        if (OriginalSz > SrcSz)
           Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
-        Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+        Type *SrcTy = IntegerType::get(DstTy->getContext(), SrcSz);
         Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,

diff  --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
index 500f10659f04cb..1e7eb4a4167242 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reductions.ll
@@ -801,10 +801,20 @@ entry:
 define i64 @red_zext_ld_4xi64(ptr %ptr) {
 ; CHECK-LABEL: @red_zext_ld_4xi64(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i8>, ptr [[PTR:%.*]], align 1
-; CHECK-NEXT:    [[TMP1:%.*]] = zext <4 x i8> [[TMP0]] to <4 x i16>
-; CHECK-NEXT:    [[TMP2:%.*]] = call i16 @llvm.vector.reduce.add.v4i16(<4 x i16> [[TMP1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TMP2]] to i64
+; CHECK-NEXT:    [[LD0:%.*]] = load i8, ptr [[PTR:%.*]], align 1
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[LD0]] to i64
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 1
+; CHECK-NEXT:    [[LD1:%.*]] = load i8, ptr [[GEP]], align 1
+; CHECK-NEXT:    [[ZEXT_1:%.*]] = zext i8 [[LD1]] to i64
+; CHECK-NEXT:    [[ADD_1:%.*]] = add nuw nsw i64 [[ZEXT]], [[ZEXT_1]]
+; CHECK-NEXT:    [[GEP_1:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 2
+; CHECK-NEXT:    [[LD2:%.*]] = load i8, ptr [[GEP_1]], align 1
+; CHECK-NEXT:    [[ZEXT_2:%.*]] = zext i8 [[LD2]] to i64
+; CHECK-NEXT:    [[ADD_2:%.*]] = add nuw nsw i64 [[ADD_1]], [[ZEXT_2]]
+; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 3
+; CHECK-NEXT:    [[LD3:%.*]] = load i8, ptr [[GEP_2]], align 1
+; CHECK-NEXT:    [[ZEXT_3:%.*]] = zext i8 [[LD3]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i64 [[ADD_2]], [[ZEXT_3]]
 ; CHECK-NEXT:    ret i64 [[TMP3]]

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
index 44738aa1a67479..a8d481a3e28a5c 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-drop-wrapping-flags.ll
@@ -5,17 +5,22 @@ define i32 @test() {
 ; CHECK-LABEL: define i32 @test() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A_PROMOTED:%.*]] = load i8, ptr null, align 1
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <4 x i8> poison, i8 [[A_PROMOTED]], i32 0
-; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i8> [[TMP0]], <4 x i8> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i8> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = or <4 x i8> [[TMP1]], zeroinitializer
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> [[TMP3]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i16>
-; CHECK-NEXT:    [[TMP6:%.*]] = add <4 x i16> [[TMP5]], <i16 -1, i16 0, i16 0, i16 0>
-; CHECK-NEXT:    [[TMP7:%.*]] = call i16 @llvm.vector.reduce.or.v4i16(<4 x i16> [[TMP6]])
-; CHECK-NEXT:    [[TMP8:%.*]] = zext i16 [[TMP7]] to i32
+; CHECK-NEXT:    [[DEC_4:%.*]] = add i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_4:%.*]] = zext i8 [[DEC_4]] to i32
+; CHECK-NEXT:    [[SUB_I_4:%.*]] = add nuw nsw i32 [[CONV_I_4]], 0
+; CHECK-NEXT:    [[DEC_5:%.*]] = add i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_5:%.*]] = zext i8 [[DEC_5]] to i32
+; CHECK-NEXT:    [[SUB_I_5:%.*]] = add nuw nsw i32 [[CONV_I_5]], 65535
+; CHECK-NEXT:    [[TMP0:%.*]] = or i32 [[SUB_I_4]], [[SUB_I_5]]
+; CHECK-NEXT:    [[DEC_6:%.*]] = or i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_6:%.*]] = zext i8 [[DEC_6]] to i32
+; CHECK-NEXT:    [[SUB_I_6:%.*]] = add nuw nsw i32 [[CONV_I_6]], 0
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP0]], [[SUB_I_6]]
+; CHECK-NEXT:    [[TMP10:%.*]] = or i8 [[A_PROMOTED]], 0
+; CHECK-NEXT:    [[CONV_I_7:%.*]] = zext i8 [[TMP10]] to i32
+; CHECK-NEXT:    [[SUB_I_7:%.*]] = add nuw nsw i32 [[CONV_I_7]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = or i32 [[TMP1]], [[SUB_I_7]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = and i32 [[TMP8]], 65535
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i8> [[TMP4]], i32 3
 ; CHECK-NEXT:    store i8 [[TMP10]], ptr null, align 1
 ; CHECK-NEXT:    [[CALL3:%.*]] = tail call i32 (ptr, ...) null(ptr null, i32 [[TMP9]])
 ; CHECK-NEXT:    ret i32 0

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
index 4acd63078b82ef..4af69dff179e26 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minbitwidth-transformed-operand.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
-; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-6 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
+; RUN: opt -passes=slp-vectorizer -S -slp-threshold=-7 -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
 define void @test(i64 %d.promoted.i) {
 ; CHECK-LABEL: define void @test(


More information about the llvm-commits mailing list