[llvm] 41cedb1 - [LV][ARM] Tighten up MLA reduction costing

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 28 04:51:19 PDT 2021


Author: David Green
Date: 2021-07-28T12:50:58+01:00
New Revision: 41cedb1c9a380628ac162bf76148cbd143f41450

URL: https://github.com/llvm/llvm-project/commit/41cedb1c9a380628ac162bf76148cbd143f41450
DIFF: https://github.com/llvm/llvm-project/commit/41cedb1c9a380628ac162bf76148cbd143f41450.diff

LOG: [LV][ARM] Tighten up MLA reduction costing

This makes a couple of changes to the costing of MLA reduction patterns,
to more accurately cost various patterns that can come up from
vectorization.

 - The Arm implementation of getExtendedAddReductionCost is altered to
   only provide costs for legal or smaller types. Larger-than-legal types
   need to be split, which currently does not work very well, especially
   for predicated reductions where the predicate may be legal but needs to
   be split. For now, costs are therefore limited to legal or smaller input
   types.
 - getReductionPatternCost has learnt that reduce(ext(mul(ext, ext)))
   is a pattern that can come up, and can be treated the same as
   reduce(mul(ext, ext)) provided the inner extensions match (same
   extend opcode and same source type).
 - It has also been adjusted so that the ext in reduce(mul(ext, ext)) is
   no longer counted as part of a reduce(mul) pattern.

Together these changes help to more accurately cost the mla reductions
in cases where the extend types don't match or the extend opcodes
differ, picking better vector factors that don't result in expanded
reductions.

Differential Revision: https://reviews.llvm.org/D106166

Added: 
    

Modified: 
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index cf7456e9e4f5..92892d3abdc6 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1623,13 +1623,24 @@ ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
                                         TTI::TargetCostKind CostKind) {
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
+
   if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
     std::pair<InstructionCost, MVT> LT =
         TLI->getTypeLegalizationCost(DL, ValTy);
-    if ((LT.second == MVT::v16i8 && ResVT.getSizeInBits() <= 32) ||
-        (LT.second == MVT::v8i16 &&
-         ResVT.getSizeInBits() <= (IsMLA ? 64 : 32)) ||
-        (LT.second == MVT::v4i32 && ResVT.getSizeInBits() <= 64))
+
+    // The legal cases are:
+    //   VADDV u/s 8/16/32
+    //   VMLAV u/s 8/16/32
+    //   VADDLV u/s 32
+    //   VMLALV u/s 16/32
+    // Codegen currently cannot always handle larger than legal vectors very
+    // well, especially for predicated reductions where the mask needs to be
+    // split, so restrict to 128bit or smaller input types.
+    unsigned RevVTSize = ResVT.getSizeInBits();
+    if (ValVT.getSizeInBits() <= 128 &&
+        ((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
+         (LT.second == MVT::v8i16 && RevVTSize <= (IsMLA ? 64 : 32)) ||
+         (LT.second == MVT::v4i32 && RevVTSize <= 64)))
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index f24ae6b100d5..a3bd2e3054ca 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7210,8 +7210,41 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
   VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy);
 
   Instruction *Op0, *Op1;
-  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) &&
-      !TheLoop->isLoopInvariant(RedOp)) {
+  if (RedOp &&
+      match(RedOp,
+            m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+      match(Op0, m_ZExtOrSExt(m_Value())) &&
+      Op0->getOpcode() == Op1->getOpcode() &&
+      Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+      !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1) &&
+      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+
+    // Matched reduce(ext(mul(ext(A), ext(B)))
+    // Note that the extend opcodes need to all match, or if A==B they will have
+    // been converted to zext(mul(sext(A), sext(A))) as it is known positive,
+    // which is equally fine.
+    bool IsUnsigned = isa<ZExtInst>(Op0);
+    auto *ExtType = VectorType::get(Op0->getOperand(0)->getType(), VectorTy);
+    auto *MulType = VectorType::get(Op0->getType(), VectorTy);
+
+    InstructionCost ExtCost =
+        TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
+                             TTI::CastContextHint::None, CostKind, Op0);
+    InstructionCost MulCost =
+        TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
+    InstructionCost Ext2Cost =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, MulType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+
+    InstructionCost RedCost = TTI.getExtendedAddReductionCost(
+        /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
+        CostKind);
+
+    if (RedCost.isValid() &&
+        RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
+      return I == RetI ? RedCost : 0;
+  } else if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value())) &&
+             !TheLoop->isLoopInvariant(RedOp)) {
     // Matched reduce(ext(A))
     bool IsUnsigned = isa<ZExtInst>(RedOp);
     auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy);
@@ -7245,7 +7278,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
 
       if (RedCost.isValid() && RedCost < ExtCost * 2 + MulCost + BaseCost)
         return I == RetI ? RedCost : 0;
-    } else {
+    } else if (!match(I, m_ZExtOrSExt(m_Value()))) {
       // Matched reduce(mul())
       InstructionCost MulCost =
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);

diff  --git a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
index abd59b42048e..3e8ac1bad93c 100644
--- a/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/ARM/mve-reductions.ll
@@ -722,34 +722,34 @@ for.cond.cleanup:                                 ; preds = %for.body, %entry
 }
 
 ; 8x to use VMLAL.u16
-; FIXME: 8x, TailPredicate, double-extended
+; FIXME: TailPredicate
 define i64 @mla_i8_i64(i8* nocapture readonly %x, i8* nocapture readonly %y, i32 %n) #0 {
 ; CHECK-LABEL: @mla_i8_i64(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP10:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP10]], label [[FOR_BODY_PREHEADER:%.*]], label [[FOR_COND_CLEANUP:%.*]]
 ; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 16
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 8
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], -16
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], -8
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[X:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <16 x i8>*
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, <16 x i8>* [[TMP1]], align 1
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <8 x i8>*
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i8>, <8 x i8>* [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <8 x i8> [[WIDE_LOAD]] to <8 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[Y:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <16 x i8>*
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, <16 x i8>* [[TMP4]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nuw nsw <16 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = zext <16 x i32> [[TMP6]] to <16 x i64>
-; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v16i64(<16 x i64> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <8 x i8>*
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i8>, <8 x i8>* [[TMP4]], align 1
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i8> [[WIDE_LOAD1]] to <8 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nuw nsw <8 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <8 x i32> [[TMP6]] to <8 x i64>
+; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i64 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 16
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP18:![0-9]+]]
 ; CHECK:       middle.block:
@@ -1140,26 +1140,26 @@ define i32 @red_mla_ext_s8_s16_s32(i8* noalias nocapture readonly %A, i16* noali
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 7
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -8
+; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 3
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -4
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 [[INDEX]], i32 [[N]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <8 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <8 x i8> @llvm.masked.load.v8i8.p0v8i8(<8 x i8>* [[TMP1]], i32 1, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <8 x i8> [[WIDE_MASKED_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <4 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP1]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, i16* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <8 x i16>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <8 x i16> @llvm.masked.load.v8i16.p0v8i16(<8 x i16>* [[TMP4]], i32 2, <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i16> poison)
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <8 x i16> [[WIDE_MASKED_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <8 x i1> [[ACTIVE_LANE_MASK]], <8 x i32> [[TMP6]], <8 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <4 x i16>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i16> @llvm.masked.load.v4i16.p0v4i16(<4 x i16>* [[TMP4]], i32 2, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i16> poison)
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i16> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i32 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 8
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[FOR_COND_CLEANUP]], label [[VECTOR_BODY]], !llvm.loop [[LOOP26:![0-9]+]]
 ; CHECK:       for.cond.cleanup:
@@ -1197,34 +1197,34 @@ for.cond.cleanup:                                 ; preds = %for.cond.cleanup.lo
   ret i32 %s.0.lcssa
 }
 
-; FIXME: 4x as different sext vs zext
+; 4x as different sext vs zext
 define i64 @red_mla_ext_s16_u16_s64(i16* noalias nocapture readonly %A, i16* noalias nocapture readonly %B, i32 %n) #0 {
 ; CHECK-LABEL: @red_mla_ext_s16_u16_s64(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY_PREHEADER:%.*]]
 ; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 8
+; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i32 [[N]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], -8
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N]], -4
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i16, i16* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[TMP0]] to <8 x i16>*
-; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i16>, <8 x i16>* [[TMP1]], align 1
-; CHECK-NEXT:    [[TMP2:%.*]] = sext <8 x i16> [[WIDE_LOAD]] to <8 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i16* [[TMP0]] to <4 x i16>*
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[TMP1]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = sext <4 x i16> [[WIDE_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i16, i16* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <8 x i16>*
-; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <8 x i16>, <8 x i16>* [[TMP4]], align 2
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <8 x i16> [[WIDE_LOAD1]] to <8 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <8 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = zext <8 x i32> [[TMP6]] to <8 x i64>
-; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[TMP3]] to <4 x i16>*
+; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <4 x i16>, <4 x i16>* [[TMP4]], align 2
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i16> [[WIDE_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = zext <4 x i32> [[TMP6]] to <4 x i64>
+; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i64 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 8
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP27:![0-9]+]]
 ; CHECK:       middle.block:
@@ -1285,33 +1285,33 @@ for.cond.cleanup:                                 ; preds = %for.cond.cleanup.lo
   ret i64 %s.0.lcssa
 }
 
-; FIXME: 4x as different sext vs zext
+; 4x as different sext vs zext
 define i32 @red_mla_u8_s8_u32(i8* noalias nocapture readonly %A, i8* noalias nocapture readonly %B, i32 %n) #0 {
 ; CHECK-LABEL: @red_mla_u8_s8_u32(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CMP9_NOT:%.*]] = icmp eq i32 [[N:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP9_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 15
-; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -16
+; CHECK-NEXT:    [[N_RND_UP:%.*]] = add i32 [[N]], 3
+; CHECK-NEXT:    [[N_VEC:%.*]] = and i32 [[N_RND_UP]], -4
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi i32 [ 0, [[VECTOR_PH]] ], [ [[TMP9:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i32(i32 [[INDEX]], i32 [[N]])
+; CHECK-NEXT:    [[ACTIVE_LANE_MASK:%.*]] = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 [[INDEX]], i32 [[N]])
 ; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <16 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0v16i8(<16 x i8>* [[TMP1]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP2:%.*]] = zext <16 x i8> [[WIDE_MASKED_LOAD]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[TMP0]] to <4 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP1]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP2:%.*]] = zext <4 x i8> [[WIDE_MASKED_LOAD]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i8, i8* [[B:%.*]], i32 [[INDEX]]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <16 x i8>*
-; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <16 x i8> @llvm.masked.load.v16i8.p0v16i8(<16 x i8>* [[TMP4]], i32 1, <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i8> poison)
-; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_MASKED_LOAD1]] to <16 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <16 x i32> [[TMP5]], [[TMP2]]
-; CHECK-NEXT:    [[TMP7:%.*]] = select <16 x i1> [[ACTIVE_LANE_MASK]], <16 x i32> [[TMP6]], <16 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP7]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i8* [[TMP3]] to <4 x i8>*
+; CHECK-NEXT:    [[WIDE_MASKED_LOAD1:%.*]] = call <4 x i8> @llvm.masked.load.v4i8.p0v4i8(<4 x i8>* [[TMP4]], i32 1, <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i8> poison)
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <4 x i8> [[WIDE_MASKED_LOAD1]] to <4 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = mul nsw <4 x i32> [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select <4 x i1> [[ACTIVE_LANE_MASK]], <4 x i32> [[TMP6]], <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP7]])
 ; CHECK-NEXT:    [[TMP9]] = add i32 [[TMP8]], [[VEC_PHI]]
-; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 16
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i32 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP10:%.*]] = icmp eq i32 [[INDEX_NEXT]], [[N_VEC]]
 ; CHECK-NEXT:    br i1 [[TMP10]], label [[FOR_COND_CLEANUP]], label [[VECTOR_BODY]], !llvm.loop [[LOOP29:![0-9]+]]
 ; CHECK:       for.cond.cleanup:


        


More information about the llvm-commits mailing list