[Mlir-commits] [mlir] 40dd3aa - [mlir][Interfaces] `Variable` abstraction for `ValueBoundsOpInterface` (#87980)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 16 01:59:06 PDT 2024


Author: Matthias Springer
Date: 2024-04-16T10:59:02+02:00
New Revision: 40dd3aa91d3f73184e34e45e597b84bec059c572

URL: https://github.com/llvm/llvm-project/commit/40dd3aa91d3f73184e34e45e597b84bec059c572
DIFF: https://github.com/llvm/llvm-project/commit/40dd3aa91d3f73184e34e45e597b84bec059c572.diff

LOG: [mlir][Interfaces] `Variable` abstraction for `ValueBoundsOpInterface` (#87980)

This commit generalizes and cleans up the `ValueBoundsConstraintSet`
API. The API used to provide function overloads for comparing/computing
bounds of:
- index-typed SSA value
- dimension of shaped value
- affine map + operands

This commit removes all overloads. There is now a single entry point for
each `compare` variant and each `computeBound` variant. These functions
now take a `Variable`, which is internally represented as an affine map
and map operands.

This commit also adds support for computing bounds for an affine map +
operands. There was previously no public API for that.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
    mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
    mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
    mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
    mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
    mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
    mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
    mlir/lib/Dialect/Tensor/Utils/Utils.cpp
    mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
    mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
    mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 8e840e744064d5..1ea73752208156 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
 /// maximally compose chains of AffineApplyOps.
 FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                const ValueBoundsConstraintSet::Variable &var,
+                ValueBoundsConstraintSet::StopConditionFn stopCondition,
+                bool closedUB = false);
+
 /// Reify a bound for the given index-typed value in terms of SSA values for
 /// which `stopCondition` is met. If no stop condition is specified, reify in
 /// terms of the operands of the owner op.

diff  --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
index 970a52a06a11a2..bbc7e5d3e0dd70 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -24,6 +24,17 @@ enum class BoundType;
 
 namespace arith {
 
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                const ValueBoundsConstraintSet::Variable &var,
+                ValueBoundsConstraintSet::StopConditionFn stopCondition,
+                bool closedUB = false);
+
 /// Reify a bound for the given index-typed value in terms of SSA values for
 /// which `stopCondition` is met. If no stop condition is specified, reify in
 /// terms of the operands of the owner op.

diff  --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3..ac17ace5a976d2 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the variable's operands is visited in a
+  /// worklist-driven manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,8 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
+                          const Variable &rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +252,12 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(const Variable &lhs, ComparisonOperator cmp,
+                      const Variable &rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +324,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +378,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +386,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.

diff  --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d..82a9fb0d490882 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }

diff  --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7..1a266b72d1f8d3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -16,16 +16,15 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                ValueBoundsConstraintSet::StopConditionFn stopCondition,
-                bool closedUB) {
+FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type,
+    const ValueBoundsConstraintSet::Variable &var,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, var, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     // the owner of `value`.
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, dim,
+  return reifyValueBound(b, loc, type, {value, dim},
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
                              ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+  return reifyValueBound(b, loc, type, value,
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }

diff  --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d43808bc45df..7cfcc4180539c2 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
     // If trueValue <= falseValue:
     // * result <= falseValue
     // * result >= trueValue
-    if (cstr.compare(trueValue, dim,
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
                      ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     falseValue, dim)) {
+                     /*rhs=*/{falseValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
     // If falseValue <= trueValue:
     // * result <= trueValue
     // * result >= falseValue
-    if (cstr.compare(falseValue, dim,
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
                      ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     trueValue, dim)) {
+                     /*rhs=*/{trueValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);

diff  --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a..f87f3d6350c022 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();

diff  --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190e..5fb7953f937007 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
   return buildExpr(map.getResult(0));
 }
 
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                ValueBoundsConstraintSet::StopConditionFn stopCondition,
-                bool closedUB) {
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type,
+    const ValueBoundsConstraintSet::Variable &var,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, var, stopCondition, closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     // the owner of `value`.
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, dim,
+  return reifyValueBound(b, loc, type, {value, dim},
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
                              ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+  return reifyValueBound(b, loc, type, value,
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db248989..518d2e138c02a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049..71eb59d40836c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9..1f06318cbd60e0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }

diff  --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a..17a1c016ea16d5 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
     if (cstr.populateAndCompare(
-            yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
-            iterArg, dim)) {
+            /*lhs=*/{yieldedValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::EQ,
+            /*rhs=*/{iterArg, dim})) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
-        cstr.bound(value) == initArg;
+        cstr.bound(value) == cstr.getExpr(initArg);
       }
     }
   }
@@ -113,8 +114,9 @@ struct IfOpInterface
     // * result <= elseValue
     // * result >= thenValue
     if (cstr.populateAndCompare(
-            thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            elseValue, dim)) {
+            /*lhs=*/{thenValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{elseValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
     // * result <= thenValue
     // * result >= elseValue
     if (cstr.populateAndCompare(
-            elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            thenValue, dim)) {
+            /*lhs=*/{elseValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{thenValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c13..d25efcf50ec566 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
 
   info.isAlignedToInnerTileSize = false;
   FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::UB,
-      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+      presburger::BoundType::UB, tileSize,
       /*stopCondition=*/nullptr, /*closedUB=*/true);
   std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
   if (!failed(cstSize) && cstInnerSize) {

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49b3..a89ce20048dff3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
           boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          independencies,
+          /*closedUB=*/true)))
     return failure();
   return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }

diff  --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a1700..15381ec520e211 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), srcDim, resultDim);
+        {op.getSource(), srcDim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), dim, resultDim);
+        {op.getSource(), dim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++resultDim;

diff  --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ffa4c0b55cad7c..87937591e60ad8 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
 } // namespace mlir
 
+static Operation *getOwnerOfValue(Value value) {
+  if (auto bbArg = dyn_cast<BlockArgument>(value))
+    return bbArg.getOwner()->getParentOp();
+  return value.getDefiningOp();
+}
+
 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
                                              ArrayRef<OpFoldResult> sizes,
                                              ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+    : Variable(ofr, std::nullopt) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+    : Variable(static_cast<OpFoldResult>(indexValue)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+    : Variable(static_cast<OpFoldResult>(shapedValue), std::optional(dim)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+                                             std::optional<int64_t> dim) {
+  Builder b(ofr.getContext());
+  if (auto constInt = ::getConstantIntValue(ofr)) {
+    assert(!dim && "expected no dim for index-typed values");
+    map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+                         b.getAffineConstantExpr(*constInt));
+    return;
+  }
+  Value value = cast<Value>(ofr);
+#ifndef NDEBUG
+  if (dim) {
+    assert(isa<ShapedType>(value.getType()) && "expected shaped type");
+  } else {
+    assert(value.getType().isIndex() && "expected index type");
+  }
+#endif // NDEBUG
+  map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+                       b.getAffineSymbolExpr(0));
+  mapOperands.emplace_back(value, dim);
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Variable> mapOperands) {
+  assert(map.getNumResults() == 1 && "expected single result");
+
+  // Turn all dims into symbols.
+  Builder b(map.getContext());
+  SmallVector<AffineExpr> dimReplacements, symReplacements;
+  for (int64_t i = 0, e = map.getNumDims(); i < e; ++i)
+    dimReplacements.push_back(b.getAffineSymbolExpr(i));
+  for (int64_t i = 0, e = map.getNumSymbols(); i < e; ++i)
+    symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+  AffineMap tmpMap = map.replaceDimsAndSymbols(
+      dimReplacements, symReplacements, /*numResultDims=*/0,
+      /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+  // Inline operands.
+  DenseMap<AffineExpr, AffineExpr> replacements;
+  for (auto [index, var] : llvm::enumerate(mapOperands)) {
+    assert(var.map.getNumResults() == 1 && "expected single result");
+    assert(var.map.getNumDims() == 0 && "expected only symbols");
+    SmallVector<AffineExpr> symReplacements;
+    for (auto valueDim : var.mapOperands) {
+      auto it = llvm::find(this->mapOperands, valueDim);
+      if (it != this->mapOperands.end()) {
+        // There is already a symbol for this operand.
+        symReplacements.push_back(b.getAffineSymbolExpr(
+            std::distance(this->mapOperands.begin(), it)));
+      } else {
+        // This is a new operand: add a new symbol.
+        symReplacements.push_back(
+            b.getAffineSymbolExpr(this->mapOperands.size()));
+        this->mapOperands.push_back(valueDim);
+      }
+    }
+    replacements[b.getAffineSymbolExpr(index)] =
+        var.map.getResult(0).replaceSymbols(symReplacements);
+  }
+  this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+                             /*numResultSyms=*/this->mapOperands.size());
+}
+
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Value> mapOperands)
+    : Variable(map, llvm::map_to_vector(mapOperands,
+                                        [](Value v) { return Variable(v); })) {}
+
 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
     MLIRContext *ctx, StopConditionFn stopCondition)
     : builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
   assert(!valueDimToPosition.contains(valueDim) && "already mapped");
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+                          << " for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+                          << "\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
   return pos;
 }
 
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+  return insert(var.map, var.mapOperands, isSymbol);
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
           cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
          "unstructured control flow is not supported");
 #endif // NDEBUG
-
+  LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   auto it =
       valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
   assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
   return it != valueDimToPosition.end();
 }
 
-static Operation *getOwnerOfValue(Value value) {
-  if (auto bbArg = dyn_cast<BlockArgument>(value))
-    return bbArg.getOwner()->getParentOp();
-  return value.getDefiningOp();
-}
-
 void ValueBoundsConstraintSet::processWorklist() {
   LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
   }
 }
 
+void ValueBoundsConstraintSet::projectOutAnonymous(
+    std::optional<int64_t> except) {
+  int64_t nextPos = 0;
+  while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+    if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+      ++nextPos;
+    } else {
+      projectOut(nextPos);
+      // The column was projected out so another column is now at that position.
+      // Do not increase the counter.
+    }
+  }
+}
+
 LogicalResult ValueBoundsConstraintSet::computeBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
-    bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+    const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+  MLIRContext *ctx = var.getContext();
   int64_t ubAdjustment = closedUB ? 0 : 1;
-  Builder b(value.getContext());
+  Builder b(ctx);
   mapOperands.clear();
 
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
-  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
-  assert(!stopCondition(value, dim, cstr) &&
-         "stop condition should not be satisfied for starting point");
-  int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+  ValueBoundsConstraintSet cstr(ctx, stopCondition);
+  int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+  assert(pos == 0 && "expected first column");
   cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
   cstr.projectOut([&](ValueDim p) {
-    // Do not project out `valueDim`.
-    if (valueDim == p)
-      return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
     return !stopCondition(p.first, maybeDim, cstr);
   });
+  cstr.projectOutAnonymous(/*except=*/pos);
 
   // Compute lower and upper bounds for `valueDim`.
   SmallVector<AffineMap> lb(1), ub(1);
-  cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+  cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
                            /*closedUB=*/true);
 
   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
 
 LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueDimList dependencies,
-    bool closedUB) {
+    const Variable &var, ValueDimList dependencies, bool closedUB) {
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
 
 LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueRange independencies,
-    bool closedUB) {
+    const Variable &var, ValueRange independencies, bool closedUB) {
   // Return "true" if the given value is independent of all values in
   // `independencies`. I.e., neither the value itself nor any value in the
   // backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 
   // Reify bounds in terms of any independent values.
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return isIndependent(v);
       },
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 }
 
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, Value value, std::optional<int64_t> dim,
-    StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
-  AffineMap map =
-      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                     Builder(value.getContext()).getAffineDimExpr(0));
-  return computeConstantBound(type, map, {{value, dim}}, stopCondition,
-                              closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+    presburger::BoundType type, const Variable &var,
     StopConditionFn stopCondition, bool closedUB) {
-  ValueDimList valueDims;
-  for (Value v : operands) {
-    assert(v.getType().isIndex() && "expected index type");
-    valueDims.emplace_back(v, std::nullopt);
-  }
-  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ValueDimList operands,
-    StopConditionFn stopCondition, bool closedUB) {
-  assert(map.getNumResults() == 1 && "expected affine map with one result");
-
   // Default stop condition if none was specified: Keep adding constraints until
   // a bound could be computed.
   int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   };
 
   ValueBoundsConstraintSet cstr(
-      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
-  pos = cstr.populateConstraints(map, operands);
+      var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.populateConstraints(var.map, var.mapOperands);
   assert(pos == 0 && "expected `map` is the first column");
 
   // Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  return computeConstantBound(presburger::BoundType::EQ, map,
-                              {{value1, dim1}, {value2, dim2}});
+  return computeConstantBound(presburger::BoundType::EQ,
+                              Variable(map, {{value1, dim1}, {value2, dim2}}));
 }
 
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
-                                                std::optional<int64_t> lhsDim,
-                                                ComparisonOperator cmp,
-                                                OpFoldResult rhs,
-                                                std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+                                          ComparisonOperator cmp,
+                                          int64_t rhsPos) {
   // This function returns "true" if "lhs CMP rhs" is proven to hold.
   //
   // Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
     return false;
   }
 
-  // EQ can be expressed as LE and GE.
-  if (cmp == EQ)
-    return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
-           compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
-  // Construct inequality. For the above example: lhs > rhs.
-  // `IntegerRelation` inequalities are expressed in the "flattened" form and
-  // with ">= 0". I.e., lhs - rhs - 1 >= 0.
-  SmallVector<int64_t> eq(cstr.getNumCols(), 0);
-  auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
-                     int64_t factor) {
-    if (auto constVal = ::getConstantIntValue(ofr)) {
-      eq[cstr.getNumCols() - 1] += *constVal * factor;
-    } else {
-      eq[getPos(cast<Value>(ofr), dim)] += factor;
-    }
-  };
-  if (cmp == LT || cmp == LE) {
-    addToEq(lhs, lhsDim, 1);
-    addToEq(rhs, rhsDim, -1);
-  } else if (cmp == GT || cmp == GE) {
-    addToEq(lhs, lhsDim, -1);
-    addToEq(rhs, rhsDim, 1);
-  } else {
-    llvm_unreachable("unsupported comparison operator");
-  }
-  if (cmp == LE || cmp == GE)
-    eq[cstr.getNumCols() - 1] -= 1;
-
-  // Add inequality to the constraint set and check if it made the constraint
-  // set empty.
-  int64_t ineqPos = cstr.getNumInequalities();
-  cstr.addInequality(eq);
-  bool isEmpty = cstr.isEmpty();
-  cstr.removeInequality(ineqPos);
-  return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
-                                          ComparisonOperator cmp,
-                                          int64_t rhsPos) {
-  // This function returns "true" if "lhs CMP rhs" is proven to hold. For
-  // detailed documentation, see `compareValueDims`.
-
   // EQ can be expressed as LE and GE.
   if (cmp == EQ)
     return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
@@ -712,48 +727,17 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
-bool ValueBoundsConstraintSet::populateAndCompare(
-    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
-    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    populateConstraints(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    populateConstraints(rhsVal, rhsDim);
-
-  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
+                                                  ComparisonOperator cmp,
+                                                  const Variable &rhs) {
+  int64_t lhsPos = populateConstraints(lhs.map, lhs.mapOperands);
+  int64_t rhsPos = populateConstraints(rhs.map, rhs.mapOperands);
+  return comparePos(lhsPos, cmp, rhsPos);
 }
 
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
-                                       std::optional<int64_t> lhsDim,
-                                       ComparisonOperator cmp, OpFoldResult rhs,
-                                       std::optional<int64_t> rhsDim) {
-  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
-                           ValueBoundsConstraintSet &cstr) {
-    // Keep processing as long as lhs/rhs are not mapped.
-    if (auto lhsVal = dyn_cast<Value>(lhs))
-      if (!cstr.isMapped(lhsVal, dim))
-        return false;
-    if (auto rhsVal = dyn_cast<Value>(rhs))
-      if (!cstr.isMapped(rhsVal, dim))
-        return false;
-    // Keep processing as long as the relation cannot be proven.
-    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
-  };
-
-  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(const Variable &lhs,
+                                       ComparisonOperator cmp,
+                                       const Variable &rhs) {
   int64_t lhsPos = -1, rhsPos = -1;
   auto stopCondition = [&](Value v, std::optional<int64_t> dim,
                            ValueBoundsConstraintSet &cstr) {
@@ -765,39 +749,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
     return cstr.comparePos(lhsPos, cmp, rhsPos);
   };
   ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  lhsPos = cstr.insert(lhs, lhsOperands);
-  rhsPos = cstr.insert(rhs, rhsOperands);
-  cstr.processWorklist();
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
-                                       ArrayRef<Value> lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ArrayRef<Value> rhsOperands) {
-  ValueDimList lhsValueDimOperands =
-      llvm::map_to_vector(lhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  ValueDimList rhsValueDimOperands =
-      llvm::map_to_vector(rhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
-                                           rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
-                                        value2, dim2))
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
+                                                   const Variable &var2) {
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
     return true;
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
-                                        value2, dim2) ||
-      ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
-                                        value2, dim2))
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+      ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
     return false;
   return failure();
 }
@@ -833,7 +795,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;
@@ -850,7 +812,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;

diff  --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 23c6872dcebe94..935c08aceff548 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -131,3 +131,27 @@ func.func @compare_affine_min(%a: index, %b: index) {
   "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
   return
 }
+
+// -----
+
+func.func @compare_const_map() {
+  %c5 = arith.constant 5 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%c5) {cmp = "GT", rhs_map = affine_map<() -> (4)>}
+      : (index) -> ()
+  // expected-remark @below{{true}}
+  "test.compare"(%c5) {cmp = "LT", lhs_map = affine_map<() -> (4)>}
+      : (index) -> ()
+  return
+}
+
+// -----
+
+func.func @compare_maps(%a: index, %b: index) {
+  // expected-remark @below{{true}}
+  "test.compare"(%a, %b, %b, %a)
+      {cmp = "GT", lhs_map = affine_map<(d0, d1) -> (1 + d0 + d1)>,
+       rhs_map = affine_map<(d0, d1) -> (d0 + d1)>}
+      : (index, index, index, index) -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6730f9b292ad93..b098a5a23fd316 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
     FailureOr<OpFoldResult> reified = failure();
     if (constant) {
       auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
-          boundType, value, dim, /*stopCondition=*/nullptr);
+          boundType, {value, dim}, /*stopCondition=*/nullptr);
       if (succeeded(reifiedConst))
         reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
     } else if (scalable) {
@@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
             rewriter, loc, reifiedScalable->map, vscaleOperand);
       }
     } else {
-      if (dim) {
-        if (useArithOps) {
-          reified = arith::reifyShapedValueDimBound(
-              rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
-        } else {
-          reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
-                                             value, *dim, stopCondition);
-        }
+      if (useArithOps) {
+        reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
+                                         op.getVariable(), stopCondition);
       } else {
-        if (useArithOps) {
-          reified = arith::reifyIndexValueBound(
-              rewriter, op->getLoc(), boundType, value, stopCondition);
-        } else {
-          reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
-                                         value, stopCondition);
-        }
+        reified = reifyValueBound(rewriter, op->getLoc(), boundType,
+                                  op.getVariable(), stopCondition);
       }
     }
     if (failed(reified)) {
@@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
     }
 
     auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
-      return ValueBoundsConstraintSet::compare(
-          /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
-          /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+      return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
     };
     if (compare(cmpType)) {
       op->emitRemark("true");

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190ca0ef3a..36d7606fe1345b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
   return success();
 }
 
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+  if (getDim().has_value())
+    return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+  return ValueBoundsConstraintSet::Variable(getVar());
+}
+
 ::mlir::ValueBoundsConstraintSet::ComparisonOperator
 CompareOp::getComparisonOperator() {
   if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
   llvm_unreachable("invalid comparison operator");
 }
 
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+  if (!getLhsMap())
+    return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+  SmallVector<Value> mapOperands(
+      getVarOperands().slice(0, getLhsMap()->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+  int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+  if (!getRhsMap())
+    return ValueBoundsConstraintSet::Variable(
+        getVarOperands()[rhsOperandsBegin]);
+  SmallVector<Value> mapOperands(
+      getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
+}
+
+LogicalResult CompareOp::verify() {
+  if (getCompose() && (getLhsMap() || getRhsMap()))
+    return emitOpError(
+        "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+  int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+  expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+  if (getVarOperands().size() != expectedNumOperands)
+    return emitOpError("expected ")
+           << expectedNumOperands << " operands, but got "
+           << getVarOperands().size();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Test removing op with inner ops.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ebf158b8bb8203..b641b3da719c78 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
 
   let extraClassDeclaration = [{
     ::mlir::presburger::BoundType getBoundType();
+    ::mlir::ValueBoundsConstraintSet::Variable getVariable();
   }];
 
   let hasVerifier = 1;
@@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> {
     Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
     specified comparison operator was proven to hold. The remark also indicates
     whether the opposite comparison operator was proven to hold.
+
+    By default, `var_operands` must have exactly two operands: one for the LHS
+    and one for the RHS. If `lhs_map` is specified, the single LHS operand is
+    replaced by as many operands as `lhs_map` has inputs. Likewise, if
+    `rhs_map` is specified, the single RHS operand is replaced by as many
+    operands as `rhs_map` has inputs. LHS operands always precede RHS operands.
   }];
 
-  let arguments = (ins Index:$lhs,
-                       Index:$rhs,
+  let arguments = (ins Variadic<Index>:$var_operands,
                        DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+                       OptionalAttr<AffineMapAttr>:$lhs_map,
+                       OptionalAttr<AffineMapAttr>:$rhs_map,
                        UnitAttr:$compose);
   let results = (outs);
 
   let extraClassDeclaration = [{
     ::mlir::ValueBoundsConstraintSet::ComparisonOperator
         getComparisonOperator();
+    ::mlir::ValueBoundsConstraintSet::Variable getLhs();
+    ::mlir::ValueBoundsConstraintSet::Variable getRhs();
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list