[Mlir-commits] [mlir] [mlir][Vector] Add utility for computing scalable value bounds (PR #83876)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Mar 12 06:35:00 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/83876

>From ffb5404711c2f3ce1a7340a33bd8796bc6b083b9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 1 Mar 2024 18:23:09 +0000
Subject: [PATCH 1/5] [mlir][Vector] Add utility for computing scalable value
 bounds

This adds a new API built with the `ValueBoundsConstraintSet` to compute
the bounds of possibly scalable quantities. It uses knowledge of the
range of vscale (which is defined by the target architecture), to solve
for the bound as either a constant or an expression in terms of vscale.

The result is an `AffineMap` that will always take at most one parameter,
vscale, and return a single result, which is the bound of `value`.

The API is defined as follows:

```c++
FailureOr<ConstantOrScalableBound>
vector::computeScalableBound(Value value, std::optional<int64_t> dim,
                             unsigned vscaleMin, unsigned vscaleMax,
                             presburger::BoundType boundType);
```

Note: `ConstantOrScalableBound` is a thin wrapper over the `AffineMap`
with a utility for converting the bound to a single quantity (i.e. a
size and scalable flag).

We believe this API could prove useful downstream in IREE (which uses
a similar analysis to hoist allocas, which currently fails for scalable
vectors).
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  29 ++++
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  20 ++-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 130 +++++++++++++++++
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 102 +++++++++----
 .../Vector/test-scalable-upper-bound.mlir     | 137 ++++++++++++++++++
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  30 +++-
 6 files changed, 417 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index f6b03a0f2c8007..635d609d0b3e71 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -98,6 +99,34 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
+struct ConstantOrScalableBound {
+  AffineMap map; // The bound itself: at most one input (vscale), one result.
+
+  struct BoundSize {
+    int64_t baseSize{0};  // Constant size, or the multiplier of vscale.
+    bool scalable{false}; // Whether `baseSize` is multiplied by vscale.
+  };
+
+  /// Returns the (possibly scalable) size of the bound, or failure if the
+  /// bound cannot be represented as a single quantity.
+  FailureOr<BoundSize> getSize() const;
+};
+
+/// Computes a (possibly) scalable bound for a given value. This is similar to
+/// `ValueBoundsConstraintSet::computeConstantBound()`, but uses knowledge of
+/// the range of vscale to compute either a constant bound, an expression in
+/// terms of vscale, or failure if no bound can be computed.
+///
+/// The resulting `AffineMap` will always take at most one parameter, vscale,
+/// and return a single result, which is the bound of `value`.
+///
+/// Note: `vscaleMin` must be `<=` `vscaleMax`. If `vscaleMin` ==
+/// `vscaleMax`, the resulting bound (if found) will be constant.
+FailureOr<ConstantOrScalableBound>
+computeScalableBound(Value value, std::optional<int64_t> dim,
+                     unsigned vscaleMin, unsigned vscaleMax,
+                     presburger::BoundType boundType);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 28dadfb9ecf868..6d0e16bf215f8a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -265,10 +265,28 @@ class ValueBoundsConstraintSet {
 
   ValueBoundsConstraintSet(MLIRContext *ctx);
 
+  /// A callback to allow injecting custom value bounds constraints.
+  /// It takes the current value, the dim (or kIndexValue), and a reference to
+  /// the constraints set.
+  using PopulateCustomValueBoundsFn =
+      function_ref<void(Value, int64_t, ValueBoundsConstraintSet &)>;
+
+  /// Populates the constraint set for a value/map without actually computing
+  /// the bound.
+  int64_t populateConstraintsSet(
+      Value value, std::optional<int64_t> dim = std::nullopt,
+      PopulateCustomValueBoundsFn customValueBounds = nullptr,
+      StopConditionFn stopCondition = nullptr);
+  int64_t populateConstraintsSet(
+      AffineMap map, ValueDimList mapOperands,
+      PopulateCustomValueBoundsFn customValueBounds = nullptr,
+      StopConditionFn stopCondition = nullptr, int64_t *posOut = nullptr);
+
   /// Iteratively process all elements on the worklist until an index-typed
   /// value or shaped value meets `stopCondition`. Such values are not processed
   /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  void processWorklist(StopConditionFn stopCondition,
+                       PopulateCustomValueBoundsFn customValueBounds = nullptr);
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d613672608c3ad..77b2ad6ff540ce 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 
@@ -300,3 +301,132 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
 }
+
+FailureOr<vector::ConstantOrScalableBound::BoundSize>
+vector::ConstantOrScalableBound::getSize() const {
+  // A constant map is a fixed-size bound.
+  if (map.isSingleConstant())
+    return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
+  // Otherwise only a single-input, single-result multiplication qualifies.
+  if (map.getNumResults() != 1 || map.getNumInputs() != 1)
+    return failure();
+  auto mul = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
+  if (!mul || mul.getKind() != AffineExprKind::Mul)
+    return failure();
+  // Yields `factor * symbol` as a scalable size, or failure if the operands
+  // are not a constant and a symbol respectively.
+  auto matchScaled = [](AffineExpr factor,
+                        AffineExpr symbol) -> FailureOr<BoundSize> {
+    auto constant = dyn_cast<AffineConstantExpr>(factor);
+    if (!constant || !isa<AffineSymbolExpr>(symbol))
+      return failure();
+    return BoundSize{constant.getValue(), /*scalable=*/true};
+  };
+  // Accept either operand order: `cst * s0` or `s0 * cst`.
+  auto scaled = matchScaled(mul.getLHS(), mul.getRHS());
+  if (succeeded(scaled))
+    return scaled;
+  return matchScaled(mul.getRHS(), mul.getLHS());
+}
+
+namespace {
+/// Constraint set for computing bounds that may be expressed in terms of
+/// `vector.vscale`, whose range is supplied by the caller.
+struct ScalableValueBoundsConstraintSet : public ValueBoundsConstraintSet {
+  using ValueBoundsConstraintSet::ValueBoundsConstraintSet;
+
+  /// Returns the defining op, or the block's parent op for a block argument.
+  static Operation *getOwnerOfValue(Value value) {
+    if (auto bbArg = dyn_cast<BlockArgument>(value))
+      return bbArg.getOwner()->getParentOp();
+    return value.getDefiningOp();
+  }
+
+  /// Computes the `boundType` bound of `value`/`dim` as a map over at most one
+  /// input (the first vscale value found), assuming vscale lies in
+  /// `[vscaleMin, vscaleMax]`. Fails if no such bound can be derived.
+  static FailureOr<AffineMap>
+  computeScalableBound(Value value, std::optional<int64_t> dim,
+                       unsigned vscaleMin, unsigned vscaleMax,
+                       presburger::BoundType boundType) {
+    using namespace presburger;
+
+    assert(vscaleMin <= vscaleMax);
+    ScalableValueBoundsConstraintSet cstr(value.getContext());
+
+    Value vscale;
+    int64_t pos = cstr.populateConstraintsSet(
+        value, dim,
+        /* Custom vscale value bounds */
+        [&vscale, vscaleMin, vscaleMax](Value value, int64_t dim,
+                                        ValueBoundsConstraintSet &cstr) {
+          if (dim != ValueBoundsConstraintSet::kIndexValue)
+            return;
+          if (isa_and_present<vector::VectorScaleOp>(getOwnerOfValue(value))) {
+            if (vscale) {
+              // All copies of vscale are equivalent.
+              cstr.bound(value) == cstr.getExpr(vscale);
+            } else {
+              // We know vscale is confined to [vscaleMin, vscaleMax].
+              cstr.bound(value) >= vscaleMin;
+              cstr.bound(value) <= vscaleMax;
+              vscale = value;
+            }
+          }
+        },
+        /* Stop condition: never stop, so the whole worklist is drained. */
+        [](auto, auto) { return false; });
+
+    // Project out all variables apart from the first vscale.
+    cstr.projectOut([&](ValueDim p) { return p.first != vscale; });
+
+    assert(cstr.cstr.getNumDimAndSymbolVars() ==
+               cstr.positionToValueDim.size() &&
+           "inconsistent mapping state");
+
+    // Apart from the bounded column itself, only vscale may remain; any other
+    // leftover variable means the bound is not expressible in vscale alone.
+    const ValueDim vscaleDim(vscale, ValueBoundsConstraintSet::kIndexValue);
+    for (int64_t i = 0, e = cstr.cstr.getNumDimAndSymbolVars(); i < e; ++i) {
+      if (i != pos && cstr.positionToValueDim[i] != vscaleDim)
+        return failure();
+    }
+
+    SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
+    cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
+                             &upperBound,
+                             /*closedUB=*/true);
+
+    auto invalidBound = [](auto &bound) {
+      return !bound[0] || bound[0].getNumResults() != 1;
+    };
+
+    // EQ is only valid when the lower and (closed) upper bounds coincide.
+    AffineMap bound = [&] {
+      if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
+          lowerBound[0] == upperBound[0])
+        return lowerBound[0];
+      if (boundType == BoundType::LB && !invalidBound(lowerBound))
+        return lowerBound[0];
+      if (boundType == BoundType::UB && !invalidBound(upperBound))
+        return upperBound[0];
+      return AffineMap{};
+    }();
+    if (!bound)
+      return failure();
+    return bound;
+  }
+};
+
+} // namespace
+
+FailureOr<vector::ConstantOrScalableBound>
+vector::computeScalableBound(Value value, std::optional<int64_t> dim,
+                             unsigned vscaleMin, unsigned vscaleMax,
+                             presburger::BoundType boundType) {
+  auto boundMap = ScalableValueBoundsConstraintSet::computeScalableBound(
+      value, dim, vscaleMin, vscaleMax, boundType);
+  if (succeeded(boundMap)) // Wrap the map; otherwise propagate the failure.
+    return ConstantOrScalableBound{*boundMap};
+  return failure();
+}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 85abc2df894797..ac4e3b935a0542 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -191,7 +191,9 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist(
+    StopConditionFn stopCondition,
+    PopulateCustomValueBoundsFn customValueBounds) {
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -215,8 +217,11 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
     if (stopCondition(value, maybeDim))
       continue;
 
-    // Query `ValueBoundsOpInterface` for constraints. New items may be added to
-    // the worklist.
+    // 1. Query `customValueBounds` for constraints (if provided).
+    if (customValueBounds)
+      customValueBounds(value, dim, *this);
+
+    // 2. Query `ValueBoundsOpInterface` for constraints.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
     if (valueBoundsOp) {
@@ -228,6 +233,8 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
       continue;
     }
 
+    // Steps 1 and 2 above may add new items to the worklist.
+
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
     // tied to OpOperands. Tied values have the same shape.
@@ -471,55 +478,92 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
                               closedUB);
 }
 
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
+    StopConditionFn stopCondition, bool closedUB) {
+  ValueDimList valueDims;
+  for (Value v : operands) {
+    assert(v.getType().isIndex() && "expected index type");
+    valueDims.emplace_back(v, std::nullopt);
+  }
+  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+}
+
 FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, AffineMap map, ValueDimList operands,
     StopConditionFn stopCondition, bool closedUB) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
   ValueBoundsConstraintSet cstr(map.getContext());
-  int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+  int64_t pos = 0;
+  if (stopCondition) {
+    cstr.populateConstraintsSet(map, operands, nullptr, stopCondition, &pos);
+  } else {
+    // No stop condition specified: Keep adding constraints until a bound could
+    // be computed.
+    cstr.populateConstraintsSet(
+        map, operands, nullptr,
+        [&](Value v, std::optional<int64_t> dim) {
+          return cstr.cstr.getConstantBound64(type, pos).has_value();
+        },
+        &pos);
+  }
+  // Compute constant bound for `valueDim`.
+  int64_t ubAdjustment = closedUB ? 0 : 1;
+  if (auto bound = cstr.cstr.getConstantBound64(type, pos))
+    return type == BoundType::UB ? *bound + ubAdjustment : *bound;
+  return failure();
+}
+
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(
+    Value value, std::optional<int64_t> dim,
+    PopulateCustomValueBoundsFn customValueBounds,
+    StopConditionFn stopCondition) {
+#ifndef NDEBUG
+  assertValidValueDim(value, dim);
+#endif // NDEBUG
+
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                     Builder(value.getContext()).getAffineDimExpr(0));
+  return populateConstraintsSet(map, {{value, dim}}, customValueBounds,
+                                stopCondition);
+}
+
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(
+    AffineMap map, ValueDimList operands,
+    PopulateCustomValueBoundsFn customValueBounds,
+    StopConditionFn stopCondition, int64_t *posOut) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  int64_t pos = insert(/*isSymbol=*/false);
+  if (posOut)
+    *posOut = pos;
 
   // Add map and operands to the constraint set. Dimensions are converted to
   // symbols. All operands are added to the worklist.
   auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
-    return cstr.getExpr(v.first, v.second);
+    return getExpr(v.first, v.second);
   };
   SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
       llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
   SmallVector<AffineExpr> symReplacements = llvm::to_vector(
       llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
-  cstr.addBound(
+  addBound(
       presburger::BoundType::EQ, pos,
       map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
 
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
   if (stopCondition) {
-    cstr.processWorklist(stopCondition);
+    processWorklist(stopCondition, customValueBounds);
   } else {
-    // No stop condition specified: Keep adding constraints until a bound could
-    // be computed.
-    cstr.processWorklist(
-        /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        });
+    // No stop condition specified: Keep adding constraints until the worklist
+    // is empty.
+    processWorklist([](Value v, std::optional<int64_t> dim) { return false; },
+                    customValueBounds);
   }
 
-  // Compute constant bound for `valueDim`.
-  int64_t ubAdjustment = closedUB ? 0 : 1;
-  if (auto bound = cstr.cstr.getConstantBound64(type, pos))
-    return type == BoundType::UB ? *bound + ubAdjustment : *bound;
-  return failure();
-}
-
-FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
-    presburger::BoundType type, AffineMap map, ArrayRef<Value> operands,
-    StopConditionFn stopCondition, bool closedUB) {
-  ValueDimList valueDims;
-  for (Value v : operands) {
-    assert(v.getType().isIndex() && "expected index type");
-    valueDims.emplace_back(v, std::nullopt);
-  }
-  return computeConstantBound(type, map, valueDims, stopCondition, closedUB);
+  return pos;
 }
 
 FailureOr<int64_t>
diff --git a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
new file mode 100644
index 00000000000000..2afc4db874b73e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -cse -verify-diagnostics \
+// RUN:   -split-input-file | FileCheck %s
+
+#fixedDim0Map = affine_map<(d0)[s0] -> (-d0 + 32400, s0)>
+#fixedDim1Map = affine_map<(d0)[s0] -> (-d0 + 16, s0)>
+
+// Here the upper bound for min_i is 4 x vscale, as we know 4 x vscale is
+// always less than 32400. The bound for min_j is 16 as at vscale > 4,
+// 4 x vscale will be > 16, so the value will be clamped at 16.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_0:.*]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK-LABEL: @fixed_size_loop_nest
+//   CHECK-DAG:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_0]]()[%vscale]
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]], %[[C16]]) : (index, index) -> ()
+func.func @fixed_size_loop_nest() {
+  %c16 = arith.constant 16 : index
+  %c32400 = arith.constant 32400 : index
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  scf.for %i = %c0 to %c32400 step %c4_vscale {
+    %min_i = affine.min #fixedDim0Map(%i)[%c4_vscale]
+    scf.for %j = %c0 to %c16 step %c4_vscale {
+      %min_j = affine.min #fixedDim1Map(%j)[%c4_vscale]
+      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB"} : (index) -> index
+      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB"} : (index) -> index
+      "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
+    }
+  }
+  return
+}
+
+// -----
+
+#dynamicDim0Map = affine_map<(d0, d1)[s0] -> (-d0 + d1, s0)>
+#dynamicDim1Map = affine_map<(d0, d1)[s0] -> (-d0 + d1, s0)>
+
+// Here upper bounds for both min_i and min_j are both 4 x vscale, as we know
+// that is always the largest value they could take. As if `dim < 4 x vscale`
+// then 4 x vscale is an overestimate, and if `dim > 4 x vscale` then the min
+// will be clamped to 4 x vscale.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_1:.*]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK-LABEL: @dynamic_size_loop_nest
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_1]]()[%vscale]
+//       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]], %[[SCALABLE_BOUND]]) : (index, index) -> ()
+func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %vscale = vector.vscale
+  %c4_vscale = arith.muli %vscale, %c4 : index
+  scf.for %i = %c0 to %dim0 step %c4_vscale {
+    %min_i = affine.min #dynamicDim0Map(%i)[%c4_vscale, %dim0]
+    scf.for %j = %c0 to %dim1 step %c4_vscale {
+      %min_j = affine.min #dynamicDim1Map(%j)[%c4_vscale, %dim1]
+      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB"} : (index) -> index
+      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB"} : (index) -> index
+      "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
+    }
+  }
+  return
+}
+
+// -----
+
+// Here the upper bound is just a value + a constant.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_2:.*]] = affine_map<()[s0] -> (s0 + 8)>
+
+// CHECK-LABEL: @add_to_vscale
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_2]]()[%vscale]
+//       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
+func.func @add_to_vscale() {
+  %vscale = vector.vscale
+  %c8 = arith.constant 8 : index
+  %vscale_plus_c8 = arith.addi %vscale, %c8 : index
+  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "UB"} : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
+
+// -----
+
+// Here we know vscale is always 2 so we get a constant upper bound.
+
+// CHECK-LABEL: @vscale_fixed_size
+//       CHECK:   %[[C2:.*]] = arith.constant 2 : index
+//       CHECK:   "test.some_use"(%[[C2]]) : (index) -> ()
+func.func @vscale_fixed_size() {
+  %vscale = vector.vscale
+  %bound = "test.reify_scalable_bound"(%vscale) {type = "UB", vscale_min = 2, vscale_max = 2} : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
+
+// -----
+
+// Here we don't know the upper bound (%a is underspecified)
+
+func.func @unknown_bound(%a: index) {
+  %vscale = vector.vscale
+  %vscale_plus_a = arith.muli %vscale, %a : index
+  // expected-error @below{{could not reify bound}}
+  %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB"} : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
+
+// -----
+
+// Here we have two vscale values (that have not been CSE'd), but they should
+// still be treated as equivalent.
+
+// CHECK: #[[$SCALABLE_BOUND_MAP_3:.*]] = affine_map<()[s0] -> (s0 * 6)>
+
+// CHECK-LABEL: @duplicate_vscale_values
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_3]]()[%vscale]
+//       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
+func.func @duplicate_vscale_values() {
+  %c4 = arith.constant 4 : index
+  %vscale_0 = vector.vscale
+
+  %c2 = arith.constant 2 : index
+  %vscale_1 = vector.vscale
+
+  %c4_vscale = arith.muli %vscale_0, %c4 : index
+  %c2_vscale = arith.muli %vscale_1, %c2 : index
+  %add = arith.addi %c2_vscale, %c4_vscale : index
+
+  %bound = "test.reify_scalable_bound"(%add) {type = "UB"} : (index) -> index
+  "test.some_use"(%bound) : (index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..a2fe4b3f6c34f7 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Pass/Pass.h"
@@ -75,7 +76,8 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
   WalkResult result = funcOp.walk([&](Operation *op) {
     // Look for test.reify_bound ops.
     if (op->getName().getStringRef() == "test.reify_bound" ||
-        op->getName().getStringRef() == "test.reify_constant_bound") {
+        op->getName().getStringRef() == "test.reify_constant_bound" ||
+        op->getName().getStringRef() == "test.reify_scalable_bound") {
       if (op->getNumOperands() != 1 || op->getNumResults() != 1 ||
           !op->getResultTypes()[0].isIndex()) {
         op->emitOpError("invalid op");
@@ -110,6 +112,9 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
       bool constant =
           op->getName().getStringRef() == "test.reify_constant_bound";
 
+      bool scalable = !constant && op->getName().getStringRef() ==
+                                       "test.reify_scalable_bound";
+
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
       std::function<bool(Value, std::optional<int64_t>)> stopCondition =
@@ -137,6 +142,29 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         if (succeeded(reifiedConst))
           reified =
               FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
+      } else if (scalable) {
+        unsigned vscaleMin = 1;
+        unsigned vscaleMax = 16;
+
+        if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr))
+          vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+        if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr))
+          vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+
+        auto loc = op->getLoc();
+        auto reifiedScalable = vector::computeScalableBound(
+            value, dim, vscaleMin, vscaleMax, *boundType);
+        if (succeeded(reifiedScalable)) {
+          SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
+              vscaleOperand;
+          if (reifiedScalable->map.getNumInputs() == 1) {
+            // The only possible input to the bound is vscale.
+            vscaleOperand.push_back(std::make_pair(
+                rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
+          }
+          reified = affine::materializeComputedBound(
+              rewriter, loc, reifiedScalable->map, vscaleOperand);
+        }
       } else {
         if (dim) {
           if (useArithOps) {

>From 24479f998ce98e0dc01edf85fcfd0bdc73279eae Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 7 Mar 2024 17:26:02 +0000
Subject: [PATCH 2/5] Fix `affine_map` in dynamic dim test

---
 mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
index 2afc4db874b73e..dc6ead13d23207 100644
--- a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
@@ -35,8 +35,7 @@ func.func @fixed_size_loop_nest() {
 
 // -----
 
-#dynamicDim0Map = affine_map<(d0, d1)[s0] -> (-d0 + d1, s0)>
-#dynamicDim1Map = affine_map<(d0, d1)[s0] -> (-d0 + d1, s0)>
+#dynamicDimMap = affine_map<(d0)[s0, s1] -> (-d0 + s1, s0)>
 
 // Here upper bounds for both min_i and min_j are both 4 x vscale, as we know
 // that is always the largest value they could take. As if `dim < 4 x vscale`
@@ -54,9 +53,9 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
   %vscale = vector.vscale
   %c4_vscale = arith.muli %vscale, %c4 : index
   scf.for %i = %c0 to %dim0 step %c4_vscale {
-    %min_i = affine.min #dynamicDim0Map(%i)[%c4_vscale, %dim0]
+    %min_i = affine.min #dynamicDimMap(%i)[%c4_vscale, %dim0]
     scf.for %j = %c0 to %dim1 step %c4_vscale {
-      %min_j = affine.min #dynamicDim1Map(%j)[%c4_vscale, %dim1]
+      %min_j = affine.min #dynamicDimMap(%j)[%c4_vscale, %dim1]
       %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB"} : (index) -> index
       %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB"} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()

>From 7af6eb186d2035ca0cdbe2095f7608332ae6b053 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 8 Mar 2024 09:56:17 +0000
Subject: [PATCH 3/5] Require vscale min/max to be explicit in tests

---
 .../Vector/test-scalable-upper-bound.mlir       | 14 +++++++-------
 .../lib/Dialect/Affine/TestReifyValueBounds.cpp | 17 ++++++++++++-----
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
index dc6ead13d23207..3b8fe2bc8ac1d0 100644
--- a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
@@ -25,8 +25,8 @@ func.func @fixed_size_loop_nest() {
     %min_i = affine.min #fixedDim0Map(%i)[%c4_vscale]
     scf.for %j = %c0 to %c16 step %c4_vscale {
       %min_j = affine.min #fixedDim1Map(%j)[%c4_vscale]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB"} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB"} : (index) -> index
+      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -56,8 +56,8 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
     %min_i = affine.min #dynamicDimMap(%i)[%c4_vscale, %dim0]
     scf.for %j = %c0 to %dim1 step %c4_vscale {
       %min_j = affine.min #dynamicDimMap(%j)[%c4_vscale, %dim1]
-      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB"} : (index) -> index
-      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB"} : (index) -> index
+      %bound_i = "test.reify_scalable_bound"(%min_i) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+      %bound_j = "test.reify_scalable_bound"(%min_j) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
       "test.some_use"(%bound_i, %bound_j) : (index, index) -> ()
     }
   }
@@ -77,7 +77,7 @@ func.func @add_to_vscale() {
   %vscale = vector.vscale
   %c8 = arith.constant 8 : index
   %vscale_plus_c8 = arith.addi %vscale, %c8 : index
-  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "UB"} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -104,7 +104,7 @@ func.func @unknown_bound(%a: index) {
   %vscale = vector.vscale
   %vscale_plus_a = arith.muli %vscale, %a : index
   // expected-error @below{{could not reify bound}}
-  %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB"} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%vscale_plus_a) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -130,7 +130,7 @@ func.func @duplicate_vscale_values() {
   %c2_vscale = arith.muli %vscale_1, %c2 : index
   %add = arith.addi %c2_vscale, %c4_vscale : index
 
-  %bound = "test.reify_scalable_bound"(%add) {type = "UB"} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%add) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index a2fe4b3f6c34f7..65e2caa6de79fa 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -143,13 +143,20 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
           reified =
               FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
       } else if (scalable) {
-        unsigned vscaleMin = 1;
-        unsigned vscaleMax = 16;
-
-        if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr))
+        unsigned vscaleMin = 0;
+        unsigned vscaleMax = 0;
+        if (auto attr = "vscale_min"; op->hasAttrOfType<IntegerAttr>(attr)) {
           vscaleMin = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
-        if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr))
+        } else {
+          op->emitOpError("expected `vscale_min` to be provided");
+          return WalkResult::skip();
+        }
+        if (auto attr = "vscale_max"; op->hasAttrOfType<IntegerAttr>(attr)) {
           vscaleMax = unsigned(op->getAttrOfType<IntegerAttr>(attr).getInt());
+        } else {
+          op->emitOpError("expected `vscale_max` to be provided");
+          return WalkResult::skip();
+        }
 
         auto loc = op->getLoc();
         auto reifiedScalable = vector::computeScalableBound(

>From 725daf33d4b71730095b6f25d482f46f29745b80 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 12 Mar 2024 12:31:03 +0000
Subject: [PATCH 4/5] Rewrite to use RTTI and vscale impl of
 `ValueBoundsOpInterface`

---
 .../IR/ScalableValueBoundsConstraintSet.h     |  87 ++++++++++++
 .../Vector/IR/ValueBoundsOpInterfaceImpl.h    |  20 +++
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  29 ----
 mlir/include/mlir/InitAllDialects.h           |   2 +
 .../mlir/Interfaces/ValueBoundsOpInterface.h  |  32 ++---
 mlir/lib/Dialect/Vector/IR/CMakeLists.txt     |   2 +
 .../IR/ScalableValueBoundsConstraintSet.cpp   | 110 +++++++++++++++
 .../Vector/IR/ValueBoundsOpInterfaceImpl.cpp  |  51 +++++++
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 130 ------------------
 .../lib/Interfaces/ValueBoundsOpInterface.cpp |  36 ++---
 ...r-bound.mlir => test-scalable-bounds.mlir} |  22 +--
 .../Dialect/Affine/TestReifyValueBounds.cpp   |   7 +-
 12 files changed, 316 insertions(+), 212 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
 create mode 100644 mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
 create mode 100644 mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
 rename mlir/test/Dialect/Vector/{test-scalable-upper-bound.mlir => test-scalable-bounds.mlir} (88%)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
new file mode 100644
index 00000000000000..a1e155582705f7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -0,0 +1,87 @@
+//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
+#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
+
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+namespace mlir::vector {
+
+/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds.
+struct ScalableValueBoundsConstraintSet
+    : public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
+                               ValueBoundsConstraintSet> {
+  ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
+                                   unsigned vscaleMax)
+      : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
+
+  /// A thin wrapper over an `AffineMap` which can represent a constant bound,
+  /// or a scalable bound (in terms of vscale). The `AffineMap` will always
+  /// take at most one parameter, vscale, and returns a single result, which is
+  /// the bound of value.
+  struct ConstantOrScalableBound {
+    AffineMap map;
+
+    struct BoundSize {
+      int64_t baseSize{0};
+      bool scalable{false};
+    };
+
+    /// Get the (possibly) scalable size of the bound, returns failure if
+    /// the bound cannot be represented as a single quantity.
+    FailureOr<BoundSize> getSize() const;
+  };
+
+  /// Computes a (possibly) scalable bound for a given value. This is
+  /// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but
+  /// uses knowledge of the range of vscale to compute either a constant
+  /// bound, an expression in terms of vscale, or failure if no bound can
+  /// be computed.
+  ///
+  /// The resulting `AffineMap` will always take at most one parameter,
+  /// vscale, and return a single result, which is the bound of `value`.
+  ///
+  /// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` ==
+  /// `vscaleMax`, the resulting bound (if found), will be constant.
+  static FailureOr<ConstantOrScalableBound>
+  computeScalableBound(Value value, std::optional<int64_t> dim,
+                       unsigned vscaleMin, unsigned vscaleMax,
+                       presburger::BoundType boundType);
+
+  /// Get the value of vscale. Returns `nullptr` if vscale was not encountered.
+  Value getVscaleValue() const { return vscale; }
+
+  /// Sets the value of vscale. Asserts if vscale has already been set.
+  void setVscale(vector::VectorScaleOp vscaleOp) {
+    assert(!vscale && "expected vscale to be unset");
+    vscale = vscaleOp.getResult();
+  }
+
+  /// The minimum possible value of vscale.
+  unsigned getVscaleMin() const { return vscaleMin; }
+
+  /// The maximum possible value of vscale.
+  unsigned getVscaleMax() const { return vscaleMax; }
+
+  static char ID;
+
+private:
+  const unsigned vscaleMin;
+  const unsigned vscaleMax;
+  Value vscale = nullptr;
+};
+
+using ConstantOrScalableBound =
+    ScalableValueBoundsConstraintSet::ConstantOrScalableBound;
+
+} // namespace mlir::vector
+
+#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
diff --git a/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
index 00000000000000..4794bc9016c6f9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace vector {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 635d609d0b3e71..f6b03a0f2c8007 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 #define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
 
-#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -99,34 +98,6 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 std::optional<StaticTileOffsetRange>
 createUnrollIterator(VectorType vType, int64_t targetRank = 1);
 
-struct ConstantOrScalableBound {
-  AffineMap map;
-
-  struct BoundSize {
-    int64_t baseSize{0};
-    bool scalable{false};
-  };
-
-  /// Get the (possibly) scalable size of the bound, returns failure if the
-  /// bound cannot be represented as a single quantity.
-  FailureOr<BoundSize> getSize() const;
-};
-
-/// Computes a (possibly) scalable bound for a given value. This is similar to
-/// `ValueBoundsConstraintSet::computeConstantBound()`, but uses knowledge of
-/// the range of vscale to compute either a constant bound, an expression in
-/// terms of vscale, or failure if no bound can be computed.
-///
-/// The resulting `AffineMap` will always take at most one parameter, vscale,
-/// and return a single result, which is the bound of `value`.
-///
-/// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` ==
-/// `vscaleMax`, the resulting bound (if found), will be constant.
-FailureOr<ConstantOrScalableBound>
-computeScalableBound(Value value, std::optional<int64_t> dim,
-                     unsigned vscaleMin, unsigned vscaleMax,
-                     presburger::BoundType boundType);
-
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e508d51205f347..5868f7e50b1724 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -85,6 +85,7 @@
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
@@ -178,6 +179,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
   vector::registerSubsetOpInterfaceExternalModels(registry);
+  vector::registerValueBoundsOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);
   ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
   spirv::registerSPIRVTargetInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 6d0e16bf215f8a..b4ed0967e63f18 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/Support/ExtensibleRTTI.h"
 
 #include <queue>
 
@@ -63,7 +64,8 @@ using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
 ///
 /// Note: Any modification of existing IR invalides the data stored in this
 /// class. Adding new operations is allowed.
-class ValueBoundsConstraintSet {
+class ValueBoundsConstraintSet
+    : public llvm::RTTIExtends<ValueBoundsConstraintSet, llvm::RTTIRoot> {
 protected:
   /// Helper class that builds a bound for a shaped value dimension or
   /// index-typed value.
@@ -107,6 +109,8 @@ class ValueBoundsConstraintSet {
   };
 
 public:
+  static char ID;
+
   /// The stop condition when traversing the backward slice of a shaped value/
   /// index-type value. The traversal continues until the stop condition
   /// evaluates to "true" for a value.
@@ -265,28 +269,20 @@ class ValueBoundsConstraintSet {
 
   ValueBoundsConstraintSet(MLIRContext *ctx);
 
-  /// A callback to allow injecting custom value bounds constraints.
-  /// It takes the current value, the dim (or kIndexValue), and a reference to
-  /// the constraints set.
-  using PopulateCustomValueBoundsFn =
-      function_ref<void(Value, int64_t, ValueBoundsConstraintSet &)>;
-
   /// Populates the constraint set for a value/map without actually computing
-  /// the bound.
-  int64_t populateConstraintsSet(
-      Value value, std::optional<int64_t> dim = std::nullopt,
-      PopulateCustomValueBoundsFn customValueBounds = nullptr,
-      StopConditionFn stopCondition = nullptr);
-  int64_t populateConstraintsSet(
-      AffineMap map, ValueDimList mapOperands,
-      PopulateCustomValueBoundsFn customValueBounds = nullptr,
-      StopConditionFn stopCondition = nullptr, int64_t *posOut = nullptr);
+  /// the bound. Returns the position for the value/map (via the return value
+  /// and `posOut` output parameter).
+  int64_t populateConstraintsSet(Value value,
+                                 std::optional<int64_t> dim = std::nullopt,
+                                 StopConditionFn stopCondition = nullptr);
+  int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
+                                 StopConditionFn stopCondition = nullptr,
+                                 int64_t *posOut = nullptr);
 
   /// Iteratively process all elements on the worklist until an index-typed
   /// value or shaped value meets `stopCondition`. Such values are not processed
   /// any further.
-  void processWorklist(StopConditionFn stopCondition,
-                       PopulateCustomValueBoundsFn customValueBounds = nullptr);
+  void processWorklist(StopConditionFn stopCondition);
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
index 70f3fa8c297d4b..204462ffd047c6 100644
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRVectorDialect
   VectorOps.cpp
+  ValueBoundsOpInterfaceImpl.cpp
+  ScalableValueBoundsConstraintSet.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
new file mode 100644
index 00000000000000..497ec603bf8430
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -0,0 +1,110 @@
+//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::vector {
+
+FailureOr<ConstantOrScalableBound::BoundSize>
+ConstantOrScalableBound::getSize() const {
+  if (map.isSingleConstant())
+    return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
+  if (map.getNumResults() != 1 || map.getNumInputs() != 1)
+    return failure();
+  auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
+  if (!binop || binop.getKind() != AffineExprKind::Mul)
+    return failure();
+  auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
+    if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
+      constant = cst.getValue();
+      return true;
+    }
+    return false;
+  };
+  // Match `s0 * cst` or `cst * s0`:
+  int64_t cst = 0;
+  auto lhs = binop.getLHS();
+  auto rhs = binop.getRHS();
+  if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
+      (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
+    return BoundSize{cst, /*scalable=*/true};
+  }
+  return failure();
+}
+
+char ScalableValueBoundsConstraintSet::ID = 0;
+
+FailureOr<ConstantOrScalableBound>
+ScalableValueBoundsConstraintSet::computeScalableBound(
+    Value value, std::optional<int64_t> dim, unsigned vscaleMin,
+    unsigned vscaleMax, presburger::BoundType boundType) {
+  using namespace presburger;
+
+  assert(vscaleMin <= vscaleMax);
+  ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
+                                                vscaleMax);
+
+  int64_t pos = scalableCstr.populateConstraintsSet(value, dim,
+                                                    /* Stop condition */
+                                                    [](auto, auto) {
+                                                      // Keep adding constraints
+                                                      // till the worklist is
+                                                      // empty.
+                                                      return false;
+                                                    });
+
+  // Project out all variables apart from vscale.
+  // This should result in constraints in terms of vscale only.
+  scalableCstr.projectOut(
+      [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
+
+  assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
+             scalableCstr.positionToValueDim.size() &&
+         "inconsistent mapping state");
+
+  // Check that the only symbols left are vscale.
+  for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
+    if (i == pos)
+      continue;
+    if (scalableCstr.positionToValueDim[i] !=
+        ValueDim(scalableCstr.getVscaleValue(),
+                 ValueBoundsConstraintSet::kIndexValue)) {
+      return failure();
+    }
+  }
+
+  SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
+  scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
+                                   &upperBound,
+                                   /*closedUB=*/true);
+
+  auto invalidBound = [](auto &bound) {
+    return !bound[0] || bound[0].getNumResults() != 1;
+  };
+
+  AffineMap bound = [&] {
+    if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
+        lowerBound[0] == upperBound[0]) {
+      return lowerBound[0];
+    } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
+      return lowerBound[0];
+    } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
+      return upperBound[0];
+    }
+    return AffineMap{};
+  }();
+
+  if (!bound)
+    return failure();
+
+  return ConstantOrScalableBound{bound};
+}
+
+} // namespace mlir::vector
diff --git a/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..f2e9bc58adc000
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,51 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir::vector {
+namespace {
+
+struct VectorScaleOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<VectorScaleOpInterface,
+                                                   VectorScaleOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto *scalableCstr = dyn_cast<ScalableValueBoundsConstraintSet>(&cstr);
+    if (!scalableCstr)
+      return;
+    auto vscaleOp = cast<VectorScaleOp>(op);
+    assert(value == vscaleOp.getResult() && "invalid value");
+    if (auto vscale = scalableCstr->getVscaleValue()) {
+      // All copies of vscale are equivalent.
+      scalableCstr->bound(value) == cstr.getExpr(vscale);
+    } else {
+      // We know vscale is confined to [vscaleMin, vscaleMax].
+      cstr.bound(value) >= scalableCstr->getVscaleMin();
+      cstr.bound(value) <= scalableCstr->getVscaleMax();
+      scalableCstr->setVscale(vscaleOp);
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::vector
+
+void mlir::vector::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+    vector::VectorScaleOp::attachInterface<vector::VectorScaleOpInterface>(
+        *ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 77b2ad6ff540ce..d613672608c3ad 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -24,7 +24,6 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/MathExtras.h"
 
@@ -301,132 +300,3 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
   return StaticTileOffsetRange(shapeToUnroll, /*unrollStep=*/1);
 }
-
-FailureOr<vector::ConstantOrScalableBound::BoundSize>
-vector::ConstantOrScalableBound::getSize() const {
-  if (map.isSingleConstant())
-    return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
-  if (map.getNumResults() != 1 || map.getNumInputs() != 1)
-    return failure();
-  auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
-  if (!binop || binop.getKind() != AffineExprKind::Mul)
-    return failure();
-  auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
-    if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
-      constant = cst.getValue();
-      return true;
-    }
-    return false;
-  };
-  // Match `s0 * cst` or `cst * s0`:
-  int64_t cst = 0;
-  auto lhs = binop.getLHS();
-  auto rhs = binop.getRHS();
-  if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
-      (matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
-    return BoundSize{cst, /*scalable=*/true};
-  }
-  return failure();
-}
-
-namespace {
-struct ScalableValueBoundsConstraintSet : public ValueBoundsConstraintSet {
-  using ValueBoundsConstraintSet::ValueBoundsConstraintSet;
-
-  static Operation *getOwnerOfValue(Value value) {
-    if (auto bbArg = dyn_cast<BlockArgument>(value))
-      return bbArg.getOwner()->getParentOp();
-    return value.getDefiningOp();
-  }
-
-  static FailureOr<AffineMap>
-  computeScalableBound(Value value, std::optional<int64_t> dim,
-                       unsigned vscaleMin, unsigned vscaleMax,
-                       presburger::BoundType boundType) {
-    using namespace presburger;
-
-    assert(vscaleMin <= vscaleMax);
-    ScalableValueBoundsConstraintSet cstr(value.getContext());
-
-    Value vscale;
-    int64_t pos = cstr.populateConstraintsSet(
-        value, dim,
-        /* Custom vscale value bounds */
-        [&vscale, vscaleMin, vscaleMax](Value value, int64_t dim,
-                                        ValueBoundsConstraintSet &cstr) {
-          if (dim != ValueBoundsConstraintSet::kIndexValue)
-            return;
-          if (isa_and_present<vector::VectorScaleOp>(getOwnerOfValue(value))) {
-            if (vscale) {
-              // All copies of vscale are equivalent.
-              cstr.bound(value) == cstr.getExpr(vscale);
-            } else {
-              // We know vscale is confined to [vscaleMin, vscaleMax].
-              cstr.bound(value) >= vscaleMin;
-              cstr.bound(value) <= vscaleMax;
-              vscale = value;
-            }
-          }
-        },
-        /* Stop condition */
-        [](auto, auto) {
-          // Keep adding constraints till the worklist is empty.
-          return false;
-        });
-
-    // Project out all variables apart from the first vscale.
-    cstr.projectOut([&](ValueDim p) { return p.first != vscale; });
-
-    assert(cstr.cstr.getNumDimAndSymbolVars() ==
-               cstr.positionToValueDim.size() &&
-           "inconsistent mapping state");
-
-    for (int64_t i = 0; i < cstr.cstr.getNumDimAndSymbolVars(); ++i) {
-      if (i == pos)
-        continue;
-      if (cstr.positionToValueDim[i] !=
-          ValueDim(vscale, ValueBoundsConstraintSet::kIndexValue)) {
-        return failure();
-      }
-    }
-
-    SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
-    cstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
-                             &upperBound,
-                             /*closedUB=*/true);
-
-    auto invalidBound = [](auto &bound) {
-      return !bound[0] || bound[0].getNumResults() != 1;
-    };
-
-    AffineMap bound = [&] {
-      if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
-          lowerBound[0] == lowerBound[0]) {
-        return lowerBound[0];
-      } else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
-        return lowerBound[0];
-      } else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
-        return upperBound[0];
-      }
-      return AffineMap{};
-    }();
-
-    if (!bound)
-      return failure();
-
-    return bound;
-  }
-};
-
-} // namespace
-
-FailureOr<vector::ConstantOrScalableBound>
-vector::computeScalableBound(Value value, std::optional<int64_t> dim,
-                             unsigned vscaleMin, unsigned vscaleMax,
-                             presburger::BoundType boundType) {
-  auto bound = ScalableValueBoundsConstraintSet::computeScalableBound(
-      value, dim, vscaleMin, vscaleMax, boundType);
-  if (failed(bound))
-    return failure();
-  return ConstantOrScalableBound{*bound};
-}
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index ac4e3b935a0542..06ec3f4e135e9f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -70,6 +70,8 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
 ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
     : builder(ctx) {}
 
+char ValueBoundsConstraintSet::ID = 0;
+
 #ifndef NDEBUG
 static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
   if (value.getType().isIndex()) {
@@ -191,9 +193,7 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(
-    StopConditionFn stopCondition,
-    PopulateCustomValueBoundsFn customValueBounds) {
+void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -217,11 +217,8 @@ void ValueBoundsConstraintSet::processWorklist(
     if (stopCondition(value, maybeDim))
       continue;
 
-    // 1. Query `customValueBounds` for constraints (if provided).
-    if (customValueBounds)
-      customValueBounds(value, dim, *this);
-
-    // 2. Query `ValueBoundsOpInterface` for constraints.
+    // Query `ValueBoundsOpInterface` for constraints. New items may be added to
+    // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
     if (valueBoundsOp) {
@@ -233,8 +230,6 @@ void ValueBoundsConstraintSet::processWorklist(
       continue;
     }
 
-    // Steps 1 and 2 above may add new items to the worklist.
-
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
     // tied to OpOperands. Tied values have the same shape.
@@ -497,12 +492,12 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
 
   int64_t pos = 0;
   if (stopCondition) {
-    cstr.populateConstraintsSet(map, operands, nullptr, stopCondition, &pos);
+    cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
   } else {
     // No stop condition specified: Keep adding constraints until a bound could
     // be computed.
     cstr.populateConstraintsSet(
-        map, operands, nullptr,
+        map, operands,
         [&](Value v, std::optional<int64_t> dim) {
           return cstr.cstr.getConstantBound64(type, pos).has_value();
         },
@@ -516,9 +511,7 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
 }
 
 int64_t ValueBoundsConstraintSet::populateConstraintsSet(
-    Value value, std::optional<int64_t> dim,
-    PopulateCustomValueBoundsFn customValueBounds,
-    StopConditionFn stopCondition) {
+    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
 #endif // NDEBUG
@@ -526,14 +519,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
   AffineMap map =
       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                      Builder(value.getContext()).getAffineDimExpr(0));
-  return populateConstraintsSet(map, {{value, dim}}, customValueBounds,
-                                stopCondition);
+  return populateConstraintsSet(map, {{value, dim}}, stopCondition);
 }
 
 int64_t ValueBoundsConstraintSet::populateConstraintsSet(
-    AffineMap map, ValueDimList operands,
-    PopulateCustomValueBoundsFn customValueBounds,
-    StopConditionFn stopCondition, int64_t *posOut) {
+    AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
+    int64_t *posOut) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
   int64_t pos = insert(/*isSymbol=*/false);
   if (posOut)
@@ -555,12 +546,11 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
   if (stopCondition) {
-    processWorklist(stopCondition, customValueBounds);
+    processWorklist(stopCondition);
   } else {
     // No stop condition specified: Keep adding constraints until the worklist
     // is empty.
-    processWorklist([](Value v, std::optional<int64_t> dim) { return false; },
-                    customValueBounds);
+    processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
   }
 
   return pos;
diff --git a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
similarity index 88%
rename from mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
rename to mlir/test/Dialect/Vector/test-scalable-bounds.mlir
index 3b8fe2bc8ac1d0..2d382d76f7c73c 100644
--- a/mlir/test/Dialect/Vector/test-scalable-upper-bound.mlir
+++ b/mlir/test/Dialect/Vector/test-scalable-bounds.mlir
@@ -11,7 +11,8 @@
 // CHECK: #[[$SCALABLE_BOUND_MAP_0:.*]] = affine_map<()[s0] -> (s0 * 4)>
 
 // CHECK-LABEL: @fixed_size_loop_nest
-//   CHECK-DAG:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_0]]()[%vscale]
+//   CHECK-DAG:   %[[VSCALE:.*]] = vector.vscale
+//   CHECK-DAG:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_0]]()[%[[VSCALE]]]
 //   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
 //       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]], %[[C16]]) : (index, index) -> ()
 func.func @fixed_size_loop_nest() {
@@ -45,7 +46,8 @@ func.func @fixed_size_loop_nest() {
 // CHECK: #[[$SCALABLE_BOUND_MAP_1:.*]] = affine_map<()[s0] -> (s0 * 4)>
 
 // CHECK-LABEL: @dynamic_size_loop_nest
-//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_1]]()[%vscale]
+//       CHECK:   %[[VSCALE:.*]] = vector.vscale
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_1]]()[%[[VSCALE]]]
 //       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]], %[[SCALABLE_BOUND]]) : (index, index) -> ()
 func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
   %c4 = arith.constant 4 : index
@@ -66,32 +68,33 @@ func.func @dynamic_size_loop_nest(%dim0: index, %dim1: index) {
 
 // -----
 
-// Here the upper bound is just a value + a constant.
+// Here the bound is just a value + a constant.
 
 // CHECK: #[[$SCALABLE_BOUND_MAP_2:.*]] = affine_map<()[s0] -> (s0 + 8)>
 
 // CHECK-LABEL: @add_to_vscale
-//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_2]]()[%vscale]
+//       CHECK:   %[[VSCALE:.*]] = vector.vscale
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_2]]()[%[[VSCALE]]]
 //       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
 func.func @add_to_vscale() {
   %vscale = vector.vscale
   %c8 = arith.constant 8 : index
   %vscale_plus_c8 = arith.addi %vscale, %c8 : index
-  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%vscale_plus_c8) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
 
 // -----
 
-// Here we know vscale is always 2 so we get a constant upper bound.
+// Here we know vscale is always 2 so we get a constant bound.
 
 // CHECK-LABEL: @vscale_fixed_size
 //       CHECK:   %[[C2:.*]] = arith.constant 2 : index
 //       CHECK:   "test.some_use"(%[[C2]]) : (index) -> ()
 func.func @vscale_fixed_size() {
   %vscale = vector.vscale
-  %bound = "test.reify_scalable_bound"(%vscale) {type = "UB", vscale_min = 2, vscale_max = 2} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%vscale) {type = "EQ", vscale_min = 2, vscale_max = 2} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
@@ -117,7 +120,8 @@ func.func @unknown_bound(%a: index) {
 // CHECK: #[[$SCALABLE_BOUND_MAP_3:.*]] = affine_map<()[s0] -> (s0 * 6)>
 
 // CHECK-LABEL: @duplicate_vscale_values
-//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_3]]()[%vscale]
+//       CHECK:   %[[VSCALE:.*]] = vector.vscale
+//       CHECK:   %[[SCALABLE_BOUND:.*]] = affine.apply #[[$SCALABLE_BOUND_MAP_3]]()[%[[VSCALE]]]
 //       CHECK:   "test.some_use"(%[[SCALABLE_BOUND]]) : (index) -> ()
 func.func @duplicate_vscale_values() {
   %c4 = arith.constant 4 : index
@@ -130,7 +134,7 @@ func.func @duplicate_vscale_values() {
   %c2_vscale = arith.muli %vscale_1, %c2 : index
   %add = arith.addi %c2_vscale, %c4_vscale : index
 
-  %bound = "test.reify_scalable_bound"(%add) {type = "UB", vscale_min = 1, vscale_max = 16} : (index) -> index
+  %bound = "test.reify_scalable_bound"(%add) {type = "EQ", vscale_min = 1, vscale_max = 16} : (index) -> index
   "test.some_use"(%bound) : (index) -> ()
   return
 }
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 65e2caa6de79fa..5e160b720db627 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -13,7 +13,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Pass/Pass.h"
@@ -159,8 +159,9 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
         }
 
         auto loc = op->getLoc();
-        auto reifiedScalable = vector::computeScalableBound(
-            value, dim, vscaleMin, vscaleMax, *boundType);
+        auto reifiedScalable =
+            vector::ScalableValueBoundsConstraintSet::computeScalableBound(
+                value, dim, vscaleMin, vscaleMax, *boundType);
         if (succeeded(reifiedScalable)) {
           SmallVector<std::pair<Value, std::optional<int64_t>>, 1>
               vscaleOperand;

>From acffc546086f3cd7ec897fa2beef8985e7a0b4c3 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 12 Mar 2024 13:33:53 +0000
Subject: [PATCH 5/5] Expose computeScalableBound() `stopCondition`

---
 .../Vector/IR/ScalableValueBoundsConstraintSet.h     |  3 ++-
 .../Vector/IR/ScalableValueBoundsConstraintSet.cpp   | 12 +++---------
 2 files changed, 5 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
index a1e155582705f7..ac74352f9af623 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -54,7 +54,8 @@ struct ScalableValueBoundsConstraintSet
   static FailureOr<ConstantOrScalableBound>
   computeScalableBound(Value value, std::optional<int64_t> dim,
                        unsigned vscaleMin, unsigned vscaleMax,
-                       presburger::BoundType boundType);
+                       presburger::BoundType boundType,
+                       StopConditionFn stopCondition = nullptr);
 
   /// Get the value of vscale. Returns `nullptr` if vscale has not been encountered.
   Value getVscaleValue() const { return vscale; }
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index 497ec603bf8430..de15739fc7a9b9 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -44,21 +44,15 @@ char ScalableValueBoundsConstraintSet::ID = 0;
 FailureOr<ConstantOrScalableBound>
 ScalableValueBoundsConstraintSet::computeScalableBound(
     Value value, std::optional<int64_t> dim, unsigned vscaleMin,
-    unsigned vscaleMax, presburger::BoundType boundType) {
+    unsigned vscaleMax, presburger::BoundType boundType,
+    StopConditionFn stopCondition) {
   using namespace presburger;
 
   assert(vscaleMin <= vscaleMax);
   ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
                                                 vscaleMax);
 
-  int64_t pos = scalableCstr.populateConstraintsSet(value, dim,
-                                                    /* Stop condition */
-                                                    [](auto, auto) {
-                                                      // Keep adding constraints
-                                                      // till the worklist is
-                                                      // empty.
-                                                      return false;
-                                                    });
+  int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
 
   // Project out all variables apart from vscale.
   // This should result in constraints in terms of vscale only.



More information about the Mlir-commits mailing list