[llvm] e38af7b - [LV] Refactor getExtendedAddReductionCost to support other extended reduction more than Add.

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 2 01:02:49 PDT 2022


Author: jacquesguan
Date: 2022-08-02T16:02:38+08:00
New Revision: e38af7ba958aeb26a60b8a188596eb5762c3fde6

URL: https://github.com/llvm/llvm-project/commit/e38af7ba958aeb26a60b8a188596eb5762c3fde6
DIFF: https://github.com/llvm/llvm-project/commit/e38af7ba958aeb26a60b8a188596eb5762c3fde6.diff

LOG: [LV] Refactor getExtendedAddReductionCost to support other extended reduction more than Add.

Currently the API getExtendedAddReductionCost is used to determine the cost of an extended Add reduction with an optional Mul. For Arm, this covers the cases it needs. But other targets, for example RISCV, support other kinds of extended reduction, such as FAdd.

This patch does the following changes:
1, Split getExtendedAddReductionCost into 2 new APIs: getExtendedReductionCost, which handles the extended reduction with an additional Opcode input; and getMulAccReductionCost, which handles the MLA cases previously handled by getExtendedAddReductionCost.
2, Refactor getReductionPatternCost, adding constraint conditions to make sure that getMulAccReductionCost only handles reductions of Add + Mul.

Differential Revision: https://reviews.llvm.org/D130868

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
    llvm/include/llvm/CodeGen/BasicTTIImpl.h
    llvm/lib/Analysis/TargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c6637709ff7b7..b26f9f2927f7c 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1265,13 +1265,21 @@ class TargetTransformInfo {
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// Calculate the cost of an extended reduction pattern, similar to
-  /// getArithmeticReductionCost of an Add reduction with an extension and
-  /// optional multiply. This is the cost of as:
-  /// ResTy vecreduce.add(ext(Ty A)), or if IsMLA flag is set then:
-  /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)). The reduction happens
-  /// on a VectorType with ResTy elements and Ty lanes.
-  InstructionCost getExtendedAddReductionCost(
-      bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+  /// getArithmeticReductionCost of an Add reduction with multiply and optional
+  /// extensions. This is the cost of as:
+  /// ResTy vecreduce.add(mul (A, B)).
+  /// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
+  InstructionCost getMulAccReductionCost(
+      bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
+
+  /// Calculate the cost of an extended reduction pattern, similar to
+  /// getArithmeticReductionCost of a reduction with an extension.
+  /// This is the cost of as:
+  /// ResTy vecreduce(ext(Ty A)).
+  InstructionCost getExtendedReductionCost(
+      unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      Optional<FastMathFlags> FMF,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
 
   /// \returns The cost of Intrinsic instructions. Analyses the real arguments.
@@ -1775,8 +1783,12 @@ class TargetTransformInfo::Concept {
   virtual InstructionCost
   getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy, bool IsUnsigned,
                          TTI::TargetCostKind CostKind) = 0;
-  virtual InstructionCost getExtendedAddReductionCost(
-      bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+  virtual InstructionCost getExtendedReductionCost(
+      unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      Optional<FastMathFlags> FMF,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
+  virtual InstructionCost getMulAccReductionCost(
+      bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) = 0;
   virtual InstructionCost
   getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
@@ -2345,11 +2357,17 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                          TTI::TargetCostKind CostKind) override {
     return Impl.getMinMaxReductionCost(Ty, CondTy, IsUnsigned, CostKind);
   }
-  InstructionCost getExtendedAddReductionCost(
-      bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+  InstructionCost getExtendedReductionCost(
+      unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+      Optional<FastMathFlags> FMF,
+      TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) override {
+    return Impl.getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
+                                         CostKind);
+  }
+  InstructionCost getMulAccReductionCost(
+      bool IsUnsigned, Type *ResTy, VectorType *Ty,
       TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) override {
-    return Impl.getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, Ty,
-                                            CostKind);
+    return Impl.getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
   }
   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                         TTI::TargetCostKind CostKind) override {

diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 289721dc5bc5f..74582a1cfdbe1 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -682,10 +682,16 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
-  InstructionCost
-  getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned, Type *ResTy,
-                              VectorType *Ty,
-                              TTI::TargetCostKind CostKind) const {
+  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           Type *ResTy, VectorType *Ty,
+                                           Optional<FastMathFlags> FMF,
+                                           TTI::TargetCostKind CostKind) const {
+    return 1;
+  }
+
+  InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
+                                         VectorType *Ty,
+                                         TTI::TargetCostKind CostKind) const {
     return 1;
   }
 

diff  --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 2dc63917ea5ef..44c028ccbd2e6 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2318,25 +2318,38 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
            thisT()->getVectorInstrCost(Instruction::ExtractElement, Ty, 0);
   }
 
-  InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
-                                              Type *ResTy, VectorType *Ty,
-                                              TTI::TargetCostKind CostKind) {
+  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           Type *ResTy, VectorType *Ty,
+                                           Optional<FastMathFlags> FMF,
+                                           TTI::TargetCostKind CostKind) {
     // Without any native support, this is equivalent to the cost of
-    // vecreduce.add(ext) or if IsMLA vecreduce.add(mul(ext, ext))
+    // vecreduce.op(ext).
+    VectorType *ExtTy = VectorType::get(ResTy, Ty);
+    InstructionCost RedCost =
+        thisT()->getArithmeticReductionCost(Opcode, ExtTy, FMF, CostKind);
+    InstructionCost ExtCost = thisT()->getCastInstrCost(
+        IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty,
+        TTI::CastContextHint::None, CostKind);
+
+    return RedCost + ExtCost;
+  }
+
+  InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
+                                         VectorType *Ty,
+                                         TTI::TargetCostKind CostKind) {
+    // Without any native support, this is equivalent to the cost of
+    // vecreduce.add(mul(ext, ext)).
     VectorType *ExtTy = VectorType::get(ResTy, Ty);
     InstructionCost RedCost = thisT()->getArithmeticReductionCost(
         Instruction::Add, ExtTy, None, CostKind);
-    InstructionCost MulCost = 0;
     InstructionCost ExtCost = thisT()->getCastInstrCost(
         IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty,
         TTI::CastContextHint::None, CostKind);
-    if (IsMLA) {
-      MulCost =
-          thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
-      ExtCost *= 2;
-    }
 
-    return RedCost + MulCost + ExtCost;
+    InstructionCost MulCost =
+        thisT()->getArithmeticInstrCost(Instruction::Mul, ExtTy, CostKind);
+
+    return RedCost + MulCost + 2 * ExtCost;
   }
 
   InstructionCost getVectorSplitCost() { return 1; }

diff  --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index afd24950667d5..4684dbe424d64 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -976,11 +976,17 @@ InstructionCost TargetTransformInfo::getMinMaxReductionCost(
   return Cost;
 }
 
-InstructionCost TargetTransformInfo::getExtendedAddReductionCost(
-    bool IsMLA, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+InstructionCost TargetTransformInfo::getExtendedReductionCost(
+    unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *Ty,
+    Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const {
+  return TTIImpl->getExtendedReductionCost(Opcode, IsUnsigned, ResTy, Ty, FMF,
+                                           CostKind);
+}
+
+InstructionCost TargetTransformInfo::getMulAccReductionCost(
+    bool IsUnsigned, Type *ResTy, VectorType *Ty,
     TTI::TargetCostKind CostKind) const {
-  return TTIImpl->getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, Ty,
-                                              CostKind);
+  return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
 }
 
 InstructionCost

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 3c102463ba080..0114265b23476 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -1677,10 +1677,45 @@ ARMTTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
 }
 
+InstructionCost ARMTTIImpl::getExtendedReductionCost(
+    unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *ValTy,
+    Optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) {
+  EVT ValVT = TLI->getValueType(DL, ValTy);
+  EVT ResVT = TLI->getValueType(DL, ResTy);
+
+  int ISD = TLI->InstructionOpcodeToISD(Opcode);
+
+  switch (ISD) {
+  case ISD::ADD:
+    if (ST->hasMVEIntegerOps() && ValVT.isSimple() && ResVT.isSimple()) {
+      std::pair<InstructionCost, MVT> LT =
+          TLI->getTypeLegalizationCost(DL, ValTy);
+
+      // The legal cases are:
+      //   VADDV u/s 8/16/32
+      //   VADDLV u/s 32
+      // Codegen currently cannot always handle larger than legal vectors very
+      // well, especially for predicated reductions where the mask needs to be
+      // split, so restrict to 128bit or smaller input types.
+      unsigned RevVTSize = ResVT.getSizeInBits();
+      if (ValVT.getSizeInBits() <= 128 &&
+          ((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
+           (LT.second == MVT::v8i16 && RevVTSize <= 32) ||
+           (LT.second == MVT::v4i32 && RevVTSize <= 64)))
+        return ST->getMVEVectorCostFactor(CostKind) * LT.first;
+    }
+    break;
+  default:
+    break;
+  }
+  return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, ValTy, FMF,
+                                         CostKind);
+}
+
 InstructionCost
-ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
-                                        Type *ResTy, VectorType *ValTy,
-                                        TTI::TargetCostKind CostKind) {
+ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
+                                   VectorType *ValTy,
+                                   TTI::TargetCostKind CostKind) {
   EVT ValVT = TLI->getValueType(DL, ValTy);
   EVT ResVT = TLI->getValueType(DL, ResTy);
 
@@ -1689,9 +1724,7 @@ ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
         TLI->getTypeLegalizationCost(DL, ValTy);
 
     // The legal cases are:
-    //   VADDV u/s 8/16/32
     //   VMLAV u/s 8/16/32
-    //   VADDLV u/s 32
     //   VMLALV u/s 16/32
     // Codegen currently cannot always handle larger than legal vectors very
     // well, especially for predicated reductions where the mask needs to be
@@ -1699,13 +1732,12 @@ ARMTTIImpl::getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
     unsigned RevVTSize = ResVT.getSizeInBits();
     if (ValVT.getSizeInBits() <= 128 &&
         ((LT.second == MVT::v16i8 && RevVTSize <= 32) ||
-         (LT.second == MVT::v8i16 && RevVTSize <= (IsMLA ? 64u : 32u)) ||
+         (LT.second == MVT::v8i16 && RevVTSize <= 64) ||
          (LT.second == MVT::v4i32 && RevVTSize <= 64)))
       return ST->getMVEVectorCostFactor(CostKind) * LT.first;
   }
 
-  return BaseT::getExtendedAddReductionCost(IsMLA, IsUnsigned, ResTy, ValTy,
-                                            CostKind);
+  return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
 }
 
 InstructionCost

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 9c3980d79e606..ee3023a69d861 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -275,9 +275,13 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                              Optional<FastMathFlags> FMF,
                                              TTI::TargetCostKind CostKind);
-  InstructionCost getExtendedAddReductionCost(bool IsMLA, bool IsUnsigned,
-                                              Type *ResTy, VectorType *ValTy,
-                                              TTI::TargetCostKind CostKind);
+  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
+                                           Type *ResTy, VectorType *ValTy,
+                                           Optional<FastMathFlags> FMF,
+                                           TTI::TargetCostKind CostKind);
+  InstructionCost getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
+                                         VectorType *ValTy,
+                                         TTI::TargetCostKind CostKind);
 
   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                         TTI::TargetCostKind CostKind);

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index b74963d1104a2..d0eff1d74b00c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6547,7 +6547,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
   VectorTy = VectorType::get(I->getOperand(0)->getType(), VectorTy);
 
   Instruction *Op0, *Op1;
-  if (RedOp &&
+  if (RedOp && RdxDesc.getOpcode() == Instruction::Add &&
       match(RedOp,
             m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
       match(Op0, m_ZExtOrSExt(m_Value())) &&
@@ -6556,7 +6556,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
       !TheLoop->isLoopInvariant(Op0) && !TheLoop->isLoopInvariant(Op1) &&
       (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
 
-    // Matched reduce(ext(mul(ext(A), ext(B)))
+    // Matched reduce.add(ext(mul(ext(A), ext(B)))
     // Note that the extend opcodes need to all match, or if A==B they will have
     // been converted to zext(mul(sext(A), sext(A))) as it is known positive,
     // which is equally fine.
@@ -6573,9 +6573,8 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
         TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, MulType,
                              TTI::CastContextHint::None, CostKind, RedOp);
 
-    InstructionCost RedCost = TTI.getExtendedAddReductionCost(
-        /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
-        CostKind);
+    InstructionCost RedCost = TTI.getMulAccReductionCost(
+        IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
 
     if (RedCost.isValid() &&
         RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -6585,16 +6584,16 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
     // Matched reduce(ext(A))
     bool IsUnsigned = isa<ZExtInst>(RedOp);
     auto *ExtType = VectorType::get(RedOp->getOperand(0)->getType(), VectorTy);
-    InstructionCost RedCost = TTI.getExtendedAddReductionCost(
-        /*IsMLA=*/false, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
-        CostKind);
+    InstructionCost RedCost = TTI.getExtendedReductionCost(
+        RdxDesc.getOpcode(), IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
+        RdxDesc.getFastMathFlags(), CostKind);
 
     InstructionCost ExtCost =
         TTI.getCastInstrCost(RedOp->getOpcode(), VectorTy, ExtType,
                              TTI::CastContextHint::None, CostKind, RedOp);
     if (RedCost.isValid() && RedCost < BaseCost + ExtCost)
       return I == RetI ? RedCost : 0;
-  } else if (RedOp &&
+  } else if (RedOp && RdxDesc.getOpcode() == Instruction::Add &&
              match(RedOp, m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) {
     if (match(Op0, m_ZExtOrSExt(m_Value())) &&
         Op0->getOpcode() == Op1->getOpcode() &&
@@ -6607,7 +6606,7 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
                                                                     : Op0Ty;
       auto *ExtType = VectorType::get(LargestOpTy, VectorTy);
 
-      // Matched reduce(mul(ext(A), ext(B))), where the two ext may be of
+      // Matched reduce.add(mul(ext(A), ext(B))), where the two ext may be of
       // different sizes. We take the largest type as the ext to reduce, and add
       // the remaining cost as, for example reduce(mul(ext(ext(A)), ext(B))).
       InstructionCost ExtCost0 = TTI.getCastInstrCost(
@@ -6619,9 +6618,8 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
       InstructionCost MulCost =
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
-      InstructionCost RedCost = TTI.getExtendedAddReductionCost(
-          /*IsMLA=*/true, IsUnsigned, RdxDesc.getRecurrenceType(), ExtType,
-          CostKind);
+      InstructionCost RedCost = TTI.getMulAccReductionCost(
+          IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
       InstructionCost ExtraExtCost = 0;
       if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
         Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -6635,13 +6633,12 @@ Optional<InstructionCost> LoopVectorizationCostModel::getReductionPatternCost(
           (RedCost + ExtraExtCost) < (ExtCost0 + ExtCost1 + MulCost + BaseCost))
         return I == RetI ? RedCost : 0;
     } else if (!match(I, m_ZExtOrSExt(m_Value()))) {
-      // Matched reduce(mul())
+      // Matched reduce.add(mul())
       InstructionCost MulCost =
           TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
 
-      InstructionCost RedCost = TTI.getExtendedAddReductionCost(
-          /*IsMLA=*/true, true, RdxDesc.getRecurrenceType(), VectorTy,
-          CostKind);
+      InstructionCost RedCost = TTI.getMulAccReductionCost(
+          true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
 
       if (RedCost.isValid() && RedCost < MulCost + BaseCost)
         return I == RetI ? RedCost : 0;


        


More information about the llvm-commits mailing list