[llvm-branch-commits] [mlir] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass stop condition in the constructor (PR #86099)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Mar 22 22:58:36 PDT 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86099

>From ad1b2ac46a744887594e13370762939034797d31 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sat, 23 Mar 2024 05:57:33 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass stop
 condition in the constructor

This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed.

This change is in preparation for adding support for branches.
---
 .../IR/ScalableValueBoundsConstraintSet.h     |  9 +-
 .../mlir/Interfaces/ValueBoundsOpInterface.h  | 16 ++--
 .../Affine/Transforms/ReifyValueBounds.cpp    |  6 +-
 .../Arith/Transforms/ReifyValueBounds.cpp     |  6 +-
 .../Linalg/Transforms/HoistPadding.cpp        |  2 +-
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  2 +-
 .../IR/ScalableValueBoundsConstraintSet.cpp   | 21 +++--
 .../lib/Interfaces/ValueBoundsOpInterface.cpp | 82 ++++++++++---------
 .../Dialect/Affine/TestReifyValueBounds.cpp   |  9 +-
 9 files changed, 90 insertions(+), 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
index 31e19ff1ad39f7f..67a6581eb2fb4b3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
@@ -29,9 +29,12 @@ struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
 struct ScalableValueBoundsConstraintSet
     : public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
                                detail::ValueBoundsConstraintSet> {
-  ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
-                                   unsigned vscaleMax)
-      : RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
+  ScalableValueBoundsConstraintSet(
+      MLIRContext *context,
+      ValueBoundsConstraintSet::StopConditionFn stopCondition,
+      unsigned vscaleMin, unsigned vscaleMax)
+      : RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
+        vscaleMax(vscaleMax) {};
 
   using RTTIExtends::bound;
   using RTTIExtends::StopConditionFn;
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 6b4ac4174b16b59..c281739e0ded2c6 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -117,8 +117,9 @@ class ValueBoundsConstraintSet
   ///
   /// The first parameter of the function is the shaped value/index-typed
   /// value. The second parameter is the dimension in case of a shaped value.
-  using StopConditionFn =
-      function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
+  /// The third parameter is this constraint set.
+  using StopConditionFn = std::function<bool(
+      Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
 
   /// Compute a bound for the given index-typed value or shape dimension size.
   /// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -267,22 +268,20 @@ class ValueBoundsConstraintSet
   /// An index-typed value or the dimension of a shaped-type value.
   using ValueDim = std::pair<Value, int64_t>;
 
-  ValueBoundsConstraintSet(MLIRContext *ctx);
+  ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
 
   /// Populates the constraint set for a value/map without actually computing
   /// the bound. Returns the position for the value/map (via the return value
   /// and `posOut` output parameter).
   int64_t populateConstraintsSet(Value value,
-                                 std::optional<int64_t> dim = std::nullopt,
-                                 StopConditionFn stopCondition = nullptr);
+                                 std::optional<int64_t> dim = std::nullopt);
   int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
-                                 StopConditionFn stopCondition = nullptr,
                                  int64_t *posOut = nullptr);
 
   /// Iteratively process all elements on the worklist until an index-typed
   /// value or shaped value meets `stopCondition`. Such values are not processed
   /// any further.
-  void processWorklist(StopConditionFn stopCondition);
+  void processWorklist();
 
   /// Bound the given column in the underlying constraint set by the given
   /// expression.
@@ -330,6 +329,9 @@ class ValueBoundsConstraintSet
 
   /// Builder for constructing affine expressions.
   Builder builder;
+
+  /// The current stop condition function.
+  StopConditionFn stopCondition = nullptr;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f76d4465df..117ee8e8701ad7c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd1478aa9e61..fad221288f190ed 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
     bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     // We are trying to reify a bound for `value` in terms of the owning op's
     // operands. Construct a stop condition that evaluates to "true" for any SSA
     // value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
 FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
     OpBuilder &b, Location loc, presburger::BoundType type, Value value,
     ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
-  auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+  auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+                             ValueBoundsConstraintSet &cstr) {
     return v != value;
   };
   return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8eebaecb92..c3a08ce86082a8e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
     FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
         rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
         /*stopCondition=*/
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           if (v == forOp.getUpperBound())
             return false;
           // Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24e..1e13e60068ee7f6 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
     ValueDimList boundOperands;
     LogicalResult status = ValueBoundsConstraintSet::computeBound(
         bound, boundOperands, BoundType::EQ, yieldedValue, dim,
-        [&](Value v, std::optional<int64_t> d) {
+        [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
           // Stop when reaching a block argument of the loop body.
           if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
             return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index 6d7e3bc70f59de9..52359fa8a510d35 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
     unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
     StopConditionFn stopCondition) {
   using namespace presburger;
-
   assert(vscaleMin <= vscaleMax);
-  ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
-                                                vscaleMax);
 
-  int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
+  // No stop condition specified: Keep adding constraints until the worklist
+  // is empty.
+  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+                                  mlir::ValueBoundsConstraintSet &cstr) {
+    return false;
+  };
+
+  ScalableValueBoundsConstraintSet scalableCstr(
+      value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
+      vscaleMin, vscaleMax);
+  int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
 
   // Project out all variables apart from vscale.
   // This should result in constraints in terms of vscale only.
-  scalableCstr.projectOut(
-      [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
+  auto projectOutFn = [&](ValueDim p) {
+    return p.first != scalableCstr.getVscaleValue();
+  };
+  scalableCstr.projectOut(projectOutFn);
 
   assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
              scalableCstr.positionToValueDim.size() &&
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 5b00566ee3bf07d..9028fb3fb767774 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
   return std::nullopt;
 }
 
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
-    : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+    MLIRContext *ctx, StopConditionFn stopCondition)
+    : builder(ctx), stopCondition(stopCondition) {
+  assert(stopCondition && "expected non-null stop condition");
+}
 
 char ValueBoundsConstraintSet::ID = 0;
 
@@ -230,7 +233,8 @@ static Operation *getOwnerOfValue(Value value) {
   return value.getDefiningOp();
 }
 
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+  LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
   while (!worklist.empty()) {
     int64_t pos = worklist.front();
     worklist.pop();
@@ -251,13 +255,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
 
     // Do not process any further if the stop condition is met.
     auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
-    if (stopCondition(value, maybeDim))
+    if (stopCondition(value, maybeDim, *this)) {
+      LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+                              << " (dim: " << maybeDim << ")\n");
       continue;
+    }
 
     // Query `ValueBoundsOpInterface` for constraints. New items may be added to
     // the worklist.
     auto valueBoundsOp =
         dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+    LLVM_DEBUG(llvm::dbgs()
+               << "Query value bounds for: " << value
+               << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
     if (valueBoundsOp) {
       if (dim == kIndexValue) {
         valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -266,6 +276,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
       }
       continue;
     }
+    LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
 
     // If the op does not implement `ValueBoundsOpInterface`, check if it
     // implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -315,8 +326,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
     bool closedUB) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
-  assert(!stopCondition(value, dim) &&
-         "stop condition should not be satisfied for starting point");
 #endif // NDEBUG
 
   int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -326,9 +335,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
   // Process the backward slice of `value` (i.e., reverse use-def chain) until
   // `stopCondition` is met.
   ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
-  ValueBoundsConstraintSet cstr(value.getContext());
+  ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+  assert(!stopCondition(value, dim, cstr) &&
+         "stop condition should not be satisfied for starting point");
   int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
-  cstr.processWorklist(stopCondition);
+  cstr.processWorklist();
 
   // Project out all variables (apart from `valueDim`) that do not match the
   // stop condition.
@@ -338,7 +349,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
       return false;
     auto maybeDim =
         p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
-    return !stopCondition(p.first, maybeDim);
+    return !stopCondition(p.first, maybeDim, cstr);
   });
 
   // Compute lower and upper bounds for `valueDim`.
@@ -444,7 +455,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
     bool closedUB) {
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) {
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
         return llvm::is_contained(dependencies, std::make_pair(v, d));
       },
       closedUB);
@@ -480,7 +491,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
   // Reify bounds in terms of any independent values.
   return computeBound(
       resultMap, mapOperands, type, value, dim,
-      [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+      [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+        return isIndependent(v);
+      },
       closedUB);
 }
 
@@ -513,21 +526,19 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
     presburger::BoundType type, AffineMap map, ValueDimList operands,
     StopConditionFn stopCondition, bool closedUB) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
-  ValueBoundsConstraintSet cstr(map.getContext());
 
-  int64_t pos = 0;
-  if (stopCondition) {
-    cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
-  } else {
-    // No stop condition specified: Keep adding constraints until a bound could
-    // be computed.
-    cstr.populateConstraintsSet(
-        map, operands,
-        [&](Value v, std::optional<int64_t> dim) {
-          return cstr.cstr.getConstantBound64(type, pos).has_value();
-        },
-        &pos);
-  }
+  // Default stop condition if none was specified: Keep adding constraints until
+  // a bound could be computed.
+  int64_t pos;
+  auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+                                  ValueBoundsConstraintSet &cstr) {
+    return cstr.cstr.getConstantBound64(type, pos).has_value();
+  };
+
+  ValueBoundsConstraintSet cstr(
+      map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+  cstr.populateConstraintsSet(map, operands, &pos);
+
   // Compute constant bound for `valueDim`.
   int64_t ubAdjustment = closedUB ? 0 : 1;
   if (auto bound = cstr.cstr.getConstantBound64(type, pos))
@@ -535,8 +546,9 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
   return failure();
 }
 
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
-    Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+int64_t
+ValueBoundsConstraintSet::populateConstraintsSet(Value value,
+                                                 std::optional<int64_t> dim) {
 #ifndef NDEBUG
   assertValidValueDim(value, dim);
 #endif // NDEBUG
@@ -544,12 +556,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
   AffineMap map =
       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                      Builder(value.getContext()).getAffineDimExpr(0));
-  return populateConstraintsSet(map, {{value, dim}}, stopCondition);
+  return populateConstraintsSet(map, {{value, dim}});
 }
 
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
-    AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
-    int64_t *posOut) {
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
+                                                         ValueDimList operands,
+                                                         int64_t *posOut) {
   assert(map.getNumResults() == 1 && "expected affine map with one result");
   int64_t pos = insert(/*isSymbol=*/false);
   if (posOut)
@@ -570,13 +582,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
 
   // Process the backward slice of `operands` (i.e., reverse use-def chain)
   // until `stopCondition` is met.
-  if (stopCondition) {
-    processWorklist(stopCondition);
-  } else {
-    // No stop condition specified: Keep adding constraints until the worklist
-    // is empty.
-    processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
-  }
+  processWorklist();
 
   return pos;
 }
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 5e160b720db627e..4b2b1a06341b717 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -117,14 +117,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
 
       // Prepare stop condition. By default, reify in terms of the op's
       // operands. No stop condition is used when a constant was requested.
-      std::function<bool(Value, std::optional<int64_t>)> stopCondition =
-          [&](Value v, std::optional<int64_t> d) {
+      std::function<bool(Value, std::optional<int64_t>,
+                         ValueBoundsConstraintSet & cstr)>
+          stopCondition = [&](Value v, std::optional<int64_t> d,
+                              ValueBoundsConstraintSet &cstr) {
             // Reify in terms of SSA values that are different from `value`.
             return v != value;
           };
       if (reifyToFuncArgs) {
         // Reify in terms of function block arguments.
-        stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+        stopCondition = [](Value v, std::optional<int64_t> d,
+                           ValueBoundsConstraintSet &cstr) {
           auto bbArg = dyn_cast<BlockArgument>(v);
           if (!bbArg)
             return false;



More information about the llvm-branch-commits mailing list