[llvm] 928c4b4 - [SCEV] Refactor isHighCostExpansionHelper

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 7 03:58:22 PDT 2020


Author: Sam Parker
Date: 2020-09-07T11:57:46+01:00
New Revision: 928c4b4b4988b4d633a96afa4c7f4584bc0009e5

URL: https://github.com/llvm/llvm-project/commit/928c4b4b4988b4d633a96afa4c7f4584bc0009e5
DIFF: https://github.com/llvm/llvm-project/commit/928c4b4b4988b4d633a96afa4c7f4584bc0009e5.diff

LOG: [SCEV] Refactor isHighCostExpansionHelper

To enable the cost of constants, the helper function has been
reorganised:
- A struct has been introduced to hold SCEV operand information so
  that we know the user of the operand, as well as the operand index.
  The Worklist now uses this struct instead of a bare SCEV.
- The costing of each SCEV, and collection of its operands, is now
  performed in a helper function.

Differential Revision: https://reviews.llvm.org/D86050

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 78ae38288c0c..77360cb2671d 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -39,6 +39,19 @@ bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE);
 bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint,
                       ScalarEvolution &SE);
 
+/// Work item describing one SCEV whose expansion into IR is to be costed,
+/// together with the opcode and operand slot of the instruction that uses it.
+struct SCEVOperand {
+  explicit SCEVOperand(unsigned Opc, int Idx, const SCEV *S) :
+    ParentOpcode(Opc), OperandIdx(Idx), S(S) { }
+  /// LLVM instruction opcode of the expanded user of this operand.
+  unsigned ParentOpcode;
+  /// Index of this operand within that expanded user instruction.
+  int OperandIdx;
+  /// The SCEV operand to be costed.
+  const SCEV* S;
+};
+
 /// This class uses information about analyze scalars to rewrite expressions
 /// in canonical form.
 ///
@@ -220,14 +233,14 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
     assert(At && "This function requires At instruction to be provided.");
     if (!TTI)      // In assert-less builds, avoid crashing
       return true; // by always claiming to be high-cost.
-    SmallVector<const SCEV *, 8> Worklist;
+    SmallVector<SCEVOperand, 8> Worklist;
     SmallPtrSet<const SCEV *, 8> Processed;
     int BudgetRemaining = Budget * TargetTransformInfo::TCC_Basic;
-    Worklist.emplace_back(Expr);
+    Worklist.emplace_back(-1, -1, Expr);
     while (!Worklist.empty()) {
-      const SCEV *S = Worklist.pop_back_val();
-      if (isHighCostExpansionHelper(S, L, *At, BudgetRemaining, *TTI, Processed,
-                                    Worklist))
+      const SCEVOperand WorkItem = Worklist.pop_back_val();
+      if (isHighCostExpansionHelper(WorkItem, L, *At, BudgetRemaining,
+                                    *TTI, Processed, Worklist))
         return true;
     }
     assert(BudgetRemaining >= 0 && "Should have returned from inner loop.");
@@ -394,11 +407,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
 
   /// Recursive helper function for isHighCostExpansion.
-  bool isHighCostExpansionHelper(const SCEV *S, Loop *L, const Instruction &At,
-                                 int &BudgetRemaining,
-                                 const TargetTransformInfo &TTI,
-                                 SmallPtrSetImpl<const SCEV *> &Processed,
-                                 SmallVectorImpl<const SCEV *> &Worklist);
+  bool isHighCostExpansionHelper(
+    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
+    int &BudgetRemaining, const TargetTransformInfo &TTI,
+    SmallPtrSetImpl<const SCEV *> &Processed,
+    SmallVectorImpl<SCEVOperand> &Worklist);
 
   /// Insert the specified binary operator, doing a small amount of work to
   /// avoid inserting an obviously redundant operation, and hoisting to an

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 1e8b11d6ac5f..1bb827cd3057 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -2177,13 +2177,133 @@ SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At,
   return None;
 }
 
+/// Cost the instructions needed to expand the SCEV in \p WorkItem (a \p T) and
+/// queue each operand on \p Worklist once per emitted opcode, tagged with that
+/// opcode and its operand index. \returns the total cost (0 for leaf SCEVs).
+template<typename T> static int costAndCollectOperands(
+  const SCEVOperand &WorkItem, const TargetTransformInfo &TTI,
+  TargetTransformInfo::TargetCostKind CostKind,
+  SmallVectorImpl<SCEVOperand> &Worklist) {
+  const T *S = cast<T>(WorkItem.S);
+  int Cost = 0;
+  // Collect the opcodes of all the instructions that will be needed to expand
+  // the SCEVExpr. This is so that when we come to cost the operands, we know
+  // what the generated user(s) will be.
+  SmallVector<unsigned, 2> Opcodes;
+
+  auto CastCost = [&](unsigned Opcode) {
+    Opcodes.push_back(Opcode);
+    return TTI.getCastInstrCost(Opcode, S->getType(),
+                                S->getOperand(0)->getType(),
+                                TTI::CastContextHint::None, CostKind);
+  };
+
+  auto ArithCost = [&](unsigned Opcode, unsigned NumRequired) {
+    Opcodes.push_back(Opcode);
+    return NumRequired *
+      TTI.getArithmeticInstrCost(Opcode, S->getType(), CostKind);
+  };
+
+  auto CmpSelCost = [&](unsigned Opcode, unsigned NumRequired) {
+    Opcodes.push_back(Opcode);
+    Type *OpType = S->getOperand(0)->getType();
+    return NumRequired *
+      TTI.getCmpSelInstrCost(Opcode, OpType,
+                             CmpInst::makeCmpResultType(OpType), CostKind);
+  };
+
+  switch (S->getSCEVType()) {
+  default:
+    llvm_unreachable("No other scev expressions possible.");
+  case scUnknown:
+  case scConstant:
+    return 0; // Treated as free; no operands to queue.
+  case scTruncate:
+    Cost = CastCost(Instruction::Trunc);
+    break;
+  case scZeroExtend:
+    Cost = CastCost(Instruction::ZExt);
+    break;
+  case scSignExtend:
+    Cost = CastCost(Instruction::SExt);
+    break;
+  case scUDivExpr: {
+    unsigned Opcode = Instruction::UDiv;
+    if (auto *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
+      if (SC->getAPInt().isPowerOf2())
+        Opcode = Instruction::LShr;
+    Cost = ArithCost(Opcode, 1);
+    break;
+  }
+  case scAddExpr:
+    Cost = ArithCost(Instruction::Add, S->getNumOperands() - 1);
+    break;
+  case scMulExpr:
+    // TODO: this is a very pessimistic cost modelling for Mul,
+    // because of Bin Pow algorithm actually used by the expander,
+    // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
+    Cost = ArithCost(Instruction::Mul, S->getNumOperands() - 1);
+    break;
+  case scSMaxExpr:
+  case scUMaxExpr:
+  case scSMinExpr:
+  case scUMinExpr:
+    Cost += CmpSelCost(Instruction::ICmp, S->getNumOperands() - 1);
+    Cost += CmpSelCost(Instruction::Select, S->getNumOperands() - 1);
+    break;
+  case scAddRecExpr: {
+    // In this polynomial, we may have some zero operands, and we shouldn't
+    // really charge for those. So how many non-zero coefficients are there?
+    int NumTerms = llvm::count_if(
+      S->operands(), [](const SCEV *Op) { return !Op->isZero(); });
+    assert(NumTerms >= 1 && "Polynominal should have at least one term.");
+    assert(!(*std::prev(S->operands().end()))->isZero() &&
+           "Last operand should not be zero");
+
+    // Ignoring the constant term (operand 0), how many coefficients are u> 1?
+    int NumNonZeroDegreeNonOneTerms =
+      llvm::count_if(llvm::drop_begin(S->operands(), 1), [](const SCEV *Op) {
+                      auto *SConst = dyn_cast<SCEVConstant>(Op);
+                      return !SConst || SConst->getAPInt().ugt(1);
+                    });
+
+    // Much like with normal add expr, the polynomial will require
+    // one less addition than the number of its terms.
+    int AddCost = ArithCost(Instruction::Add, NumTerms - 1);
+    // Each of those needs one multiplication; MulCost is a single multiply.
+    int MulCost = ArithCost(Instruction::Mul, 1);
+    Cost = AddCost + MulCost * NumNonZeroDegreeNonOneTerms;
+
+    // What is the degree of this polynomial?
+    int PolyDegree = S->getNumOperands() - 1;
+    assert(PolyDegree >= 1 && "Should be at least affine.");
+
+    // The final term will be:
+    //   Op_{PolyDegree} * x ^ {PolyDegree}
+    // Where  x ^ {PolyDegree}  will again require PolyDegree-1 mul operations.
+    // Note that  x ^ {PolyDegree} = x * x ^ {PolyDegree-1}  so charging for
+    // x ^ {PolyDegree}  will give us  x ^ {2} .. x ^ {PolyDegree-1}  for free.
+    // FIXME: this is conservatively correct, but might be overly pessimistic.
+    Cost += MulCost * (PolyDegree - 1);
+    break;
+  }
+  }
+
+  for (unsigned Opc : Opcodes)
+    for (auto I : enumerate(S->operands()))
+      Worklist.emplace_back(Opc, I.index(), I.value());
+  return Cost;
+}
+
 bool SCEVExpander::isHighCostExpansionHelper(
-    const SCEV *S, Loop *L, const Instruction &At, int &BudgetRemaining,
-    const TargetTransformInfo &TTI, SmallPtrSetImpl<const SCEV *> &Processed,
-    SmallVectorImpl<const SCEV *> &Worklist) {
+    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
+    int &BudgetRemaining, const TargetTransformInfo &TTI,
+    SmallPtrSetImpl<const SCEV *> &Processed,
+    SmallVectorImpl<SCEVOperand> &Worklist) {
   if (BudgetRemaining < 0)
     return true; // Already run out of budget, give up.
 
+  const SCEV *S = WorkItem.S;
   // Was the cost of expansion of this expression already accounted for?
   if (!Processed.insert(S).second)
     return false; // We have already accounted for this expression.
@@ -2202,44 +2322,12 @@ bool SCEVExpander::isHighCostExpansionHelper(
   TargetTransformInfo::TargetCostKind CostKind =
     TargetTransformInfo::TCK_RecipThroughput;
 
-  if (auto *CastExpr = dyn_cast<SCEVCastExpr>(S)) {
-    unsigned Opcode;
-    switch (S->getSCEVType()) {
-    case scTruncate:
-      Opcode = Instruction::Trunc;
-      break;
-    case scZeroExtend:
-      Opcode = Instruction::ZExt;
-      break;
-    case scSignExtend:
-      Opcode = Instruction::SExt;
-      break;
-    default:
-      llvm_unreachable("There are no other cast types.");
-    }
-    const SCEV *Op = CastExpr->getOperand();
-    BudgetRemaining -= TTI.getCastInstrCost(
-        Opcode, /*Dst=*/S->getType(),
-        /*Src=*/Op->getType(), TTI::CastContextHint::None, CostKind);
-    Worklist.emplace_back(Op);
+  if (isa<SCEVCastExpr>(S)) {
+    int Cost =
+      costAndCollectOperands<SCEVCastExpr>(WorkItem, TTI, CostKind, Worklist);
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
-  }
-
-  if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) {
-    // If the divisor is a power of two count this as a logical right-shift.
-    if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS())) {
-      if (SC->getAPInt().isPowerOf2()) {
-        BudgetRemaining -=
-            TTI.getArithmeticInstrCost(Instruction::LShr, S->getType(),
-                                       CostKind);
-        // Note that we don't count the cost of RHS, because it is a constant,
-        // and we consider those to be free. But if that changes, we would need
-        // to log2() it first before calling isHighCostExpansionHelper().
-        Worklist.emplace_back(UDivExpr->getLHS());
-        return false; // Will answer upon next entry into this function.
-      }
-    }
-
+  } else if (isa<SCEVUDivExpr>(S)) {
     // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or
     // HowManyLessThans produced to compute a precise expression, rather than a
     // UDiv from the user's code. If we can't find a UDiv in the code with some
@@ -2252,117 +2340,28 @@ bool SCEVExpander::isHighCostExpansionHelper(
             SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L))
       return false; // Consider it to be free.
 
+    int Cost =
+      costAndCollectOperands<SCEVUDivExpr>(WorkItem, TTI, CostKind, Worklist);
     // Need to count the cost of this UDiv.
-    BudgetRemaining -=
-        TTI.getArithmeticInstrCost(Instruction::UDiv, S->getType(),
-                                   CostKind);
-    Worklist.insert(Worklist.end(), {UDivExpr->getLHS(), UDivExpr->getRHS()});
+    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
-  }
-
-  if (const auto *NAry = dyn_cast<SCEVAddRecExpr>(S)) {
-    Type *OpType = NAry->getType();
-
-    assert(NAry->getNumOperands() >= 2 &&
-           "Polynomial should be at least linear");
-
-    int AddCost =
-      TTI.getArithmeticInstrCost(Instruction::Add, OpType, CostKind);
-    int MulCost =
-      TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
-
-    // In this polynominal, we may have some zero operands, and we shouldn't
-    // really charge for those. So how many non-zero coeffients are there?
-    int NumTerms = llvm::count_if(NAry->operands(),
-                                  [](const SCEV *S) { return !S->isZero(); });
-    assert(NumTerms >= 1 && "Polynominal should have at least one term.");
-    assert(!(*std::prev(NAry->operands().end()))->isZero() &&
-           "Last operand should not be zero");
-
-    // Much like with normal add expr, the polynominal will require
-    // one less addition than the number of it's terms.
-    BudgetRemaining -= AddCost * (NumTerms - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // Ignoring constant term (operand 0), how many of the coeffients are u> 1?
-    int NumNonZeroDegreeNonOneTerms =
-        llvm::count_if(make_range(std::next(NAry->op_begin()), NAry->op_end()),
-                       [](const SCEV *S) {
-                         auto *SConst = dyn_cast<SCEVConstant>(S);
-                         return !SConst || SConst->getAPInt().ugt(1);
-                       });
-    // Here, *each* one of those will require a multiplication.
-    BudgetRemaining -= MulCost * NumNonZeroDegreeNonOneTerms;
-    if (BudgetRemaining < 0)
-      return true;
-
-    // What is the degree of this polynominal?
-    int PolyDegree = NAry->getNumOperands() - 1;
-    assert(PolyDegree >= 1 && "Should be at least affine.");
-
-    // The final term will be:
-    //   Op_{PolyDegree} * x ^ {PolyDegree}
-    // Where  x ^ {PolyDegree}  will again require PolyDegree-1 mul operations.
-    // Note that  x ^ {PolyDegree} = x * x ^ {PolyDegree-1}  so charging for
-    // x ^ {PolyDegree}  will give us  x ^ {2} .. x ^ {PolyDegree-1}  for free.
-    // FIXME: this is conservatively correct, but might be overly pessimistic.
-    BudgetRemaining -= MulCost * (PolyDegree - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
-  }
-
-  if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
-    Type *OpType = NAry->getType();
-
-    int PairCost;
-    switch (S->getSCEVType()) {
-    case scAddExpr:
-      PairCost =
-        TTI.getArithmeticInstrCost(Instruction::Add, OpType, CostKind);
-      break;
-    case scMulExpr:
-      // TODO: this is a very pessimistic cost modelling for Mul,
-      // because of Bin Pow algorithm actually used by the expander,
-      // see SCEVExpander::visitMulExpr(), ExpandOpBinPowN().
-      PairCost =
-        TTI.getArithmeticInstrCost(Instruction::Mul, OpType, CostKind);
-      break;
-    case scSMaxExpr:
-    case scUMaxExpr:
-    case scSMinExpr:
-    case scUMinExpr:
-      PairCost = TTI.getCmpSelInstrCost(Instruction::ICmp, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind) +
-                 TTI.getCmpSelInstrCost(Instruction::Select, OpType,
-                                        CmpInst::makeCmpResultType(OpType),
-                                        CostKind);
-      break;
-    default:
-      llvm_unreachable("There are no other variants here.");
-    }
-
+  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) {
     assert(NAry->getNumOperands() > 1 &&
            "Nary expr should have more than 1 operand.");
     // The simple nary expr will require one less op (or pair of ops)
     // than the number of it's terms.
-    BudgetRemaining -= PairCost * (NAry->getNumOperands() - 1);
-    if (BudgetRemaining < 0)
-      return true;
-
-    // And finally, the operands themselves should fit within the budget.
-    Worklist.insert(Worklist.end(), NAry->operands().begin(),
-                    NAry->operands().end());
-    return false; // So far so good, though ops may be too costly?
-  }
-
-  llvm_unreachable("No other scev expressions possible.");
+    int Cost =
+      costAndCollectOperands<SCEVNAryExpr>(WorkItem, TTI, CostKind, Worklist);
+    BudgetRemaining -= Cost;
+    return BudgetRemaining < 0;
+  } else if (const auto *NAry = dyn_cast<SCEVAddRecExpr>(S)) {
+    assert(NAry->getNumOperands() >= 2 &&
+           "Polynomial should be at least linear");
+    BudgetRemaining -= costAndCollectOperands<SCEVAddRecExpr>(
+      WorkItem, TTI, CostKind, Worklist);
+    return BudgetRemaining < 0;
+  } else
+    llvm_unreachable("No other scev expressions possible.");
 }
 
 Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,


        


More information about the llvm-commits mailing list