[Mlir-commits] [mlir] c777e51 - [mlir][Analysis][NFC] FlatAffineConstraints: Use BoundType enum in functions

Matthias Springer llvmlistbot at llvm.org
Wed Aug 18 18:33:57 PDT 2021


Author: Matthias Springer
Date: 2021-08-19T10:33:42+09:00
New Revision: c777e51468f5d44ad4600344683ecf9b46aa2b0f

URL: https://github.com/llvm/llvm-project/commit/c777e51468f5d44ad4600344683ecf9b46aa2b0f
DIFF: https://github.com/llvm/llvm-project/commit/c777e51468f5d44ad4600344683ecf9b46aa2b0f.diff

LOG: [mlir][Analysis][NFC] FlatAffineConstraints: Use BoundType enum in functions

Differential Revision: https://reviews.llvm.org/D108185

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/lib/Analysis/AffineAnalysis.cpp
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/Utils.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp
    mlir/lib/Transforms/Utils/Utils.cpp
    mlir/unittests/Analysis/AffineStructuresTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 76f62355d9af8..e33f7bc124753 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -194,14 +194,23 @@ class FlatAffineConstraints {
     return inequalities.getRow(idx);
   }
 
-  /// Adds a lower or an upper bound for the identifier at the specified
-  /// position with constraints being drawn from the specified bound map. If
-  /// `eq` is true, add a single equality equal to the bound map's first result
-  /// expr.
+  /// The type of bound: equal, lower bound or upper bound.
+  enum BoundType { EQ, LB, UB };
+
+  /// Adds a bound for the identifier at the specified position with constraints
+  /// being drawn from the specified bound map. In case of an EQ bound, the
+  /// bound map is expected to have exactly one result. In case of a LB/UB, the
+  /// bound map may have more than one result, for each of which an inequality
+  /// is added.
   /// Note: The dimensions/symbols of this FlatAffineConstraints must match the
   /// dimensions/symbols of the affine map.
-  LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap, bool eq,
-                                     bool lower = true);
+  LogicalResult addBound(BoundType type, unsigned pos, AffineMap boundMap);
+
+  /// Adds a constant bound for the specified identifier.
+  void addBound(BoundType type, unsigned pos, int64_t value);
+
+  /// Adds a constant bound for the specified expression.
+  void addBound(BoundType type, ArrayRef<int64_t> expr, int64_t value);
 
   /// Returns the constraint system as an integer set. Returns a null integer
   /// set if the system has no constraints, or if an integer set couldn't be
@@ -224,11 +233,6 @@ class FlatAffineConstraints {
   /// Adds an equality from the coefficients specified in `eq`.
   void addEquality(ArrayRef<int64_t> eq);
 
-  /// Adds a constant lower bound constraint for the specified identifier.
-  void addConstantLowerBound(unsigned pos, int64_t lb);
-  /// Adds a constant upper bound constraint for the specified identifier.
-  void addConstantUpperBound(unsigned pos, int64_t ub);
-
   /// Adds a new local identifier as the floordiv of an affine function of other
   /// identifiers, the coefficients of which are provided in `dividend` and with
   /// respect to a positive constant `divisor`. Two constraints are added to the
@@ -236,14 +240,6 @@ class FlatAffineConstraints {
   /// q = dividend floordiv c    <=>   c*q <= dividend <= c*q + c - 1.
   void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor);
 
-  /// Adds a constant lower bound constraint for the specified expression.
-  void addConstantLowerBound(ArrayRef<int64_t> expr, int64_t lb);
-  /// Adds a constant upper bound constraint for the specified expression.
-  void addConstantUpperBound(ArrayRef<int64_t> expr, int64_t ub);
-
-  /// Sets the identifier at the specified position to a constant.
-  void setIdToConstant(unsigned pos, int64_t val);
-
   /// Swap the posA^th identifier with the posB^th identifier.
   virtual void swapId(unsigned posA, unsigned posB);
 
@@ -349,13 +345,10 @@ class FlatAffineConstraints {
       SmallVectorImpl<int64_t> *ub = nullptr, unsigned *minLbPos = nullptr,
       unsigned *minUbPos = nullptr) const;
 
-  /// Returns the constant lower bound for the pos^th identifier if there is
-  /// one; None otherwise.
-  Optional<int64_t> getConstantLowerBound(unsigned pos) const;
-
-  /// Returns the constant upper bound for the pos^th identifier if there is
-  /// one; None otherwise.
-  Optional<int64_t> getConstantUpperBound(unsigned pos) const;
+  /// Returns the constant bound for the pos^th identifier if there is one;
+  /// None otherwise.
+  // TODO: Support EQ bounds.
+  Optional<int64_t> getConstantBound(BoundType type, unsigned pos) const;
 
   /// Gets the lower and upper bound of the `offset` + `pos`th identifier
   /// treating [0, offset) U [offset + num, symStartPos) as dimensions and
@@ -611,14 +604,18 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   /// the columns in the current one regarding numbers and values.
   void addAffineIfOpDomain(AffineIfOp ifOp);
 
-  /// Adds a lower or an upper bound for the identifier at the specified
-  /// position with constraints being drawn from the specified bound map and
-  /// operands. If `eq` is true, add a single equality equal to the bound map's
-  /// first result expr.
-  LogicalResult addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
-                                     ValueRange operands, bool eq,
-                                     bool lower = true);
-  using FlatAffineConstraints::addLowerOrUpperBound;
+  /// Adds a bound for the identifier at the specified position with constraints
+  /// being drawn from the specified bound map and operands. In case of an
+  /// EQ bound, the bound map is expected to have exactly one result. In case
+  /// of a LB/UB, the bound map may have more than one result, for each of which
+  /// an inequality is added.
+  LogicalResult addBound(BoundType type, unsigned pos, AffineMap boundMap,
+                         ValueRange operands);
+
+  /// Adds a constant bound for the identifier associated with the given Value.
+  void addBound(BoundType type, Value val, int64_t value);
+
+  using FlatAffineConstraints::addBound;
 
   /// Returns the bound for the identifier at `pos` from the inequality at
   /// `ineqPos` as a 1-d affine value map (affine map + operands). The returned
@@ -640,11 +637,6 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
                                ArrayRef<AffineMap> ubMaps,
                                ArrayRef<Value> operands);
 
-  /// Sets the identifier corresponding to the specified Value `value` to a
-  /// constant. Asserts if the `value` is not found.
-  void setIdToConstant(Value value, int64_t val);
-  using FlatAffineConstraints::setIdToConstant;
-
   /// Looks up the position of the identifier with the specified Value. Returns
   /// true if found (false otherwise). `pos` is set to the (column) position of
   /// the identifier.

diff  --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index e3feb780bdf46..b83f02283ba08 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -664,8 +664,9 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
       assert(isValidSymbol(symbol));
       // Check if the symbol is a constant.
       if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
-        dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
-                                          cOp.getValue());
+        dependenceDomain->addBound(FlatAffineConstraints::EQ,
+                                   valuePosMap.getSymPos(symbol),
+                                   cOp.getValue());
     }
   };
 
@@ -885,10 +886,12 @@ static void computeDirectionVector(
   dependenceComponents->resize(numCommonLoops);
   for (unsigned j = 0; j < numCommonLoops; ++j) {
     (*dependenceComponents)[j].op = commonLoops[j].getOperation();
-    auto lbConst = dependenceDomain->getConstantLowerBound(j);
+    auto lbConst =
+        dependenceDomain->getConstantBound(FlatAffineConstraints::LB, j);
     (*dependenceComponents)[j].lb =
         lbConst.getValueOr(std::numeric_limits<int64_t>::min());
-    auto ubConst = dependenceDomain->getConstantUpperBound(j);
+    auto ubConst =
+        dependenceDomain->getConstantBound(FlatAffineConstraints::UB, j);
     (*dependenceComponents)[j].ub =
         ubConst.getValueOr(std::numeric_limits<int64_t>::max());
   }

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index bb636299358e3..58eb5c230b742 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -550,7 +550,7 @@ void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
   addSymbolId(getNumSymbolIds(), val);
   // Check if the symbol is a constant.
   if (auto constOp = val.getDefiningOp<ConstantIndexOp>())
-    setIdToConstant(val, constOp.getValue());
+    addBound(BoundType::EQ, val, constOp.getValue());
 }
 
 LogicalResult
@@ -588,23 +588,21 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
   }
 
   if (forOp.hasConstantLowerBound()) {
-    addConstantLowerBound(pos, forOp.getConstantLowerBound());
+    addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
   } else {
     // Non-constant lower bound case.
-    if (failed(addLowerOrUpperBound(pos, forOp.getLowerBoundMap(),
-                                    forOp.getLowerBoundOperands(),
-                                    /*eq=*/false, /*lower=*/true)))
+    if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
+                        forOp.getLowerBoundOperands())))
       return failure();
   }
 
   if (forOp.hasConstantUpperBound()) {
-    addConstantUpperBound(pos, forOp.getConstantUpperBound() - 1);
+    addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
     return success();
   }
   // Non-constant upper bound case.
-  return addLowerOrUpperBound(pos, forOp.getUpperBoundMap(),
-                              forOp.getUpperBoundOperands(),
-                              /*eq=*/false, /*lower=*/false);
+  return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
+                  forOp.getUpperBoundOperands());
 }
 
 LogicalResult
@@ -649,12 +647,9 @@ FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
     // This slice refers to a loop that doesn't exist in the IR yet. Add its
     // bounds to the system assuming its dimension identifier position is the
     // same as the position of the loop in the loop nest.
-    if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false,
-                                             /*lower=*/true)))
+    if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
       return failure();
-
-    if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false,
-                                             /*lower=*/false)))
+    if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
       return failure();
   }
   return success();
@@ -1393,7 +1388,8 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
 
     // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
     if (quotientCount >= 1) {
-      auto ub = cst.getConstantUpperBound(dimExpr.getPosition());
+      auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB,
+                                     dimExpr.getPosition());
       // If `id_n` has an upperbound that is less than the divisor, mod can be
       // eliminated altogether.
       if (ub.hasValue() && ub.getValue() < divisor)
@@ -1768,8 +1764,8 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
       if (memo[pos])
         continue;
 
-      auto lbConst = getConstantLowerBound(pos);
-      auto ubConst = getConstantUpperBound(pos);
+      auto lbConst = getConstantBound(BoundType::LB, pos);
+      auto ubConst = getConstantBound(BoundType::UB, pos);
       if (lbConst.hasValue() && ubConst.hasValue()) {
         // Detect equality to a constant.
         if (lbConst.getValue() == ubConst.getValue()) {
@@ -1878,7 +1874,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
       if (!lbMap || lbMap.getNumResults() > 1) {
         LLVM_DEBUG(llvm::dbgs()
                    << "WARNING: Potentially over-approximating slice lb\n");
-        auto lbConst = getConstantLowerBound(pos + offset);
+        auto lbConst = getConstantBound(BoundType::LB, pos + offset);
         if (lbConst.hasValue()) {
           lbMap = AffineMap::get(
               numMapDims, numMapSymbols,
@@ -1888,7 +1884,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
       if (!ubMap || ubMap.getNumResults() > 1) {
         LLVM_DEBUG(llvm::dbgs()
                    << "WARNING: Potentially over-approximating slice ub\n");
-        auto ubConst = getConstantUpperBound(pos + offset);
+        auto ubConst = getConstantBound(BoundType::UB, pos + offset);
         if (ubConst.hasValue()) {
           (ubMap) = AffineMap::get(
               numMapDims, numMapSymbols,
@@ -1931,18 +1927,17 @@ LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
   return success();
 }
 
-LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
-                                                          AffineMap boundMap,
-                                                          bool eq, bool lower) {
+LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos,
+                                              AffineMap boundMap) {
   assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
   assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
   assert(pos < getNumDimAndSymbolIds() && "invalid position");
 
   // Equality follows the logic of lower bound except that we add an equality
   // instead of an inequality.
-  assert((!eq || boundMap.getNumResults() == 1) && "single result expected");
-  if (eq)
-    lower = true;
+  assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
+         "single result expected");
+  bool lower = type == BoundType::LB || type == BoundType::EQ;
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
   if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
@@ -1973,7 +1968,7 @@ LogicalResult FlatAffineConstraints::addLowerOrUpperBound(unsigned pos,
         lower ? -flatExpr[flatExpr.size() - 1]
               // Upper bound in flattenedExpr is an exclusive one.
               : flatExpr[flatExpr.size() - 1] - 1;
-    eq ? addEquality(ineq) : addInequality(ineq);
+    type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
   }
 
   return success();
@@ -2008,9 +2003,9 @@ FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
   return alignedMap;
 }
 
-LogicalResult FlatAffineValueConstraints::addLowerOrUpperBound(
-    unsigned pos, AffineMap boundMap, ValueRange boundOperands, bool eq,
-    bool lower) {
+LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
+                                                   AffineMap boundMap,
+                                                   ValueRange boundOperands) {
   // Fully compose map and operands; canonicalize and simplify so that we
   // transitively get to terminal symbols or loop IVs.
   auto map = boundMap;
@@ -2020,7 +2015,7 @@ LogicalResult FlatAffineValueConstraints::addLowerOrUpperBound(
   canonicalizeMapAndOperands(&map, &operands);
   for (auto operand : operands)
     addInductionVarOrTerminalSymbol(operand);
-  return addLowerOrUpperBound(pos, computeAlignedMap(map, operands), eq, lower);
+  return addBound(type, pos, computeAlignedMap(map, operands));
 }
 
 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
@@ -2052,8 +2047,7 @@ LogicalResult FlatAffineValueConstraints::addSliceBounds(
     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
         ubMap.getNumResults() == 1 &&
         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
-      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/true,
-                                      /*lower=*/true)))
+      if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
         return failure();
       continue;
     }
@@ -2063,11 +2057,9 @@ LogicalResult FlatAffineValueConstraints::addSliceBounds(
     // part of the slice.
     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
         ubMap.getNumResults() != 0) {
-      if (failed(addLowerOrUpperBound(pos, lbMap, operands, /*eq=*/false,
-                                      /*lower=*/true)))
+      if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
         return failure();
-      if (failed(addLowerOrUpperBound(pos, ubMap, operands, /*eq=*/false,
-                                      /*lower=*/false)))
+      if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
         return failure();
     } else {
       auto loop = getForInductionVarOwner(values[i]);
@@ -2092,33 +2084,30 @@ void FlatAffineConstraints::addInequality(ArrayRef<int64_t> inEq) {
     inequalities(row, i) = inEq[i];
 }
 
-void FlatAffineConstraints::addConstantLowerBound(unsigned pos, int64_t lb) {
-  assert(pos < getNumCols());
-  unsigned row = inequalities.appendExtraRow();
-  inequalities(row, pos) = 1;
-  inequalities(row, getNumCols() - 1) = -lb;
-}
-
-void FlatAffineConstraints::addConstantUpperBound(unsigned pos, int64_t ub) {
+void FlatAffineConstraints::addBound(BoundType type, unsigned pos,
+                                     int64_t value) {
   assert(pos < getNumCols());
-  unsigned row = inequalities.appendExtraRow();
-  inequalities(row, pos) = -1;
-  inequalities(row, getNumCols() - 1) = ub;
-}
-
-void FlatAffineConstraints::addConstantLowerBound(ArrayRef<int64_t> expr,
-                                                  int64_t lb) {
-  addInequality(expr);
-  inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += -lb;
+  if (type == BoundType::EQ) {
+    unsigned row = equalities.appendExtraRow();
+    equalities(row, pos) = 1;
+    equalities(row, getNumCols() - 1) = -value;
+  } else {
+    unsigned row = inequalities.appendExtraRow();
+    inequalities(row, pos) = type == BoundType::LB ? 1 : -1;
+    inequalities(row, getNumCols() - 1) =
+        type == BoundType::LB ? -value : value;
+  }
 }
 
-void FlatAffineConstraints::addConstantUpperBound(ArrayRef<int64_t> expr,
-                                                  int64_t ub) {
+void FlatAffineConstraints::addBound(BoundType type, ArrayRef<int64_t> expr,
+                                     int64_t value) {
+  assert(type != BoundType::EQ && "EQ not implemented");
   assert(expr.size() == getNumCols());
   unsigned row = inequalities.appendExtraRow();
   for (unsigned i = 0, e = expr.size(); i < e; ++i)
-    inequalities(row, i) = -expr[i];
-  inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) += ub;
+    inequalities(row, i) = type == BoundType::LB ? expr[i] : -expr[i];
+  inequalities(inequalities.getNumRows() - 1, getNumCols() - 1) +=
+      type == BoundType::LB ? -value : value;
 }
 
 /// Adds a new local identifier as the floordiv of an affine function of other
@@ -2193,22 +2182,13 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
   numSymbols = newSymbolCount;
 }
 
-/// Sets the specified identifier to a constant value.
-void FlatAffineConstraints::setIdToConstant(unsigned pos, int64_t val) {
-  equalities.resizeVertically(equalities.getNumRows() + 1);
-  unsigned row = equalities.getNumRows() - 1;
-  equalities(row, pos) = 1;
-  equalities(row, getNumCols() - 1) = -val;
-}
-
-/// Sets the specified identifier to a constant value; asserts if the id is not
-/// found.
-void FlatAffineValueConstraints::setIdToConstant(Value value, int64_t val) {
+void FlatAffineValueConstraints::addBound(BoundType type, Value val,
+                                          int64_t value) {
   unsigned pos;
-  if (!findId(value, &pos))
+  if (!findId(val, &pos))
     // This is a pre-condition for this method.
     assert(0 && "id not found");
-  setIdToConstant(pos, val);
+  addBound(type, pos, value);
 }
 
 void FlatAffineConstraints::removeEquality(unsigned pos) {
@@ -2485,15 +2465,12 @@ FlatAffineConstraints::computeConstantLowerOrUpperBound(unsigned pos) {
   return minOrMaxConst;
 }
 
-Optional<int64_t>
-FlatAffineConstraints::getConstantLowerBound(unsigned pos) const {
-  FlatAffineConstraints tmpCst(*this);
-  return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
-}
-
-Optional<int64_t>
-FlatAffineConstraints::getConstantUpperBound(unsigned pos) const {
+Optional<int64_t> FlatAffineConstraints::getConstantBound(BoundType type,
+                                                          unsigned pos) const {
+  assert(type != BoundType::EQ && "EQ not implemented");
   FlatAffineConstraints tmpCst(*this);
+  if (type == BoundType::LB)
+    return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
   return tmpCst.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
 }
 
@@ -3042,8 +3019,8 @@ FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
       minLb.back() -= otherLbFloorDivisor - 1;
     } else {
       // Uncomparable - check for constant lower/upper bounds.
-      auto constLb = getConstantLowerBound(d);
-      auto constOtherLb = otherCst.getConstantLowerBound(d);
+      auto constLb = getConstantBound(BoundType::LB, d);
+      auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
       if (!constLb.hasValue() || !constOtherLb.hasValue())
         return failure();
       std::fill(minLb.begin(), minLb.end(), 0);
@@ -3058,8 +3035,8 @@ FlatAffineConstraints::unionBoundingBox(const FlatAffineConstraints &otherCst) {
       maxUb = otherUb;
     } else {
       // Uncomparable - check for constant lower/upper bounds.
-      auto constUb = getConstantUpperBound(d);
-      auto constOtherUb = otherCst.getConstantUpperBound(d);
+      auto constUb = getConstantBound(BoundType::UB, d);
+      auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
       if (!constUb.hasValue() || !constOtherUb.hasValue())
         return failure();
       std::fill(maxUb.begin(), maxUb.end(), 0);

diff  --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 93b36929999d9..558808bcbe00e 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -99,7 +99,7 @@ ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) {
     if (isValidSymbol(value)) {
       // Check if the symbol is a constant.
       if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
-        cst->setIdToConstant(value, cOp.getValue());
+        cst->addBound(FlatAffineConstraints::EQ, value, cOp.getValue());
     } else if (auto loop = getForInductionVarOwner(value)) {
       if (failed(cst->addAffineForOpDomain(loop)))
         return failure();
@@ -357,11 +357,11 @@ Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
   // that will need non-trivials means to eliminate.
   FlatAffineConstraints cstWithShapeBounds(cst);
   for (unsigned r = 0; r < rank; r++) {
-    cstWithShapeBounds.addConstantLowerBound(r, 0);
+    cstWithShapeBounds.addBound(FlatAffineConstraints::LB, r, 0);
     int64_t dimSize = memRefType.getDimSize(r);
     if (ShapedType::isDynamic(dimSize))
       continue;
-    cstWithShapeBounds.addConstantUpperBound(r, dimSize - 1);
+    cstWithShapeBounds.addBound(FlatAffineConstraints::UB, r, dimSize - 1);
   }
 
   // Find a constant upper bound on the extent of this memref region along each
@@ -518,7 +518,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
       // Check if the symbol is a constant.
       if (auto *op = symbol.getDefiningOp()) {
         if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
-          cst.setIdToConstant(symbol, constOp.getValue());
+          cst.addBound(FlatAffineConstraints::EQ, symbol, constOp.getValue());
         }
       }
     }
@@ -583,10 +583,11 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
   if (addMemRefDimBounds) {
     auto memRefType = memref.getType().cast<MemRefType>();
     for (unsigned r = 0; r < rank; r++) {
-      cst.addConstantLowerBound(/*pos=*/r, /*lb=*/0);
+      cst.addBound(FlatAffineConstraints::LB, /*pos=*/r, /*value=*/0);
       if (memRefType.isDynamicDim(r))
         continue;
-      cst.addConstantUpperBound(/*pos=*/r, memRefType.getDimSize(r) - 1);
+      cst.addBound(FlatAffineConstraints::UB, /*pos=*/r,
+                   memRefType.getDimSize(r) - 1);
     }
   }
   cst.removeTrivialRedundancy();
@@ -688,7 +689,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
       continue;
 
     // Check for overflow: d_i >= memref dim size.
-    ucst.addConstantLowerBound(r, dimSize);
+    ucst.addBound(FlatAffineConstraints::LB, r, dimSize);
     outOfBounds = !ucst.isEmpty();
     if (outOfBounds && emitError) {
       loadOrStoreOp.emitOpError()
@@ -699,7 +700,7 @@ LogicalResult mlir::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
     FlatAffineConstraints lcst(*region.getConstraints());
     std::fill(ineq.begin(), ineq.end(), 0);
     // d_i <= -1;
-    lcst.addConstantUpperBound(r, -1);
+    lcst.addBound(FlatAffineConstraints::UB, r, -1);
     outOfBounds = !lcst.isEmpty();
     if (outOfBounds && emitError) {
       loadOrStoreOp.emitOpError()

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index 2c9cd3ce6bc44..05b292ab72ed3 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -2695,8 +2695,8 @@ static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
   for (unsigned d = 0; d < rank; d++) {
     auto dimSize = memRefType.getDimSize(d);
     assert(dimSize > 0 && "filtered dynamic shapes above");
-    regionCst->addConstantLowerBound(d, 0);
-    regionCst->addConstantUpperBound(d, dimSize - 1);
+    regionCst->addBound(FlatAffineConstraints::LB, d, 0);
+    regionCst->addBound(FlatAffineConstraints::UB, d, dimSize - 1);
   }
   return true;
 }

diff  --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index fdff3160e95f6..d8ffab3852f65 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -722,8 +722,8 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
   for (unsigned d = 0; d < rank; ++d) {
     // Use constraint system only in static dimensions.
     if (shape[d] > 0) {
-      fac.addConstantLowerBound(d, 0);
-      fac.addConstantUpperBound(d, shape[d] - 1);
+      fac.addBound(FlatAffineConstraints::LB, d, 0);
+      fac.addBound(FlatAffineConstraints::UB, d, shape[d] - 1);
     } else {
       memrefTypeDynDims.emplace_back(d);
     }
@@ -746,7 +746,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
       newShape[d] = -1;
     } else {
       // The lower bound for the shape is always zero.
-      auto ubConst = fac.getConstantUpperBound(d);
+      auto ubConst = fac.getConstantBound(FlatAffineConstraints::UB, d);
       // For a static memref and an affine map with no symbols, this is
       // always bounded.
       assert(ubConst.hasValue() && "should always have an upper bound");

diff  --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
index 4fe46c342c29e..971ca2b2ce303 100644
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -551,12 +551,12 @@ TEST(FlatAffineConstraintsTest, removeRedundantConstraintsTest) {
 
 TEST(FlatAffineConstraintsTest, addConstantUpperBound) {
   FlatAffineConstraints fac = makeFACFromConstraints(2, {}, {});
-  fac.addConstantUpperBound(0, 1);
+  fac.addBound(FlatAffineConstraints::UB, 0, 1);
   EXPECT_EQ(fac.atIneq(0, 0), -1);
   EXPECT_EQ(fac.atIneq(0, 1), 0);
   EXPECT_EQ(fac.atIneq(0, 2), 1);
 
-  fac.addConstantUpperBound({1, 2, 3}, 1);
+  fac.addBound(FlatAffineConstraints::UB, {1, 2, 3}, 1);
   EXPECT_EQ(fac.atIneq(1, 0), -1);
   EXPECT_EQ(fac.atIneq(1, 1), -2);
   EXPECT_EQ(fac.atIneq(1, 2), -2);
@@ -564,12 +564,12 @@ TEST(FlatAffineConstraintsTest, addConstantUpperBound) {
 
 TEST(FlatAffineConstraintsTest, addConstantLowerBound) {
   FlatAffineConstraints fac = makeFACFromConstraints(2, {}, {});
-  fac.addConstantLowerBound(0, 1);
+  fac.addBound(FlatAffineConstraints::LB, 0, 1);
   EXPECT_EQ(fac.atIneq(0, 0), 1);
   EXPECT_EQ(fac.atIneq(0, 1), 0);
   EXPECT_EQ(fac.atIneq(0, 2), -1);
 
-  fac.addConstantLowerBound({1, 2, 3}, 1);
+  fac.addBound(FlatAffineConstraints::LB, {1, 2, 3}, 1);
   EXPECT_EQ(fac.atIneq(1, 0), 1);
   EXPECT_EQ(fac.atIneq(1, 1), 2);
   EXPECT_EQ(fac.atIneq(1, 2), 2);


        


More information about the Mlir-commits mailing list