[llvm-branch-commits] [mlir] [mlir][Interfaces][WIP] `Variable` abstraction for `ValueBoundsOpInterface` (PR #87980)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 8 04:31:16 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/87980

This commit generalizes and cleans up the `ValueBoundsConstraintSet` API. The API used to provide separate function overloads for comparing/computing bounds of:
- an index-typed SSA value
- a dimension of a shaped value
- an affine map and its operands

This commit removes all overloads. There is now a single entry point for each `compare` variant and each `computeBound` variant. These functions now take a `Variable`, which is internally represented as an affine map and map operands.

This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.

This PR remains WIP until I have added a test case for `computeBound` on an affine map (`computeBound(AffineMap)`).


>From ed12ff5144bc1fe5013698ee19ffcab9f831d7eb Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 8 Apr 2024 11:25:29 +0000
Subject: [PATCH] [mlir][Interfaces][WIP] `ValueBoundsOpInterface`: `Variable`

---
 .../mlir/Interfaces/ValueBoundsOpInterface.h  | 117 +++---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  |   6 +-
 .../Affine/Transforms/ReifyValueBounds.cpp    |   2 +-
 .../Arith/IR/ValueBoundsOpInterfaceImpl.cpp   |  69 ++++
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |   2 +-
 .../Arith/Transforms/ReifyValueBounds.cpp     |   4 +-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |   6 +-
 .../Dialect/Linalg/Transforms/Promotion.cpp   |   6 +-
 .../Transforms/IndependenceTransforms.cpp     |   5 +-
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  17 +-
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   |   3 +-
 .../Transforms/IndependenceTransforms.cpp     |   3 +-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |   4 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 337 ++++++++----------
 .../Dialect/Affine/TestReifyValueBounds.cpp   |   6 +-
 15 files changed, 312 insertions(+), 275 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the given value is visited in a worklist-driven
+  /// manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(Variable var1, Variable var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d1..82a9fb0d490882f 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7c..6c59df91e8af781 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5a..411fc117a4d9f5d 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
   }
 };
 
+struct SelectOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+                                                   SelectOp> {
+
+  static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+                             ValueBoundsConstraintSet &cstr) {
+    Value value = selectOp.getResult();
+    Value condition = selectOp.getCondition();
+    Value trueValue = selectOp.getTrueValue();
+    Value falseValue = selectOp.getFalseValue();
+
+    if (isa<ShapedType>(condition.getType())) {
+      // If the condition is a shaped type, the condition is applied
+      // element-wise. All three operands must have the same shape.
+      cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+      cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+      return;
+    }
+
+    // Populate constraints for the true/false values (and all values on the
+    // backward slice, as long as the current stop condition is not satisfied).
+    cstr.populateConstraints(trueValue, dim);
+    cstr.populateConstraints(falseValue, dim);
+    auto boundsBuilder = cstr.bound(value);
+    if (dim)
+      boundsBuilder[*dim];
+
+    // Compare yielded values.
+    // If trueValue <= falseValue:
+    // * result <= falseValue
+    // * result >= trueValue
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{falseValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+      } else {
+        cstr.bound(value) >= trueValue;
+        cstr.bound(value) <= falseValue;
+      }
+    }
+    // If falseValue <= trueValue:
+    // * result <= trueValue
+    // * result >= falseValue
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
+                     ValueBoundsConstraintSet::ComparisonOperator::LE,
+                     /*rhs=*/{trueValue, dim})) {
+      if (dim) {
+        cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+        cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+      } else {
+        cstr.bound(value) >= falseValue;
+        cstr.bound(value) <= trueValue;
+      }
+    }
+  }
+
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+  }
+
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    populateBounds(cast<SelectOp>(op), dim, cstr);
+  }
+};
 } // namespace
 } // namespace arith
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a2..f87f3d6350c0221 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190ed..5bb7d83bf1e3f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type,
+          ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+          closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489897..518d2e138c02a97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049d..71eb59d40836c1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9f..1f06318cbd60e04 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a3..17a1c016ea16d5a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
     if (cstr.populateAndCompare(
-            yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
-            iterArg, dim)) {
+            /*lhs=*/{yieldedValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::EQ,
+            /*rhs=*/{iterArg, dim})) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
-        cstr.bound(value) == initArg;
+        cstr.bound(value) == cstr.getExpr(initArg);
       }
     }
   }
@@ -113,8 +114,9 @@ struct IfOpInterface
     // * result <= elseValue
     // * result >= thenValue
     if (cstr.populateAndCompare(
-            thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            elseValue, dim)) {
+            /*lhs=*/{thenValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{elseValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
     // * result <= thenValue
     // * result >= elseValue
     if (cstr.populateAndCompare(
-            elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            thenValue, dim)) {
+            /*lhs=*/{elseValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{thenValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c135..d25efcf50ec566f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
 
   info.isAlignedToInnerTileSize = false;
   FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::UB,
-      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+      presburger::BoundType::UB, tileSize,
       /*stopCondition=*/nullptr, /*closedUB=*/true);
   std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
   if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49b37..a89ce20048dff3d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
           boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          independencies,
+          /*closedUB=*/true)))
     return failure();
   return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a17003..15381ec520e2119 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), srcDim, resultDim);
+        {op.getSource(), srcDim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), dim, resultDim);
+        {op.getSource(), dim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++resultDim;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index fa66da4a0def937..9f220f5f6ceb729 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
 } // namespace mlir
 
+static Operation *getOwnerOfValue(Value value) {
+  if (auto bbArg = dyn_cast<BlockArgument>(value))
+    return bbArg.getOwner()->getParentOp();
+  return value.getDefiningOp();
+}
+
 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
                                              ArrayRef<OpFoldResult> sizes,
                                              ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+    : Variable(ofr, std::nullopt) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+    : Variable(static_cast<OpFoldResult>(indexValue)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+    : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+                                             std::optional<int64_t> dim) {
+  Builder b(ofr.getContext());
+  if (auto constInt = ::getConstantIntValue(ofr)) {
+    assert(!dim && "expected no dim for index-typed values");
+    map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+                         b.getAffineConstantExpr(*constInt));
+    return;
+  }
+  Value value = cast<Value>(ofr);
+#ifndef NDEBUG
+  if (dim) {
+    assert(isa<ShapedType>(value.getType()) && "expected shaped type");
+  } else {
+    assert(value.getType().isIndex() && "expected index type");
+  }
+#endif // NDEBUG
+  map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+                       b.getAffineSymbolExpr(0));
+  mapOperands.emplace_back(value, dim);
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Variable> mapOperands) {
+  assert(map.getNumResults() == 1 && "expected single result");
+
+  // Turn all dims into symbols.
+  Builder b(map.getContext());
+  SmallVector<AffineExpr> dimReplacements, symReplacements;
+  for (int64_t i = 0; i < map.getNumDims(); ++i)
+    dimReplacements.push_back(b.getAffineSymbolExpr(i));
+  for (int64_t i = 0; i < map.getNumSymbols(); ++i)
+    symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+  AffineMap tmpMap = map.replaceDimsAndSymbols(
+      dimReplacements, symReplacements, /*numResultDims=*/0,
+      /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+  // Inline operands.
+  DenseMap<AffineExpr, AffineExpr> replacements;
+  for (auto [index, var] : llvm::enumerate(mapOperands)) {
+    assert(var.map.getNumResults() == 1 && "expected single result");
+    assert(var.map.getNumDims() == 0 && "expected only symbols");
+    SmallVector<AffineExpr> symReplacements;
+    for (auto valueDim : var.mapOperands) {
+      auto it = llvm::find(this->mapOperands, valueDim);
+      if (it != this->mapOperands.end()) {
+        // There is already a symbol for this operand.
+        symReplacements.push_back(b.getAffineSymbolExpr(
+            std::distance(this->mapOperands.begin(), it)));
+      } else {
+        // This is a new operand: add a new symbol.
+        symReplacements.push_back(
+            b.getAffineSymbolExpr(this->mapOperands.size()));
+        this->mapOperands.push_back(valueDim);
+      }
+    }
+    replacements[b.getAffineSymbolExpr(index)] =
+        var.map.getResult(0).replaceSymbols(symReplacements);
+  }
+  this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+                             /*numResultSyms=*/this->mapOperands.size());
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Value> mapOperands)
+    : Variable(map, llvm::map_to_vector(mapOperands,
+                                        [](Value v) { return Variable(v); })) {}
+
 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
     MLIRContext *ctx, StopConditionFn stopCondition)
     : builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
   assert(!valueDimToPosition.contains(valueDim) && "already mapped");
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+                          << " for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+                          << "\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+  return insert(var.map, var.mapOperands, isSymbol);
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
           cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
          "unstructured control flow is not supported");
 #endif // NDEBUG
-
+  LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   auto it =
       valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
   assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
   return it != valueDimToPosition.end();
 }
 
-static Operation *getOwnerOfValue(Value value) {
-  if (auto bbArg = dyn_cast<BlockArgument>(value))
-    return bbArg.getOwner()->getParentOp();
-  return value.getDefiningOp();
-}
-
 void ValueBoundsConstraintSet::processWorklist() {
   LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
   }
 }
 
+void ValueBoundsConstraintSet::projectOutAnonymous(
+    std::optional<int64_t> except) {
+  int64_t nextPos = 0;
+  while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+    if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+      ++nextPos;
+    } else {
+      projectOut(nextPos);
+      // The column was projected out so another column is now at that position.
+      // Do not increase the counter.
+    }
+  }
+}
+
 LogicalResult ValueBoundsConstraintSet::computeBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
-    bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+    const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+  MLIRContext *ctx = var.getContext();
   int64_t ubAdjustment = closedUB ? 0 : 1;
-  Builder b(value.getContext());
+  Builder b(ctx);
   mapOperands.clear();
 
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
-  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
-  assert(!stopCondition(value, dim, cstr) &&
-         "stop condition should not be satisfied for starting point");
-  int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+  ValueBoundsConstraintSet cstr(ctx, stopCondition);
+  int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+  assert(pos == 0 && "expected first column");
   cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
   cstr.projectOut([&](ValueDim p) {
-    // Do not project out `valueDim`.
-    if (valueDim == p)
-      return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
     return !stopCondition(p.first, maybeDim, cstr);
   });
+  cstr.projectOutAnonymous(/*except=*/pos);
 
   // Compute lower and upper bounds for `valueDim`.
   SmallVector<AffineMap> lb(1), ub(1);
-  cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+  cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
                            /*closedUB=*/true);
 
   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
 
 LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueDimList dependencies,
-    bool closedUB) {
+    const Variable &var, ValueDimList dependencies, bool closedUB) {
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
 
 LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueRange independencies,
-    bool closedUB) {
+    const Variable &var, ValueRange independencies, bool closedUB) {
   // Return "true" if the given value is independent of all values in
   // `independencies`. I.e., neither the value itself nor any value in the
   // backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 
   // Reify bounds in terms of any independent values.
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return isIndependent(v);
       },
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 }
 
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, Value value, std::optional<int64_t> dim,
-    StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
-  AffineMap map =
-      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                     Builder(value.getContext()).getAffineDimExpr(0));
-  return computeConstantBound(type, map, {{value, dim}}, stopCondition,
-                              closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+    presburger::BoundType type, const Variable &var,
     StopConditionFn stopCondition, bool closedUB) {
-  ValueDimList valueDims;
-  for (Value v : operands) {
-    assert(v.getType().isIndex() && "expected index type");
-    valueDims.emplace_back(v, std::nullopt);
-  }
-  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ValueDimList operands,
-    StopConditionFn stopCondition, bool closedUB) {
-  assert(map.getNumResults() == 1 && "expected affine map with one result");
-
   // Default stop condition if none was specified: Keep adding constraints until
   // a bound could be computed.
   int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   };
 
   ValueBoundsConstraintSet cstr(
-      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
-  pos = cstr.populateConstraints(map, operands);
+      var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.populateConstraints(var.map, var.mapOperands);
   assert(pos == 0 && "expected `map` is the first column");
 
   // Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  return computeConstantBound(presburger::BoundType::EQ, map,
-                              {{value1, dim1}, {value2, dim2}});
+  return computeConstantBound(presburger::BoundType::EQ,
+                              Variable(map, {{value1, dim1}, {value2, dim2}}));
 }
 
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
-                                                std::optional<int64_t> lhsDim,
-                                                ComparisonOperator cmp,
-                                                OpFoldResult rhs,
-                                                std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+                                          ComparisonOperator cmp,
+                                          int64_t rhsPos) {
   // This function returns "true" if "lhs CMP rhs" is proven to hold.
   //
   // Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
     return false;
   }
 
-  // EQ can be expressed as LE and GE.
-  if (cmp == EQ)
-    return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
-           compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
-  // Construct inequality. For the above example: lhs > rhs.
-  // `IntegerRelation` inequalities are expressed in the "flattened" form and
-  // with ">= 0". I.e., lhs - rhs - 1 >= 0.
-  SmallVector<int64_t> eq(cstr.getNumCols(), 0);
-  auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
-                     int64_t factor) {
-    if (auto constVal = ::getConstantIntValue(ofr)) {
-      eq[cstr.getNumCols() - 1] += *constVal * factor;
-    } else {
-      eq[getPos(cast<Value>(ofr), dim)] += factor;
-    }
-  };
-  if (cmp == LT || cmp == LE) {
-    addToEq(lhs, lhsDim, 1);
-    addToEq(rhs, rhsDim, -1);
-  } else if (cmp == GT || cmp == GE) {
-    addToEq(lhs, lhsDim, -1);
-    addToEq(rhs, rhsDim, 1);
-  } else {
-    llvm_unreachable("unsupported comparison operator");
-  }
-  if (cmp == LE || cmp == GE)
-    eq[cstr.getNumCols() - 1] -= 1;
-
-  // Add inequality to the constraint set and check if it made the constraint
-  // set empty.
-  int64_t ineqPos = cstr.getNumInequalities();
-  cstr.addInequality(eq);
-  bool isEmpty = cstr.isEmpty();
-  cstr.removeInequality(ineqPos);
-  return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
-                                          ComparisonOperator cmp,
-                                          int64_t rhsPos) {
-  // This function returns "true" if "lhs CMP rhs" is proven to hold. For
-  // detailed documentation, see `compareValueDims`.
-
   // EQ can be expressed as LE and GE.
   if (cmp == EQ)
     return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
@@ -712,48 +727,16 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
-bool ValueBoundsConstraintSet::populateAndCompare(
-    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
-    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    populateConstraints(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    populateConstraints(rhsVal, rhsDim);
-
-  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+bool ValueBoundsConstraintSet::populateAndCompare(Variable lhs,
+                                                  ComparisonOperator cmp,
+                                                  Variable rhs) {
+  int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
+  int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
+  return comparePos(lhsPos, cmp, rhsPos);
 }
 
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
-                                       std::optional<int64_t> lhsDim,
-                                       ComparisonOperator cmp, OpFoldResult rhs,
-                                       std::optional<int64_t> rhsDim) {
-  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
-                           ValueBoundsConstraintSet &cstr) {
-    // Keep processing as long as lhs/rhs are not mapped.
-    if (auto lhsVal = dyn_cast<Value>(lhs))
-      if (!cstr.isMapped(lhsVal, dim))
-        return false;
-    if (auto rhsVal = dyn_cast<Value>(rhs))
-      if (!cstr.isMapped(rhsVal, dim))
-        return false;
-    // Keep processing as long as the relation cannot be proven.
-    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
-  };
-
-  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(Variable lhs, ComparisonOperator cmp,
+                                       Variable rhs) {
   int64_t lhsPos = -1, rhsPos = -1;
   auto stopCondition = [&](Value v, std::optional<int64_t> dim,
                            ValueBoundsConstraintSet &cstr) {
@@ -765,39 +748,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
     return cstr.comparePos(lhsPos, cmp, rhsPos);
   };
   ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  lhsPos = cstr.insert(lhs, lhsOperands);
-  rhsPos = cstr.insert(rhs, rhsOperands);
-  cstr.processWorklist();
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
-                                       ArrayRef<Value> lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ArrayRef<Value> rhsOperands) {
-  ValueDimList lhsValueDimOperands =
-      llvm::map_to_vector(lhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  ValueDimList rhsValueDimOperands =
-      llvm::map_to_vector(rhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
-                                           rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
-                                        value2, dim2))
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(Variable var1,
+                                                   Variable var2) {
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
     return true;
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
-                                        value2, dim2) ||
-      ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
-                                        value2, dim2))
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+      ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
     return false;
   return failure();
 }
@@ -833,7 +794,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;
@@ -850,7 +811,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index f38631054fb3c14..af4ba7de3df1f6c 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -169,7 +169,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       FailureOr<OpFoldResult> reified = failure();
       if (constant) {
         auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
-            *boundType, value, dim, /*stopCondition=*/nullptr);
+            *boundType, {value, dim}, /*stopCondition=*/nullptr);
         if (succeeded(reifiedConst))
           reified =
               FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
@@ -285,8 +285,8 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
 
       auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
         return ValueBoundsConstraintSet::compare(
-            /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
-            /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
+            /*lhs=*/op->getOperand(0), cmp,
+            /*rhs=*/op->getOperand(1));
       };
       if (compare(*cmpType)) {
         op->emitRemark("true");



More information about the llvm-branch-commits mailing list