[llvm-branch-commits] [mlir] [mlir][Interfaces] `ValueBoundsOpInterface`: Add API to compare values (PR #86915)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Mar 27 23:24:52 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86915

This commit adds a new public API to `ValueBoundsOpInterface` to compare values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT.

The new `ValueBoundsOpInterface::compare` API replaces and generalizes `ValueBoundsOpInterface::areEqual`. Not only does it provide additional comparison operators, it also works in cases where the difference between the two values/dims is non-constant. The previous implementation of `areEqual` computed a constant bound of `val1 - val2`, so it could only succeed when that difference was a constant.
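
To illustrate how a three-state `areEqual` can be layered on top of a `compare` that only ever answers "proven" or "not proven", here is a minimal, self-contained sketch. It is a toy model and not the MLIR implementation: `Range`, `Cmp`, `proveCompare` and this `areEqual` are hypothetical names, and values are modeled as independent integer ranges rather than as columns of a constraint set (so, unlike the real analysis, the toy cannot relate two non-constant values to each other).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Toy stand-in for "what the analysis knows" about a value: lo <= value <= hi.
// Hypothetical type for illustration only; this is not an MLIR class.
struct Range {
  int64_t lo;
  int64_t hi;
};

enum class Cmp { LT, LE, EQ, GT, GE };

// Returns true only if "lhs cmp rhs" holds for every pair of values drawn from
// the two ranges. A "false" result means "not proven", not "disproven".
bool proveCompare(Range lhs, Cmp cmp, Range rhs) {
  assert(lhs.lo <= lhs.hi && rhs.lo <= rhs.hi && "invalid range");
  switch (cmp) {
  case Cmp::LT:
    return lhs.hi < rhs.lo;
  case Cmp::LE:
    return lhs.hi <= rhs.lo;
  case Cmp::GT:
    return lhs.lo > rhs.hi;
  case Cmp::GE:
    return lhs.lo >= rhs.hi;
  case Cmp::EQ:
    // EQ can be expressed as LE and GE.
    return proveCompare(lhs, Cmp::LE, rhs) && proveCompare(lhs, Cmp::GE, rhs);
  }
  return false;
}

// Three-state equality: true (proven equal), false (proven different) or
// std::nullopt (unknown).
std::optional<bool> areEqual(Range a, Range b) {
  if (proveCompare(a, Cmp::EQ, b))
    return true;
  if (proveCompare(a, Cmp::LT, b) || proveCompare(a, Cmp::GT, b))
    return false;
  return std::nullopt;
}
```

In this model, `areEqual(Range{5, 5}, Range{5, 5})` is proven equal, `areEqual(Range{0, 3}, Range{4, 9})` is proven different, and `areEqual(Range{0, 10}, Range{5, 6})` is unknown because neither EQ, LT nor GT can be proven.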

Note: This commit refactors, generalizes and adds a public API for value/dim comparison. The comparison functionality itself was introduced in #85895 and is already in use for analyzing `scf.if`.
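
The comparison is a proof by contradiction: the negated relation is added to the constraint set, and the relation counts as proven if the set becomes empty. The sketch below shows that idea on a much simpler system than the one MLIR uses. It assumes integer variables and constraints of the restricted form `x_a - x_b <= c`, and it checks emptiness with a Bellman-Ford negative-cycle test instead of a Presburger library. `DiffConstraint`, `isEmpty` and `proveLE` are hypothetical names introduced for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One constraint of the form: var[a] - var[b] <= c (integer variables).
struct DiffConstraint {
  int a;
  int b;
  int64_t c;
};

// Returns true if the constraints have no solution. Each constraint
// "a - b <= c" is an edge b -> a with weight c; the system is infeasible
// exactly when that graph contains a negative cycle.
bool isEmpty(int numVars, const std::vector<DiffConstraint> &cs) {
  assert(numVars > 0 && "expected at least one variable");
  // Distance 0 everywhere models a virtual source connected to every node.
  std::vector<int64_t> dist(numVars, 0);
  for (int iter = 0; iter < numVars; ++iter) {
    bool changed = false;
    for (const DiffConstraint &k : cs) {
      assert(k.a >= 0 && k.a < numVars && k.b >= 0 && k.b < numVars);
      if (dist[k.b] + k.c < dist[k.a]) {
        dist[k.a] = dist[k.b] + k.c;
        changed = true;
      }
    }
    if (!changed)
      return false; // Fixed point reached: `dist` is a valid solution.
  }
  return true; // Still relaxing after numVars passes: negative cycle.
}

// Try to prove "var[lhs] <= var[rhs]". The negation is "lhs > rhs", i.e.,
// "rhs - lhs <= -1" over the integers. If adding it makes the system empty,
// the relation is proven. "false" means "could not prove", not "disproven".
bool proveLE(int numVars, std::vector<DiffConstraint> cs, int lhs, int rhs) {
  cs.push_back({rhs, lhs, -1});
  return isEmpty(numVars, cs);
}
```

For example, with variables `0 = a`, `1 = iv`, `2 = b` and the constraints `{{0, 1, 0}, {1, 2, -1}}` (that is, `a <= iv` and `iv <= b - 1`), `proveLE(3, cs, 0, 2)` proves `a <= b` even though `b - a` is not a constant, while `proveLE(3, cs, 2, 0)` returns "false" because `b <= a` cannot be proven.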

In the long term, this improvement will allow for a more powerful analysis of subset ops. A future commit will update `areOverlappingSlices` to use the new comparison API. (`areEquivalentSlices` is already using the new API.) This will improve subset equivalence/disjointness checks with non-constant offsets/sizes/strides.


From 4b38a84e39cbb26e2d59660485f26bc47394868c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 28 Mar 2024 06:12:48 +0000
Subject: [PATCH] [mlir][Interfaces][WIP] Expose public `compare` API

Also use `compare` API for `areEqual` etc.
---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  61 ++++-
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  31 +--
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 239 +++++++++++++-----
 .../value-bounds-op-interface-impl.mlir       |  43 +++-
 .../SCF/value-bounds-op-interface-impl.mlir   |  12 +
 .../value-bounds-op-interface-impl.mlir       |  16 +-
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  79 +++++-
 7 files changed, 369 insertions(+), 112 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index f35432ca0136f3..d27081fad8c6c0 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
   /// Comparison operator for `ValueBoundsConstraintSet::compare`.
   enum ComparisonOperator { LT, LE, EQ, GT, GE };
 
-  /// Try to prove that, based on the current state of this constraint set
+  /// Populate constraints for lhs/rhs (until the stop condition is met). Then,
+  /// try to prove that, based on the current state of this constraint set
   /// (i.e., without analyzing additional IR or adding new constraints), the
   /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
   ///
@@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
-               Value rhs, std::optional<int64_t> rhsDim);
+  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                          ComparisonOperator cmp, OpFoldResult rhs,
+                          std::optional<int64_t> rhsDim);
+
+  /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
+  /// specified relation could not be proven. This could be because the
+  /// specified relation does in fact not hold or because there is not enough
+  /// information in the constraint set. In other words, if we do not know for
+  /// sure, this function returns "false".
+  ///
+  /// This function keeps traversing the backward slice of lhs/rhs until it
+  /// can prove the relation or until it runs out of IR.
+  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                      ComparisonOperator cmp, OpFoldResult rhs,
+                      std::optional<int64_t> rhsDim);
+  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
+                      ComparisonOperator cmp, AffineMap rhs,
+                      ValueDimList rhsOperands);
+  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
+                      ComparisonOperator cmp, AffineMap rhs,
+                      ArrayRef<Value> rhsOperands);
 
   /// Compute whether the given values/dimensions are equal. Return "failure" if
   /// equality could not be determined.
   ///
   /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
   /// index-typed.
-  static FailureOr<bool> areEqual(Value value1, Value value2,
+  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
                                   std::optional<int64_t> dim1 = std::nullopt,
                                   std::optional<int64_t> dim2 = std::nullopt);
 
-  /// Compute whether the given values/attributes are equal. Return "failure" if
-  /// equality could not be determined.
-  ///
-  /// `ofr1`/`ofr2` must be of index type.
-  static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
-
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
   /// Return "failure" if unknown.
@@ -290,6 +304,20 @@ class ValueBoundsConstraintSet
 
   ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
+  /// Return "true" if, based on the current state of the constraint system,
+  /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
+  /// could not be proven. This could be because the specified relation does in
+  /// fact not hold or because there is not enough information in the constraint
+  /// set. In other words, if we do not know for sure, this function returns
+  /// "false".
+  ///
+  /// This function does not analyze any IR and does not populate any additional
+  /// constraints.
+  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+                        ComparisonOperator cmp, OpFoldResult rhs,
+                        std::optional<int64_t> rhsDim);
+  bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+
   /// Given an affine map with a single result (and map operands), add a new
   /// column to the constraint set that represents the result of the map.
   /// Traverse additional IR starting from the map operands as needed (as long
@@ -311,6 +339,14 @@ class ValueBoundsConstraintSet
   /// value/dimension exists in the constraint set.
   int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
 
+  /// Return an affine expression that represents column `pos` in the constraint
+  /// set.
+  AffineExpr getPosExpr(int64_t pos);
+
+  /// Return "true" if the given value/dim is mapped (i.e., has a corresponding
+  /// column in the constraint system).
+  bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
   /// Insert a value/dimension into the constraint set. If `isSymbol` is set to
   /// "false", a dimension is added. The value/dimension is added to the
   /// worklist if `addToWorklist` is set.
@@ -330,6 +366,11 @@ class ValueBoundsConstraintSet
   /// dimensions but not for symbols.
   int64_t insert(bool isSymbol = true);
 
+  /// Insert the given affine map and its bound operands as a new column in the
+  /// constraint system. Return the position of the new column. Any operands
+  /// that were not analyzed yet are put on the worklist.
+  int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
 
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 72c5aaa2306783..087ffc438a830a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,20 +58,11 @@ struct ForOpInterface
     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
     Value initArg = forOp.getInitArgs()[iterArgIdx];
 
-    // Populate constraints for the yielded value.
-    cstr.populateConstraints(yieldedValue, dim);
-    // Populate constraints for the iter_arg. This is just to ensure that the
-    // iter_arg is mapped in the constraint set, which is a prerequisite for
-    // `compare`. It may lead to a recursive call to this function in case the
-    // iter_arg was not visited when the constraints for the yielded value were
-    // populated, but no additional work is done.
-    cstr.populateConstraints(iterArg, dim);
-
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
-    if (cstr.compare(yieldedValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
-                     dim)) {
+    if (cstr.populateAndCompare(
+            yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
+            iterArg, dim)) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
@@ -113,10 +104,6 @@ struct IfOpInterface
     Value thenValue = ifOp.thenYield().getResults()[resultNum];
     Value elseValue = ifOp.elseYield().getResults()[resultNum];
 
-    // Populate constraints for the yielded value (and all values on the
-    // backward slice, as long as the current stop condition is not satisfied).
-    cstr.populateConstraints(thenValue, dim);
-    cstr.populateConstraints(elseValue, dim);
     auto boundsBuilder = cstr.bound(value);
     if (dim)
       boundsBuilder[*dim];
@@ -125,9 +112,9 @@ struct IfOpInterface
     // If thenValue <= elseValue:
     // * result <= elseValue
     // * result >= thenValue
-    if (cstr.compare(thenValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     elseValue, dim)) {
+    if (cstr.populateAndCompare(
+            thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+            elseValue, dim)) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -139,9 +126,9 @@ struct IfOpInterface
     // If elseValue <= thenValue:
     // * result <= thenValue
     // * result >= elseValue
-    if (cstr.compare(elseValue, dim,
-                     ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     thenValue, dim)) {
+    if (cstr.populateAndCompare(
+            elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+            thenValue, dim)) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index dd98da9adc7d96..d7ffed14daccdd 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -212,6 +212,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
+                                         bool isSymbol) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  int64_t pos = insert(/*isSymbol=*/false);
+
+  // Add map and operands to the constraint set. Dimensions are converted to
+  // symbols. All operands are added to the worklist (unless they were already
+  // processed).
+  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+    return getExpr(v.first, v.second);
+  };
+  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+  addBound(
+      presburger::BoundType::EQ, pos,
+      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  return pos;
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -227,6 +249,20 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
   return it->second;
 }
 
+AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
+  assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
+  return pos < cstr.getNumDimVars()
+             ? builder.getAffineDimExpr(pos)
+             : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
+}
+
+bool ValueBoundsConstraintSet::isMapped(Value value,
+                                        std::optional<int64_t> dim) const {
+  auto it =
+      valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
+  return it != valueDimToPosition.end();
+}
+
 static Operation *getOwnerOfValue(Value value) {
   if (auto bbArg = dyn_cast<BlockArgument>(value))
     return bbArg.getOwner()->getParentOp();
@@ -563,27 +599,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value,
 
 int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
                                                       ValueDimList operands) {
-  assert(map.getNumResults() == 1 && "expected affine map with one result");
-  int64_t pos = insert(/*isSymbol=*/false);
-
-  // Add map and operands to the constraint set. Dimensions are converted to
-  // symbols. All operands are added to the worklist (unless they were already
-  // processed).
-  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
-    return getExpr(v.first, v.second);
-  };
-  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
-      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
-  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
-      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
-  addBound(
-      presburger::BoundType::EQ, pos,
-      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
-
+  int64_t pos = insert(map, operands, /*isSymbol=*/false);
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
   processWorklist();
-
   return pos;
 }
 
@@ -603,9 +622,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
                               {{value1, dim1}, {value2, dim2}});
 }
 
-bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
-                                       ComparisonOperator cmp, Value rhs,
-                                       std::optional<int64_t> rhsDim) {
+bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
+                                                std::optional<int64_t> lhsDim,
+                                                ComparisonOperator cmp,
+                                                OpFoldResult rhs,
+                                                std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    assertValidValueDim(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
   // This function returns "true" if "lhs CMP rhs" is proven to hold.
   //
   // Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -624,19 +652,61 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
 
   // EQ can be expressed as LE and GE.
   if (cmp == EQ)
-    return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
-           compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
+    return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
+           compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
 
   // Construct inequality. For the above example: lhs > rhs.
   // `IntegerRelation` inequalities are expressed in the "flattened" form and
   // with ">= 0". I.e., lhs - rhs - 1 >= 0.
+  SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+  auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
+                     int64_t factor) {
+    if (auto constVal = ::getConstantIntValue(ofr)) {
+      eq[cstr.getNumCols() - 1] += *constVal * factor;
+    } else {
+      eq[getPos(cast<Value>(ofr), dim)] += factor;
+    }
+  };
+  if (cmp == LT || cmp == LE) {
+    addToEq(lhs, lhsDim, 1);
+    addToEq(rhs, rhsDim, -1);
+  } else if (cmp == GT || cmp == GE) {
+    addToEq(lhs, lhsDim, -1);
+    addToEq(rhs, rhsDim, 1);
+  } else {
+    llvm_unreachable("unsupported comparison operator");
+  }
+  if (cmp == LE || cmp == GE)
+    eq[cstr.getNumCols() - 1] -= 1;
+
+  // Add inequality to the constraint set and check if it made the constraint
+  // set empty.
+  int64_t ineqPos = cstr.getNumInequalities();
+  cstr.addInequality(eq);
+  bool isEmpty = cstr.isEmpty();
+  cstr.removeInequality(ineqPos);
+  return isEmpty;
+}
+
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+                                          ComparisonOperator cmp,
+                                          int64_t rhsPos) {
+  // This function returns "true" if "lhs CMP rhs" is proven to hold. For
+  // detailed documentation, see `compareValueDims`.
+
+  // EQ can be expressed as LE and GE.
+  if (cmp == EQ)
+    return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
+           comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+
+  // Construct inequality.
   SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
   if (cmp == LT || cmp == LE) {
-    ++eq[getPos(lhs, lhsDim)];
-    --eq[getPos(rhs, rhsDim)];
+    ++eq[lhsPos];
+    --eq[rhsPos];
   } else if (cmp == GT || cmp == GE) {
-    --eq[getPos(lhs, lhsDim)];
-    ++eq[getPos(rhs, rhsDim)];
+    --eq[lhsPos];
+    ++eq[rhsPos];
   } else {
     llvm_unreachable("unsupported comparison operator");
   }
@@ -652,40 +722,93 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
   return isEmpty;
 }
 
+bool ValueBoundsConstraintSet::populateAndCompare(
+    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
+    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    assertValidValueDim(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
+  if (auto lhsVal = dyn_cast<Value>(lhs))
+    populateConstraints(lhsVal, lhsDim);
+  if (auto rhsVal = dyn_cast<Value>(rhs))
+    populateConstraints(rhsVal, rhsDim);
+
+  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
+                                       std::optional<int64_t> lhsDim,
+                                       ComparisonOperator cmp, OpFoldResult rhs,
+                                       std::optional<int64_t> rhsDim) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs are not mapped.
+    if (auto lhsVal = dyn_cast<Value>(lhs))
+      if (!cstr.isMapped(lhsVal, dim))
+        return false;
+    if (auto rhsVal = dyn_cast<Value>(rhs))
+      if (!cstr.isMapped(rhsVal, dim))
+        return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+  };
+
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ValueDimList rhsOperands) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as lhs/rhs were not processed.
+    if (lhsPos >= cstr.positionToValueDim.size() ||
+        rhsPos >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.comparePos(lhsPos, cmp, rhsPos);
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.insert(lhs, lhsOperands);
+  rhsPos = cstr.insert(rhs, rhsOperands);
+  return cstr.comparePos(lhsPos, cmp, rhsPos);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs,
+                                       ArrayRef<Value> lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ArrayRef<Value> rhsOperands) {
+  ValueDimList lhsValueDimOperands =
+      llvm::map_to_vector(lhsOperands, [](Value v) {
+        return std::make_pair(v, std::optional<int64_t>());
+      });
+  ValueDimList rhsValueDimOperands =
+      llvm::map_to_vector(rhsOperands, [](Value v) {
+        return std::make_pair(v, std::optional<int64_t>());
+      });
+  return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
+                                           rhsValueDimOperands);
+}
+
 FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
                                    std::optional<int64_t> dim1,
                                    std::optional<int64_t> dim2) {
-  // Subtract the two values/dimensions from each other. If the result is 0,
-  // both are equal.
-  FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
-  if (failed(delta))
-    return failure();
-  return *delta == 0;
-}
-
-FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
-                                                   OpFoldResult ofr2) {
-  Builder b(ofr1.getContext());
-  AffineMap map =
-      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
-                     b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
-  SmallVector<OpFoldResult> ofrOperands;
-  ofrOperands.push_back(ofr1);
-  ofrOperands.push_back(ofr2);
-  SmallVector<Value> valueOperands;
-  AffineMap foldedMap =
-      foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
-  ValueDimList valueDims;
-  for (Value v : valueOperands) {
-    assert(v.getType().isIndex() && "expected index type");
-    valueDims.emplace_back(v, std::nullopt);
-  }
-  FailureOr<int64_t> delta =
-      computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
-  if (failed(delta))
-    return failure();
-  return *delta == 0;
+  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
+                                        value2, dim2))
+    return true;
+  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
+                                        value2, dim2) ||
+      ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
+                                        value2, dim2))
+    return false;
+  return failure();
 }
 
 FailureOr<bool>
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 55282e8334abd7..10da91870f49d9 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -79,6 +79,17 @@ func.func @composed_affine_apply(%i1 : index) -> (index) {
 }
 
 
+// -----
+
+func.func @are_equal(%i1 : index) {
+  %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+  %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+  %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+  // expected-remark @below{{false}}
+   "test.compare"(%i2, %i3) : (index, index) -> ()
+  return
+}
+
 // -----
 
 // Test for affine::fullyComposeAndCheckIfEqual
@@ -87,6 +98,36 @@ func.func @composed_are_equal(%i1 : index) {
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
   %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
   // expected-remark @below{{different}}
-   "test.are_equal"(%i2, %i3) {compose} : (index, index) -> ()
+   "test.compare"(%i2, %i3) {compose} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @compare_affine_max(%a: index, %b: index) {
+  %0 = affine.max affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b]
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> ()
+  // expected-error @below{{unknown}}
+  "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> ()
+  // expected-remark @below{{false}}
+  "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> ()
+  // expected-error @below{{unknown}}
+  "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @compare_affine_min(%a: index, %b: index) {
+  %0 = affine.min affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b]
+  // expected-error @below{{unknown}}
+  "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> ()
+  // expected-remark @below{{false}}
+  "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> ()
+  // expected-error @below{{unknown}}
+  "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> ()
+  // expected-remark @below{{true}}
+  "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
   return
 }
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 0ea06737886d41..9ab03da1c9a94f 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -219,3 +219,15 @@ func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
   "test.some_use"(%reify1) : (index) -> ()
   return
 }
+
+// -----
+
+func.func @compare_scf_for(%a: index, %b: index, %c: index) {
+  scf.for %iv = %a to %b step %c {
+    // expected-remark @below{{true}}
+    "test.compare"(%iv, %a) {cmp = "GE"} : (index, index) -> ()
+    // expected-remark @below{{true}}
+    "test.compare"(%iv, %b) {cmp = "LT"} : (index, index) -> ()
+  }
+  return
+}
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 45520da6aeb0b0..0c90bcdb42028c 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -163,8 +163,8 @@ func.func @dynamic_dims_are_equal(%t: tensor<?xf32>) {
   %c0 = arith.constant 0 : index
   %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
   %dim1 = tensor.dim %t, %c0 : tensor<?xf32>
-  // expected-remark @below {{equal}}
-  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  // expected-remark @below {{true}}
+  "test.compare"(%dim0, %dim1) : (index, index) -> ()
   return
 }
 
@@ -175,8 +175,8 @@ func.func @dynamic_dims_are_different(%t: tensor<?xf32>) {
   %c1 = arith.constant 1 : index
   %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
   %val = arith.addi %dim0, %c1 : index
-  // expected-remark @below {{different}}
-  "test.are_equal"(%dim0, %val) : (index, index) -> ()
+  // expected-remark @below {{false}}
+  "test.compare"(%dim0, %val) : (index, index) -> ()
   return
 }
 
@@ -186,8 +186,8 @@ func.func @dynamic_dims_are_maybe_equal_1(%t: tensor<?xf32>) {
   %c0 = arith.constant 0 : index
   %c5 = arith.constant 5 : index
   %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
-  // expected-error @below {{could not determine equality}}
-  "test.are_equal"(%dim0, %c5) : (index, index) -> ()
+  // expected-error @below {{unknown}}
+  "test.compare"(%dim0, %c5) : (index, index) -> ()
   return
 }
 
@@ -198,7 +198,7 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
   %c1 = arith.constant 1 : index
   %dim0 = tensor.dim %t, %c0 : tensor<?x?xf32>
   %dim1 = tensor.dim %t, %c1 : tensor<?x?xf32>
-  // expected-error @below {{could not determine equality}}
-  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  // expected-error @below {{unknown}}
+  "test.compare"(%dim0, %dim1) : (index, index) -> ()
   return
 }
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 4b2b1a06341b71..f38631054fb3c1 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -57,7 +57,7 @@ struct TestReifyValueBounds
 
 } // namespace
 
-FailureOr<BoundType> parseBoundType(const std::string &type) {
+static FailureOr<BoundType> parseBoundType(const std::string &type) {
   if (type == "EQ")
     return BoundType::EQ;
   if (type == "LB")
@@ -67,6 +67,34 @@ FailureOr<BoundType> parseBoundType(const std::string &type) {
   return failure();
 }
 
+static FailureOr<ValueBoundsConstraintSet::ComparisonOperator>
+parseComparisonOperator(const std::string &type) {
+  if (type == "EQ")
+    return ValueBoundsConstraintSet::ComparisonOperator::EQ;
+  if (type == "LT")
+    return ValueBoundsConstraintSet::ComparisonOperator::LT;
+  if (type == "LE")
+    return ValueBoundsConstraintSet::ComparisonOperator::LE;
+  if (type == "GT")
+    return ValueBoundsConstraintSet::ComparisonOperator::GT;
+  if (type == "GE")
+    return ValueBoundsConstraintSet::ComparisonOperator::GE;
+  return failure();
+}
+
+static ValueBoundsConstraintSet::ComparisonOperator
+invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
+  if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
+    return ValueBoundsConstraintSet::ComparisonOperator::GE;
+  if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
+    return ValueBoundsConstraintSet::ComparisonOperator::GT;
+  if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
+    return ValueBoundsConstraintSet::ComparisonOperator::LE;
+  if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
+    return ValueBoundsConstraintSet::ComparisonOperator::LT;
+  llvm_unreachable("unsupported comparison operator");
+}
+
 /// Look for "test.reify_bound" ops in the input and replace their results with
 /// the reified values.
 static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
@@ -215,18 +243,34 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
   return failure(result.wasInterrupted());
 }
 
-/// Look for "test.are_equal" ops and emit errors/remarks.
+/// Look for "test.compare" ops and emit errors/remarks.
 static LogicalResult testEquality(func::FuncOp funcOp) {
   IRRewriter rewriter(funcOp.getContext());
   WalkResult result = funcOp.walk([&](Operation *op) {
-    // Look for test.are_equal ops.
-    if (op->getName().getStringRef() == "test.are_equal") {
+    // Look for test.compare ops.
+    if (op->getName().getStringRef() == "test.compare") {
       if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
           !op->getOperand(1).getType().isIndex()) {
         op->emitOpError("invalid op");
         return WalkResult::skip();
       }
+
+      // Get comparison operator.
+      std::string cmpStr = "EQ";
+      if (auto cmpAttr = op->getAttrOfType<StringAttr>("cmp"))
+        cmpStr = cmpAttr.str();
+      auto cmpType = parseComparisonOperator(cmpStr);
+      if (failed(cmpType)) {
+        op->emitOpError("invalid comparison operator");
+        return WalkResult::interrupt();
+      }
+
       if (op->hasAttr("compose")) {
+        if (cmpType != ValueBoundsConstraintSet::EQ) {
+          op->emitOpError(
+              "comparison operator must be EQ when 'compose' is specified");
+          return WalkResult::interrupt();
+        }
         FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
             op->getOperand(0), op->getOperand(1));
         if (failed(delta)) {
@@ -236,16 +280,25 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
         } else {
           op->emitRemark("different");
         }
+        return WalkResult::advance();
+      }
+
+      auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
+        return ValueBoundsConstraintSet::compare(
+            /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
+            /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
+      };
+      if (compare(*cmpType)) {
+        op->emitRemark("true");
+      } else if (*cmpType != ValueBoundsConstraintSet::EQ &&
+                 compare(invertComparisonOperator(*cmpType))) {
+        op->emitRemark("false");
+      } else if (*cmpType == ValueBoundsConstraintSet::EQ &&
+                 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
+                  compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
+        op->emitRemark("false");
       } else {
-        FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
-            op->getOperand(0), op->getOperand(1));
-        if (failed(equal)) {
-          op->emitError("could not determine equality");
-        } else if (*equal) {
-          op->emitRemark("equal");
-        } else {
-          op->emitRemark("different");
-        }
+        op->emitError("unknown");
       }
     }
     return WalkResult::advance();


