[llvm] fed3041 - [LV][ARM] Improve reduction costmodel for mismatching extension types.

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 10 07:41:03 PST 2021


Author: David Green
Date: 2021-12-10T15:40:58Z
New Revision: fed3041863eba8de6a5ba15a38be38ca8114faf8

URL: https://github.com/llvm/llvm-project/commit/fed3041863eba8de6a5ba15a38be38ca8114faf8
DIFF: https://github.com/llvm/llvm-project/commit/fed3041863eba8de6a5ba15a38be38ca8114faf8.diff

LOG: [LV][ARM] Improve reduction costmodel for mismatching extension types.

Given an MLA reduction from two different types (say i8 and i16), we were
previously failing to find the reduction pattern, often making us choose
the lower vector factor. This improves that by using the largest of the
two extension types as the type of the reduction, allowing us to use the
larger VF.

As per https://godbolt.org/z/KP549EEYM the backend handles this
valiantly, leading to better performance.
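
As a hedged illustration (an editorial sketch modelled on the
red_mla_ext_s8_s16_s32 test updated below, and an assumption about what the
godbolt link shows, not copied from it), the affected source pattern is an
MLA reduction whose two operands are widened from different types:

    // i8 and i16 operands are both widened to i32 before the multiply, so
    // the two extends have mismatching source types. With this patch the
    // reduction is costed on the wider (i16) extend plus one extra extend
    // for the i8 operand, letting the vectorizer pick VF=8 on MVE.
    int red_mla_ext_s8_s16_s32(const signed char *A, const short *B, int n) {
      int sum = 0;
      for (int i = 0; i < n; i++)
        sum += (int)A[i] * (int)B[i];
      return sum;
    }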

Differential Revision: https://reviews.llvm.org/D115432

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a96c844652dba..cc439a924d4bc 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7006,22 +7006,41 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
              match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
     if (match(Op0, m_ZExtOrSExt(m_Value())) &&
         Op0->getOpcode() == Op1->getOpcode() &&
-        Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
         !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1)) {
       bool IsUnsigned = isa<ZExtInst>(Op0);
-      auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
-      // Matched reduce(mul(ext, ext))
-      InstructionCost ExtCost =
-          TTI.getCastInstrCost(Op0->getOpcode(), VectorTy, ExtType,
-                               TTI::CastContextHint::None, CostKind, Op0);
+      Type *Op0Ty = Op0->getOperand(0)->getType();
+      Type *Op1Ty = Op1->getOperand(0)->getType();
+      Type *LargestOpTy =
+          Op0Ty->getIntegerBitWidth() < Op1Ty->getIntegerBitWidth() ? Op1Ty
+                                                                    : Op0Ty;
+      auto *ExtType = VectorType::get(LargestOpTy, VectorTy);
+
+      // Matched reduce(mul(ext(A), ext(B))), where the two ext may be of
+      // different sizes. We take the largest type as the ext to reduce, and add
+      // the remaining cost as, for example reduce(mul(ext(ext(A)), ext(B))).
+      InstructionCost ExtCost0 = TTI.getCastInstrCost(
+          Op0->getOpcode(), VectorTy, VectorType::get(Op0Ty, VectorTy),
+          TTI::CastContextHint::None, CostKind, Op0);
+      InstructionCost ExtCost1 = TTI.getCastInstrCost(
+          Op1->getOpcode(), VectorTy, VectorType::get(Op1Ty, VectorTy),
+          TTI::CastContextHint::None, CostKind, Op1);
       InstructionCost MulCost =
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
       InstructionCost RedCost = TTI.getExtendedAddReductionCost(
           /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
           CostKind);
+      InstructionCost ExtraExtCost = 0;
+      if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
+        Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
+        ExtraExtCost = TTI.getCastInstrCost(
+            ExtraExtOp->getOpcode(), ExtType,
+            VectorType::get(ExtraExtOp->getOperand(0)->getType(), VectorTy),
+            TTI::CastContextHint::None, CostKind, ExtraExtOp);
+      }
 
-      if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
+      if (RedCost.isValid() &&
+          (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost))
         return I == RetI ? RedCost : 0;
     } else if (!match(I, m_ZExtOrSExt(m_Value()))) {
       // Matched reduce(mul())

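For readers following the arithmetic, here is a standalone sketch of the new
cost comparison, kept separate from the LLVM APIs above; the per-instruction
costs are illustrative assumptions for an MVE-like target, not values
returned by any real TTI:

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Source types of the two extends feeding the multiply, e.g. sext i8
      // and sext i16; the reduction is costed on the wider of the two.
      unsigned Op0Bits = 8, Op1Bits = 16;
      unsigned LargestBits = std::max(Op0Bits, Op1Bits);

      // Hypothetical costs: two extends, the multiply, the plain add
      // reduction (BaseCost), and the fused extended-MLA reduction.
      unsigned ExtCost0 = 1, ExtCost1 = 1, MulCost = 1, BaseCost = 2;
      unsigned RedCost = 2;

      // One extra extend is charged for the narrower operand, mirroring the
      // reduce(mul(ext(ext(A)), ext(B))) form described in the comment above.
      unsigned ExtraExtCost =
          (Op0Bits != LargestBits || Op1Bits != LargestBits) ? 1 : 0;

      bool UseExtendedReduction =
          (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost);
      std::printf("prefer extended MLA reduction: %s\n",
                  UseExtendedReduction ? "yes" : "no");
      return 0;
    }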
diff --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index e66fbede57b7b..a699304abf9df 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -1133,33 +1133,33 @@ for.cond.cleanup:                                 ; preds = %for.body, %entry
   ret i8 %r.0.lcssa
 }
 
-; 4x or 8x as different types
+; 8x as different types
 define i32 @red_mla_ext_s8_s16_s32(i8* noalias nocapture readonly %A, i16* noalias nocapture readonly %B, i32 %n) #0 {
 ; CHECK-LABEL: @red_mla_ext_s8_s16_s32(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 3
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -4
+; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 7
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -8
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <4 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP1]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <8 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0v8i8(<8 x i8>* [[TMP1]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, i16* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <4 x i16>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i16> @llvm.masked.load.v4i16.p0v4i16(<4 x i16>* [[TMP4]], i32 2, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i16> poison)
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[WIDE_MASKED_LOAD1]] to <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <8 x i16>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>* [[TMP4]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP6]], <8 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i32 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 4
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[FOR_COND_CLEANUP]], label [[VECTOR_BODY]], !llvm.loop [[LOOP26:![0-9]+]]
 ; CHECK:       for.cond.cleanup:


        


More information about the llvm-commits mailing list