[llvm-branch-commits] [llvm] 9b76160 - [Support] Introduce a new InstructionCost class

David Sherwood via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Dec 11 00:45:01 PST 2020


Author: David Sherwood
Date: 2020-12-11T08:12:54Z
New Revision: 9b76160e53f67008ff21095098129a2949595a06

URL: https://github.com/llvm/llvm-project/commit/9b76160e53f67008ff21095098129a2949595a06
DIFF: https://github.com/llvm/llvm-project/commit/9b76160e53f67008ff21095098129a2949595a06.diff

LOG: [Support] Introduce a new InstructionCost class

This is the first in a series of patches that attempts to migrate
existing cost functions to return a new InstructionCost class
in place of a simple integer. This new class is intended to be
as light-weight and simple as possible, with a full range of
arithmetic and comparison operators that largely mirror the same
sets of operations on basic types, such as integers. The main
advantage to using an InstructionCost is that it can encode a
particular cost state in addition to a value. The initial
implementation only has two states - Valid and Invalid - but these
could be expanded over time if necessary. An invalid state can
be used to represent an unknown cost or an instruction that is
prohibitively expensive.

This patch adds the new class and changes the getInstructionCost
interface to return the new class. Other cost functions, such as
getUserCost, etc., will be migrated in future patches as I believe
this to be less disruptive. One benefit of this new class is that
it provides a way to unify many of the magic costs in the codebase
where the cost is set to a deliberately high number to prevent
optimisations taking place, e.g. vectorization. It also provides
a route to represent the extremely high, and unknown, cost of
scalarization of scalable vectors, which is not currently supported.

Differential Revision: https://reviews.llvm.org/D91174

Added: 
    llvm/include/llvm/Support/InstructionCost.h
    llvm/lib/Support/InstructionCost.cpp
    llvm/unittests/Support/InstructionCostTest.cpp

Modified: 
    llvm/include/llvm/Analysis/TargetTransformInfo.h
    llvm/include/llvm/IR/DiagnosticInfo.h
    llvm/lib/Analysis/CostModel.cpp
    llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
    llvm/lib/IR/DiagnosticInfo.cpp
    llvm/lib/Support/CMakeLists.txt
    llvm/lib/Transforms/IPO/HotColdSplitting.cpp
    llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/unittests/Support/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index af57176401b4..abaf07fad3d4 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -27,6 +27,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/InstructionCost.h"
 #include <functional>
 
 namespace llvm {
@@ -231,19 +232,28 @@ class TargetTransformInfo {
   ///
   /// Note, this method does not cache the cost calculation and it
   /// can be expensive in some cases.
-  int getInstructionCost(const Instruction *I, enum TargetCostKind kind) const {
+  InstructionCost getInstructionCost(const Instruction *I,
+                                     enum TargetCostKind kind) const {
+    InstructionCost Cost;
+    // No default label: with a covered switch over TargetCostKind, -Wswitch
+    // will flag any enumerator added later that is not handled here.
     switch (kind) {
     case TCK_RecipThroughput:
-      return getInstructionThroughput(I);
-
+      Cost = getInstructionThroughput(I);
+      break;
     case TCK_Latency:
-      return getInstructionLatency(I);
-
+      Cost = getInstructionLatency(I);
+      break;
     case TCK_CodeSize:
     case TCK_SizeAndLatency:
-      return getUserCost(I, kind);
+      Cost = getUserCost(I, kind);
+      break;
     }
-    llvm_unreachable("Unknown instruction cost kind");
+    // The underlying cost hooks use -1 to signal an unknown cost; map that
+    // sentinel onto the Invalid state.
+    if (Cost == -1)
+      Cost.setInvalid();
+    return Cost;
   }
 
   /// Underlying constants for 'cost' values in this interface.

diff  --git a/llvm/include/llvm/IR/DiagnosticInfo.h b/llvm/include/llvm/IR/DiagnosticInfo.h
index 644d853b9b0d..c457072d50f1 100644
--- a/llvm/include/llvm/IR/DiagnosticInfo.h
+++ b/llvm/include/llvm/IR/DiagnosticInfo.h
@@ -35,6 +35,7 @@ namespace llvm {
 class DiagnosticPrinter;
 class Function;
 class Instruction;
+class InstructionCost;
 class LLVMContext;
 class Module;
 class SMDiagnostic;
@@ -437,6 +438,7 @@ class DiagnosticInfoOptimizationBase : public DiagnosticInfoWithLocationBase {
     Argument(StringRef Key, ElementCount EC);
     Argument(StringRef Key, bool B) : Key(Key), Val(B ? "true" : "false") {}
     Argument(StringRef Key, DebugLoc dl);
+    Argument(StringRef Key, InstructionCost C);
   };
 
   /// \p PassName is the name of the pass emitting this diagnostic. \p

diff  --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h
new file mode 100644
index 000000000000..fe56d49b4174
--- /dev/null
+++ b/llvm/include/llvm/Support/InstructionCost.h
@@ -0,0 +1,255 @@
+//===- InstructionCost.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines an InstructionCost class that is used when calculating
+/// the cost of an instruction, or a group of instructions. In addition to a
+/// numeric value representing the cost the class also contains a state that
+/// can be used to encode particular properties, e.g. a cost being invalid or
+/// unknown.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
+#define LLVM_SUPPORT_INSTRUCTIONCOST_H
+
+#include "llvm/ADT/Optional.h"
+
+namespace llvm {
+
+class raw_ostream;
+
+class InstructionCost {
+public:
+  using CostType = int;
+
+  /// These states can currently be used to indicate whether a cost is valid or
+  /// invalid. Examples of an invalid cost might be where the cost is
+  /// prohibitively expensive and the user wants to prevent certain
+  /// optimizations being performed. Or perhaps the cost is simply unknown
+  /// because the operation makes no sense in certain circumstances. These
+  /// states can be expanded in future to support other cases if necessary.
+  enum CostState { Valid, Invalid };
+
+private:
+  // Default member initializers make a default-constructed cost a valid zero
+  // cost instead of one holding indeterminate values.
+  CostType Value = 0;
+  CostState State = Valid;
+
+  void propagateState(const InstructionCost &RHS) {
+    if (RHS.State == Invalid)
+      State = Invalid;
+  }
+
+public:
+  /// Constructs a valid cost of zero.
+  InstructionCost() = default;
+
+  /// Implicitly converts a plain integer into a valid cost.
+  InstructionCost(CostType Val) : Value(Val), State(Valid) {}
+
+  /// Returns a cost in the Invalid state that still carries the value \p Val.
+  static InstructionCost getInvalid(CostType Val = 0) {
+    InstructionCost Tmp(Val);
+    Tmp.setInvalid();
+    return Tmp;
+  }
+
+  bool isValid() const { return State == Valid; }
+  void setValid() { State = Valid; }
+  void setInvalid() { State = Invalid; }
+  CostState getState() const { return State; }
+
+  /// This function is intended to be used as sparingly as possible, since the
+  /// class provides the full range of operator support required for arithmetic
+  /// and comparisons.
+  Optional<CostType> getValue() const {
+    if (isValid())
+      return Value;
+    return None;
+  }
+
+  /// For all of the arithmetic operators provided here any invalid state is
+  /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
+  /// invalid, and it also inherits any invalid state from the RHS. Regardless
+  /// of the state, arithmetic and comparisons work on the actual values in the
+  /// same way as they would on a basic type, such as integer.
+
+  InstructionCost &operator+=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value += RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator+=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this += RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value -= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this -= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value *= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this *= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value /= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this /= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator++() {
+    *this += 1;
+    return *this;
+  }
+
+  InstructionCost operator++(int) {
+    InstructionCost Copy = *this;
+    ++*this;
+    return Copy;
+  }
+
+  InstructionCost &operator--() {
+    *this -= 1;
+    return *this;
+  }
+
+  InstructionCost operator--(int) {
+    InstructionCost Copy = *this;
+    --*this;
+    return Copy;
+  }
+
+  bool operator==(const InstructionCost &RHS) const {
+    return State == RHS.State && Value == RHS.Value;
+  }
+
+  bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
+
+  bool operator==(const CostType RHS) const {
+    return State == Valid && Value == RHS;
+  }
+
+  bool operator!=(const CostType RHS) const { return !(*this == RHS); }
+
+  /// For the comparison operators we have chosen to use a total ordering that
+  /// is lexicographical on (state, value):
+  ///  1. If the states differ then the order is determined by the state, so
+  ///     every valid cost compares less than every invalid cost.
+  ///  2. If the states are equal then order based upon value.
+  /// This keeps operator< consistent with operator== (two costs are equivalent
+  /// under the ordering exactly when they compare equal) and avoids having to
+  /// add asserts to the comparison operators that the states are valid; users
+  /// can test for validity of the cost explicitly.
+  bool operator<(const InstructionCost &RHS) const {
+    if (State != RHS.State)
+      return State < RHS.State;
+    return Value < RHS.Value;
+  }
+
+  bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
+
+  bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
+
+  bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
+
+  bool operator<(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this < RHS2;
+  }
+
+  bool operator>(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this > RHS2;
+  }
+
+  bool operator<=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this <= RHS2;
+  }
+
+  bool operator>=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this >= RHS2;
+  }
+
+  /// Returns the lesser of \p LHS and \p RHS according to operator<.
+  static InstructionCost min(InstructionCost LHS, InstructionCost RHS) {
+    return LHS < RHS ? LHS : RHS;
+  }
+
+  /// Returns the greater of \p LHS and \p RHS according to operator>.
+  static InstructionCost max(InstructionCost LHS, InstructionCost RHS) {
+    return LHS > RHS ? LHS : RHS;
+  }
+
+  /// Prints the numeric value, or "Invalid" if the cost is not valid.
+  void print(raw_ostream &OS) const;
+};
+
+inline InstructionCost operator+(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 += RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator-(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 -= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator*(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 *= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator/(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 /= RHS;
+  return LHS2;
+}
+
+inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
+  V.print(OS);
+  return OS;
+}
+
+} // namespace llvm
+
+#endif

diff  --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp
index 0fcd69fe7dc4..19c307b4ef8e 100644
--- a/llvm/lib/Analysis/CostModel.cpp
+++ b/llvm/lib/Analysis/CostModel.cpp
@@ -57,7 +57,7 @@ namespace {
-    /// Returns -1 if the cost is unknown.
+    /// Returns an invalid InstructionCost if the cost is unknown.
     /// Note, this method does not cache the cost calculation and it
     /// can be expensive in some cases.
-    unsigned getInstructionCost(const Instruction *I) const {
+    InstructionCost getInstructionCost(const Instruction *I) const {
       return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput);
     }
 
@@ -103,9 +103,9 @@ void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
 
   for (BasicBlock &B : *F) {
     for (Instruction &Inst : B) {
-      unsigned Cost = TTI->getInstructionCost(&Inst, CostKind);
-      if (Cost != (unsigned)-1)
-        OS << "Cost Model: Found an estimated cost of " << Cost;
+      InstructionCost Cost = TTI->getInstructionCost(&Inst, CostKind);
+      if (auto CostVal = Cost.getValue())
+        OS << "Cost Model: Found an estimated cost of " << *CostVal;
       else
         OS << "Cost Model: Unknown cost";
 

diff  --git a/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp b/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
index f7131926ee65..80c5cc7506d4 100644
--- a/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
+++ b/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
@@ -1130,8 +1130,8 @@ bool InterleavedLoadCombineImpl::combine(std::list<VectorInfo> &InterleavedLoad,
   std::set<Instruction *> Is;
   std::set<Instruction *> SVIs;
 
-  unsigned InterleavedCost;
-  unsigned InstructionCost = 0;
+  InstructionCost InterleavedCost;
+  InstructionCost InstructionCost = 0;
 
   // Get the interleave factor
   unsigned Factor = InterleavedLoad.size();
@@ -1174,6 +1174,10 @@ bool InterleavedLoadCombineImpl::combine(std::list<VectorInfo> &InterleavedLoad,
     }
   }
 
+  // We need to have a valid cost in order to proceed.
+  if (!InstructionCost.isValid())
+    return false;
+
   // We know that all LoadInst are within the same BB. This guarantees that
   // either everything or nothing is loaded.
   LoadInst *First = findFirstLoad(LIs);

diff  --git a/llvm/lib/IR/DiagnosticInfo.cpp b/llvm/lib/IR/DiagnosticInfo.cpp
index 46acc9040433..8820a79421c3 100644
--- a/llvm/lib/IR/DiagnosticInfo.cpp
+++ b/llvm/lib/IR/DiagnosticInfo.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -220,6 +221,13 @@ DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key,
   EC.print(OS);
 }
 
+DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key,
+                                                   InstructionCost C)
+    : Key(std::string(Key)) {
+  raw_string_ostream OS(Val);
+  C.print(OS); // Val receives the numeric cost, or "Invalid".
+}
+
 DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key, DebugLoc Loc)
     : Key(std::string(Key)), Loc(Loc) {
   if (Loc) {

diff  --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 337d07d0a9dc..cdee11412eb5 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -128,6 +128,7 @@ add_llvm_component_library(LLVMSupport
   GraphWriter.cpp
   Hashing.cpp
   InitLLVM.cpp
+  InstructionCost.cpp
   IntEqClasses.cpp
   IntervalMap.cpp
   ItaniumManglingCanonicalizer.cpp

diff  --git a/llvm/lib/Support/InstructionCost.cpp b/llvm/lib/Support/InstructionCost.cpp
new file mode 100644
index 000000000000..c485ce9107af
--- /dev/null
+++ b/llvm/lib/Support/InstructionCost.cpp
@@ -0,0 +1,24 @@
+//===- InstructionCost.cpp --------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file includes the function definitions for the InstructionCost class
+/// that is used when calculating the cost of an instruction, or a group of
+/// instructions.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+void InstructionCost::print(raw_ostream &OS) const {
+  if (!isValid()) // Invalid costs print a fixed marker, not their value.
+    OS << "Invalid";
+  else
+    OS << Value;
+}

diff  --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 2460099fba43..042a8dbad6bd 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -233,11 +233,11 @@ bool HotColdSplitting::shouldOutlineFrom(const Function &F) const {
 }
 
 /// Get the benefit score of outlining \p Region.
-static int getOutliningBenefit(ArrayRef<BasicBlock *> Region,
-                               TargetTransformInfo &TTI) {
+static InstructionCost getOutliningBenefit(ArrayRef<BasicBlock *> Region,
+                                           TargetTransformInfo &TTI) {
   // Sum up the code size costs of non-terminator instructions. Tight coupling
   // with \ref getOutliningPenalty is needed to model the costs of terminators.
-  int Benefit = 0;
+  InstructionCost Benefit = 0;
   for (BasicBlock *BB : Region)
     for (Instruction &I : BB->instructionsWithoutDebug())
       if (&I != BB->getTerminator())
@@ -324,12 +324,12 @@ Function *HotColdSplitting::extractColdRegion(
   // splitting.
   SetVector<Value *> Inputs, Outputs, Sinks;
   CE.findInputsOutputs(Inputs, Outputs, Sinks);
-  int OutliningBenefit = getOutliningBenefit(Region, TTI);
+  InstructionCost OutliningBenefit = getOutliningBenefit(Region, TTI);
   int OutliningPenalty =
       getOutliningPenalty(Region, Inputs.size(), Outputs.size());
   LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit
                     << ", penalty = " << OutliningPenalty << "\n");
-  if (OutliningBenefit <= OutliningPenalty)
+  if (!OutliningBenefit.isValid() || OutliningBenefit <= OutliningPenalty)
     return nullptr;
 
   Function *OrigF = Region[0]->getParent();

diff  --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index b26bd1114bd4..2eb94b721d96 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -208,7 +208,7 @@ static bool canSplitCallSite(CallBase &CB, TargetTransformInfo &TTI) {
   // instructions before the call is less then DuplicationThreshold. The
   // instructions before the call will be duplicated in the split blocks and
   // corresponding uses will be updated.
-  unsigned Cost = 0;
+  InstructionCost Cost = 0;
   for (auto &InstBeforeCall :
        llvm::make_range(CallSiteBB->begin(), CB.getIterator())) {
     Cost += TTI.getInstructionCost(&InstBeforeCall,

diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index a91fb988badf..c381377b67c9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7192,8 +7192,12 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I,
       return std::min(CallCost, getVectorIntrinsicCost(CI, VF));
     return CallCost;
   }
-  case Instruction::ExtractValue:
-    return TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
+  case Instruction::ExtractValue: {
+    InstructionCost ExtractCost =
+        TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
+    assert(ExtractCost.isValid() && "Invalid cost for ExtractValue");
+    return *(ExtractCost.getValue());
+  }
   default:
     // The cost of executing VF copies of the scalar instruction. This opcode
     // is unknown. Assume that it is the same as 'mul'.

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 456485f45809..dc35f5c3df3d 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3793,17 +3793,23 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       if (NeedToShuffleReuses) {
         for (unsigned Idx : E->ReuseShuffleIndices) {
           Instruction *I = cast<Instruction>(VL[Idx]);
-          ReuseShuffleCost -= TTI->getInstructionCost(I, CostKind);
+          InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+          assert(Cost.isValid() && "Invalid instruction cost");
+          ReuseShuffleCost -= *(Cost.getValue());
         }
         for (Value *V : VL) {
           Instruction *I = cast<Instruction>(V);
-          ReuseShuffleCost += TTI->getInstructionCost(I, CostKind);
+          InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+          assert(Cost.isValid() && "Invalid instruction cost");
+          ReuseShuffleCost += *(Cost.getValue());
         }
       }
       for (Value *V : VL) {
         Instruction *I = cast<Instruction>(V);
         assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
-        ScalarCost += TTI->getInstructionCost(I, CostKind);
+        InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+        assert(Cost.isValid() && "Invalid instruction cost");
+        ScalarCost += *(Cost.getValue());
       }
       // VecCost is equal to sum of the cost of creating 2 vectors
       // and the cost of creating shuffle.

diff  --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt
index ed00d38f248c..44fff481d1b6 100644
--- a/llvm/unittests/Support/CMakeLists.txt
+++ b/llvm/unittests/Support/CMakeLists.txt
@@ -40,6 +40,7 @@ add_llvm_unittest(SupportTests
   GlobPatternTest.cpp
   Host.cpp
   IndexedAccessorTest.cpp
+  InstructionCostTest.cpp
   ItaniumManglingCanonicalizerTest.cpp
   JSONTest.cpp
   KnownBitsTest.cpp

diff  --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp
new file mode 100644
index 000000000000..da3d3f47a212
--- /dev/null
+++ b/llvm/unittests/Support/InstructionCostTest.cpp
@@ -0,0 +1,64 @@
+//===- InstructionCostTest.cpp - InstructionCost tests --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/InstructionCost.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+struct CostTest : public testing::Test {
+  CostTest() {}
+};
+
+} // namespace
+
+TEST_F(CostTest, Operators) {
+  InstructionCost VThree = 3;
+  InstructionCost VNegTwo = -2;
+  InstructionCost VSix = 6;
+  InstructionCost IThreeA = InstructionCost::getInvalid(3);
+  InstructionCost IThreeB = InstructionCost::getInvalid(3);
+  InstructionCost TmpCost;
+
+  EXPECT_NE(VThree, VNegTwo);
+  EXPECT_GT(VThree, VNegTwo);
+  EXPECT_NE(VThree, IThreeA); // Same value, different state.
+  EXPECT_EQ(IThreeA, IThreeB); // Same state and same value.
+  EXPECT_GE(IThreeA, VNegTwo); // Invalid orders after every valid cost.
+  EXPECT_LT(VSix, IThreeA);
+  EXPECT_EQ(VSix - IThreeA, IThreeB); // Invalid state propagates.
+  EXPECT_EQ(VThree - VNegTwo, 5);
+  EXPECT_EQ(VThree * VNegTwo, -6);
+  EXPECT_EQ(VSix / VThree, 2);
+
+  EXPECT_FALSE(IThreeA.isValid());
+  EXPECT_EQ(IThreeA.getState(), InstructionCost::Invalid);
+
+  TmpCost = VThree + IThreeA;
+  EXPECT_FALSE(TmpCost.isValid());
+
+  // Test increments, decrements
+  EXPECT_EQ(++VThree, 4);
+  EXPECT_EQ(VThree++, 4);
+  EXPECT_EQ(VThree, 5);
+  EXPECT_EQ(--VThree, 4);
+  EXPECT_EQ(VThree--, 4);
+  EXPECT_EQ(VThree, 3);
+
+  TmpCost = VThree * IThreeA;
+  EXPECT_FALSE(TmpCost.isValid());
+
+  // Test value extraction
+  EXPECT_EQ(*(VThree.getValue()), 3);
+  EXPECT_EQ(IThreeA.getValue(), None); // No value for an invalid cost.
+
+  EXPECT_EQ(InstructionCost::min(VThree, VNegTwo), -2);
+  EXPECT_EQ(InstructionCost::max(VThree, VSix), 6);
+}


        


More information about the llvm-branch-commits mailing list