[llvm] [LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext))) patterns more reliably (PR #115847)

David Green via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 02:14:49 PST 2024


https://github.com/davemgreen created https://github.com/llvm/llvm-project/pull/115847

Previously, ext(mul(ext, ext)) patterns were detected when looking up through the tree, but not when looking down. This hopefully brings the cost model closer to the VPlan version, avoiding some asserts and reducing the diffs needed in #113903.

>From f7bd42b79c1f6f09147231abdf8fe96352cb4127 Mon Sep 17 00:00:00 2001
From: David Green <david.green at arm.com>
Date: Tue, 12 Nov 2024 10:11:13 +0000
Subject: [PATCH] [LoopVectorizer][ARM] Detect reduce(ext(mul(ext, ext)))
 patterns more reliably.

Previously, ext(mul(ext, ext)) patterns were detected when looking up through
the tree, but not when looking down. This hopefully brings the cost model
closer to the VPlan version, avoiding some asserts.
---
 .../Transforms/Vectorize/LoopVectorize.cpp    | 15 +++++++-
 .../LoopVectorize/ARM/mve-reductions.ll       | 38 +++++++++----------
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1ebc62f9843905..568aeae2260f11 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -5818,6 +5818,15 @@ LoopVectorizationCostModel::getReductionPatternCost(
   if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
       RetI->user_back()->getOpcode() == Instruction::Add) {
     RetI = RetI->user_back();
+  } else if (match(RetI, m_OneUse(m_Mul(m_Value(), m_Value()))) &&
+             ((match(I, m_ZExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_ZExt(m_Value())))) ||
+              (match(I, m_SExt(m_Value())) &&
+               match(RetI->user_back(), m_OneUse(m_SExt(m_Value()))))) &&
+             RetI->user_back()->user_back()->getOpcode() == Instruction::Add) {
+    // This looks through ext(mul(ext, ext)), making sure that the extensions
+    // are the same sign.
+    RetI = RetI->user_back()->user_back();
   }
 
   // Test if the found instruction is a reduction, and if not return an invalid
@@ -7316,7 +7325,7 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
     // Also include the operands of instructions in the chain, as the cost-model
     // may mark extends as free.
     //
-    // For ARM, some of the instruction can folded into the reducion
+    // For ARM, some of the instructions can be folded into the reduction
     // instruction. So we need to mark all folded instructions free.
     // For example: We can fold reduce(mul(ext(A), ext(B))) into one
     // instruction.
@@ -7324,6 +7333,10 @@ LoopVectorizationPlanner::precomputeCosts(VPlan &Plan, ElementCount VF,
       for (Value *Op : ChainOp->operands()) {
         if (auto *I = dyn_cast<Instruction>(Op)) {
           ChainOpsAndOperands.insert(I);
+          if (IsZExtOrSExt(I->getOpcode())) {
+            ChainOpsAndOperands.insert(I);
+            I = dyn_cast<Instruction>(I->getOperand(0));
+          }
           if (I->getOpcode() == Instruction::Mul) {
             auto *Ext0 = dyn_cast<Instruction>(I->getOperand(0));
             auto *Ext1 = dyn_cast<Instruction>(I->getOperand(1));
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index c115c91cff896c..a4f96adccb64b5 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1722,10 +1722,10 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[TMP0:%.*]] = add nsw i32 [[N]], -1
 ; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = add nuw i32 [[TMP1]], 1
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 7
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 15
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -4
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP2]], -8
 ; CHECK-NEXT:    [[IND_END:%.*]] = shl i32 [[N_VEC]], 1
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -1733,26 +1733,26 @@ define i64 @test_fir_q15(ptr %x, ptr %y, i32 %n) #0 {
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP16:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = shl i32 [[INDEX]], 1
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, ptr [[X:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <8 x i16>, ptr [[TMP3]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <8 x i16> [[WIDE_VEC]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[STRIDED_VEC]] to <4 x i32>
+; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <16 x i16>, ptr [[TMP3]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC1:%.*]] = shufflevector <16 x i16> [[WIDE_VEC]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[STRIDED_VEC]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[Y:%.*]], i32 [[OFFSET_IDX]]
-; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <8 x i16>, ptr [[TMP4]], align 2
-; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
-; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <8 x i16> [[WIDE_VEC2]], <8 x i16> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
-; CHECK-NEXT:    [[TMP6:%.*]] = sext <4 x i16> [[STRIDED_VEC3]] to <4 x i32>
-; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <4 x i32> [[TMP6]], [[TMP5]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext <4 x i32> [[TMP7]] to <4 x i64>
-; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i16> [[STRIDED_VEC1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i16> [[STRIDED_VEC4]] to <4 x i32>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <4 x i32> [[TMP14]], [[TMP13]]
-; CHECK-NEXT:    [[TMP12:%.*]] = sext <4 x i32> [[TMP11]] to <4 x i64>
-; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP8]])
+; CHECK-NEXT:    [[WIDE_VEC2:%.*]] = load <16 x i16>, ptr [[TMP4]], align 2
+; CHECK-NEXT:    [[STRIDED_VEC3:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+; CHECK-NEXT:    [[STRIDED_VEC4:%.*]] = shufflevector <16 x i16> [[WIDE_VEC2]], <16 x i16> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <8 x i16> [[STRIDED_VEC3]] to <8 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = mul nsw <8 x i32> [[TMP6]], [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = sext <8 x i32> [[TMP7]] to <8 x i64>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <8 x i16> [[STRIDED_VEC1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <8 x i16> [[STRIDED_VEC4]] to <8 x i32>
+; CHECK-NEXT:    [[TMP11:%.*]] = mul nsw <8 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP12:%.*]] = sext <8 x i32> [[TMP11]] to <8 x i64>
+; CHECK-NEXT:    [[TMP9:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP8]])
 ; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[TMP9]], [[VEC_PHI]]
-; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP12]])
+; CHECK-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP12]])
 ; CHECK-NEXT:    [[TMP16]] = add i64 [[TMP15]], [[TMP10]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP17:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP37:![0-9]+]]
 ; CHECK:       middle.block:



More information about the llvm-commits mailing list