[llvm] af1e59a - [SLP]Fix PR107037: correctly track original/modified after vectorizations reduced values

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 4 04:54:42 PDT 2024


Author: Alexey Bataev
Date: 2024-09-04T04:54:32-07:00
New Revision: af1e59aea2ea7d07ece1f34621dda38c995826a3

URL: https://github.com/llvm/llvm-project/commit/af1e59aea2ea7d07ece1f34621dda38c995826a3
DIFF: https://github.com/llvm/llvm-project/commit/af1e59aea2ea7d07ece1f34621dda38c995826a3.diff

LOG: [SLP]Fix PR107037: correctly track original/modified after vectorizations reduced values

Reduced values with multiple uses must be tracked correctly within the same
reduction emission attempt. Otherwise, the number of reuses might be
calculated incorrectly, which may cause a compiler crash.

Fixes https://github.com/llvm/llvm-project/issues/107037

Added: 
    llvm/test/Transforms/SLPVectorizer/X86/multi-tracked-reduced-value.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a7272deb3c34fb..19b95cf473e914 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -17721,6 +17721,12 @@ class HorizontalReduction {
       for (Value *V : Candidates)
         TrackedVals.try_emplace(V, V);
 
+    auto At = [](MapVector<Value *, unsigned> &MV, Value *V) -> unsigned & {
+      auto *It = MV.find(V);
+      assert(It != MV.end() && "Unable to find given key.");
+      return It->second;
+    };
+
     DenseMap<Value *, unsigned> VectorizedVals(ReducedVals.size());
     // List of the values that were reduced in other trees as part of gather
     // nodes and thus requiring extract if fully vectorized in other trees.
@@ -17738,7 +17744,7 @@ class HorizontalReduction {
       Candidates.reserve(2 * OrigReducedVals.size());
       DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size());
       for (unsigned Cnt = 0, Sz = OrigReducedVals.size(); Cnt < Sz; ++Cnt) {
-        Value *RdxVal = TrackedVals.find(OrigReducedVals[Cnt])->second;
+        Value *RdxVal = TrackedVals.at(OrigReducedVals[Cnt]);
         // Check if the reduction value was not overriden by the extractelement
         // instruction because of the vectorization and exclude it, if it is not
         // compatible with other values.
@@ -17757,7 +17763,7 @@ class HorizontalReduction {
           I + 1 < E) {
         SmallVector<Value *> CommonCandidates(Candidates);
         for (Value *RV : ReducedVals[I + 1]) {
-          Value *RdxVal = TrackedVals.find(RV)->second;
+          Value *RdxVal = TrackedVals.at(RV);
           // Check if the reduction value was not overriden by the
           // extractelement instruction because of the vectorization and
           // exclude it, if it is not compatible with other values.
@@ -17778,10 +17784,12 @@ class HorizontalReduction {
       // Emit code for constant values.
       if (Candidates.size() > 1 && allConstant(Candidates)) {
         Value *Res = Candidates.front();
-        ++VectorizedVals.try_emplace(Candidates.front(), 0).first->getSecond();
+        Value *OrigV = TrackedToOrig.at(Candidates.front());
+        ++VectorizedVals.try_emplace(OrigV).first->getSecond();
         for (Value *VC : ArrayRef(Candidates).drop_front()) {
           Res = createOp(Builder, RdxKind, Res, VC, "const.rdx", ReductionOps);
-          ++VectorizedVals.try_emplace(VC, 0).first->getSecond();
+          Value *OrigV = TrackedToOrig.at(VC);
+          ++VectorizedVals.try_emplace(OrigV).first->getSecond();
           if (auto *ResI = dyn_cast<Instruction>(Res))
             V.analyzedReductionRoot(ResI);
         }
@@ -17802,8 +17810,10 @@ class HorizontalReduction {
       // Gather same values.
       MapVector<Value *, unsigned> SameValuesCounter;
       if (IsSupportedHorRdxIdentityOp)
-        for (Value *V : Candidates)
-          ++SameValuesCounter.insert(std::make_pair(V, 0)).first->second;
+        for (Value *V : Candidates) {
+          Value *OrigV = TrackedToOrig.at(V);
+          ++SameValuesCounter.try_emplace(OrigV).first->second;
+        }
       // Used to check if the reduced values used same number of times. In this
       // case the compiler may produce better code. E.g. if reduced values are
       // aabbccdd (8 x values), then the first node of the tree will have a node
@@ -17827,12 +17837,12 @@ class HorizontalReduction {
                    });
         Candidates.resize(SameValuesCounter.size());
         transform(SameValuesCounter, Candidates.begin(),
-                  [](const auto &P) { return P.first; });
+                  [&](const auto &P) { return TrackedVals.at(P.first); });
         NumReducedVals = Candidates.size();
         // Have a reduction of the same element.
         if (NumReducedVals == 1) {
-          Value *OrigV = TrackedToOrig.find(Candidates.front())->second;
-          unsigned Cnt = SameValuesCounter.lookup(OrigV);
+          Value *OrigV = TrackedToOrig.at(Candidates.front());
+          unsigned Cnt = At(SameValuesCounter, OrigV);
           Value *RedVal =
               emitScaleForReusedOps(Candidates.front(), Builder, Cnt);
           VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal);
@@ -17936,8 +17946,8 @@ class HorizontalReduction {
             if (Cnt >= Pos && Cnt < Pos + ReduxWidth)
               continue;
             Value *V = Candidates[Cnt];
-            Value *OrigV = TrackedToOrig.find(V)->second;
-            ++SameValuesCounter[OrigV];
+            Value *OrigV = TrackedToOrig.at(V);
+            ++SameValuesCounter.try_emplace(OrigV).first->second;
           }
         }
         SmallPtrSet<Value *, 4> VLScalars(VL.begin(), VL.end());
@@ -17955,10 +17965,10 @@ class HorizontalReduction {
             LocalExternallyUsedValues[RdxVal];
             continue;
           }
-          Value *OrigV = TrackedToOrig.find(RdxVal)->second;
+          Value *OrigV = TrackedToOrig.at(RdxVal);
           unsigned NumOps =
-              VectorizedVals.lookup(RdxVal) + SameValuesCounter[OrigV];
-          if (NumOps != ReducedValsToOps.find(OrigV)->second.size())
+              VectorizedVals.lookup(OrigV) + At(SameValuesCounter, OrigV);
+          if (NumOps != ReducedValsToOps.at(OrigV).size())
             LocalExternallyUsedValues[RdxVal];
         }
         // Do not need the list of reused scalars in regular mode anymore.
@@ -17983,9 +17993,8 @@ class HorizontalReduction {
           break;
         if (Cost >= -SLPCostThreshold) {
           V.getORE()->emit([&]() {
-            return OptimizationRemarkMissed(
-                       SV_NAME, "HorSLPNotBeneficial",
-                       ReducedValsToOps.find(VL[0])->second.front())
+            return OptimizationRemarkMissed(SV_NAME, "HorSLPNotBeneficial",
+                                            ReducedValsToOps.at(VL[0]).front())
                    << "Vectorizing horizontal reduction is possible "
                    << "but not beneficial with cost " << ore::NV("Cost", Cost)
                    << " and threshold "
@@ -17999,9 +18008,8 @@ class HorizontalReduction {
         LLVM_DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:"
                           << Cost << ". (HorRdx)\n");
         V.getORE()->emit([&]() {
-          return OptimizationRemark(
-                     SV_NAME, "VectorizedHorizontalReduction",
-                     ReducedValsToOps.find(VL[0])->second.front())
+          return OptimizationRemark(SV_NAME, "VectorizedHorizontalReduction",
+                                    ReducedValsToOps.at(VL[0]).front())
                  << "Vectorized horizontal reduction with cost "
                  << ore::NV("Cost", Cost) << " and with tree size "
                  << ore::NV("TreeSize", V.getTreeSize());
@@ -18083,12 +18091,12 @@ class HorizontalReduction {
         VectorizedTree = GetNewVectorizedTree(VectorizedTree, ReducedSubTree);
         // Count vectorized reduced values to exclude them from final reduction.
         for (Value *RdxVal : VL) {
-          Value *OrigV = TrackedToOrig.find(RdxVal)->second;
+          Value *OrigV = TrackedToOrig.at(RdxVal);
           if (IsSupportedHorRdxIdentityOp) {
-            VectorizedVals.try_emplace(OrigV, SameValuesCounter[RdxVal]);
+            VectorizedVals.try_emplace(OrigV, At(SameValuesCounter, OrigV));
             continue;
           }
-          ++VectorizedVals.try_emplace(OrigV, 0).first->getSecond();
+          ++VectorizedVals.try_emplace(OrigV).first->getSecond();
           if (!V.isVectorized(RdxVal))
             RequiredExtract.insert(RdxVal);
         }
@@ -18099,10 +18107,10 @@ class HorizontalReduction {
       }
       if (OptReusedScalars && !AnyVectorized) {
         for (const std::pair<Value *, unsigned> &P : SameValuesCounter) {
-          Value *RedVal = emitScaleForReusedOps(P.first, Builder, P.second);
+          Value *RdxVal = TrackedVals.at(P.first);
+          Value *RedVal = emitScaleForReusedOps(RdxVal, Builder, P.second);
           VectorizedTree = GetNewVectorizedTree(VectorizedTree, RedVal);
-          Value *OrigV = TrackedToOrig.find(P.first)->second;
-          VectorizedVals.try_emplace(OrigV, P.second);
+          VectorizedVals.try_emplace(P.first, P.second);
         }
         continue;
       }
@@ -18190,8 +18198,7 @@ class HorizontalReduction {
             continue;
           unsigned NumOps = VectorizedVals.lookup(RdxVal);
           for (Instruction *RedOp :
-               ArrayRef(ReducedValsToOps.find(RdxVal)->second)
-                   .drop_back(NumOps))
+               ArrayRef(ReducedValsToOps.at(RdxVal)).drop_back(NumOps))
             ExtraReductions.emplace_back(RedOp, RdxVal);
         }
       }
@@ -18430,7 +18437,7 @@ class HorizontalReduction {
       // root = mul prev_root, <1, 1, n, 1>
       SmallVector<Constant *> Vals;
       for (Value *V : VL) {
-        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.at(V));
         Vals.push_back(ConstantInt::get(V->getType(), Cnt, /*IsSigned=*/false));
       }
       auto *Scale = ConstantVector::get(Vals);
@@ -18468,7 +18475,7 @@ class HorizontalReduction {
       bool NeedShuffle = false;
       for (unsigned I = 0, VF = VL.size(); I < VF; ++I) {
         Value *V = VL[I];
-        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.at(V));
         if (Cnt % 2 == 0) {
           Mask[I] = VF;
           NeedShuffle = true;
@@ -18488,7 +18495,7 @@ class HorizontalReduction {
       // root = fmul prev_root, <1.0, 1.0, n.0, 1.0>
       SmallVector<Constant *> Vals;
       for (Value *V : VL) {
-        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.find(V)->second);
+        unsigned Cnt = SameValuesCounter.lookup(TrackedToOrig.at(V));
         Vals.push_back(ConstantFP::get(V->getType(), Cnt));
       }
       auto *Scale = ConstantVector::get(Vals);

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/multi-tracked-reduced-value.ll b/llvm/test/Transforms/SLPVectorizer/X86/multi-tracked-reduced-value.ll
new file mode 100644
index 00000000000000..e012cc60960b3c
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/multi-tracked-reduced-value.ll
@@ -0,0 +1,47 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=slp-vectorizer -mtriple=x86_64-unknown-linux-gnu < %s | FileCheck %s
+
+define i8 @test() {
+; CHECK-LABEL: define i8 @test() {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i32 0 to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 0 to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 0 to i8
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 0 to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> zeroinitializer)
+; CHECK-NEXT:    [[OP_RDX:%.*]] = or i8 [[TMP4]], [[TMP0]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = or i8 [[OP_RDX]], [[TMP2]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = or i8 [[OP_RDX1]], [[TMP0]]
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i8 [[OP_RDX2]], [[TMP1]]
+; CHECK-NEXT:    [[OP_RDX4:%.*]] = or i8 [[OP_RDX3]], [[TMP3]]
+; CHECK-NEXT:    ret i8 [[OP_RDX4]]
+;
+entry:
+  %0 = trunc i32 0 to i8
+  %1 = add i8 %0, 0
+  %2 = add i8 %0, 0
+  %3 = add i8 %0, 0
+  %4 = add i8 %0, 0
+  %5 = trunc i32 0 to i8
+  %6 = or i8 %5, %0
+  %7 = or i8 %6, %2
+  %8 = or i8 %7, %3
+  %9 = or i8 %8, %0
+  %10 = or i8 %9, %4
+  %conv4 = or i8 %10, %1
+  %11 = trunc i32 0 to i8
+  %12 = add i8 %11, 0
+  %conv7 = or i8 %conv4, %12
+  %13 = add i8 %11, 0
+  %14 = add i8 %11, 0
+  %15 = add i8 %11, 0
+  %16 = trunc i32 0 to i8
+  %17 = or i8 %13, %16
+  %18 = or i8 %17, %14
+  %19 = or i8 %18, %11
+  %20 = or i8 %19, %15
+  %conv5 = or i8 %20, %conv7
+  %21 = trunc i32 0 to i8
+  %conv6 = or i8 %21, %conv5
+  ret i8 %conv6
+}


        


More information about the llvm-commits mailing list