[llvm] [LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext))) patterns more reliably (PR #115847)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 02:15:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

Previously, ext(mul(ext, ext)) patterns were detected when looking up through the tree, but not when looking down. This change should bring the cost model closer to the VPlan version, avoiding some asserts and reducing the diffs needed in #113903.

---
Full diff: https://github.com/llvm/llvm-project/pull/115847.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+14-1) 
- (modified) llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll (+19-19) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..568aeae2260f11 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
   if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     RetI = RetI->user_back();
+  } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+             ((match(I, m_ZExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+              (match(I, m_SExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+             RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+    // This looks through ext(mul(ext, ext)), making sure that the extensions
+    // are the same sign.
+    RetI = RetI->user_back()->user_back();
   }
 
   // Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
     //
-    // For ARM, some of the instruction can folded into the reducion
+    // For ARM, some of the instructions can be folded into the reduction
     // instruction. So we need to mark all folded instructions free.
     // For example: We can fold reduce(mul(ext(A), ext(B))) into one
     // instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (IsZExtOrSExt(I->getOpcode())) {
+            ChainOpsAndOperands.insert(I);
+            I = dyn_cast<Instruction>(I->getOperand(0));
+          }
           if (I->getOpcode() == Instruction::Mul) {
             auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
             auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896c..a4f96adccb64b5 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:

``````````

</details>


https://github.com/llvm/llvm-project/pull/115847


More information about the llvm-commits mailing list