[llvm] [SLP][REVEC] Make tryToReduce and related functions support vector instructions. (PR #102327)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 7 09:19:41 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Han-Kuan Chen (HanKuanChen)

Changes:

When REVEC is enabled and the reduced values are themselves fixed-width vectors, `tryToReduce` can no longer emit a single scalar reduction over the vectorized root. This patch makes it emit one reduction per lane instead: each lane is gathered from the vectorized root with a strided shuffle mask, reduced on its own, and the per-lane results are inserted back into a vector of the original element type. `getReductionCost` is updated to match, summing the per-lane shuffle, reduction, and insertelement costs.
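As a quick illustration of the mask shape, here is a standalone sketch mirroring `llvm::createStrideMask` (names and values are illustrative, not the LLVM implementation itself): for lane `Start` of sub-vectors that are `Stride` elements wide, packed `VF` deep into one flat vector, the mask selects `{Start, Start + Stride, Start + 2*Stride, ...}`.

```cpp
#include <cstdio>
#include <vector>

// Sketch of the strided lane-extraction mask used by the patch:
// element Start of each of VF sub-vectors, each Stride elements wide.
static std::vector<int> strideMask(unsigned Start, unsigned Stride,
                                   unsigned VF) {
  std::vector<int> Mask;
  for (unsigned J = 0; J != VF; ++J)
    Mask.push_back(Start + J * Stride);
  return Mask;
}

int main() {
  // Lane 1 of four <4 x i1> operands packed into <16 x i1>:
  // prints "1 5 9 13", matching the shufflevector masks in revec.ll below.
  for (int M : strideMask(/*Start=*/1, /*Stride=*/4, /*VF=*/4))
    printf("%d ", M);
  printf("\n");
}
```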
---
Full diff: https://github.com/llvm/llvm-project/pull/102327.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+36-5) 
- (modified) llvm/test/Transforms/SLPVectorizer/revec.ll (+75) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 4186b17e644b0..7f96c91810910 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -17710,8 +17710,25 @@ class HorizontalReduction {
                                          SameValuesCounter, TrackedToOrig);
         }
 
-        Value *ReducedSubTree =
-            emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
+        Value *ReducedSubTree;
+        Type *ScalarTy = VL.front()->getType();
+        if (isa<FixedVectorType>(ScalarTy)) {
+          assert(SLPReVec && "FixedVectorType is not expected.");
+          unsigned ScalarTyNumElements = getNumElements(ScalarTy);
+          ReducedSubTree = PoisonValue::get(FixedVectorType::get(
+              VectorizedRoot->getType()->getScalarType(), ScalarTyNumElements));
+          for (unsigned I = 0; I != ScalarTyNumElements; ++I) {
+            // Do reduction for each lane.
+            SmallVector<int, 16> Mask =
+                createStrideMask(I, ScalarTyNumElements, VL.size());
+            Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
+            ReducedSubTree = Builder.CreateInsertElement(
+                ReducedSubTree, emitReduction(Lane, Builder, ReduxWidth, TTI),
+                I);
+          }
+        } else
+          ReducedSubTree =
+              emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
         if (ReducedSubTree->getType() != VL.front()->getType()) {
           assert(ReducedSubTree->getType() != VL.front()->getType() &&
                  "Expected different reduction type.");
@@ -17939,9 +17956,23 @@ class HorizontalReduction {
     case RecurKind::FAdd:
     case RecurKind::FMul: {
       unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
-      if (!AllConsts)
-        VectorCost =
-            TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
+      if (!AllConsts) {
+        if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+          assert(SLPReVec && "FixedVectorType is not expected.");
+          unsigned ScalarTyNumElements = VecTy->getNumElements();
+          for (unsigned I = 0, End = ReducedVals.size(); I != End; ++I) {
+            VectorCost += TTI->getShuffleCost(
+                TTI::SK_PermuteSingleSrc, VectorTy,
+                createStrideMask(I, ScalarTyNumElements, End));
+            VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy, FMF,
+                                                          CostKind);
+            VectorCost += TTI->getVectorInstrCost(
+                Instruction::InsertElement, VecTy, TTI::TCK_RecipThroughput, I);
+          }
+        } else
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
+                                                       CostKind);
+      }
       ScalarCost = EvaluateScalarCost([&]() {
         return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
       });
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index d6dd4128de9c7..39fc906535948 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -124,3 +124,78 @@ entry:
   store <8 x i1> %6, ptr %7, align 1
   ret void
 }
+
+define <4 x i1> @test5(ptr %in1, ptr %in2) {
+; CHECK-LABEL: @test5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[IN1:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i16>, ptr [[IN2:%.*]], align 2
+; CHECK-NEXT:    [[TMP2:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> poison, <4 x i32> poison, i64 4)
+; CHECK-NEXT:    [[TMP3:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP2]], <4 x i32> poison, i64 8)
+; CHECK-NEXT:    [[TMP4:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP3]], <4 x i32> poison, i64 12)
+; CHECK-NEXT:    [[TMP5:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP4]], <4 x i32> [[TMP0]], i64 0)
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <16 x i32> [[TMP5]], <16 x i32> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP7:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> poison, <4 x i32> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP7]], <4 x i32> zeroinitializer, i64 4)
+; CHECK-NEXT:    [[TMP9:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP8]], <4 x i32> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP10:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP9]], <4 x i32> zeroinitializer, i64 12)
+; CHECK-NEXT:    [[TMP11:%.*]] = icmp ugt <16 x i32> [[TMP6]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> poison, <4 x i16> poison, i64 4)
+; CHECK-NEXT:    [[TMP13:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP12]], <4 x i16> poison, i64 8)
+; CHECK-NEXT:    [[TMP14:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP13]], <4 x i16> poison, i64 12)
+; CHECK-NEXT:    [[TMP15:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP14]], <4 x i16> [[TMP1]], i64 0)
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <16 x i16> [[TMP15]], <16 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP17:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> poison, <4 x i16> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP18:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP17]], <4 x i16> zeroinitializer, i64 4)
+; CHECK-NEXT:    [[TMP19:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP18]], <4 x i16> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP20:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP19]], <4 x i16> zeroinitializer, i64 12)
+; CHECK-NEXT:    [[TMP21:%.*]] = icmp eq <16 x i16> [[TMP16]], [[TMP20]]
+; CHECK-NEXT:    [[TMP22:%.*]] = and <16 x i1> [[TMP11]], [[TMP21]]
+; CHECK-NEXT:    [[TMP23:%.*]] = icmp ugt <16 x i32> [[TMP6]], [[TMP10]]
+; CHECK-NEXT:    [[TMP24:%.*]] = and <16 x i1> [[TMP22]], [[TMP23]]
+; CHECK-NEXT:    [[TMP25:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 0, i32 4, i32 8, i32 12>
+; CHECK-NEXT:    [[TMP26:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP25]])
+; CHECK-NEXT:    [[TMP27:%.*]] = insertelement <4 x i1> poison, i1 [[TMP26]], i64 0
+; CHECK-NEXT:    [[TMP28:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 1, i32 5, i32 9, i32 13>
+; CHECK-NEXT:    [[TMP29:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP28]])
+; CHECK-NEXT:    [[TMP30:%.*]] = insertelement <4 x i1> [[TMP27]], i1 [[TMP29]], i64 1
+; CHECK-NEXT:    [[TMP31:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 2, i32 6, i32 10, i32 14>
+; CHECK-NEXT:    [[TMP32:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP31]])
+; CHECK-NEXT:    [[TMP33:%.*]] = insertelement <4 x i1> [[TMP30]], i1 [[TMP32]], i64 2
+; CHECK-NEXT:    [[TMP34:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 3, i32 7, i32 11, i32 15>
+; CHECK-NEXT:    [[TMP35:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP34]])
+; CHECK-NEXT:    [[TMP36:%.*]] = insertelement <4 x i1> [[TMP33]], i1 [[TMP35]], i64 3
+; CHECK-NEXT:    [[VBSL:%.*]] = select <4 x i1> [[TMP36]], <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt <4 x i32> [[VBSL]], <i32 2, i32 3, i32 4, i32 5>
+; CHECK-NEXT:    ret <4 x i1> [[CMP]]
+;
+entry:
+  %0 = load <4 x i32>, ptr %in1, align 4
+  %1 = load <4 x i16>, ptr %in2, align 2
+  %cmp000 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp001 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp002 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp003 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp100 = icmp eq <4 x i16> %1, zeroinitializer
+  %cmp101 = icmp eq <4 x i16> %1, zeroinitializer
+  %cmp102 = icmp eq <4 x i16> %1, zeroinitializer
+  %cmp103 = icmp eq <4 x i16> %1, zeroinitializer
+  %and.cmp0 = and <4 x i1> %cmp000, %cmp100
+  %and.cmp1 = and <4 x i1> %cmp001, %cmp101
+  %and.cmp2 = and <4 x i1> %cmp002, %cmp102
+  %and.cmp3 = and <4 x i1> %cmp003, %cmp103
+  %cmp004 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp005 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp006 = icmp ugt <4 x i32> %0, zeroinitializer
+  %cmp007 = icmp ugt <4 x i32> %0, zeroinitializer
+  %and.cmp4 = and <4 x i1> %and.cmp0, %cmp004
+  %and.cmp5 = and <4 x i1> %and.cmp1, %cmp005
+  %and.cmp6 = and <4 x i1> %and.cmp2, %cmp006
+  %and.cmp7 = and <4 x i1> %and.cmp3, %cmp007
+  %or0 = or <4 x i1> %and.cmp5, %and.cmp4
+  %or1 = or <4 x i1> %or0, %and.cmp6
+  %or2 = or <4 x i1> %or1, %and.cmp7
+  %vbsl = select <4 x i1> %or2, <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>
+  %cmp = icmp ugt <4 x i32> %vbsl, <i32 2, i32 3, i32 4, i32 5>
+  ret <4 x i1> %cmp
+}

``````````
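To make the new lowering concrete, here is a minimal standalone model of what the per-lane reduction computes for the or-reduction in `test5` above (plain C++ over `bool`, not SLPVectorizer code; the packed input values are made up for illustration): each lane is gathered with a stride mask, or-reduced across the four operands, and written into the result, mirroring the `shufflevector`/`llvm.vector.reduce.or`/`insertelement` sequence in the CHECK lines.

```cpp
#include <array>
#include <cstdio>

int main() {
  // Four <4 x i1> reduction operands packed into one flat 16-element root,
  // as in the <16 x i1> vectorized root of test5.
  std::array<bool, 16> Root = {false, true,  false, false,  // operand 0
                               true,  false, false, false,  // operand 1
                               false, false, true,  false,  // operand 2
                               false, false, false, true};  // operand 3
  constexpr unsigned NumLanes = 4; // elements per reduction operand
  constexpr unsigned NumOps = 4;   // number of reduction operands (VL.size())
  std::array<bool, NumLanes> Result{};
  for (unsigned I = 0; I != NumLanes; ++I) {
    // Stride-mask gather of lane I ({I, I+4, I+8, I+12}) plus or-reduction.
    bool Lane = false;
    for (unsigned J = 0; J != NumOps; ++J)
      Lane = Lane || Root[I + J * NumLanes];
    Result[I] = Lane; // insertelement into the <4 x i1> result
  }
  for (bool B : Result)
    printf("%d", B); // prints 1111 for this input
  printf("\n");
}
```

The CHECK lines show exactly this shape: four strided `shufflevector`s with masks `{0,4,8,12}` through `{3,7,11,15}`, each feeding `llvm.vector.reduce.or.v4i1` and an `insertelement`.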



https://github.com/llvm/llvm-project/pull/102327

