[llvm] 41b6057 - [InstructionCost] Add saturation support.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 10 03:56:52 PDT 2021


Author: Sander de Smalen
Date: 2021-07-10T11:28:42+01:00
New Revision: 41b6057641720e6ba7d4b6c7c2905f2870a885d3

URL: https://github.com/llvm/llvm-project/commit/41b6057641720e6ba7d4b6c7c2905f2870a885d3
DIFF: https://github.com/llvm/llvm-project/commit/41b6057641720e6ba7d4b6c7c2905f2870a885d3.diff

LOG: [InstructionCost] Add saturation support.

This patch makes the operations on InstructionCost saturate, so that when
costs are accumulated they saturate to <max value> (or to <min value> when
the result would underflow).

One of the compelling reasons for wanting to have saturation support
is because in various places, arbitrary values are used to represent
a 'high' cost, but when accumulating the cost of some set of operations
or a loop, overflow is not taken into account, which may lead to unexpected
results. By defining the operations to saturate, we can express the cost
of something 'very expensive' as InstructionCost::getMax().

Reviewed By: kparzysz, dmgreen

Differential Revision: https://reviews.llvm.org/D105108

Added: 
    

Modified: 
    llvm/include/llvm/Support/InstructionCost.h
    llvm/unittests/Support/InstructionCostTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h
index 030c1fb8e4a1e..30f68597e6f22 100644
--- a/llvm/include/llvm/Support/InstructionCost.h
+++ b/llvm/include/llvm/Support/InstructionCost.h
@@ -9,8 +9,9 @@
 /// This file defines an InstructionCost class that is used when calculating
 /// the cost of an instruction, or a group of instructions. In addition to a
 /// numeric value representing the cost the class also contains a state that
-/// can be used to encode particular properties, i.e. a cost being invalid or
-/// unknown.
+/// can be used to encode particular properties, such as a cost being invalid.
+/// Operations on InstructionCost implement saturation arithmetic, so that
+/// accumulating large cost values does not overflow.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -18,6 +19,8 @@
 #define LLVM_SUPPORT_INSTRUCTIONCOST_H
 
 #include "llvm/ADT/Optional.h"
+#include "llvm/Support/MathExtras.h"
+#include <limits>
 
 namespace llvm {
 
@@ -27,13 +30,24 @@ class InstructionCost {
 public:
   using CostType = int;
 
-  /// These states can currently be used to indicate whether a cost is valid or
-  /// invalid. Examples of an invalid cost might be where the cost is
-  /// prohibitively expensive and the user wants to prevent certain
-  /// optimizations being performed. Or perhaps the cost is simply unknown
-  /// because the operation makes no sense in certain circumstances. These
-  /// states can be expanded in future to support other cases if necessary.
-  enum CostState { Valid, Invalid };
+  /// CostState describes the state of a cost.
+  enum CostState {
+    Valid,  ///< The cost value represents a valid cost, even when the
+            /// cost-value is large.
+    Invalid ///< Invalid indicates there is no way to represent the cost as a
+            /// numeric value. This state exists to represent a possible issue,
+            /// e.g. if the cost-model knows the operation cannot be expanded
+            /// into a valid code-sequence by the code-generator.  While some
+            /// passes may assert that the calculated cost must be valid, it is
+            /// up to individual passes how to interpret an Invalid cost. For
+            /// example, a transformation pass could choose not to perform a
+            /// transformation if the resulting cost would end up Invalid.
+            /// Because some passes may assert a cost is Valid, it is not
+            /// recommended to use Invalid costs to model 'Unknown'.
+            /// Note that Invalid is semantically different from a (very) high,
+            /// but valid cost, which intentionally indicates no issue, but
+            /// rather a strong preference not to select a certain operation.
+  };
 
 private:
   CostType Value = 0;
@@ -44,6 +58,9 @@ class InstructionCost {
       State = Invalid;
   }
 
+  static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); }
+  static CostType getMinValue() { return std::numeric_limits<CostType>::min(); }
+
 public:
   // A default constructed InstructionCost is a valid zero cost
   InstructionCost() = default;
@@ -51,6 +68,8 @@ class InstructionCost {
   InstructionCost(CostState) = delete;
   InstructionCost(CostType Val) : Value(Val), State(Valid) {}
 
+  static InstructionCost getMax() { return getMaxValue(); }
+  static InstructionCost getMin() { return getMinValue(); }
   static InstructionCost getInvalid(CostType Val = 0) {
     InstructionCost Tmp(Val);
     Tmp.setInvalid();
@@ -73,13 +92,19 @@ class InstructionCost {
 
   /// For all of the arithmetic operators provided here any invalid state is
   /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
-  /// invalid, and it also inherits any invalid state from the RHS. Regardless
-  /// of the state, arithmetic work on the actual values in the same way as they
-  /// would on a basic type, such as integer.
+  /// invalid, and it also inherits any invalid state from the RHS.
+  /// Arithmetic on the actual values is implemented with saturation,
+  /// to avoid overflow when using extreme cost values.
 
   InstructionCost &operator+=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value += RHS.Value;
+
+    // Saturating addition.
+    InstructionCost::CostType Result;
+    if (AddOverflow(Value, RHS.Value, Result))
+      Result = RHS.Value > 0 ? getMaxValue() : getMinValue();
+
+    Value = Result;
     return *this;
   }
 
@@ -91,7 +116,12 @@ class InstructionCost {
 
   InstructionCost &operator-=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value -= RHS.Value;
+
+    // Saturating subtract.
+    InstructionCost::CostType Result;
+    if (SubOverflow(Value, RHS.Value, Result))
+      Result = RHS.Value > 0 ? getMinValue() : getMaxValue();
+    Value = Result;
     return *this;
   }
 
@@ -103,7 +133,17 @@ class InstructionCost {
 
   InstructionCost &operator*=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value *= RHS.Value;
+
+    // Saturating multiply.
+    InstructionCost::CostType Result;
+    if (MulOverflow(Value, RHS.Value, Result)) {
+      if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0))
+        Result = getMaxValue();
+      else
+        Result = getMinValue();
+    }
+
+    Value = Result;
     return *this;
   }
 

diff  --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp
index 6c8a9151e18d7..e31bf34233a23 100644
--- a/llvm/unittests/Support/InstructionCostTest.cpp
+++ b/llvm/unittests/Support/InstructionCostTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Support/InstructionCost.h"
 #include "gtest/gtest.h"
+#include <limits>
 
 using namespace llvm;
 
@@ -75,4 +76,20 @@ TEST_F(CostTest, Operators) {
 
   EXPECT_EQ(std::min(VThree, VNegTwo), -2);
   EXPECT_EQ(std::max(VThree, VSix), 6);
+
+  // Test saturation
+  auto Max = InstructionCost::getMax();
+  auto Min = InstructionCost::getMin();
+  auto MinusOne = InstructionCost(-1);
+  auto MinusTwo = InstructionCost(-2);
+  auto One = InstructionCost(1);
+  auto Two = InstructionCost(2);
+  EXPECT_EQ(Max + One, Max);
+  EXPECT_EQ(Min + MinusOne, Min);
+  EXPECT_EQ(Min - One, Min);
+  EXPECT_EQ(Max - MinusOne, Max);
+  EXPECT_EQ(Max * Two, Max);
+  EXPECT_EQ(Min * Two, Min);
+  EXPECT_EQ(Max * MinusTwo, Min);
+  EXPECT_EQ(Min * MinusTwo, Max);
 }


        


More information about the llvm-commits mailing list