[llvm] [SLP]Fix PR87011: missing sign extension of demoted type before zero (PR #87054)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 29 05:56:53 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

Fix PR87011: missing sign extension of demoted type before zero extension.

Corner case, where a sext/zext node cannot be directly promoted because of
signedness switching. In this case, we first need to cast the operand
value to the original type with its original signedness, and only after that
cast the result to the new type with the new signedness. Also, we need to
adjust the cost model to handle this kind of transformation.


---
Full diff: https://github.com/llvm/llvm-project/pull/87054.diff


4 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+87-34) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll (+1-1) 
- (modified) llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll (+1-1) 
- (modified) llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2875e71081d928..579db52921e676 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8788,23 +8788,52 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto *SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
     unsigned Opcode = ShuffleOrOp;
     unsigned VecOpcode = Opcode;
+    TTI::CastContextHint VecCCH = GetCastContextHint(VL0->getOperand(0));
+    Instruction *VI = VL0;
     if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
         (SrcIt != MinBWs.end() || It != MinBWs.end())) {
+      VI = nullptr;
       // Check if the values are candidates to demote.
       unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
-      if (SrcIt != MinBWs.end()) {
-        SrcBWSz = SrcIt->second.first;
-        SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
-        SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
-      }
-      unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
-      if (BWSz == SrcBWSz) {
-        VecOpcode = Instruction::BitCast;
-      } else if (BWSz < SrcBWSz) {
-        VecOpcode = Instruction::Trunc;
-      } else if (It != MinBWs.end()) {
-        assert(BWSz > SrcBWSz && "Invalid cast!");
-        VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      if (It == MinBWs.end() && SrcIt != MinBWs.end() &&
+          SrcBWSz != SrcIt->second.first &&
+          all_of(VL, [&](Value *V) {
+            return !isKnownNonNegative(V, SimplifyQuery(*DL));
+          }) != SrcIt->second.second) {
+        // Neeed to perform first cast src to original src type.
+        if (SrcBWSz != SrcIt->second.first) {
+          CommonCost += TTI->getCastInstrCost(
+              SrcBWSz < SrcIt->second.first
+                  ? Instruction::Trunc
+                  : (SrcIt->second.second ? Instruction::SExt
+                                          : Instruction::ZExt),
+              SrcVecTy,
+              FixedVectorType::get(
+                  IntegerType::get(F->getContext(), SrcIt->second.first),
+                  VL.size()),
+              VecCCH, CostKind);
+          VecCCH = TTI::CastContextHint::None;
+        }
+      } else {
+        bool Signedness = false;
+        if (SrcIt != MinBWs.end()) {
+          SrcBWSz = SrcIt->second.first;
+          SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
+          SrcVecTy = FixedVectorType::get(SrcScalarTy, VL.size());
+          Signedness = SrcIt->second.second;
+        } else {
+          assert(It != MinBWs.end() && "Expected node in MinBWs.");
+          Signedness = It->second.second;
+        }
+        unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+        if (BWSz == SrcBWSz) {
+          VecOpcode = Instruction::BitCast;
+        } else if (BWSz < SrcBWSz) {
+          VecOpcode = Instruction::Trunc;
+        } else {
+          assert(BWSz > SrcBWSz && "Invalid cast!");
+          VecOpcode = Signedness ? Instruction::SExt : Instruction::ZExt;
+        }
       }
     }
     auto GetScalarCost = [&](unsigned Idx) -> InstructionCost {
@@ -8814,15 +8843,12 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                                    TTI::getCastContextHint(VI), CostKind, VI);
     };
     auto GetVectorCost = [=](InstructionCost CommonCost) {
-      // Do not count cost here if minimum bitwidth is in effect and it is just
-      // a bitcast (here it is just a noop).
+      // Do not count cost here if minimum bitwidth is in effect and it is
+      // just a bitcast (here it is just a noop).
       if (VecOpcode != Opcode && VecOpcode == Instruction::BitCast)
         return CommonCost;
-      auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
-      TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
-      return CommonCost +
-             TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
-                                   VecOpcode == Opcode ? VI : nullptr);
+      return CommonCost + TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy,
+                                                VecCCH, CostKind, VI);
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
   }
@@ -12145,18 +12171,37 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           (SrcIt != MinBWs.end() || It != MinBWs.end() ||
            SrcScalarTy != CI->getOperand(0)->getType())) {
         // Check if the values are candidates to demote.
-        unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
-        if (SrcIt != MinBWs.end())
-          SrcBWSz = SrcIt->second.first;
+        unsigned OrigSrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+        unsigned SrcBWSz = OrigSrcBWSz;
         unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
-        if (BWSz == SrcBWSz) {
-          VecOpcode = Instruction::BitCast;
-        } else if (BWSz < SrcBWSz) {
-          VecOpcode = Instruction::Trunc;
-        } else if (SrcIt != MinBWs.end()) {
-          assert(BWSz > SrcBWSz && "Invalid cast!");
-          VecOpcode =
-              SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
+        if (It == MinBWs.end() && SrcIt != MinBWs.end() &&
+            all_of(E->Scalars, [&](Value *V) {
+              return !isKnownNonNegative(V, SimplifyQuery(*DL));
+            }) != SrcIt->second.second) {
+          // Neeed to perform first cast.
+          InVec = Builder.CreateIntCast(
+              InVec,
+              VectorType::get(
+                  CI->getOperand(0)->getType(),
+                  cast<VectorType>(InVec->getType())->getElementCount()),
+              SrcIt->second.second);
+        } else {
+          bool Signedness = false;
+          if (SrcIt != MinBWs.end()) {
+            SrcBWSz = SrcIt->second.first;
+            Signedness = SrcIt->second.second;
+          } else {
+            assert(It != MinBWs.end() && "Expected node in MinBWs.");
+            Signedness = It->second.second;
+          }
+          if (BWSz == SrcBWSz) {
+            VecOpcode = Instruction::BitCast;
+          } else if (BWSz < SrcBWSz) {
+            VecOpcode = Instruction::Trunc;
+          } else {
+            assert(BWSz > SrcBWSz && "Invalid cast!");
+            VecOpcode = Signedness ? Instruction::SExt : Instruction::ZExt;
+          }
         }
       }
       Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -14454,10 +14499,18 @@ void BoUpSLP::computeMinimumValueSizes() {
       Value *V = VectorizableTree[Idx]->Scalars.front();
       uint32_t OrigBitWidth = DL->getTypeSizeInBits(V->getType());
       if (OrigBitWidth > MaxBitWidth) {
-      APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
-      if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
-        ToDemote.push_back(V);
+        APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, MaxBitWidth);
+        if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL))) {
+          ToDemote.push_back(V);
+          continue;
+        }
       }
+      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+      unsigned BitWidth = OrigBitWidth - NumSignBits;
+      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+        ++BitWidth;
+      if (BitWidth <= MaxBitWidth)
+        ToDemote.push_back(V);
     }
     RootDemotes.clear();
     IsTopRoot = false;
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
index 1cce52060c479f..866afeea50108c 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/getelementptr2.ll
@@ -14,7 +14,7 @@
 ; YAML-NEXT:  Function:        test_i16_extend
 ; YAML-NEXT:  Args:
 ; YAML-NEXT:    - String:          'SLP vectorized with cost '
-; YAML-NEXT:    - Cost:            '-20'
+; YAML-NEXT:    - Cost:            '-16'
 ; YAML-NEXT:    - String:          ' and with tree size '
 ; YAML-NEXT:    - TreeSize:        '5'
 ; YAML-NEXT:  ...
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
index 7c5f9847db1f41..21d4383b3e3563 100644
--- a/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/horizontal.ll
@@ -228,7 +228,7 @@ for.end:                                          ; preds = %for.end.loopexit, %
 ; YAML-NEXT: Function:        test_unrolled_select
 ; YAML-NEXT: Args:
 ; YAML-NEXT:   - String:          'Vectorized horizontal reduction with cost '
-; YAML-NEXT:   - Cost:            '-41'
+; YAML-NEXT:   - Cost:            '-39'
 ; YAML-NEXT:   - String:          ' and with tree size '
 ; YAML-NEXT:   - TreeSize:        '10'
 
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
index 436fba3261d602..1166b1fca826b6 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/init-ext-node-not-truncable.ll
@@ -7,7 +7,7 @@ define void @test() {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    store <2 x i64> <i64 -1, i64 0>, ptr @h, align 8
+; CHECK-NEXT:    store <2 x i64> <i64 4294967295, i64 0>, ptr @h, align 8
 ; CHECK-NEXT:    ret void
 ;
 entry:

``````````

</details>


https://github.com/llvm/llvm-project/pull/87054


More information about the llvm-commits mailing list