[Mlir-commits] [mlir] [mlir][affine][Analysis] Add conservative bounds for semi-affine mods (PR #93576)

Benjamin Maxwell llvmlistbot at llvm.org
Wed May 29 10:09:06 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/93576

From 209a78c90b784217283da29883566a1f763608d2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 23 May 2024 11:09:55 +0000
Subject: [PATCH 1/3] [mlir][affine][Analysis] Add conservative bounds for
 semi-affine mods

This patch adds support for computing bounds for semi-affine mod
expressions to FlatLinearConstraints. This is then enabled within the
ScalableValueBoundsConstraintSet to allow computing the bounds of
scalable remainder loops.

For example, computing the bound of something like:
```

%0 = affine.apply #remainder_start_index()[%c8_vscale]
scf.for %i = %0 to %c1000 step %c8_vscale {
  %remaining_iterations = affine.apply #remaining_iterations(%i)
  // The upper bound for the remainder loop iterations should be:
  // %c8_vscale - 1  (expressed as an affine map,
  // affine_map<()[s0] -> (s0 * 8 - 1)>, where s0 is vscale)
  %bound = "test.reify_bound"(%remaining_iterations) <{scalable, ...}>
}
```
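A brute-force check makes the expected bound concrete. The following C++ sketch assumes the loop bound of 1000 and the `8 * vscale` step from the example, and assumes `#remainder_start_index` computes `1000 - (1000 mod step)` (matching the `#map_remainder_start` map in the test added by this patch). The helper names `maxRemainingIterations` and `boundHoldsUpTo` are hypothetical.

```cpp
// Sketch: brute-force check of the remainder-loop bound from the example.
// Assumes N = 1000 and step = 8 * vscale, as in the example above.
#include <cstdint>

// Largest value of `1000 - i` over all iterations of the remainder loop for a
// given vscale (returns 0 if the remainder loop does not execute).
int64_t maxRemainingIterations(int64_t vscale) {
  const int64_t n = 1000;
  const int64_t step = 8 * vscale;      // %c8_vscale
  const int64_t start = n - (n % step); // assumed #remainder_start_index
  int64_t maxRemaining = 0;
  for (int64_t i = start; i < n; i += step) {
    const int64_t remaining = n - i; // #remaining_iterations
    if (remaining > maxRemaining)
      maxRemaining = remaining;
  }
  return maxRemaining;
}

// True if the claimed bound `vscale * 8 - 1` holds for every vscale in
// [1, maxVscale].
bool boundHoldsUpTo(int64_t maxVscale) {
  for (int64_t vscale = 1; vscale <= maxVscale; ++vscale)
    if (maxRemainingIterations(vscale) > vscale * 8 - 1)
      return false;
  return true;
}
```

Because the loop starts at `1000 - (1000 mod step)`, it runs at most once, and `1000 - i` is then `1000 mod step`, which is at most `step - 1 = 8 * vscale - 1`.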

There are caveats to this implementation. To be able to add a bound for
a `mod`, we need to assume the rhs is positive (> 0). This may not be
known when adding the bounds for the `mod` expression, so to handle this
a constraint is added for `rhs > 0`. This may later be found not to hold,
in which case the constraint set becomes empty (i.e. invalid).
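To make the three added constraints concrete, the following C++ sketch builds the same rows as `addLocalModConservativeBounds`, but uses `int64_t` in place of `MPInt` and `std::vector` in place of the Presburger row types (both substitutions are assumptions for illustration). Each row encodes `sum(row[k] * x[k]) + constant >= 0` over the existing variables followed by the new local `q = lhs mod rhs`. As in the patch, only the coefficients of `rhs` are needed. The names `Row`, `modConservativeRows` and `satisfies` are hypothetical.

```cpp
// Sketch: conservative rows for a local q = lhs mod rhs, using int64_t.
// Row layout: [coefficients of existing vars..., coefficient of q, constant].
#include <cstddef>
#include <cstdint>
#include <vector>

using Row = std::vector<int64_t>;

// `lhs`/`rhs` hold coefficients of the existing variables followed by a
// constant term.
std::vector<Row> modConservativeRows(const Row &lhs, const Row &rhs) {
  (void)lhs; // The three bounds below depend only on rhs.
  // 1. rhs - 1 >= 0, i.e. rhs > 0 (the assumed precondition).
  Row positive(rhs);
  positive.insert(positive.end() - 1, int64_t{0});
  positive.back() -= 1;
  // 2. rhs - q - 1 >= 0, i.e. q < rhs.
  Row upper(rhs);
  upper.insert(upper.end() - 1, int64_t{-1});
  upper.back() -= 1;
  // 3. q >= 0.
  Row lower(rhs.size(), 0);
  lower.insert(lower.end() - 1, int64_t{1});
  return {positive, upper, lower};
}

// Evaluates `row >= 0` at `vars` (existing variables followed by q).
bool satisfies(const Row &row, const std::vector<int64_t> &vars) {
  int64_t sum = row.back();
  for (std::size_t k = 0; k + 1 < row.size(); ++k)
    sum += row[k] * vars[k];
  return sum >= 0;
}
```

For `1000 mod s0` the rows are `s0 - 1 >= 0`, `s0 - q - 1 >= 0` and `q >= 0`. Any point with `s0 <= 0` violates the first row, which is how a wrong `rhs > 0` assumption surfaces as an empty constraint set.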

This is not a problem for computing scalable bounds, where it is safe to
assume `s0` is vscale (or some positive multiple of it). However, it may
need to be considered when enabling this feature elsewhere, to ensure
correctness.
---
 .../Analysis/FlatLinearValueConstraints.h     | 49 +++++++---
 .../Analysis/Presburger/IntegerRelation.h     | 14 +++
 .../IR/ScalableValueBoundsConstraintSet.h     |  5 +-
 mlir/include/mlir/IR/AffineExprVisitor.h      |  6 +-
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  6 +-
 .../Analysis/FlatLinearValueConstraints.cpp   | 98 +++++++++++++------
 .../Analysis/Presburger/IntegerRelation.cpp   | 31 ++++++
 .../IR/ScalableValueBoundsConstraintSet.cpp   | 10 ++
 mlir/lib/IR/AffineExpr.cpp                    | 20 ++--
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 18 +++-
 .../Dialect/Vector/test-scalable-bounds.mlir  | 56 +++++++++++
 11 files changed, 255 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
index 29c19442a7c7c..a85e2790373bd 100644
--- a/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
+++ b/mlir/include/mlir/Analysis/FlatLinearValueConstraints.h
@@ -66,6 +66,10 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
   /// Return the kind of this object.
   Kind getKind() const override { return Kind::FlatLinearConstraints; }
 
+  /// Flag to control if conservative semi-affine bounds should be added in
+  /// `addBound()`.
+  enum class AddConservativeSemiAffineBounds { No = 0, Yes };
+
   /// Adds a bound for the variable at the specified position with constraints
   /// being drawn from the specified bound map. In case of an EQ bound, the
   /// bound map is expected to have exactly one result. In case of a LB/UB, the
@@ -77,21 +81,39 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
   /// as a closed bound by +1/-1 respectively. In case of an EQ bound, it can
   /// only be added as a closed bound.
   ///
+  /// Conservative bounds for semi-affine expressions will be added if
+  /// `AddConservativeSemiAffineBounds` is set to `Yes`. This currently does not
+  /// cover all semi-affine expressions, so `addBound()` still may fail with
+  /// this set. Note: If enabled it is possible for the resulting constraint set
+  /// to become empty if a precondition of a conservative bound is found not to
+  /// hold.
+  ///
   /// Note: The dimensions/symbols of this FlatLinearConstraints must match the
   /// dimensions/symbols of the affine map.
-  LogicalResult addBound(presburger::BoundType type, unsigned pos,
-                         AffineMap boundMap, bool isClosedBound);
+  LogicalResult addBound(
+      presburger::BoundType type, unsigned pos, AffineMap boundMap,
+      bool isClosedBound,
+      AddConservativeSemiAffineBounds = AddConservativeSemiAffineBounds::No);
 
   /// Adds a bound for the variable at the specified position with constraints
   /// being drawn from the specified bound map. In case of an EQ bound, the
   /// bound map is expected to have exactly one result. In case of a LB/UB, the
   /// bound map may have more than one result, for each of which an inequality
   /// is added.
+  ///
+  /// Conservative bounds for semi-affine expressions will be added if
+  /// `AddConservativeSemiAffineBounds` is set to `Yes`. This currently does not
+  /// cover all semi-affine expressions, so `addBound()` still may fail with
+  /// this set. If enabled it is possible for the resulting constraint set
+  /// to become empty if a precondition of a conservative bound is found not to
+  /// hold.
+  ///
   /// Note: The dimensions/symbols of this FlatLinearConstraints must match the
   /// dimensions/symbols of the affine map. By default the lower bound is closed
   /// and the upper bound is open.
-  LogicalResult addBound(presburger::BoundType type, unsigned pos,
-                         AffineMap boundMap);
+  LogicalResult addBound(
+      presburger::BoundType type, unsigned pos, AffineMap boundMap,
+      AddConservativeSemiAffineBounds = AddConservativeSemiAffineBounds::No);
 
   /// The `addBound` overload above hides the inherited overloads by default, so
   /// we explicitly introduce them here.
@@ -193,7 +215,8 @@ class FlatLinearConstraints : public presburger::IntegerPolyhedron {
   /// Note: This is a shared helper function of `addLowerOrUpperBound` and
   ///       `composeMatchingMap`.
   LogicalResult flattenAlignedMapAndMergeLocals(
-      AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs);
+      AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
+      bool addConservativeSemiAffineBounds = false);
 
   /// Prints the number of constraints, dimensions, symbols and locals in the
   /// FlatLinearConstraints. Also, prints for each variable whether there is
@@ -468,18 +491,19 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
 /// Flattens 'expr' into 'flattenedExpr', which contains the coefficients of the
 /// dimensions, symbols, and additional variables that represent floor divisions
 /// of dimensions, symbols, and in turn other floor divisions.  Returns failure
-/// if 'expr' could not be flattened (i.e., semi-affine is not yet handled).
+/// if 'expr' could not be flattened (i.e., an unhandled semi-affine was found).
 /// 'cst' contains constraints that connect newly introduced local variables
 /// to existing dimensional and symbolic variables. See documentation for
 /// AffineExprFlattener on how mod's and div's are flattened.
-LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
-                                     unsigned numSymbols,
-                                     SmallVectorImpl<int64_t> *flattenedExpr,
-                                     FlatLinearConstraints *cst = nullptr);
+LogicalResult
+getFlattenedAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
+                       SmallVectorImpl<int64_t> *flattenedExpr,
+                       FlatLinearConstraints *cst = nullptr,
+                       bool addConservativeSemiAffineBounds = false);
 
 /// Flattens the result expressions of the map to their corresponding flattened
 /// forms and set in 'flattenedExprs'. Returns failure if any expression in the
-/// map could not be flattened (i.e., semi-affine is not yet handled). 'cst'
+/// map could not be flattened (i.e., an unhandled semi-affine was found). 'cst'
 /// contains constraints that connect newly introduced local variables to
 /// existing dimensional and / symbolic variables. See documentation for
 /// AffineExprFlattener on how mod's and div's are flattened. For all affine
@@ -490,7 +514,8 @@ LogicalResult getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
 LogicalResult
 getFlattenedAffineExprs(AffineMap map,
                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
-                        FlatLinearConstraints *cst = nullptr);
+                        FlatLinearConstraints *cst = nullptr,
+                        bool addConservativeSemiAffineBounds = false);
 LogicalResult
 getFlattenedAffineExprs(IntegerSet set,
                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 163f365c623d7..c7e2e55372324 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -454,6 +454,20 @@ class IntegerRelation {
     addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor));
   }
 
+  /// Adds a new local variable as the mod of an affine function of other
+  /// variables. The coefficients of the operands of the mod are provided in
+  /// `lhs` and `rhs` respectively. Three constraints are added to provide a
+  /// conservative bound for the mod:
+  ///  1. rhs > 0 (assumption/precondition)
+  ///  2. lhs % rhs < rhs
+  ///  3. lhs % rhs >= 0
+  /// We ensure the rhs is positive so we can assume the result is non-negative.
+  void addLocalModConservativeBounds(ArrayRef<MPInt> lhs, ArrayRef<MPInt> rhs);
+  void addLocalModConservativeBounds(ArrayRef<int64_t> lhs,
+                                     ArrayRef<int64_t> rhs) {
+    addLocalModConservativeBounds(getMPIntVec(lhs), getMPIntVec(rhs));
+  }
+
   /// Projects out (aka eliminates) `num` variables starting at position
   /// `pos`. The resulting constraint system is the shadow along the dimensions
   /// that still exist. This method may not always be integer exact.
diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
index 67a6581eb2fb4..93b3c92533c54 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -33,8 +33,9 @@ struct ScalableValueBoundsConstraintSet
       MLIRContext *context,
       ValueBoundsConstraintSet::StopConditionFn stopCondition,
       unsigned vscaleMin, unsigned vscaleMax)
-      : RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
-        vscaleMax(vscaleMax) {};
+      : RTTIExtends(context, stopCondition,
+                    /*addConservativeSemiAffineBounds=*/true),
+        vscaleMin(vscaleMin), vscaleMax(vscaleMax) {};
 
   using RTTIExtends::bound;
   using RTTIExtends::StopConditionFn;
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index 27c49cd80018e..bff9c9d4a029c 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -413,7 +413,8 @@ class SimpleAffineExprFlattener
   /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
   /// symbolic rhs expression. `localExpr` is the simplified tree expression
   /// (AffineExpr) corresponding to the quantifier.
-  virtual void addLocalIdSemiAffine(AffineExpr localExpr);
+  virtual void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
+                                    ArrayRef<int64_t> rhs);
 
 private:
   /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
@@ -422,7 +423,8 @@ class SimpleAffineExprFlattener
   /// quantifier is already present, we put the coefficient in the proper index
   /// of `result`, otherwise we add a new local variable and put the coefficient
   /// there.
-  void addLocalVariableSemiAffine(AffineExpr expr,
+  void addLocalVariableSemiAffine(AffineExpr expr, ArrayRef<int64_t> lhs,
+                                  ArrayRef<int64_t> rhs,
                                   SmallVectorImpl<int64_t> &result,
                                   unsigned long resultSize);
 
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index ac17ace5a976d..337314143c80c 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -313,7 +313,8 @@ class ValueBoundsConstraintSet
   /// An index-typed value or the dimension of a shaped-type value.
   using ValueDim = std::pair<Value, int64_t>;
 
-  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
+  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition,
+                           bool addConservativeSemiAffineBounds = false);
 
   /// Return "true" if, based on the current state of the constraint system,
   /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
@@ -404,6 +405,9 @@ class ValueBoundsConstraintSet
 
   /// The current stop condition function.
   StopConditionFn stopCondition = nullptr;
+
+  /// Should conservative bounds be added for semi-affine expressions.
+  bool addConservativeSemiAffineBounds = false;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 8b38016d61498..35ee989272899 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -46,9 +46,15 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
   // inequalities.
   IntegerPolyhedron localVarCst;
 
-  AffineExprFlattener(unsigned nDims, unsigned nSymbols)
+  AffineExprFlattener(unsigned nDims, unsigned nSymbols,
+                      bool addConservativeSemiAffineBounds = false)
       : SimpleAffineExprFlattener(nDims, nSymbols),
-        localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
+        localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)),
+        addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {}
+
+  bool hasUnhandledSemiAffineExpressions() const {
+    return unhandledSemiAffineExpressions;
+  }
 
 private:
   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
@@ -63,35 +69,61 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
     // Update localVarCst.
     localVarCst.addLocalFloorDiv(dividend, divisor);
   }
+
+  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
+  // expr) when the rhs is a symbolic expression. The local identifier added
+  // may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
+  // function of other identifiers, coefficients of which are specified in the
+  // lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
+  // symbolic rhs expression. `localExpr` is the simplified tree expression
+  // (AffineExpr) corresponding to the quantifier.
+  void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
+                            ArrayRef<int64_t> rhs) override {
+    SimpleAffineExprFlattener::addLocalIdSemiAffine(localExpr, lhs, rhs);
+    if (!addConservativeSemiAffineBounds) {
+      unhandledSemiAffineExpressions = true;
+      return;
+    }
+    if (localExpr.getKind() == AffineExprKind::Mod) {
+      localVarCst.addLocalModConservativeBounds(lhs, rhs);
+      return;
+    }
+    // TODO: Support other semi-affine expressions.
+    unhandledSemiAffineExpressions = true;
+  }
+
+  bool addConservativeSemiAffineBounds = false;
+  bool unhandledSemiAffineExpressions = false;
 };
 
 } // namespace
 
 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
 // flattened. For example two specific cases:
-// 1. semi-affine expressions not handled yet.
+// 1. an unhandled semi-affine expression is found.
 // 2. has poison expression (i.e., division by zero).
 static LogicalResult
 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
                         unsigned numSymbols,
                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
-                        FlatLinearConstraints *localVarCst) {
+                        FlatLinearConstraints *localVarCst,
+                        bool addConservativeSemiAffineBounds = false) {
   if (exprs.empty()) {
     if (localVarCst)
       *localVarCst = FlatLinearConstraints(numDims, numSymbols);
     return success();
   }
 
-  AffineExprFlattener flattener(numDims, numSymbols);
+  AffineExprFlattener flattener(numDims, numSymbols,
+                                addConservativeSemiAffineBounds);
   // Use the same flattener to simplify each expression successively. This way
   // local variables / expressions are shared.
   for (auto expr : exprs) {
-    if (!expr.isPureAffine())
-      return failure();
-    // has poison expression
     auto flattenResult = flattener.walkPostOrder(expr);
     if (failed(flattenResult))
       return failure();
+    if (flattener.hasUnhandledSemiAffineExpressions())
+      return failure();
   }
 
   assert(flattener.operandExprStack.size() == exprs.size());
@@ -106,33 +138,33 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
 }
 
 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
-// be flattened (semi-affine expressions not handled yet).
-LogicalResult
-mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
-                             unsigned numSymbols,
-                             SmallVectorImpl<int64_t> *flattenedExpr,
-                             FlatLinearConstraints *localVarCst) {
+// be flattened (an unhandled semi-affine was found).
+LogicalResult mlir::getFlattenedAffineExpr(
+    AffineExpr expr, unsigned numDims, unsigned numSymbols,
+    SmallVectorImpl<int64_t> *flattenedExpr, FlatLinearConstraints *localVarCst,
+    bool addConservativeSemiAffineBounds) {
   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
-  LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
-                                                &flattenedExprs, localVarCst);
+  LogicalResult ret =
+      ::getFlattenedAffineExprs({expr}, numDims, numSymbols, &flattenedExprs,
+                                localVarCst, addConservativeSemiAffineBounds);
   *flattenedExpr = flattenedExprs[0];
   return ret;
 }
 
 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
-/// flattened (i.e., semi-affine expressions not handled yet).
+/// flattened (i.e., an unhandled semi-affine was found).
 LogicalResult mlir::getFlattenedAffineExprs(
     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
-    FlatLinearConstraints *localVarCst) {
+    FlatLinearConstraints *localVarCst, bool addConservativeSemiAffineBounds) {
   if (map.getNumResults() == 0) {
     if (localVarCst)
       *localVarCst =
           FlatLinearConstraints(map.getNumDims(), map.getNumSymbols());
     return success();
   }
-  return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
-                                   map.getNumSymbols(), flattenedExprs,
-                                   localVarCst);
+  return ::getFlattenedAffineExprs(
+      map.getResults(), map.getNumDims(), map.getNumSymbols(), flattenedExprs,
+      localVarCst, addConservativeSemiAffineBounds);
 }
 
 LogicalResult mlir::getFlattenedAffineExprs(
@@ -641,9 +673,11 @@ void FlatLinearConstraints::getSliceBounds(unsigned offset, unsigned num,
 }
 
 LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
-    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
+    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
+    bool addConservativeSemiAffineBounds) {
   FlatLinearConstraints localCst;
-  if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
+  if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst,
+                                     addConservativeSemiAffineBounds))) {
     LLVM_DEBUG(llvm::dbgs()
                << "composition unimplemented for semi-affine maps\n");
     return failure();
@@ -664,9 +698,9 @@ LogicalResult FlatLinearConstraints::flattenAlignedMapAndMergeLocals(
   return success();
 }
 
-LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
-                                              AffineMap boundMap,
-                                              bool isClosedBound) {
+LogicalResult FlatLinearConstraints::addBound(
+    BoundType type, unsigned pos, AffineMap boundMap, bool isClosedBound,
+    AddConservativeSemiAffineBounds addSemiAffineBounds) {
   assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
   assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
   assert(pos < getNumDimAndSymbolVars() && "invalid position");
@@ -680,7 +714,9 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
   bool lower = type == BoundType::LB || type == BoundType::EQ;
 
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
+  if (failed(flattenAlignedMapAndMergeLocals(
+          boundMap, &flatExprs,
+          addSemiAffineBounds == AddConservativeSemiAffineBounds::Yes)))
     return failure();
   assert(flatExprs.size() == boundMap.getNumResults());
 
@@ -716,9 +752,11 @@ LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
   return success();
 }
 
-LogicalResult FlatLinearConstraints::addBound(BoundType type, unsigned pos,
-                                              AffineMap boundMap) {
-  return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
+LogicalResult FlatLinearConstraints::addBound(
+    BoundType type, unsigned pos, AffineMap boundMap,
+    AddConservativeSemiAffineBounds addSemiAffineBounds) {
+  return addBound(type, pos, boundMap,
+                  /*isClosedBound=*/type != BoundType::UB, addSemiAffineBounds);
 }
 
 /// Compute an explicit representation for local vars. For all systems coming
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index b5a2ed6ccc369..798f9deaa4028 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1521,6 +1521,37 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<MPInt> dividend,
       getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
 }
 
+/// Adds a new local variable as the mod of an affine function of other
+/// variables. The coefficients of the operands of the mod are provided in `lhs`
+/// and `rhs` respectively. Three constraints are added to provide a
+/// conservative bound for the mod:
+///  1. rhs > 0 (assumption/precondition)
+///  2. lhs % rhs < rhs
+///  3. lhs % rhs >= 0
+/// We ensure the rhs is positive so we can assume the result is non-negative.
+void IntegerRelation::addLocalModConservativeBounds(ArrayRef<MPInt> lhs,
+                                                    ArrayRef<MPInt> rhs) {
+  appendVar(VarKind::Local);
+
+  // Ensure the rhs is > 0 (most likely case).
+  // If this constraint does not hold the following bounds are incorrect.
+  SmallVector<MPInt, 8> rhsCopy(rhs);
+  rhsCopy.insert(rhsCopy.end() - 1, MPInt(0));
+  rhsCopy.back() -= MPInt(1);
+  addInequality(rhsCopy);
+
+  // rhs - (lhs % rhs) - 1 >= 0 i.e. lhs % rhs < rhs
+  SmallVector<MPInt, 8> resultUpperBound(rhs);
+  resultUpperBound.insert(resultUpperBound.end() - 1, MPInt(-1));
+  resultUpperBound.back() -= MPInt(1);
+  addInequality(resultUpperBound);
+
+  // lhs % rhs >= 0
+  SmallVector<MPInt, 8> resultLowerBound(rhs.size());
+  resultLowerBound.insert(resultLowerBound.end() - 1, MPInt(1));
+  addInequality(resultLowerBound);
+}
+
 /// Finds an equality that equates the specified variable to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index f8df34843a363..f7669335a241c 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
 
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/Support/Debug.h"
 
 namespace mlir::vector {
 
@@ -62,6 +63,10 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
   int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
   scalableCstr.processWorklist();
 
+  // Check the resulting constraints set is valid.
+  if (scalableCstr.cstr.isEmpty())
+    return failure();
+
   // Project out all columns apart from vscale and the starting point
   // (value/dim). This should result in constraints in terms of vscale only.
   auto projectOutFn = [&](ValueDim p) {
@@ -71,6 +76,11 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
     return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
   };
   scalableCstr.projectOut(projectOutFn);
+  // Also project out local variables (these are not tracked by the
+  // ValueBoundsConstraintSet).
+  for (unsigned i = 0, e = scalableCstr.cstr.getNumLocalVars(); i < e; ++i) {
+    scalableCstr.cstr.projectOut(scalableCstr.cstr.getNumDimAndSymbolVars());
+  }
 
   assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
              scalableCstr.positionToValueDim.size() &&
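The local-variable projection in the hunk above removes the first local column repeatedly. The following C++ sketch models that pattern with `std::vector<std::string>` as a hypothetical stand-in for the constraint columns, and shows why the iteration count should be captured before the loop starts: each removal shrinks the live local count. The function names are hypothetical.

```cpp
// Sketch: removing every local column by repeatedly erasing the first one.
#include <cstddef>
#include <string>
#include <vector>

// Erases all columns at index >= firstLocal, one at a time.
void projectOutAllLocals(std::vector<std::string> &columns,
                         std::size_t firstLocal) {
  // Read the count once: each erase shrinks columns.size().
  const std::size_t numLocals = columns.size() - firstLocal;
  for (std::size_t i = 0; i < numLocals; ++i)
    columns.erase(columns.begin() + firstLocal);
}

// The pitfall: the counter increases while the live count decreases, so the
// loop stops after removing only about half of the locals.
void projectOutAllLocalsBuggy(std::vector<std::string> &columns,
                              std::size_t firstLocal) {
  for (std::size_t i = 0; i < columns.size() - firstLocal; ++i)
    columns.erase(columns.begin() + firstLocal);
}
```

With three locals, the buggy variant removes two and leaves one behind.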
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 94562d0f15a24..540931521182e 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1242,12 +1242,13 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
   // variable in place of the product; the affine expression
   // corresponding to the quantifier is added to `localExprs`.
   if (!isa<AffineConstantExpr>(expr.getRHS())) {
+    SmallVector<int64_t, 8> mulLhs(lhs);
     MLIRContext *context = expr.getContext();
     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
-    addLocalVariableSemiAffine(a * b, lhs, lhs.size());
+    addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
     return success();
   }
 
@@ -1295,12 +1296,13 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
   // variable in place of the modulo value, and the affine expression
   // corresponding to the quantifier is added to `localExprs`.
   if (!isa<AffineConstantExpr>(expr.getRHS())) {
+    SmallVector<int64_t, 8> modLhs(lhs);
     AffineExpr dividendExpr = getAffineExprFromFlatForm(
         lhs, numDims, numSymbols, localExprs, context);
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
-    addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
+    addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
     return success();
   }
 
@@ -1386,13 +1388,13 @@ SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
 }
 
 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
-    AffineExpr expr, SmallVectorImpl<int64_t> &result,
-    unsigned long resultSize) {
+    AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+    SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
   assert(result.size() == resultSize &&
          "`result` vector passed is not of correct size");
   int loc;
   if ((loc = findLocalId(expr)) == -1)
-    addLocalIdSemiAffine(expr);
+    addLocalIdSemiAffine(expr, lhs, rhs);
   std::fill(result.begin(), result.end(), 0);
   if (loc == -1)
     result[getLocalVarStartIndex() + numLocals - 1] = 1;
@@ -1426,12 +1428,13 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
   // variable in place of the quotient, and the affine expression corresponding
   // to the quantifier is added to `localExprs`.
   if (!isa<AffineConstantExpr>(expr.getRHS())) {
+    SmallVector<int64_t, 8> divLhs(lhs);
     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
-    addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
+    addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
     return success();
   }
 
@@ -1503,11 +1506,14 @@ void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
   // dividend and divisor are not used here; an override of this method uses it.
 }
 
-void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
+void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr,
+                                                     ArrayRef<int64_t> lhs,
+                                                     ArrayRef<int64_t> rhs) {
   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
   localExprs.push_back(localExpr);
   ++numLocals;
+  // lhs and rhs are not used here; an override of this method uses them.
 }
 
 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 87937591e60ad..6420c192b257d 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -151,8 +151,10 @@ ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
                                         [](Value v) { return Variable(v); })) {}
 
 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
-    MLIRContext *ctx, StopConditionFn stopCondition)
-    : builder(ctx), stopCondition(stopCondition) {
+    MLIRContext *ctx, StopConditionFn stopCondition,
+    bool addConservativeSemiAffineBounds)
+    : builder(ctx), stopCondition(stopCondition),
+      addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {
   assert(stopCondition && "expected non-null stop condition");
 }
 
@@ -174,11 +176,19 @@ static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
 
 void ValueBoundsConstraintSet::addBound(BoundType type, int64_t pos,
                                         AffineExpr expr) {
+  // Note: If `addConservativeSemiAffineBounds` is true then the bound
+  // computation function needs to handle the case that the constraints set
+  // could become empty. This is because the conservative bounds add assumptions
+  // (e.g. for `mod` it assumes `rhs > 0`). If these constraints are later found
+  // not to hold, then the bound is invalid.
   LogicalResult status = cstr.addBound(
       type, pos,
-      AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr));
+      AffineMap::get(cstr.getNumDimVars(), cstr.getNumSymbolVars(), expr),
+      addConservativeSemiAffineBounds
+          ? FlatLinearConstraints::AddConservativeSemiAffineBounds::Yes
+          : FlatLinearConstraints::AddConservativeSemiAffineBounds::No);
   if (failed(status)) {
-    // Non-pure (e.g., semi-affine) expressions are not yet supported by
+    // Not all semi-affine expressions are supported yet by
     // FlatLinearConstraints. However, we can just ignore such failures here.
     // Even without this bound, there may be enough information in the
     // constraint system to compute the requested bound. In case this bound is
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index d549c5bd1c378..e24dfa6affebb 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -159,3 +159,59 @@ func.func @non_scalable_code() {
   }
   return
 }
+
+// -----
+
+#map_remainder_start = affine_map<()[s0] -> (-(1000 mod s0) + 1000)>
+#map_remainder_size = affine_map<(d0) -> (-d0 + 1000)>
+
+// CHECK: #[[$REMAINDER_START_MAP:.*]] = affine_map<()[s0] -> (-(1000 mod s0) + 1000)>
+// CHECK: #[[$SCALABLE_BOUND_MAP_4:.*]] = affine_map<()[s0] -> (s0 * 8 - 1)>
+
+// CHECK-LABEL: @test_scalable_remainder_loop
+//       CHECK:   %[[VSCALE:.*]] = vector.vscale
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_4]]()[%[[VSCALE]]]
+//       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
+func.func @test_scalable_remainder_loop() {
+  %c8 = arith.constant 8 : index
+  %c1000 = arith.constant 1000 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = affine.apply #map_remainder_start()[%c8_vscale]
+  scf.for %arg1 = %0 to %c1000 step %c8_vscale {
+    %remainder_trip_count = affine.apply #map_remainder_size(%arg1)
+    // The upper bound for the remainder loop iterations should be: %c8_vscale - 1
+    // (expressed as an affine map, affine_map<()[s0] -> (s0 * 8 - 1)>, where s0 is vscale)
+    %bound = "test.reify_bound"(%remainder_trip_count) <{scalable, type = "UB", vscale_min = 1 : i64, vscale_max = 16 : i64}> : (index) -> index
+    "test.some_use"(%bound) : (index) -> ()
+  }
+  return
+}
+
+// -----
+
+#unsupported_semi_affine = affine_map<()[s0] -> (s0 * s0)>
+
+func.func @unsupported_semi_affine() {
+  %vscale = vector.vscale
+  %0 = affine.apply #unsupported_semi_affine()[%vscale]
+  // expected-error @below{{could not reify bound}}
+  %bound = "test.reify_bound"(%0) <{scalable, type = "UB", vscale_min = 1 : i64, vscale_max = 16 : i64}> : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
+
+// -----
+
+#map_mod = affine_map<()[s0] -> (1000 mod s0)>
+
+func.func @unsupported_negative_mod() {
+  %c_minus_1 = arith.constant -1 : index
+  %vscale = vector.vscale
+  %negative_vscale = arith.muli %vscale, %c_minus_1 : index
+  %0 = affine.apply #map_mod()[%negative_vscale]
+  // expected-error @below{{could not reify bound}}
+  %bound = "test.reify_bound"(%0) <{scalable, type = "UB", vscale_min = 1 : i64, vscale_max = 16 : i64}> : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
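As a sanity check on the expected `s0 * 8 - 1` upper bound in `@test_scalable_remainder_loop`, here is a small standalone C++ sketch. It is not part of the patch, and the helper names `maxRemainderTripCount` and `remainderBoundHolds` are hypothetical. It simulates the remainder loop for each vscale in the test's `vscale_min`/`vscale_max` range and assumes `mod` behaves like C++ `%` here, which holds because both operands are positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper (not in the patch): simulates
//   %start = 1000 - (1000 mod step)              // #map_remainder_start
//   scf.for %i = %start to 1000 step %step {
//     %remainder_trip_count = 1000 - %i          // #map_remainder_size
//   }
// and returns the largest %remainder_trip_count seen, or -1 if the loop
// body never runs.
int64_t maxRemainderTripCount(int64_t step) {
  assert(step > 0 && "the conservative mod bound assumes a positive rhs");
  int64_t start = 1000 - (1000 % step);
  int64_t best = -1;
  for (int64_t i = start; i < 1000; i += step)
    best = std::max(best, 1000 - i);
  return best;
}

// Checks that the reified bound `8 * vscale - 1` holds for every vscale in
// [vscaleMin, vscaleMax].
bool remainderBoundHolds(int64_t vscaleMin, int64_t vscaleMax) {
  for (int64_t vscale = vscaleMin; vscale <= vscaleMax; ++vscale) {
    int64_t step = 8 * vscale; // %c8_vscale
    // Expected bound: affine_map<()[s0] -> (s0 * 8 - 1)> with s0 = vscale.
    if (maxRemainderTripCount(step) > step - 1)
      return false;
  }
  return true;
}
```

`remainderBoundHolds(1, 16)` should return true. The loop body runs at most once, at `i = 1000 - (1000 mod step)`, so the trip count is `1000 mod step`, which is at most `step - 1` when `step > 0`.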

>From 179b27bcef7f5157aafd79537a8703abc0c262a4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 29 May 2024 11:03:33 +0000
Subject: [PATCH 2/3] Keep changes within FlatLinearValueConstraints

---
 .../Analysis/Presburger/IntegerRelation.h     |  14 --
 mlir/include/mlir/IR/AffineExprVisitor.h      |  14 +-
 .../Analysis/FlatLinearValueConstraints.cpp   | 145 ++++++++++++------
 .../Analysis/Presburger/IntegerRelation.cpp   |  31 ----
 .../IR/ScalableValueBoundsConstraintSet.cpp   |   3 -
 mlir/lib/IR/AffineExpr.cpp                    |  24 +--
 6 files changed, 114 insertions(+), 117 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index c7e2e55372324..163f365c623d7 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -454,20 +454,6 @@ class IntegerRelation {
     addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor));
   }
 
-  /// Adds a new local variable as the mod of an affine function of other
-  /// variables. The coefficients of the operands of the mod are provided in
-  /// `lhs` and `rhs` respectively. Three constraints are added to provide a
-  /// conservative bound for the mod:
-  ///  1. rhs > 0 (assumption/precondition)
-  ///  2. lhs % rhs < rhs
-  ///  3. lhs % rhs >= 0
-  /// We ensure the rhs is positive so we can assume the result is positive.
-  void addLocalModConservativeBounds(ArrayRef<MPInt> lhs, ArrayRef<MPInt> rhs);
-  void addLocalModConservativeBounds(ArrayRef<int64_t> lhs,
-                                     ArrayRef<int64_t> rhs) {
-    addLocalModConservativeBounds(getMPIntVec(lhs), getMPIntVec(rhs));
-  }
-
   /// Projects out (aka eliminates) `num` variables starting at position
   /// `pos`. The resulting constraint system is the shadow along the dimensions
   /// that still exist. This method may not always be integer exact.
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index bff9c9d4a029c..be4bc84e02d57 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -413,8 +413,9 @@ class SimpleAffineExprFlattener
   /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
   /// symbolic rhs expression. `localExpr` is the simplified tree expression
   /// (AffineExpr) corresponding to the quantifier.
-  virtual void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
-                                    ArrayRef<int64_t> rhs);
+  virtual LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
+                                             ArrayRef<int64_t> rhs,
+                                             AffineExpr localExpr);
 
 private:
   /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
@@ -423,10 +424,11 @@ class SimpleAffineExprFlattener
   /// quantifier is already present, we put the coefficient in the proper index
   /// of `result`, otherwise we add a new local variable and put the coefficient
   /// there.
-  void addLocalVariableSemiAffine(AffineExpr expr, ArrayRef<int64_t> lhs,
-                                  ArrayRef<int64_t> rhs,
-                                  SmallVectorImpl<int64_t> &result,
-                                  unsigned long resultSize);
+  LogicalResult addLocalVariableSemiAffine(AffineExpr expr,
+                                           ArrayRef<int64_t> lhs,
+                                           ArrayRef<int64_t> rhs,
+                                           SmallVectorImpl<int64_t> &result,
+                                           unsigned long resultSize);
 
   // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
   // A floordiv is thus flattened by introducing a new local variable q, and
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 35ee989272899..9e3b84ea33c0e 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -36,25 +36,20 @@ using namespace presburger;
 namespace {
 
 // See comments for SimpleAffineExprFlattener.
-// An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
-// constraint information associated with mod's, floordiv's, and ceildiv's
-// in FlatLinearConstraints 'localVarCst'.
-struct AffineExprFlattener : public SimpleAffineExprFlattener {
-public:
+// An AffineExprFlattenerWithLocalVars extends a SimpleAffineExprFlattener by
+// recording constraint information associated with mod's, floordiv's, and
+// ceildiv's in FlatLinearConstraints 'localVarCst'.
+struct AffineExprFlattenerWithLocalVars : public SimpleAffineExprFlattener {
+  using SimpleAffineExprFlattener::SimpleAffineExprFlattener;
+
   // Constraints connecting newly introduced local variables (for mod's and
   // div's) to existing (dimensional and symbolic) ones. These are always
   // inequalities.
   IntegerPolyhedron localVarCst;
 
-  AffineExprFlattener(unsigned nDims, unsigned nSymbols,
-                      bool addConservativeSemiAffineBounds = false)
+  AffineExprFlattenerWithLocalVars(unsigned nDims, unsigned nSymbols)
       : SimpleAffineExprFlattener(nDims, nSymbols),
-        localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)),
-        addConservativeSemiAffineBounds(addConservativeSemiAffineBounds) {}
-
-  bool hasUnhandledSemiAffineExpressions() const {
-    return unhandledSemiAffineExpressions;
-  }
+        localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
 
 private:
   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
@@ -70,30 +65,71 @@ struct AffineExprFlattener : public SimpleAffineExprFlattener {
     localVarCst.addLocalFloorDiv(dividend, divisor);
   }
 
-  // Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
-  // expr) when the rhs is a symbolic expression. The local identifier added
-  // may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
-  // function of other identifiers, coefficients of which are specified in the
-  // lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
-  // symbolic rhs expression. `localExpr` is the simplified tree expression
-  // (AffineExpr) corresponding to the quantifier.
-  void addLocalIdSemiAffine(AffineExpr localExpr, ArrayRef<int64_t> lhs,
-                            ArrayRef<int64_t> rhs) override {
-    SimpleAffineExprFlattener::addLocalIdSemiAffine(localExpr, lhs, rhs);
-    if (!addConservativeSemiAffineBounds) {
-      unhandledSemiAffineExpressions = true;
-      return;
-    }
+  // Semi-affine expressions are not supported by all flatteners.
+  LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
+                                     ArrayRef<int64_t> rhs,
+                                     AffineExpr localExpr) override = 0;
+};
+
+// An AffineExprFlattener is an AffineExprFlattenerWithLocalVars that explicitly
+// disallows semi-affine expressions. Flattening will fail if a semi-affine
+// expression is encountered.
+struct AffineExprFlattener : public AffineExprFlattenerWithLocalVars {
+  using AffineExprFlattenerWithLocalVars::AffineExprFlattenerWithLocalVars;
+
+  LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
+                                     ArrayRef<int64_t> rhs,
+                                     AffineExpr localExpr) override {
+    // AffineExprFlattener does not support semi-affine expressions.
+    return failure();
+  }
+};
+
+// A SemiAffineExprFlattener is an AffineExprFlattenerWithLocalVars that adds
+// conservative bounds for semi-affine expressions (provided certain
+// assumptions hold). If the assumptions required to add the semi-affine bounds
+// are found not to hold, the final constraint set will be empty/inconsistent.
+// If the assumptions are never contradicted, the final bounds will still only
+// be correct if the assumptions hold.
+struct SemiAffineExprFlattener : public AffineExprFlattenerWithLocalVars {
+  using AffineExprFlattenerWithLocalVars::AffineExprFlattenerWithLocalVars;
+
+  LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
+                                     ArrayRef<int64_t> rhs,
+                                     AffineExpr localExpr) override {
+    auto result =
+        SimpleAffineExprFlattener::addLocalIdSemiAffine(lhs, rhs, localExpr);
+    assert(succeeded(result) &&
+           "unexpected failure in SimpleAffineExprFlattener");
+    (void)result;
+
     if (localExpr.getKind() == AffineExprKind::Mod) {
-      localVarCst.addLocalModConservativeBounds(lhs, rhs);
-      return;
+      localVarCst.appendVar(VarKind::Local);
+      // Add a conservative bound for `mod` assuming the rhs is > 0.
+
+      // Note: If the rhs is later found to be <= 0, the following two
+      // constraints will contradict each other (and the final constraint set
+      // will become empty). If the sign of the rhs is never specified, the
+      // bound will assume it is positive.
+
+      // Upper bound: rhs - (lhs % rhs) - 1 >= 0 i.e. lhs % rhs < rhs
+      // This only holds if the rhs is > 0.
+      SmallVector<int64_t, 8> resultUpperBound(rhs);
+      resultUpperBound.insert(resultUpperBound.end() - 1, -1);
+      --resultUpperBound.back();
+      localVarCst.addInequality(resultUpperBound);
+
+      // Lower bound: lhs % rhs >= 0 (always holds)
+      SmallVector<int64_t, 8> resultLowerBound(rhs.size());
+      resultLowerBound.insert(resultLowerBound.end() - 1, 1);
+      localVarCst.addInequality(resultLowerBound);
+
+      return success();
     }
+
     // TODO: Support other semi-affine expressions.
-    unhandledSemiAffineExpressions = true;
+    return failure();
   }
-
-  bool addConservativeSemiAffineBounds = false;
-  bool unhandledSemiAffineExpressions = false;
 };
 
 } // namespace
@@ -114,27 +150,34 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
     return success();
   }
 
-  AffineExprFlattener flattener(numDims, numSymbols,
-                                addConservativeSemiAffineBounds);
-  // Use the same flattener to simplify each expression successively. This way
-  // local variables / expressions are shared.
-  for (auto expr : exprs) {
-    auto flattenResult = flattener.walkPostOrder(expr);
-    if (failed(flattenResult))
-      return failure();
-    if (flattener.hasUnhandledSemiAffineExpressions())
-      return failure();
-  }
+  auto flattenExprs =
+      [&](AffineExprFlattenerWithLocalVars &flattener) -> LogicalResult {
+    // Use the same flattener to simplify each expression successively. This way
+    // local variables / expressions are shared.
+    for (auto expr : exprs) {
+      auto flattenResult = flattener.walkPostOrder(expr);
+      if (failed(flattenResult))
+        return failure();
+    }
+
+    assert(flattener.operandExprStack.size() == exprs.size());
+    flattenedExprs->clear();
+    flattenedExprs->assign(flattener.operandExprStack.begin(),
+                           flattener.operandExprStack.end());
 
-  assert(flattener.operandExprStack.size() == exprs.size());
-  flattenedExprs->clear();
-  flattenedExprs->assign(flattener.operandExprStack.begin(),
-                         flattener.operandExprStack.end());
+    if (localVarCst)
+      localVarCst->clearAndCopyFrom(flattener.localVarCst);
+
+    return success();
+  };
 
-  if (localVarCst)
-    localVarCst->clearAndCopyFrom(flattener.localVarCst);
+  if (addConservativeSemiAffineBounds) {
+    SemiAffineExprFlattener flattener(numDims, numSymbols);
+    return flattenExprs(flattener);
+  }
 
-  return success();
+  AffineExprFlattener flattener(numDims, numSymbols);
+  return flattenExprs(flattener);
 }
 
 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
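The two inequalities that `SemiAffineExprFlattener` adds for a `mod` can be reproduced outside MLIR. The sketch below uses hypothetical helper names and a plain `std::vector` instead of `SmallVector`/`IntegerPolyhedron`. It assumes the flattened column layout `[dims..., symbols..., locals..., constant]`, with the new local `m = lhs mod rhs` appended as the last local column. It also illustrates why no separate `rhs > 0` row is needed: `m >= 0` and `rhs - m - 1 >= 0` together imply `rhs >= 1`, so a non-positive rhs leaves no feasible `m`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical alias: one flattened coefficient row laid out as
// [dims..., symbols..., locals..., constant].
using Row = std::vector<int64_t>;

// rhs - m - 1 >= 0, i.e. m < rhs, where `m` is the new local appended as the
// last local column (just before the constant term).
Row modUpperBoundRow(const Row &rhs) {
  Row row(rhs);
  row.insert(row.end() - 1, -1); // coefficient of m
  --row.back();                  // constant term: -1
  return row;
}

// m >= 0.
Row modLowerBoundRow(const Row &rhs) {
  Row row(rhs.size(), 0);
  row.insert(row.end() - 1, 1); // coefficient of m
  return row;
}

// Evaluates `row` at the given values for every non-constant column; the
// inequality is satisfied when the result is >= 0.
int64_t evalRow(const Row &row, const Row &vals) {
  assert(row.size() == vals.size() + 1 && "one value per non-constant column");
  int64_t sum = row.back();
  for (std::size_t i = 0; i < vals.size(); ++i)
    sum += row[i] * vals[i];
  return sum;
}
```

For a divisor that is just the symbol `s0` (rhs = `{1, 0}`), the rows come out as `{1, -1, -1}` (`s0 - m - 1 >= 0`) and `{0, 1, 0}` (`m >= 0`). With `s0 = -8` (as in `@unsupported_negative_mod`), every integer `m` violates at least one of them.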
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 798f9deaa4028..b5a2ed6ccc369 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1521,37 +1521,6 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<MPInt> dividend,
       getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
 }
 
-/// Adds a new local variable as the mod of an affine function of other
-/// variables. The coefficients of the operands of the mod are provided in `lhs`
-/// and `rhs` respectively. Three constraints are added to provide a
-/// conservative bound for the mod:
-///  1. rhs > 0 (assumption/precondition)
-///  2. lhs % rhs < rhs
-///  3. lhs % rhs >= 0
-/// We ensure the rhs is positive so we can assume the result is positive.
-void IntegerRelation::addLocalModConservativeBounds(ArrayRef<MPInt> lhs,
-                                                    ArrayRef<MPInt> rhs) {
-  appendVar(VarKind::Local);
-
-  // Ensure the rhs is > 0 (most likely case).
-  // If this constraint does not hold the following bounds are incorrect.
-  SmallVector<MPInt, 8> rhsCopy(rhs);
-  rhsCopy.insert(rhsCopy.end() - 1, MPInt(0));
-  rhsCopy.back() -= MPInt(1);
-  addInequality(rhsCopy);
-
-  // rhs - (lhs % rhs) - 1 >= 0 i.e. lhs % rhs < rhs
-  SmallVector<MPInt, 8> resultUpperBound(rhs);
-  resultUpperBound.insert(resultUpperBound.end() - 1, MPInt(-1));
-  resultUpperBound.back() -= MPInt(1);
-  addInequality(resultUpperBound);
-
-  // lhs % rhs >= 0
-  SmallVector<MPInt, 8> resultLowerBound(rhs.size());
-  resultLowerBound.insert(resultLowerBound.end() - 1, MPInt(1));
-  addInequality(resultLowerBound);
-}
-
 /// Finds an equality that equates the specified variable to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index f7669335a241c..156e73c33c06f 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -7,10 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
-
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "llvm/Support/Debug.h"
-
 namespace mlir::vector {
 
 FailureOr<ConstantOrScalableBound::BoundSize>
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 540931521182e..9b0867894947c 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1248,8 +1248,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
                                              localExprs, context);
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
-    addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
-    return success();
+    return addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
   }
 
   // Get the RHS constant.
@@ -1302,8 +1301,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
-    addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
-    return success();
+    return addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
   }
 
   int64_t rhsConst = rhs[getConstantIndex()];
@@ -1387,19 +1385,22 @@ SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
   return success();
 }
 
-void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
+LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
     AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
     SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
   assert(result.size() == resultSize &&
          "`result` vector passed is not of correct size");
   int loc;
-  if ((loc = findLocalId(expr)) == -1)
-    addLocalIdSemiAffine(expr, lhs, rhs);
+  if ((loc = findLocalId(expr)) == -1) {
+    if (failed(addLocalIdSemiAffine(lhs, rhs, expr)))
+      return failure();
+  }
   std::fill(result.begin(), result.end(), 0);
   if (loc == -1)
     result[getLocalVarStartIndex() + numLocals - 1] = 1;
   else
     result[getLocalVarStartIndex() + loc] = 1;
+  return success();
 }
 
 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
@@ -1434,8 +1435,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
-    addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
-    return success();
+    return addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
   }
 
   // This is a pure affine expr; the RHS is a positive constant.
@@ -1506,14 +1506,14 @@ void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
   // dividend and divisor are not used here; an override of this method uses it.
 }
 
-void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr,
-                                                     ArrayRef<int64_t> lhs,
-                                                     ArrayRef<int64_t> rhs) {
+LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
+    ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
   localExprs.push_back(localExpr);
   ++numLocals;
   // lhs and rhs are not used here; an override of this method uses them.
+  return success();
 }
 
 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
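Patch 2 turns `addLocalIdSemiAffine` into a failure-propagating hook. The following self-contained C++ mock-up sketches that class design; it is not MLIR code. `Kind` stands in for `AffineExprKind`, `bool` stands in for `LogicalResult` (with `true` meaning `success()`), and `BaseFlattener`, `StrictFlattener` and `ModBoundsFlattener` are hypothetical analogues of `SimpleAffineExprFlattener`, `AffineExprFlattener` and `SemiAffineExprFlattener`.

```cpp
#include <cassert>

// Hypothetical stand-in for mlir::AffineExprKind (semi-affine kinds only).
enum class Kind { Mul, Mod, FloorDiv, CeilDiv };

// Analogue of SimpleAffineExprFlattener: the hook registers a new local and
// always succeeds.
struct BaseFlattener {
  virtual ~BaseFlattener() = default;
  unsigned numLocals = 0;

  virtual bool addLocalIdSemiAffine(Kind /*kind*/) {
    ++numLocals;
    return true;
  }

  // Analogue of addLocalVariableSemiAffine: forwards failure to the caller
  // instead of silently continuing. (The lookup of an existing local via
  // findLocalId is omitted in this sketch.)
  bool addLocalVariableSemiAffine(Kind kind) {
    if (!addLocalIdSemiAffine(kind))
      return false;
    // The real code then sets the coefficient of the local in `result`.
    return true;
  }
};

// Analogue of AffineExprFlattener: rejects every semi-affine expression.
struct StrictFlattener : BaseFlattener {
  bool addLocalIdSemiAffine(Kind) override { return false; }
};

// Analogue of SemiAffineExprFlattener: registers the local via the base
// class, then only succeeds for `mod`, where the real code adds the
// conservative bounds 0 <= m <= rhs - 1.
struct ModBoundsFlattener : BaseFlattener {
  unsigned numModBounds = 0;

  bool addLocalIdSemiAffine(Kind kind) override {
    bool ok = BaseFlattener::addLocalIdSemiAffine(kind);
    assert(ok && "unexpected failure in base flattener");
    (void)ok;
    if (kind != Kind::Mod)
      return false; // Other semi-affine expressions remain unsupported.
    ++numModBounds;
    return true;
  }
};
```

With this design, `getFlattenedAffineExprs` only has to choose which flattener to instantiate; an unsupported expression then surfaces through the normal `walkPostOrder` result instead of the separate `hasUnhandledSemiAffineExpressions()` flag removed in this patch.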

>From f2957893005eaae068ddd2ca3c3beb28b2c73465 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 29 May 2024 17:06:59 +0000
Subject: [PATCH 3/3] Consistent naming/order

---
 mlir/include/mlir/IR/AffineExprVisitor.h |  8 ++++----
 mlir/lib/IR/AffineExpr.cpp               | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
index be4bc84e02d57..fc4cd915d8453 100644
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -418,15 +418,15 @@ class SimpleAffineExprFlattener
                                              AffineExpr localExpr);
 
 private:
-  /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
+  /// Adds `localExpr`, which may be a mod, ceildiv, floordiv or mul expression
   /// representing the affine expression corresponding to the quantifier
-  /// introduced as the local variable corresponding to `expr`. If the
+  /// introduced as the local variable corresponding to `localExpr`. If the
   /// quantifier is already present, we put the coefficient in the proper index
   /// of `result`, otherwise we add a new local variable and put the coefficient
   /// there.
-  LogicalResult addLocalVariableSemiAffine(AffineExpr expr,
-                                           ArrayRef<int64_t> lhs,
+  LogicalResult addLocalVariableSemiAffine(ArrayRef<int64_t> lhs,
                                            ArrayRef<int64_t> rhs,
+                                           AffineExpr localExpr,
                                            SmallVectorImpl<int64_t> &result,
                                            unsigned long resultSize);
 
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index 9b0867894947c..5f2016470b25f 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1248,7 +1248,7 @@ LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
                                              localExprs, context);
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
-    return addLocalVariableSemiAffine(a * b, mulLhs, rhs, lhs, lhs.size());
+    return addLocalVariableSemiAffine(mulLhs, rhs, a * b, lhs, lhs.size());
   }
 
   // Get the RHS constant.
@@ -1301,7 +1301,7 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                                        localExprs, context);
     AffineExpr modExpr = dividendExpr % divisorExpr;
-    return addLocalVariableSemiAffine(modExpr, modLhs, rhs, lhs, lhs.size());
+    return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
   }
 
   int64_t rhsConst = rhs[getConstantIndex()];
@@ -1386,13 +1386,13 @@ SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
 }
 
 LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
-    AffineExpr expr, ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
+    ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr,
     SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
   assert(result.size() == resultSize &&
          "`result` vector passed is not of correct size");
   int loc;
-  if ((loc = findLocalId(expr)) == -1) {
-    if (failed(addLocalIdSemiAffine(lhs, rhs, expr)))
+  if ((loc = findLocalId(localExpr)) == -1) {
+    if (failed(addLocalIdSemiAffine(lhs, rhs, localExpr)))
       return failure();
   }
   std::fill(result.begin(), result.end(), 0);
@@ -1435,7 +1435,7 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
                                              localExprs, context);
     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
-    return addLocalVariableSemiAffine(divExpr, divLhs, rhs, lhs, lhs.size());
+    return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
   }
 
   // This is a pure affine expr; the RHS is a positive constant.


