[Mlir-commits] [llvm] [mlir] [mlir][Interfaces][WIP] `Variable` abstraction for `ValueBoundsOpInterface` (PR #87980)

Matthias Springer llvmlistbot at llvm.org
Mon Apr 15 06:31:10 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/87980

>From bc3b2989bea3ad4e15c9891e9d35618bf491b54c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Apr 2024 11:50:20 +0000
Subject: [PATCH 1/2] [mlir][Interfaces][NFC] Add TableGen test op for value
 bounds tests

This commit is a code cleanup. It defines the test ops that are used for the `ValueBoundsOpInterface` tests in TableGen, along with proper verifiers.
---
 .../value-bounds-op-interface-impl.mlir       |   2 +-
 .../Affine/value-bounds-reification.mlir      |   6 +-
 .../value-bounds-op-interface-impl.mlir       |   4 +-
 .../Dialect/Vector/test-scalable-bounds.mlir  |  18 +-
 mlir/test/lib/Dialect/Affine/CMakeLists.txt   |   8 +
 .../Dialect/Affine/TestReifyValueBounds.cpp   | 305 ++++++------------
 mlir/test/lib/Dialect/Test/CMakeLists.txt     |   1 +
 mlir/test/lib/Dialect/Test/TestDialect.cpp    |  48 +++
 mlir/test/lib/Dialect/Test/TestDialect.h      |   1 +
 mlir/test/lib/Dialect/Test/TestOps.td         |  47 +++
 .../mlir/test/BUILD.bazel                     |   2 +
 11 files changed, 226 insertions(+), 216 deletions(-)

diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 10da91870f49d9..23c6872dcebe94 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -74,7 +74,7 @@ func.func @composed_affine_apply(%i1 : index) -> (index) {
   %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
   %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
   %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
-  %reified = "test.reify_constant_bound"(%s) {type = "EQ"} : (index) -> (index)
+  %reified = "test.reify_bound"(%s) {type = "EQ", constant} : (index) -> (index)
   return %reified : index
 }
 
diff --git a/mlir/test/Dialect/Affine/value-bounds-reification.mlir b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
index 909c9098c51607..75622f59af83be 100644
--- a/mlir/test/Dialect/Affine/value-bounds-reification.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-reification.mlir
@@ -47,7 +47,7 @@ func.func @reify_slice_bound(%t: tensor<?x?xi32>, %idx: index, %ub: index, %f: f
     %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound) : (index) -> ()
 
-    %bound_const = "test.reify_constant_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+    %bound_const = "test.reify_bound"(%filled) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
     "test.some_use"(%bound_const) : (index) -> ()
   }
   return
@@ -93,7 +93,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
 
     // CHECK: %[[c129:.*]] = arith.constant 129 : index
     // CHECK: "test.some_use"(%[[c129]])
-    %lb1_ub_const = "test.reify_constant_bound"(%lb1) {type = "UB"} : (index) -> (index)
+    %lb1_ub_const = "test.reify_bound"(%lb1) {type = "UB", constant} : (index) -> (index)
     "test.some_use"(%lb1_ub_const) : (index) -> ()
 
     scf.for %iv1 = %lb1 to %ub1 step %c32 {
@@ -116,7 +116,7 @@ func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index,
 
         // CHECK: %[[c32:.*]] = arith.constant 32 : index
         // CHECK: "test.some_use"(%[[c32]])
-        %matmul_ub_const = "test.reify_constant_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index)
+        %matmul_ub_const = "test.reify_bound"(%matmul) {dim = 1, type = "UB", constant} : (tensor<1x?xi32>) -> (index)
         "test.some_use"(%matmul_ub_const) : (index) -> ()
       }
     }
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 0c90bcdb42028c..0ba9983723a0a1 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -83,7 +83,7 @@ func.func @extract_slice_static(%t: tensor<?xf32>) -> index {
 func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> index {
   %0 = tensor.extract_slice %t[2][%sz][1] : tensor<?xf32> to tensor<?xf32>
   // expected-error @below{{could not reify bound}}
-  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<?xf32>) -> (index)
+  %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<?xf32>) -> (index)
   return %1 : index
 }
 
@@ -95,7 +95,7 @@ func.func @extract_slice_dynamic_constant(%t: tensor<?xf32>, %sz: index) -> inde
 //       CHECK:   return %[[c5]]
 func.func @extract_slice_static_constant(%t: tensor<?xf32>) -> index {
   %0 = tensor.extract_slice %t[2][5][1] : tensor<?xf32> to tensor<5xf32>
-  %1 = "test.reify_constant_bound"(%0) {dim = 0} : (tensor<5xf32>) -> (index)
+  %1 = "test.reify_bound"(%0) {dim = 0, constant} : (tensor<5xf32>) -> (index)
   return %1 : index
 }
 
diff --git a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index 245a6f5c13ac3d..d549c5bd1c3785 100644
--- a/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -26,8 +26,8 @@ func.func @fixed_size_loop_nest() {
     %min_i = affine.min #map_dim_i(%i)[%c4_vscale]
     scf.for %j = %c0 to %c16 step %c4_vscale {
       %min_j = affine.min #map_dim_j(%j)[%c4_vscale]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+      %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -58,8 +58,8 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
     %min_i = affine.min #map_dynamic_dim(%i)[%c4_vscale, %dim0]
     scf.for %j = %c0 to %dim1 step %c4_vscale {
       %min_j = affine.min #map_dynamic_dim(%j)[%c4_vscale, %dim1]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
+      %bound_j = "test.reify_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -80,7 +80,7 @@ func.func @add_to_vscale() {
   %vscale = vector.vscale
   %c8 = arith.constant 8 : index
   %vscale_plus_c8 = arith.addi %vscale, %c8 : index
-  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -94,7 +94,7 @@ func.func @add_to_vscale() {
 //       CHECK:   "test.some_use"(%[[C2]]) : (index) -> ()
 func.func @vscale_fixed_size() {
   %vscale = vector.vscale
-  %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index
+  %bound = "test.reify_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -107,7 +107,7 @@ func.func @unknown_bound(%a: index) {
   %vscale = vector.vscale
   %vscale_plus_a = arith.muli %vscale, %a : index
   // expected-error @below{{could not reify bound}}
-  %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -134,7 +134,7 @@ func.func @duplicate_vscale_values() {
   %c2_vscale = arith.muli %vscale_1, %c2 : index
   %add = arith.addi %c2_vscale, %c4_vscale : index
 
-  %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -154,7 +154,7 @@ func.func @non_scalable_code() {
   %c0 = arith.constant 0 : index
   scf.for %i = %c0 to %c1024 step %c4 {
     %min_i = affine.min #map_dim_i(%i)[%c4]
-    %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+    %bound_i = "test.reify_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16, scalable} : (index) -> index
     "test.some_use"(%bound_i) : (index) -> ()
   }
   return
diff --git a/mlir/test/lib/Dialect/Affine/CMakeLists.txt b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
index 14960a45d39bab..33cefab9fa2edf 100644
--- a/mlir/test/lib/Dialect/Affine/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Affine/CMakeLists.txt
@@ -30,5 +30,13 @@ add_mlir_library(MLIRAffineTransformsTestPasses
   MLIRSupport
   MLIRMemRefDialect
   MLIRTensorDialect
+  MLIRTestDialect
   MLIRVectorUtils
   )
+
+target_include_directories(MLIRAffineTransformsTestPasses
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Test
+  )
+
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index f38631054fb3c1..6730f9b292ad93 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
@@ -57,31 +58,6 @@ struct TestReifyValueBounds
 
 } // namespace
 
-static FailureOr<BoundType> parseBoundType(const std::string &type) {
-  if (type == "EQ")
-    return BoundType::EQ;
-  if (type == "LB")
-    return BoundType::LB;
-  if (type == "UB")
-    return BoundType::UB;
-  return failure();
-}
-
-static FailureOr<ValueBoundsConstraintSet::ComparisonOperator>
-parseComparisonOperator(const std::string &type) {
-  if (type == "EQ")
-    return ValueBoundsConstraintSet::ComparisonOperator::EQ;
-  if (type == "LT")
-    return ValueBoundsConstraintSet::ComparisonOperator::LT;
-  if (type == "LE")
-    return ValueBoundsConstraintSet::ComparisonOperator::LE;
-  if (type == "GT")
-    return ValueBoundsConstraintSet::ComparisonOperator::GT;
-  if (type == "GE")
-    return ValueBoundsConstraintSet::ComparisonOperator::GE;
-  return failure();
-}
-
 static ValueBoundsConstraintSet::ComparisonOperator
 invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
   if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
@@ -101,144 +77,89 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
                                           bool reifyToFuncArgs,
                                           bool useArithOps) {
   IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    // Look for test.reify_bound ops.
-    if (op->getName().getStringRef() == "test.reify_bound" ||
-        op->getName().getStringRef() == "test.reify_constant_bound" ||
-        op->getName().getStringRef() == "test.reify_scalable_bound") {
-      if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
-          !op->getResultTypes()[0].isIndex()) {
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-      Value value = op->getOperand(0);
-      if (isa<IndexType>(value.getType()) !=
-          !op->hasAttrOfType<IntegerAttr>("dim")) {
-        // Op should have "dim" attribute if and only if the operand is an
-        // index-typed value.
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-
-      // Get bound type.
-      std::string boundTypeStr = "EQ";
-      if (auto boundTypeAttr = op->getAttrOfType<StringAttr>("type"))
-        boundTypeStr = boundTypeAttr.str();
-      auto boundType = parseBoundType(boundTypeStr);
-      if (failed(boundType)) {
-        op->emitOpError("invalid op");
-        return WalkResult::interrupt();
-      }
-
-      // Get shape dimension (if any).
-      auto dim = value.getType().isIndex()
-                     ? std::nullopt
-                     : std::make_optional<int64_t>(
-                           op->getAttrOfType<IntegerAttr>("dim").getInt());
-
-      // Check if a constant was requested.
-      bool constant =
-          op->getName().getStringRef() == "test.reify_constant_bound";
-
-      bool scalable = !constant && op->getName().getStringRef() ==
-                                       "test.reify_scalable_bound";
-
-      // Prepare stop condition. By default, reify in terms of the op's
-      // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>,
-                         ValueBoundsConstraintSet & cstr)>
-          stopCondition = [&](Value v, std::optional<int64_t> d,
-                              ValueBoundsConstraintSet &cstr) {
-            // Reify in terms of SSA values that are different from `value`.
-            return v != value;
-          };
-      if (reifyToFuncArgs) {
-        // Reify in terms of function block arguments.
-        stopCondition = [](Value v, std::optional<int64_t> d,
-                           ValueBoundsConstraintSet &cstr) {
-          auto bbArg = dyn_cast<BlockArgument>(v);
-          if (!bbArg)
-            return false;
-          return isa<FunctionOpInterface>(
-              bbArg.getParentBlock()->getParentOp());
+  WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
+    auto boundType = op.getBoundType();
+    Value value = op.getVar();
+    std::optional<int64_t> dim = op.getDim();
+    bool constant = op.getConstant();
+    bool scalable = op.getScalable();
+
+    // Prepare stop condition. By default, reify in terms of the op's
+    // operands. No stop condition is used when a constant was requested.
+    std::function<bool(Value, std::optional<int64_t>,
+                       ValueBoundsConstraintSet & cstr)>
+        stopCondition = [&](Value v, std::optional<int64_t> d,
+                            ValueBoundsConstraintSet &cstr) {
+          // Reify in terms of SSA values that are different from `value`.
+          return v != value;
         };
-      }
+    if (reifyToFuncArgs) {
+      // Reify in terms of function block arguments.
+      stopCondition = [](Value v, std::optional<int64_t> d,
+                         ValueBoundsConstraintSet &cstr) {
+        auto bbArg = dyn_cast<BlockArgument>(v);
+        if (!bbArg)
+          return false;
+        return isa<FunctionOpInterface>(bbArg.getParentBlock()->getParentOp());
+      };
+    }
 
-      // Reify value bound
-      rewriter.setInsertionPointAfter(op);
-      FailureOr<OpFoldResult> reified = failure();
-      if (constant) {
-        auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
-            *boundType, value, dim, /*stopCondition=*/nullptr);
-        if (succeeded(reifiedConst))
-          reified =
-              FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
-      } else if (scalable) {
-        unsigned vscaleMin = 0;
-        unsigned vscaleMax = 0;
-        if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr)) {
-          vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
-        } else {
-          op->emitOpError("expected `vscale_min` to be provided");
-          return WalkResult::skip();
+    // Reify value bound
+    rewriter.setInsertionPointAfter(op);
+    FailureOr<OpFoldResult> reified = failure();
+    if (constant) {
+      auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
+          boundType, value, dim, /*stopCondition=*/nullptr);
+      if (succeeded(reifiedConst))
+        reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+    } else if (scalable) {
+      auto loc = op->getLoc();
+      auto reifiedScalable =
+          vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+              value, dim, *op.getVscaleMin(), *op.getVscaleMax(), boundType);
+      if (succeeded(reifiedScalable)) {
+        SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
+        if (reifiedScalable->map.getNumInputs() == 1) {
+          // The only possible input to the bound is vscale.
+          vscaleOperand.push_back(std::make_pair(
+              rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
         }
-        if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr)) {
-          vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+        reified = affine::materializeComputedBound(
+            rewriter, loc, reifiedScalable->map, vscaleOperand);
+      }
+    } else {
+      if (dim) {
+        if (useArithOps) {
+          reified = arith::reifyShapedValueDimBound(
+              rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
         } else {
-          op->emitOpError("expected `vscale_max` to be provided");
-          return WalkResult::skip();
-        }
-
-        auto loc = op->getLoc();
-        auto reifiedScalable =
-            vector::ScalableValueBoundsConstraintSet::computeScalableBound(
-                value, dim, vscaleMin, vscaleMax, *boundType);
-        if (succeeded(reifiedScalable)) {
-          SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
-              vscaleOperand;
-          if (reifiedScalable->map.getNumInputs() == 1) {
-            // The only possible input to the bound is vscale.
-            vscaleOperand.push_back(std::make_pair(
-                rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
-          }
-          reified = affine::materializeComputedBound(
-              rewriter, loc, reifiedScalable->map, vscaleOperand);
+          reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
+                                             value, *dim, stopCondition);
         }
       } else {
-        if (dim) {
-          if (useArithOps) {
-            reified = arith::reifyShapedValueDimBound(
-                rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
-          } else {
-            reified = reifyShapedValueDimBound(
-                rewriter, op->getLoc(), *boundType, value, *dim, stopCondition);
-          }
+        if (useArithOps) {
+          reified = arith::reifyIndexValueBound(
+              rewriter, op->getLoc(), boundType, value, stopCondition);
         } else {
-          if (useArithOps) {
-            reified = arith::reifyIndexValueBound(
-                rewriter, op->getLoc(), *boundType, value, stopCondition);
-          } else {
-            reified = reifyIndexValueBound(rewriter, op->getLoc(), *boundType,
-                                           value, stopCondition);
-          }
+          reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
+                                         value, stopCondition);
         }
       }
-      if (failed(reified)) {
-        op->emitOpError("could not reify bound");
-        return WalkResult::interrupt();
-      }
+    }
+    if (failed(reified)) {
+      op->emitOpError("could not reify bound");
+      return WalkResult::interrupt();
+    }
 
-      // Replace the op with the reified bound.
-      if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
-        rewriter.replaceOp(op, val);
-        return WalkResult::skip();
-      }
-      Value constOp = rewriter.create<arith::ConstantIndexOp>(
-          op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
-      rewriter.replaceOp(op, constOp);
+    // Replace the op with the reified bound.
+    if (auto val = llvm::dyn_cast_if_present<Value>(*reified)) {
+      rewriter.replaceOp(op, val);
       return WalkResult::skip();
     }
-    return WalkResult::advance();
+    Value constOp = rewriter.create<arith::ConstantIndexOp>(
+        op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
+    rewriter.replaceOp(op, constOp);
+    return WalkResult::skip();
   });
   return failure(result.wasInterrupted());
 }
@@ -246,60 +167,42 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 /// Look for "test.compare" ops and emit errors/remarks.
 static LogicalResult testEquality(func::FuncOp funcOp) {
   IRRewriter rewriter(funcOp.getContext());
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    // Look for test.compare ops.
-    if (op->getName().getStringRef() == "test.compare") {
-      if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
-          !op->getOperand(1).getType().isIndex()) {
-        op->emitOpError("invalid op");
-        return WalkResult::skip();
-      }
-
-      // Get comparison operator.
-      std::string cmpStr = "EQ";
-      if (auto cmpAttr = op->getAttrOfType<StringAttr>("cmp"))
-        cmpStr = cmpAttr.str();
-      auto cmpType = parseComparisonOperator(cmpStr);
-      if (failed(cmpType)) {
-        op->emitOpError("invalid comparison operator");
+  WalkResult result = funcOp.walk([&](test::CompareOp op) {
+    auto cmpType = op.getComparisonOperator();
+    if (op.getCompose()) {
+      if (cmpType != ValueBoundsConstraintSet::EQ) {
+        op->emitOpError(
+            "comparison operator must be EQ when 'composed' is specified");
         return WalkResult::interrupt();
       }
-
-      if (op->hasAttr("compose")) {
-        if (cmpType != ValueBoundsConstraintSet::EQ) {
-          op->emitOpError(
-              "comparison operator must be EQ when 'composed' is specified");
-          return WalkResult::interrupt();
-        }
-        FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
-            op->getOperand(0), op->getOperand(1));
-        if (failed(delta)) {
-          op->emitError("could not determine equality");
-        } else if (*delta == 0) {
-          op->emitRemark("equal");
-        } else {
-          op->emitRemark("different");
-        }
-        return WalkResult::advance();
-      }
-
-      auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
-        return ValueBoundsConstraintSet::compare(
-            /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
-            /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
-      };
-      if (compare(*cmpType)) {
-        op->emitRemark("true");
-      } else if (*cmpType != ValueBoundsConstraintSet::EQ &&
-                 compare(invertComparisonOperator(*cmpType))) {
-        op->emitRemark("false");
-      } else if (*cmpType == ValueBoundsConstraintSet::EQ &&
-                 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
-                  compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
-        op->emitRemark("false");
+      FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
+          op->getOperand(0), op->getOperand(1));
+      if (failed(delta)) {
+        op->emitError("could not determine equality");
+      } else if (*delta == 0) {
+        op->emitRemark("equal");
       } else {
-        op->emitError("unknown");
+        op->emitRemark("different");
       }
+      return WalkResult::advance();
+    }
+
+    auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
+      return ValueBoundsConstraintSet::compare(
+          /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
+          /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+    };
+    if (compare(cmpType)) {
+      op->emitRemark("true");
+    } else if (cmpType != ValueBoundsConstraintSet::EQ &&
+               compare(invertComparisonOperator(cmpType))) {
+      op->emitRemark("false");
+    } else if (cmpType == ValueBoundsConstraintSet::EQ &&
+               (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
+                compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
+      op->emitRemark("false");
+    } else {
+      op->emitError("unknown");
     }
     return WalkResult::advance();
   });
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index 47ddcf6524748c..d246c0492a3bd5 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -85,6 +85,7 @@ add_mlir_library(MLIRTestDialect
   MLIRTensorDialect
   MLIRTransformUtils
   MLIRTransforms
+  MLIRValueBoundsOpInterface
 )
 
 add_mlir_translation_library(MLIRTestFromLLVMIRTranslation
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 74378633032ce7..25c5190ca0ef3a 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -516,6 +516,54 @@ static void printOptionalCustomParser(AsmPrinter &p, Operation *,
   p.printAttribute(result);
 }
 
+//===----------------------------------------------------------------------===//
+// ReifyBoundOp
+//===----------------------------------------------------------------------===//
+
+::mlir::presburger::BoundType ReifyBoundOp::getBoundType() {
+  if (getType() == "EQ")
+    return ::mlir::presburger::BoundType::EQ;
+  if (getType() == "LB")
+    return ::mlir::presburger::BoundType::LB;
+  if (getType() == "UB")
+    return ::mlir::presburger::BoundType::UB;
+  llvm_unreachable("invalid bound type");
+}
+
+LogicalResult ReifyBoundOp::verify() {
+  if (isa<ShapedType>(getVar().getType())) {
+    if (!getDim().has_value())
+      return emitOpError("expected 'dim' attribute for shaped type variable");
+  } else if (getVar().getType().isIndex()) {
+    if (getDim().has_value())
+      return emitOpError("unexpected 'dim' attribute for index variable");
+  } else {
+    return emitOpError("expected index-typed variable or shape type variable");
+  }
+  if (getConstant() && getScalable())
+    return emitOpError("'scalable' and 'constant' are mutually exclusive");
+  if (getScalable() != getVscaleMin().has_value())
+    return emitOpError("expected 'vscale_min' if and only if 'scalable'");
+  if (getScalable() != getVscaleMax().has_value())
+    return emitOpError("expected 'vscale_max' if and only if 'scalable'");
+  return success();
+}
+
+::mlir::ValueBoundsConstraintSet::ComparisonOperator
+CompareOp::getComparisonOperator() {
+  if (getCmp() == "EQ")
+    return ValueBoundsConstraintSet::ComparisonOperator::EQ;
+  if (getCmp() == "LT")
+    return ValueBoundsConstraintSet::ComparisonOperator::LT;
+  if (getCmp() == "LE")
+    return ValueBoundsConstraintSet::ComparisonOperator::LE;
+  if (getCmp() == "GT")
+    return ValueBoundsConstraintSet::ComparisonOperator::GT;
+  if (getCmp() == "GE")
+    return ValueBoundsConstraintSet::ComparisonOperator::GE;
+  llvm_unreachable("invalid comparison operator");
+}
+
 //===----------------------------------------------------------------------===//
 // Test removing op with inner ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index 4ba28c47ed1c33..d5b2fbeafc4104 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -41,6 +41,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
 #include <memory>
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c88f85b8b6b361..0ab2427038add3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2184,6 +2184,53 @@ def ForwardBufferOp : TEST_Op<"forward_buffer", [Pure]> {
   let results = (outs AnyRankedOrUnrankedMemRef:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Test ValueBoundsOpInterface
+//===----------------------------------------------------------------------===//
+
+def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
+  let description = [{
+    Reify a bound for the given index-typed value or dimension size of a shaped
+    value. "LB", "EQ" and "UB" bounds are supported. If `scalable` is set, a
+    bound in terms of "vector.vscale" is computed. `vscale_min` and `vscale_max`
+    must be specified in that case.
+  }];
+
+  let arguments = (ins AnyType:$var,
+                       OptionalAttr<I64Attr>:$dim,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$type,
+                       UnitAttr:$constant,
+                       UnitAttr:$scalable,
+                       OptionalAttr<I64Attr>:$vscale_min,
+                       OptionalAttr<I64Attr>:$vscale_max);
+  let results = (outs Index:$result);
+
+  let extraClassDeclaration = [{
+    ::mlir::presburger::BoundType getBoundType();
+  }];
+
+  let hasVerifier = 1;
+}
+
+def CompareOp : TEST_Op<"compare"> {
+  let description = [{
+    Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
+    specified comparison operator was proven to hold. The remark also indicates
+    whether the opposite comparison operator was proven to hold.
+  }];
+
+  let arguments = (ins Index:$lhs,
+                       Index:$rhs,
+                       DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+                       UnitAttr:$compose);
+  let results = (outs);
+
+  let extraClassDeclaration = [{
+    ::mlir::ValueBoundsConstraintSet::ComparisonOperator
+        getComparisonOperator();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index 684b59e7f62f65..dc5f4047c286db 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -421,6 +421,7 @@ cc_library(
         "//mlir:TranslateLib",
         "//mlir:TransformUtils",
         "//mlir:Transforms",
+        "//mlir:ValueBoundsOpInterface",
         "//mlir:ViewLikeInterface",
     ],
 )
@@ -644,6 +645,7 @@ cc_library(
         "lib/Dialect/Affine/*.cpp",
     ]),
     deps = [
+        ":TestDialect",
         "//llvm:Support",
         "//mlir:AffineAnalysis",
         "//mlir:AffineDialect",

>From 8ba7e9eb27463f4f777a0d03379d95c0916455e3 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Apr 2024 13:29:59 +0000
Subject: [PATCH 2/2] [mlir][Interfaces][WIP] `ValueBoundsOpInterface`:
 `Variable`

---
 .../Dialect/Affine/Transforms/Transforms.h    |  11 +
 .../Dialect/Arith/Transforms/Transforms.h     |  11 +
 .../mlir/Interfaces/ValueBoundsOpInterface.h  | 119 +++---
 .../Affine/IR/ValueBoundsOpInterfaceImpl.cpp  |   6 +-
 .../Affine/Transforms/ReifyValueBounds.cpp    |  15 +-
 .../Arith/IR/ValueBoundsOpInterfaceImpl.cpp   |   8 +-
 .../Dialect/Arith/Transforms/IntNarrowing.cpp |   2 +-
 .../Arith/Transforms/ReifyValueBounds.cpp     |  15 +-
 .../lib/Dialect/Linalg/Transforms/Padding.cpp |   6 +-
 .../Dialect/Linalg/Transforms/Promotion.cpp   |   6 +-
 .../Transforms/IndependenceTransforms.cpp     |   5 +-
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  17 +-
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   |   3 +-
 .../Transforms/IndependenceTransforms.cpp     |   3 +-
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |   4 +-
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 338 ++++++++----------
 .../value-bounds-op-interface-impl.mlir       |  24 ++
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  26 +-
 mlir/test/lib/Dialect/Test/TestDialect.cpp    |  37 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |  16 +-
 20 files changed, 361 insertions(+), 311 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 8e840e744064d5..1ea73752208156 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -53,6 +53,17 @@ void reorderOperandsByHoistability(RewriterBase &rewriter, AffineApplyOp op);
 /// maximally compose chains of AffineApplyOps.
 FailureOr<AffineApplyOp> decompose(RewriterBase &rewriter, AffineApplyOp op);
 
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                const ValueBoundsConstraintSet::Variable &var,
+                ValueBoundsConstraintSet::StopConditionFn stopCondition,
+                bool closedUB = false);
+
 /// Reify a bound for the given index-typed value in terms of SSA values for
 /// which `stopCondition` is met. If no stop condition is specified, reify in
 /// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
index 970a52a06a11a2..bbc7e5d3e0dd70 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Transforms.h
@@ -24,6 +24,17 @@ enum class BoundType;
 
 namespace arith {
 
+/// Reify a bound for the given variable in terms of SSA values for which
+/// `stopCondition` is met.
+///
+/// By default, lower/equal bounds are closed and upper bounds are open. If
+/// `closedUB` is set to "true", upper bounds are also closed.
+FailureOr<OpFoldResult>
+reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
+                const ValueBoundsConstraintSet::Variable &var,
+                ValueBoundsConstraintSet::StopConditionFn stopCondition,
+                bool closedUB = false);
+
 /// Reify a bound for the given index-typed value in terms of SSA values for
 /// which `stopCondition` is met. If no stop condition is specified, reify in
 /// terms of the operands of the owner op.
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3..ac17ace5a976d2 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
 public:
   static char ID;
 
+  /// A variable that can be added to the constraint set as a "column". The
+  /// value bounds infrastructure can compute bounds for variables and compare
+  /// two variables.
+  ///
+  /// Internally, a variable is represented as an affine map and operands.
+  class Variable {
+  public:
+    /// Construct a variable for an index-typed attribute or SSA value.
+    Variable(OpFoldResult ofr);
+
+    /// Construct a variable for an index-typed SSA value.
+    Variable(Value indexValue);
+
+    /// Construct a variable for a dimension of a shaped value.
+    Variable(Value shapedValue, int64_t dim);
+
+    /// Construct a variable for an index-typed attribute/SSA value or for a
+    /// dimension of a shaped value. A non-null dimension must be provided if
+    /// and only if `ofr` is a shaped value.
+    Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+    /// Construct a variable for a map and its operands.
+    Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+    Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+    MLIRContext *getContext() const { return map.getContext(); }
+
+  private:
+    friend class ValueBoundsConstraintSet;
+    AffineMap map;
+    ValueDimList mapOperands;
+  };
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
   using StopConditionFn = std::function<bool(
       Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
-  /// Compute a bound for the given index-typed value or shape dimension size.
-  /// The computed bound is stored in `resultMap`. The operands of the bound are
-  /// stored in `mapOperands`. An operand is either an index-type SSA value
-  /// or a shaped value and a dimension.
+  /// Compute a bound for the given variable. The computed bound is stored in
+  /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+  /// operand is either an index-type SSA value or a shaped value and a
+  /// dimension.
   ///
-  /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
-  /// is computed in terms of values/dimensions for which `stopCondition`
-  /// evaluates to "true". To that end, the backward slice (reverse use-def
-  /// chain) of the given value is visited in a worklist-driven manner and the
-  /// constraint set is populated according to `ValueBoundsOpInterface` for each
-  /// visited value.
+  /// The bound is computed in terms of values/dimensions for which
+  /// `stopCondition` evaluates to "true". To that end, the backward slice
+  /// (reverse use-def chain) of the given variable is visited in a
+  /// worklist-driven manner and the constraint set is populated according to
+  /// `ValueBoundsOpInterface` for each visited value.
   ///
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
-  static LogicalResult computeBound(AffineMap &resultMap,
-                                    ValueDimList &mapOperands,
-                                    presburger::BoundType type, Value value,
-                                    std::optional<int64_t> dim,
-                                    StopConditionFn stopCondition,
-                                    bool closedUB = false);
+  static LogicalResult
+  computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+               presburger::BoundType type, const Variable &var,
+               StopConditionFn stopCondition, bool closedUB = false);
 
   /// Compute a bound in terms of the values/dimensions in `dependencies`. The
   /// computed bound consists of only constant terms and dependent values (or
   /// dimension sizes thereof).
   static LogicalResult
   computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                        presburger::BoundType type, Value value,
-                        std::optional<int64_t> dim, ValueDimList dependencies,
-                        bool closedUB = false);
+                        presburger::BoundType type, const Variable &var,
+                        ValueDimList dependencies, bool closedUB = false);
 
   /// Compute a bound in that is independent of all values in `independencies`.
   ///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
   /// appear in the computed bound.
   static LogicalResult
   computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
-                          presburger::BoundType type, Value value,
-                          std::optional<int64_t> dim, ValueRange independencies,
-                          bool closedUB = false);
+                          presburger::BoundType type, const Variable &var,
+                          ValueRange independencies, bool closedUB = false);
 
-  /// Compute a constant bound for the given affine map, where dims and symbols
-  /// are bound to the given operands. The affine map must have exactly one
-  /// result.
+  /// Compute a constant bound for the given variable.
   ///
   /// This function traverses the backward slice of the given operands in a
   /// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
   /// By default, lower/equal bounds are closed and upper bounds are open. If
   /// `closedUB` is set to "true", upper bounds are also closed.
   static FailureOr<int64_t>
-  computeConstantBound(presburger::BoundType type, Value value,
-                       std::optional<int64_t> dim = std::nullopt,
+  computeConstantBound(presburger::BoundType type, const Variable &var,
                        StopConditionFn stopCondition = nullptr,
                        bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
-  static FailureOr<int64_t> computeConstantBound(
-      presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
-      StopConditionFn stopCondition = nullptr, bool closedUB = false);
 
   /// Compute a constant delta between the given two values. Return "failure"
   /// if a constant delta could not be determined.
@@ -221,9 +241,8 @@ class ValueBoundsConstraintSet
   /// proven. This could be because the specified relation does in fact not hold
   /// or because there is not enough information in the constraint set. In other
   /// words, if we do not know for sure, this function returns "false".
-  bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                          ComparisonOperator cmp, OpFoldResult rhs,
-                          std::optional<int64_t> rhsDim);
+  bool populateAndCompare(const Variable &lhs, ComparisonOperator cmp,
+                          const Variable &rhs);
 
   /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
   /// specified relation could not be proven. This could be because the
@@ -233,24 +252,12 @@ class ValueBoundsConstraintSet
   ///
   /// This function keeps traversing the backward slice of lhs/rhs until could
   /// prove the relation or until it ran out of IR.
-  static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                      ComparisonOperator cmp, OpFoldResult rhs,
-                      std::optional<int64_t> rhsDim);
-  static bool compare(AffineMap lhs, ValueDimList lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ValueDimList rhsOperands);
-  static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
-                      ComparisonOperator cmp, AffineMap rhs,
-                      ArrayRef<Value> rhsOperands);
-
-  /// Compute whether the given values/dimensions are equal. Return "failure" if
+  static bool compare(const Variable &lhs, ComparisonOperator cmp,
+                      const Variable &rhs);
+
+  /// Compute whether the given variables are equal. Return "failure" if
   /// equality could not be determined.
-  ///
-  /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
-  /// index-typed.
-  static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
-                                  std::optional<int64_t> dim1 = std::nullopt,
-                                  std::optional<int64_t> dim2 = std::nullopt);
+  static FailureOr<bool> areEqual(const Variable &var1, const Variable &var2);
 
   /// Return "true" if the given slices are guaranteed to be overlapping.
   /// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +324,6 @@ class ValueBoundsConstraintSet
   ///
   /// This function does not analyze any IR and does not populate any additional
   /// constraints.
-  bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
-                        ComparisonOperator cmp, OpFoldResult rhs,
-                        std::optional<int64_t> rhsDim);
   bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
 
   /// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +378,7 @@ class ValueBoundsConstraintSet
   /// constraint system. Return the position of the new column. Any operands
   /// that were not analyzed yet are put on the worklist.
   int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+  int64_t insert(const Variable &var, bool isSymbol = true);
 
   /// Project out the given column in the constraint set.
   void projectOut(int64_t pos);
@@ -381,6 +386,8 @@ class ValueBoundsConstraintSet
   /// Project out all columns for which the condition holds.
   void projectOut(function_ref<bool(ValueDim)> condition);
 
+  void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
   /// Mapping of columns to values/shape dimensions.
   SmallVector<std::optional<ValueDim>> positionToValueDim;
   /// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d..82a9fb0d490882 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
   mapOperands.push_back(value1);
   mapOperands.push_back(value2);
   affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
-  ValueDimList valueDims;
-  for (Value v : mapOperands)
-    valueDims.push_back({v, std::nullopt});
   return ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::EQ, map, valueDims);
+      presburger::BoundType::EQ,
+      ValueBoundsConstraintSet::Variable(map, mapOperands));
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7..1a266b72d1f8d3 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -16,16 +16,15 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                ValueBoundsConstraintSet::StopConditionFn stopCondition,
-                bool closedUB) {
+FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type,
+    const ValueBoundsConstraintSet::Variable &var,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, var, stopCondition, closedUB)))
     return failure();
 
   // Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     // the owner of `value`.
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, dim,
+  return reifyValueBound(b, loc, type, {value, dim},
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
                              ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+  return reifyValueBound(b, loc, type, value,
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d43808bc45df..7cfcc4180539c2 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
     // If trueValue <= falseValue:
     // * result <= falseValue
     // * result >= trueValue
-    if (cstr.compare(trueValue, dim,
+    if (cstr.compare(/*lhs=*/{trueValue, dim},
                      ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     falseValue, dim)) {
+                     /*rhs=*/{falseValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
     // If falseValue <= trueValue:
     // * result <= trueValue
     // * result >= falseValue
-    if (cstr.compare(falseValue, dim,
+    if (cstr.compare(/*lhs=*/{falseValue, dim},
                      ValueBoundsConstraintSet::ComparisonOperator::LE,
-                     trueValue, dim)) {
+                     /*rhs=*/{trueValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a..f87f3d6350c022 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
       return failure();
 
     FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
-        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+        presburger::BoundType::UB, in,
         /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(ub))
       return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190e..5fb7953f937007 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
   return buildExpr(map.getResult(0));
 }
 
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
-                Value value, std::optional<int64_t> dim,
-                ValueBoundsConstraintSet::StopConditionFn stopCondition,
-                bool closedUB) {
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+    OpBuilder &b, Location loc, presburger::BoundType type,
+    const ValueBoundsConstraintSet::Variable &var,
+    ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
   // Compute bound.
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeBound(
-          boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+          boundMap, mapOperands, type, var, stopCondition, closedUB)))
     return failure();
 
   // Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     // the owner of `value`.
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, dim,
+  return reifyValueBound(b, loc, type, {value, dim},
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
                              ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
-  return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+  return reifyValueBound(b, loc, type, value,
                          stopCondition ? stopCondition : reifyToOperands,
                          closedUB);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db248989..518d2e138c02a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, opOperand->get(),
-            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB,
+            {opOperand->get(),
+             /*dim=*/i},
+            /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
       LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049..71eb59d40836c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
       size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
     } else {
-      Value materializedSize =
-          getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
       FailureOr<int64_t> upperBound =
           ValueBoundsConstraintSet::computeConstantBound(
-              presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+              presburger::BoundType::UB, rangeValue.size,
               /*stopCondition=*/nullptr, /*closedUB=*/true);
       size = failed(upperBound)
-                 ? materializedSize
+                 ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
                  : b.create<arith::ConstantIndexOp>(loc, *upperBound);
     }
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9..1f06318cbd60e0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
                                                ValueRange independencies) {
   if (ofr.is<Attribute>())
     return ofr;
-  Value value = ofr.get<Value>();
   AffineMap boundMap;
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
-          boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+          /*closedUB=*/true)))
     return failure();
   return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a..17a1c016ea16d5 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
     if (cstr.populateAndCompare(
-            yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
-            iterArg, dim)) {
+            /*lhs=*/{yieldedValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::EQ,
+            /*rhs=*/{iterArg, dim})) {
       if (dim.has_value()) {
         cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
       } else {
-        cstr.bound(value) == initArg;
+        cstr.bound(value) == cstr.getExpr(initArg);
       }
     }
   }
@@ -113,8 +114,9 @@ struct IfOpInterface
     // * result <= elseValue
     // * result >= thenValue
     if (cstr.populateAndCompare(
-            thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            elseValue, dim)) {
+            /*lhs=*/{thenValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{elseValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
     // * result <= thenValue
     // * result >= elseValue
     if (cstr.populateAndCompare(
-            elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
-            thenValue, dim)) {
+            /*lhs=*/{elseValue, dim},
+            ValueBoundsConstraintSet::ComparisonOperator::LE,
+            /*rhs=*/{thenValue, dim})) {
       if (dim) {
         cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
         cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8e301c13..d25efcf50ec566 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
 
   info.isAlignedToInnerTileSize = false;
   FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
-      presburger::BoundType::UB,
-      getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+      presburger::BoundType::UB, tileSize,
       /*stopCondition=*/nullptr, /*closedUB=*/true);
   std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
   if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 721730862d49b3..a89ce20048dff3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
   ValueDimList mapOperands;
   if (failed(ValueBoundsConstraintSet::computeIndependentBound(
           boundMap, mapOperands, presburger::BoundType::UB, value,
-          /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+          independencies,
+          /*closedUB=*/true)))
     return failure();
   return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
 }
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2f7a1700..15381ec520e211 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), srcDim, resultDim);
+        {op.getSource(), srcDim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
-        op.getSource(), op.getResult(), dim, resultDim);
+        {op.getSource(), dim}, {op.getResult(), resultDim});
     if (failed(equalDimSize) || !*equalDimSize)
       return false;
     ++resultDim;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ffa4c0b55cad7c..93b2d89e601d6d 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -25,6 +25,12 @@ namespace mlir {
 #include "mlir/Interfaces/ValueBoundsOpInterface.cpp.inc"
 } // namespace mlir
 
+static Operation *getOwnerOfValue(Value value) {
+  if (auto bbArg = dyn_cast<BlockArgument>(value))
+    return bbArg.getOwner()->getParentOp();
+  return value.getDefiningOp();
+}
+
 HyperrectangularSlice::HyperrectangularSlice(ArrayRef<OpFoldResult> offsets,
                                              ArrayRef<OpFoldResult> sizes,
                                              ArrayRef<OpFoldResult> strides)
@@ -67,6 +73,83 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
+// Convenience constructors. Each of them forwards to the general
+// (OpFoldResult, optional dim) constructor, either directly or via the
+// OpFoldResult-only constructor.
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr)
+    : Variable(ofr, /*dim=*/std::nullopt) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value indexValue)
+    : Variable(OpFoldResult(indexValue)) {}
+
+ValueBoundsConstraintSet::Variable::Variable(Value shapedValue, int64_t dim)
+    : Variable(OpFoldResult(shapedValue), std::optional<int64_t>(dim)) {}
+
+/// Builds a variable for a constant, an index-typed value (`dim` unset) or a
+/// dimension of a shaped value (`dim` set).
+ValueBoundsConstraintSet::Variable::Variable(OpFoldResult ofr,
+                                             std::optional<int64_t> dim) {
+  Builder b(ofr.getContext());
+  // Constants become a map without inputs: `() -> (cst)`.
+  if (std::optional<int64_t> cst = ::getConstantIntValue(ofr)) {
+    assert(!dim && "expected no dim for index-typed values");
+    map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/0,
+                         b.getAffineConstantExpr(*cst));
+    return;
+  }
+  // Everything else becomes `()[s0] -> (s0)`, with `s0` bound to the
+  // (value, dim) pair.
+  Value v = cast<Value>(ofr);
+  assert((!dim || isa<ShapedType>(v.getType())) && "expected shaped type");
+  assert((dim || v.getType().isIndex()) && "expected index type");
+  map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
+                       b.getAffineSymbolExpr(0));
+  mapOperands.emplace_back(v, dim);
+}
+
+/// Builds a variable from a single-result `map` whose inputs (dims first, then
+/// symbols) are given by `mapOperands`. The operand variables are inlined into
+/// `map`, so the resulting map has only symbols, one per unique (value, dim)
+/// pair collected in `this->mapOperands`.
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Variable> mapOperands) {
+  assert(map.getNumResults() == 1 && "expected single result");
+  assert(map.getNumInputs() == mapOperands.size() &&
+         "expected one operand per map input");
+
+  // Turn all dims into symbols: dim `i` becomes symbol `i` and symbol `j`
+  // becomes symbol `numDims + j`. Afterwards, input `k` of `map` is symbol `k`.
+  Builder b(map.getContext());
+  SmallVector<AffineExpr> dimReplacements, symReplacements;
+  for (int64_t i = 0; i < map.getNumDims(); ++i)
+    dimReplacements.push_back(b.getAffineSymbolExpr(i));
+  for (int64_t i = 0; i < map.getNumSymbols(); ++i)
+    symReplacements.push_back(b.getAffineSymbolExpr(i + map.getNumDims()));
+  AffineMap tmpMap = map.replaceDimsAndSymbols(
+      dimReplacements, symReplacements, /*numResultDims=*/0,
+      /*numResultSyms=*/map.getNumSymbols() + map.getNumDims());
+
+  // Inline operands: symbol `index` of `tmpMap` is replaced by the result of
+  // the corresponding operand's map, whose symbols are renumbered to their
+  // positions in the deduplicated `this->mapOperands` list.
+  DenseMap<AffineExpr, AffineExpr> replacements;
+  for (auto [index, var] : llvm::enumerate(mapOperands)) {
+    assert(var.map.getNumResults() == 1 && "expected single result");
+    assert(var.map.getNumDims() == 0 && "expected only symbols");
+    SmallVector<AffineExpr> operandSymReplacements;
+    for (auto valueDim : var.mapOperands) {
+      auto it = llvm::find(this->mapOperands, valueDim);
+      if (it != this->mapOperands.end()) {
+        // There is already a symbol for this operand.
+        operandSymReplacements.push_back(b.getAffineSymbolExpr(
+            std::distance(this->mapOperands.begin(), it)));
+      } else {
+        // This is a new operand: add a new symbol.
+        operandSymReplacements.push_back(
+            b.getAffineSymbolExpr(this->mapOperands.size()));
+        this->mapOperands.push_back(valueDim);
+      }
+    }
+    replacements[b.getAffineSymbolExpr(index)] =
+        var.map.getResult(0).replaceSymbols(operandSymReplacements);
+  }
+  this->map = tmpMap.replace(replacements, /*numResultDims=*/0,
+                             /*numResultSyms=*/this->mapOperands.size());
+}
+
+/// Builds a variable from a single-result `map` whose inputs are index-typed
+/// values: each value is wrapped in its own Variable before delegating.
+ValueBoundsConstraintSet::Variable::Variable(AffineMap map,
+                                             ArrayRef<Value> mapOperands)
+    : Variable(map, llvm::map_to_vector(mapOperands, [](Value operand) {
+                 return Variable(operand);
+               })) {}
+
 ValueBoundsConstraintSet::ValueBoundsConstraintSet(
     MLIRContext *ctx, StopConditionFn stopCondition)
     : builder(ctx), stopCondition(stopCondition) {
@@ -176,6 +259,11 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
   assert(!valueDimToPosition.contains(valueDim) && "already mapped");
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting constraint set column " << pos
+                          << " for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -194,6 +282,8 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
 int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
   int64_t pos = isSymbol ? cstr.appendVar(VarKind::Symbol)
                          : cstr.appendVar(VarKind::SetDim);
+  LLVM_DEBUG(llvm::dbgs() << "Inserting anonymous constraint set column " << pos
+                          << "\n");
   positionToValueDim.insert(positionToValueDim.begin() + pos, std::nullopt);
   // Update reverse mapping.
   for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
@@ -224,6 +314,10 @@ int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
   return pos;
 }
 
+/// Adds `var` to the constraint set by inserting its affine map together with
+/// its (value, dim) operands; returns what the map-based `insert` returns.
+int64_t ValueBoundsConstraintSet::insert(const Variable &var, bool isSymbol) {
+  const AffineMap &varMap = var.map;
+  return insert(varMap, var.mapOperands, isSymbol);
+}
+
 int64_t ValueBoundsConstraintSet::getPos(Value value,
                                          std::optional<int64_t> dim) const {
 #ifndef NDEBUG
@@ -232,7 +326,10 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
           cast<BlockArgument>(value).getOwner()->isEntryBlock()) &&
          "unstructured control flow is not supported");
 #endif // NDEBUG
-
+  LLVM_DEBUG(llvm::dbgs() << "Getting pos for: " << value
+                          << " (dim: " << dim.value_or(kIndexValue)
+                          << ", owner: " << getOwnerOfValue(value)->getName()
+                          << ")\n");
   auto it =
       valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
   assert(it != valueDimToPosition.end() && "expected mapped entry");
@@ -253,12 +350,6 @@ bool ValueBoundsConstraintSet::isMapped(Value value,
   return it != valueDimToPosition.end();
 }
 
-static Operation *getOwnerOfValue(Value value) {
-  if (auto bbArg = dyn_cast<BlockArgument>(value))
-    return bbArg.getOwner()->getParentOp();
-  return value.getDefiningOp();
-}
-
 void ValueBoundsConstraintSet::processWorklist() {
   LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
@@ -346,41 +437,47 @@ void ValueBoundsConstraintSet::projectOut(
   }
 }
 
+/// Projects out every anonymous column (one that is not mapped to a
+/// (value, dim) pair), except for the column at position `except`, if set.
+void ValueBoundsConstraintSet::projectOutAnonymous(
+    std::optional<int64_t> except) {
+  int64_t nextPos = 0;
+  while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
+    if (positionToValueDim[nextPos].has_value() || except == nextPos) {
+      ++nextPos;
+      continue;
+    }
+    projectOut(nextPos);
+    // The column was projected out so another column is now at that position.
+    // Do not increase the counter. All later columns moved one position to the
+    // left, so `except` must be shifted to keep pointing at the same column.
+    if (except && *except > nextPos)
+      --*except;
+  }
+}
+
 LogicalResult ValueBoundsConstraintSet::computeBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition,
-    bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
+    const Variable &var, StopConditionFn stopCondition, bool closedUB) {
+  MLIRContext *ctx = var.getContext();
   int64_t ubAdjustment = closedUB ? 0 : 1;
-  Builder b(value.getContext());
+  Builder b(ctx);
   mapOperands.clear();
 
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
-  ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
-  assert(!stopCondition(value, dim, cstr) &&
-         "stop condition should not be satisfied for starting point");
-  int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+  ValueBoundsConstraintSet cstr(ctx, stopCondition);
+  int64_t pos = cstr.insert(var, /*isSymbol=*/false);
+  assert(pos == 0 && "expected first column");
   cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
   cstr.projectOut([&](ValueDim p) {
-    // Do not project out `valueDim`.
-    if (valueDim == p)
-      return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
     return !stopCondition(p.first, maybeDim, cstr);
   });
+  cstr.projectOutAnonymous(/*except=*/pos);
 
   // Compute lower and upper bounds for `valueDim`.
   SmallVector<AffineMap> lb(1), ub(1);
-  cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lb, &ub,
+  cstr.cstr.getSliceBounds(pos, 1, ctx, &lb, &ub,
                            /*closedUB=*/true);
 
   // Note: There are TODOs in the implementation of `getSliceBounds`. In such a
@@ -477,10 +574,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
 
 LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueDimList dependencies,
-    bool closedUB) {
+    const Variable &var, ValueDimList dependencies, bool closedUB) {
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
@@ -489,8 +585,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
 
 LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
     AffineMap &resultMap, ValueDimList &mapOperands, presburger::BoundType type,
-    Value value, std::optional<int64_t> dim, ValueRange independencies,
-    bool closedUB) {
+    const Variable &var, ValueRange independencies, bool closedUB) {
   // Return "true" if the given value is independent of all values in
   // `independencies`. I.e., neither the value itself nor any value in the
   // backward slice (reverse use-def chain) is contained in `independencies`.
@@ -516,7 +611,7 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 
   // Reify bounds in terms of any independent values.
   return computeBound(
-      resultMap, mapOperands, type, value, dim,
+      resultMap, mapOperands, type, var,
       [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return isIndependent(v);
       },
@@ -524,35 +619,8 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
 }
 
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, Value value, std::optional<int64_t> dim,
-    StopConditionFn stopCondition, bool closedUB) {
-#ifndef NDEBUG
-  assertValidValueDim(value, dim);
-#endif // NDEBUG
-
-  AffineMap map =
-      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
-                     Builder(value.getContext()).getAffineDimExpr(0));
-  return computeConstantBound(type, map, {{value, dim}}, stopCondition,
-                              closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+    presburger::BoundType type, const Variable &var,
     StopConditionFn stopCondition, bool closedUB) {
-  ValueDimList valueDims;
-  for (Value v : operands) {
-    assert(v.getType().isIndex() && "expected index type");
-    valueDims.emplace_back(v, std::nullopt);
-  }
-  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ValueDimList operands,
-    StopConditionFn stopCondition, bool closedUB) {
-  assert(map.getNumResults() == 1 && "expected affine map with one result");
-
   // Default stop condition if none was specified: Keep adding constraints until
   // a bound could be computed.
   int64_t pos = 0;
@@ -562,8 +630,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   };
 
   ValueBoundsConstraintSet cstr(
-      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
-  pos = cstr.populateConstraints(map, operands);
+      var.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  pos = cstr.populateConstraints(var.map, var.mapOperands);
   assert(pos == 0 && "expected `map` is the first column");
 
   // Compute constant bound for `valueDim`.
@@ -608,22 +676,13 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
   Builder b(value1.getContext());
   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                  b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
-  return computeConstantBound(presburger::BoundType::EQ, map,
-                              {{value1, dim1}, {value2, dim2}});
+  return computeConstantBound(presburger::BoundType::EQ,
+                              Variable(map, {{value1, dim1}, {value2, dim2}}));
 }
 
-bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
-                                                std::optional<int64_t> lhsDim,
-                                                ComparisonOperator cmp,
-                                                OpFoldResult rhs,
-                                                std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+                                          ComparisonOperator cmp,
+                                          int64_t rhsPos) {
   // This function returns "true" if "lhs CMP rhs" is proven to hold.
   //
   // Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -640,50 +699,6 @@ bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
     return false;
   }
 
-  // EQ can be expressed as LE and GE.
-  if (cmp == EQ)
-    return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
-           compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
-
-  // Construct inequality. For the above example: lhs > rhs.
-  // `IntegerRelation` inequalities are expressed in the "flattened" form and
-  // with ">= 0". I.e., lhs - rhs - 1 >= 0.
-  SmallVector<int64_t> eq(cstr.getNumCols(), 0);
-  auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
-                     int64_t factor) {
-    if (auto constVal = ::getConstantIntValue(ofr)) {
-      eq[cstr.getNumCols() - 1] += *constVal * factor;
-    } else {
-      eq[getPos(cast<Value>(ofr), dim)] += factor;
-    }
-  };
-  if (cmp == LT || cmp == LE) {
-    addToEq(lhs, lhsDim, 1);
-    addToEq(rhs, rhsDim, -1);
-  } else if (cmp == GT || cmp == GE) {
-    addToEq(lhs, lhsDim, -1);
-    addToEq(rhs, rhsDim, 1);
-  } else {
-    llvm_unreachable("unsupported comparison operator");
-  }
-  if (cmp == LE || cmp == GE)
-    eq[cstr.getNumCols() - 1] -= 1;
-
-  // Add inequality to the constraint set and check if it made the constraint
-  // set empty.
-  int64_t ineqPos = cstr.getNumInequalities();
-  cstr.addInequality(eq);
-  bool isEmpty = cstr.isEmpty();
-  cstr.removeInequality(ineqPos);
-  return isEmpty;
-}
-
-bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
-                                          ComparisonOperator cmp,
-                                          int64_t rhsPos) {
-  // This function returns "true" if "lhs CMP rhs" is proven to hold. For
-  // detailed documentation, see `compareValueDims`.
-
   // EQ can be expressed as LE and GE.
   if (cmp == EQ)
     return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
@@ -712,48 +727,17 @@ bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
   return isEmpty;
 }
 
-bool ValueBoundsConstraintSet::populateAndCompare(
-    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
-    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
-#ifndef NDEBUG
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    assertValidValueDim(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    assertValidValueDim(rhsVal, rhsDim);
-#endif // NDEBUG
-
-  if (auto lhsVal = dyn_cast<Value>(lhs))
-    populateConstraints(lhsVal, lhsDim);
-  if (auto rhsVal = dyn_cast<Value>(rhs))
-    populateConstraints(rhsVal, rhsDim);
-
-  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+/// Populates constraints for `lhs` and then for `rhs` in this constraint set
+/// and returns "true" if "lhs cmp rhs" is proven to hold.
+bool ValueBoundsConstraintSet::populateAndCompare(const Variable &lhs,
+                                                  ComparisonOperator cmp,
+                                                  const Variable &rhs) {
+  // Keep the population order fixed (lhs first), hence the named locals.
+  const int64_t lhsColumn = populateConstraints(lhs.map, lhs.mapOperands);
+  const int64_t rhsColumn = populateConstraints(rhs.map, rhs.mapOperands);
+  return comparePos(lhsColumn, cmp, rhsColumn);
 }
 
-bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
-                                       std::optional<int64_t> lhsDim,
-                                       ComparisonOperator cmp, OpFoldResult rhs,
-                                       std::optional<int64_t> rhsDim) {
-  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
-                           ValueBoundsConstraintSet &cstr) {
-    // Keep processing as long as lhs/rhs are not mapped.
-    if (auto lhsVal = dyn_cast<Value>(lhs))
-      if (!cstr.isMapped(lhsVal, dim))
-        return false;
-    if (auto rhsVal = dyn_cast<Value>(rhs))
-      if (!cstr.isMapped(rhsVal, dim))
-        return false;
-    // Keep processing as long as the relation cannot be proven.
-    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
-  };
-
-  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
-}
-
-bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ValueDimList rhsOperands) {
+bool ValueBoundsConstraintSet::compare(const Variable &lhs,
+                                       ComparisonOperator cmp,
+                                       const Variable &rhs) {
   int64_t lhsPos = -1, rhsPos = -1;
   auto stopCondition = [&](Value v, std::optional<int64_t> dim,
                            ValueBoundsConstraintSet &cstr) {
@@ -765,39 +749,17 @@ bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
     return cstr.comparePos(lhsPos, cmp, rhsPos);
   };
   ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
-  lhsPos = cstr.insert(lhs, lhsOperands);
-  rhsPos = cstr.insert(rhs, rhsOperands);
-  cstr.processWorklist();
+  lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands);
+  rhsPos = cstr.populateConstraints(rhs.map, rhs.mapOperands);
   return cstr.comparePos(lhsPos, cmp, rhsPos);
 }
 
-bool ValueBoundsConstraintSet::compare(AffineMap lhs,
-                                       ArrayRef<Value> lhsOperands,
-                                       ComparisonOperator cmp, AffineMap rhs,
-                                       ArrayRef<Value> rhsOperands) {
-  ValueDimList lhsValueDimOperands =
-      llvm::map_to_vector(lhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  ValueDimList rhsValueDimOperands =
-      llvm::map_to_vector(rhsOperands, [](Value v) {
-        return std::make_pair(v, std::optional<int64_t>());
-      });
-  return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
-                                           rhsValueDimOperands);
-}
-
-FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
-                                   std::optional<int64_t> dim1,
-                                   std::optional<int64_t> dim2) {
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
-                                        value2, dim2))
+/// Returns true if `var1 == var2` is proven, false if `var1 < var2` or
+/// `var1 > var2` is proven, and failure if none of these can be proven.
+FailureOr<bool> ValueBoundsConstraintSet::areEqual(const Variable &var1,
+                                                   const Variable &var2) {
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::EQ, var2))
     return true;
-  if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
-                                        value2, dim2) ||
-      ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
-                                        value2, dim2))
+  // A proven strict inequality in either direction disproves equality.
+  if (ValueBoundsConstraintSet::compare(var1, ComparisonOperator::LT, var2) ||
+      ValueBoundsConstraintSet::compare(var1, ComparisonOperator::GT, var2))
     return false;
   return failure();
 }
@@ -833,7 +795,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;
@@ -850,7 +812,7 @@ ValueBoundsConstraintSet::areOverlappingSlices(MLIRContext *ctx,
       AffineMap foldedMap =
           foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
       FailureOr<int64_t> constBound = computeConstantBound(
-          presburger::BoundType::EQ, foldedMap, valueOperands);
+          presburger::BoundType::EQ, Variable(foldedMap, valueOperands));
       foundUnknownBound |= failed(constBound);
       if (succeeded(constBound) && *constBound <= 0)
         return false;
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 23c6872dcebe94..935c08aceff548 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -131,3 +131,27 @@ func.func @compare_affine_min(%a: index, %b: index) {
   "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
   return
 }
+
+// -----
+
+// A constant map `() -> (4)` can be used as either side of test.compare:
+// checks 5 > 4 (map on the RHS) and 4 < 5 (map on the LHS).
+func.func @compare_const_map() {
+  %c5 = arith.constant 5 : index
+  // expected-remark @below{{true}}
+  "test.compare"(%c5) {cmp = "GT", rhs_map = affine_map<() -> (4)>}
+      : (index) -> ()
+  // expected-remark @below{{true}}
+  "test.compare"(%c5) {cmp = "LT", lhs_map = affine_map<() -> (4)>}
+      : (index) -> ()
+  return
+}
+
+// -----
+
+// Both sides are maps with their own operands (LHS: %a, %b; RHS: %b, %a):
+// checks 1 + a + b > b + a.
+func.func @compare_maps(%a: index, %b: index) {
+  // expected-remark @below{{true}}
+  "test.compare"(%a, %b, %b, %a)
+      {cmp = "GT", lhs_map = affine_map<(d0, d1) -> (1 + d0 + d1)>,
+       rhs_map = affine_map<(d0, d1) -> (d0 + d1)>}
+      : (index, index, index, index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6730f9b292ad93..b098a5a23fd316 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
     FailureOr<OpFoldResult> reified = failure();
     if (constant) {
       auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
-          boundType, value, dim, /*stopCondition=*/nullptr);
+          boundType, {value, dim}, /*stopCondition=*/nullptr);
       if (succeeded(reifiedConst))
         reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
     } else if (scalable) {
@@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
             rewriter, loc, reifiedScalable->map, vscaleOperand);
       }
     } else {
-      if (dim) {
-        if (useArithOps) {
-          reified = arith::reifyShapedValueDimBound(
-              rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
-        } else {
-          reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
-                                             value, *dim, stopCondition);
-        }
+      if (useArithOps) {
+        reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
+                                         op.getVariable(), stopCondition);
       } else {
-        if (useArithOps) {
-          reified = arith::reifyIndexValueBound(
-              rewriter, op->getLoc(), boundType, value, stopCondition);
-        } else {
-          reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
-                                         value, stopCondition);
-        }
+        reified = reifyValueBound(rewriter, op->getLoc(), boundType,
+                                  op.getVariable(), stopCondition);
       }
     }
     if (failed(reified)) {
@@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
     }
 
     auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
-      return ValueBoundsConstraintSet::compare(
-          /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
-          /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+      return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
     };
     if (compare(cmpType)) {
       op->emitRemark("true");
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190ca0ef3a..36d7606fe1345b 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
   return success();
 }
 
+/// Returns the variable to bound: dimension `dim` of the shaped `var` if `dim`
+/// is set, otherwise the index-typed `var` itself.
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+  using Variable = ::mlir::ValueBoundsConstraintSet::Variable;
+  if (auto dim = getDim())
+    return Variable(getVar(), *dim);
+  return Variable(getVar());
+}
+
 ::mlir::ValueBoundsConstraintSet::ComparisonOperator
 CompareOp::getComparisonOperator() {
   if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
   llvm_unreachable("invalid comparison operator");
 }
 
+/// Returns the comparison side whose operands start at index `begin` of
+/// `operands`: a single index value if `map` is absent, otherwise `map`
+/// applied to the next `map->getNumInputs()` operands.
+static ::mlir::ValueBoundsConstraintSet::Variable
+getCompareVariable(OperandRange operands, int64_t begin,
+                   std::optional<AffineMap> map) {
+  if (!map)
+    return ValueBoundsConstraintSet::Variable(operands[begin]);
+  SmallVector<Value> mapOperands(operands.slice(begin, map->getNumInputs()));
+  return ValueBoundsConstraintSet::Variable(*map, mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+  // The LHS operands always come first.
+  return getCompareVariable(getVarOperands(), /*begin=*/0, getLhsMap());
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+  // The RHS operands directly follow the LHS operands.
+  int64_t rhsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+  return getCompareVariable(getVarOperands(), rhsBegin, getRhsMap());
+}
+
+/// Verifies that `compose` is not combined with maps, that the maps have a
+/// single result, and that `var_operands` provides exactly the operands that
+/// the LHS and RHS consume.
+LogicalResult CompareOp::verify() {
+  if (getCompose() && (getLhsMap() || getRhsMap()))
+    return emitOpError(
+        "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+  // getLhs()/getRhs() build a Variable from the map, which asserts on
+  // multi-result maps; report them as verification errors instead.
+  if (getLhsMap() && getLhsMap()->getNumResults() != 1)
+    return emitOpError("expected 'lhs_map' to have a single result");
+  if (getRhsMap() && getRhsMap()->getNumResults() != 1)
+    return emitOpError("expected 'rhs_map' to have a single result");
+  int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+  expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+  if (static_cast<int64_t>(getVarOperands().size()) != expectedNumOperands)
+    return emitOpError("expected ")
+           << expectedNumOperands << " operands, but got "
+           << getVarOperands().size();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Test removing op with inner ops.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0ab2427038add3..2ed50c22faa1b1 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
 
   let extraClassDeclaration = [{
     ::mlir::presburger::BoundType getBoundType();
+    ::mlir::ValueBoundsConstraintSet::Variable getVariable();
   }];
 
   let hasVerifier = 1;
@@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> {
     Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
     specified comparison operator was proven to hold. The remark also indicates
     whether the opposite comparison operator was proven to hold.
+
+    By default, `var_operands` has exactly two operands: the first one is the
+    LHS and the second one is the RHS. If `lhs_map` is specified, the LHS is
+    `lhs_map` applied to the first N operands, where N is the number of inputs
+    of `lhs_map`. Likewise, if `rhs_map` is specified, the RHS is `rhs_map`
+    applied to the operands that follow the LHS operands.
   }];
 
-  let arguments = (ins Index:$lhs,
-                       Index:$rhs,
+  let arguments = (ins Variadic<Index>:$var_operands,
                        DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+                       OptionalAttr<AffineMapAttr>:$lhs_map,
+                       OptionalAttr<AffineMapAttr>:$rhs_map,
                        UnitAttr:$compose);
   let results = (outs);
 
   let extraClassDeclaration = [{
     ::mlir::ValueBoundsConstraintSet::ComparisonOperator
         getComparisonOperator();
+    ::mlir::ValueBoundsConstraintSet::Variable getLhs();
+    ::mlir::ValueBoundsConstraintSet::Variable getRhs();
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list