[llvm] 0a68cd2 - [SLP]Fix PR64252: Requesting cost of invalid extending instruction.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 31 13:41:22 PDT 2023


Author: Alexey Bataev
Date: 2023-07-31T13:37:52-07:00
New Revision: 0a68cd23049de778a4440168c855aff38f202940

URL: https://github.com/llvm/llvm-project/commit/0a68cd23049de778a4440168c855aff38f202940
DIFF: https://github.com/llvm/llvm-project/commit/0a68cd23049de778a4440168c855aff38f202940.diff

LOG: [SLP]Fix PR64252: Requesting cost of invalid extending instruction.

If the actual bitwidth of an instruction does not match its original size,
the casting opcode needs to be re-estimated; the compiler cannot rely on the
one provided in the instruction.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2c22789cacfb56..564cdaf6d9210a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7291,9 +7291,10 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   // If we have computed a smaller type for the expression, update VecTy so
   // that the costs will be accurate.
   auto It = MinBWs.find(VL.front());
-  if (It != MinBWs.end())
-    VecTy = FixedVectorType::get(
-        IntegerType::get(F->getContext(), It->second.first), VL.size());
+  if (It != MinBWs.end()) {
+    ScalarTy = IntegerType::get(F->getContext(), It->second.first);
+    VecTy = FixedVectorType::get(ScalarTy, VL.size());
+  }
   unsigned EntryVF = E->getVectorFactor();
   auto *FinalVecTy = FixedVectorType::get(VecTy->getElementType(), EntryVF);
 
@@ -7303,6 +7304,16 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       return 0;
     if (isa<InsertElementInst>(VL[0]))
       return InstructionCost::getInvalid();
+    // The gather nodes use small bitwidth only if all operands use the same
+    // bitwidth. Otherwise - use the original one.
+    if (It != MinBWs.end() && any_of(VL.drop_front(), [&](Value *V) {
+          auto VIt = MinBWs.find(V);
+          return VIt == MinBWs.end() || VIt->second.first != It->second.first ||
+                 VIt->second.second != It->second.second;
+        })) {
+      ScalarTy = VL.front()->getType();
+      VecTy = FixedVectorType::get(ScalarTy, VL.size());
+    }
     ShuffleCostEstimator Estimator(*TTI, VectorizedVals, *this,
                                    CheckedExtracts);
     unsigned VF = E->getVectorFactor();
@@ -7725,22 +7736,60 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast: {
+    auto SrcIt = MinBWs.find(VL0->getOperand(0));
+    Type *SrcScalarTy = VL0->getOperand(0)->getType();
+    auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+    unsigned Opcode = ShuffleOrOp;
+    if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
+        (SrcIt != MinBWs.end() || It != MinBWs.end())) {
+      // Check if the values are candidates to demote.
+      unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+      if (SrcIt != MinBWs.end()) {
+        SrcBWSz = SrcIt->second.first;
+        SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
+        SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+      }
+      unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+      if (BWSz == SrcBWSz) {
+        Opcode = Instruction::BitCast;
+      } else if (BWSz < SrcBWSz) {
+        Opcode = Instruction::Trunc;
+      } else if (It != MinBWs.end()) {
+        assert(BWSz > SrcBWSz && "Invalid cast!");
+        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+    }
     auto GetScalarCost = [=](unsigned Idx) {
-      auto *VI = cast<Instruction>(VL[Idx]);
-      return TTI->getCastInstrCost(E->getOpcode(), ScalarTy,
-                                   VI->getOperand(0)->getType(),
+      auto *VI =
+          VL0->getOpcode() == Opcode ? cast<Instruction>(VL[Idx]) : nullptr;
+      return TTI->getCastInstrCost(Opcode, ScalarTy, SrcScalarTy,
                                    TTI::getCastContextHint(VI), CostKind, VI);
     };
+    TTI::CastContextHint CCH = TTI::CastContextHint::None;
+    if (const TreeEntry *OpTE = getTreeEntry(VL0->getOperand(0))) {
+      if (OpTE->State == TreeEntry::ScatterVectorize) {
+        CCH = TTI::CastContextHint::GatherScatter;
+      } else if (OpTE->State == TreeEntry::Vectorize &&
+                 OpTE->getOpcode() == Instruction::Load &&
+                 !OpTE->isAltShuffle()) {
+        if (OpTE->ReorderIndices.empty()) {
+          CCH = TTI::CastContextHint::Normal;
+        } else {
+          SmallVector<int> Mask;
+          inversePermutation(OpTE->ReorderIndices, Mask);
+          if (ShuffleVectorInst::isReverseMask(Mask))
+            CCH = TTI::CastContextHint::Reversed;
+        }
+      }
+    } else {
+      InstructionsState SrcState = getSameOpcode(E->getOperand(0), *TLI);
+      if (SrcState.getOpcode() == Instruction::Load && !SrcState.isAltShuffle())
+        CCH = TTI::CastContextHint::GatherScatter;
+    }
     auto GetVectorCost = [=](InstructionCost CommonCost) {
-      Type *SrcTy = VL0->getOperand(0)->getType();
-      auto *SrcVecTy = FixedVectorType::get(SrcTy, VL.size());
-      InstructionCost VecCost = CommonCost;
-      // Check if the values are candidates to demote.
-      if (!MinBWs.contains(VL0) || VecTy != SrcVecTy)
-        VecCost +=
-            TTI->getCastInstrCost(E->getOpcode(), VecTy, SrcVecTy,
-                                  TTI::getCastContextHint(VL0), CostKind, VL0);
-      return VecCost;
+      auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
+      return CommonCost +
+             TTI->getCastInstrCost(Opcode, VecTy, SrcVecTy, CCH, CostKind, VI);
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
index 68ed6062e3c40b..903adc8893f346 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
@@ -1,9 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+sse2     -S | FileCheck %s --check-prefixes=CHECK,SSE
-; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx      -S | FileCheck %s --check-prefixes=CHECK,AVX
-; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx2     -S | FileCheck %s --check-prefixes=CHECK,AVX
-; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx512f  -S | FileCheck %s --check-prefixes=CHECK,AVX
-; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx512vl -S | FileCheck %s --check-prefixes=CHECK,AVX
+; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+sse2     -S | FileCheck %s --check-prefix=SSE
+; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx      -S | FileCheck %s --check-prefix=AVX
+; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx2     -S | FileCheck %s --check-prefix=AVX
+; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx512f  -S | FileCheck %s --check-prefix=AVX
+; RUN: opt < %s -slp-threshold=-6 -passes=slp-vectorizer,instcombine -mattr=+avx512vl -S | FileCheck %s --check-prefix=AVX
 
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -15,21 +15,34 @@ target triple = "x86_64-unknown-linux-gnu"
 ; zero-extend the roots back to their original sizes.
 ;
 define i8 @PR31243_zext(i8 %v0, i8 %v1, i8 %v2, i8 %v3, ptr %ptr) {
-; CHECK-LABEL: @PR31243_zext(
-; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = insertelement <2 x i8> poison, i8 [[V0:%.*]], i64 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[TMP0]], i8 [[V1:%.*]], i64 1
-; CHECK-NEXT:    [[TMP2:%.*]] = or <2 x i8> [[TMP1]], <i8 1, i8 1>
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i8> [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = zext i8 [[TMP3]] to i64
-; CHECK-NEXT:    [[TMP_4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP4]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <2 x i8> [[TMP2]], i64 1
-; CHECK-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP5]] to i64
-; CHECK-NEXT:    [[TMP_5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP6]]
-; CHECK-NEXT:    [[TMP_6:%.*]] = load i8, ptr [[TMP_4]], align 1
-; CHECK-NEXT:    [[TMP_7:%.*]] = load i8, ptr [[TMP_5]], align 1
-; CHECK-NEXT:    [[TMP_8:%.*]] = add i8 [[TMP_6]], [[TMP_7]]
-; CHECK-NEXT:    ret i8 [[TMP_8]]
+; SSE-LABEL: @PR31243_zext(
+; SSE-NEXT:  entry:
+; SSE-NEXT:    [[TMP0:%.*]] = or i8 [[V0:%.*]], 1
+; SSE-NEXT:    [[TMP1:%.*]] = or i8 [[V1:%.*]], 1
+; SSE-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP0]] to i64
+; SSE-NEXT:    [[TMP_4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP2]]
+; SSE-NEXT:    [[TMP3:%.*]] = zext i8 [[TMP1]] to i64
+; SSE-NEXT:    [[TMP_5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP3]]
+; SSE-NEXT:    [[TMP_6:%.*]] = load i8, ptr [[TMP_4]], align 1
+; SSE-NEXT:    [[TMP_7:%.*]] = load i8, ptr [[TMP_5]], align 1
+; SSE-NEXT:    [[TMP_8:%.*]] = add i8 [[TMP_6]], [[TMP_7]]
+; SSE-NEXT:    ret i8 [[TMP_8]]
+;
+; AVX-LABEL: @PR31243_zext(
+; AVX-NEXT:  entry:
+; AVX-NEXT:    [[TMP0:%.*]] = insertelement <2 x i8> poison, i8 [[V0:%.*]], i64 0
+; AVX-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[TMP0]], i8 [[V1:%.*]], i64 1
+; AVX-NEXT:    [[TMP2:%.*]] = or <2 x i8> [[TMP1]], <i8 1, i8 1>
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <2 x i8> [[TMP2]], i64 0
+; AVX-NEXT:    [[TMP4:%.*]] = zext i8 [[TMP3]] to i64
+; AVX-NEXT:    [[TMP_4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP4]]
+; AVX-NEXT:    [[TMP5:%.*]] = extractelement <2 x i8> [[TMP2]], i64 1
+; AVX-NEXT:    [[TMP6:%.*]] = zext i8 [[TMP5]] to i64
+; AVX-NEXT:    [[TMP_5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP6]]
+; AVX-NEXT:    [[TMP_6:%.*]] = load i8, ptr [[TMP_4]], align 1
+; AVX-NEXT:    [[TMP_7:%.*]] = load i8, ptr [[TMP_5]], align 1
+; AVX-NEXT:    [[TMP_8:%.*]] = add i8 [[TMP_6]], [[TMP_7]]
+; AVX-NEXT:    ret i8 [[TMP_8]]
 ;
 entry:
   %tmp_0 = zext i8 %v0 to i32


        


More information about the llvm-commits mailing list