[llvm-branch-commits] [mlir] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass stop condition in the constructor (PR #86099)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Mar 21 01:05:56 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86099
This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of to `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed.
This change is in preparation for adding support for branches.
>From db3dde1d9c6e3eb1b85083d1a3545691f47acb7c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 21 Mar 2024 08:04:11 +0000
Subject: [PATCH] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass stop
condition in the constructor
This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of to `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed as a reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed.
This change is in preparation for adding support for branches.
---
.../mlir/Interfaces/ValueBoundsOpInterface.h | 16 +++--
.../Affine/Transforms/ReifyValueBounds.cpp | 6 +-
.../Arith/Transforms/ReifyValueBounds.cpp | 6 +-
.../Linalg/Transforms/HoistPadding.cpp | 2 +-
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 2 +-
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 60 +++++++++++--------
.../Dialect/Affine/TestReifyValueBounds.cpp | 9 ++-
7 files changed, 62 insertions(+), 39 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 94a8a8b429c801..b79c44162ea8ef 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -113,8 +113,9 @@ class ValueBoundsConstraintSet {
///
/// The first parameter of the function is the shaped value/index-typed
/// value. The second parameter is the dimension in case of a shaped value.
- using StopConditionFn =
- function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
+ /// The third parameter is this constraint set.
+ using StopConditionFn = function_ref<bool(
+ Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
/// Compute a bound for the given index-typed value or shape dimension size.
/// The computed bound is stored in `resultMap`. The operands of the bound are
@@ -263,12 +264,12 @@ class ValueBoundsConstraintSet {
/// An index-typed value or the dimension of a shaped-type value.
using ValueDim = std::pair<Value, int64_t>;
- ValueBoundsConstraintSet(MLIRContext *ctx);
+ ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
/// Iteratively process all elements on the worklist until an index-typed
- /// value or shaped value meets `stopCondition`. Such values are not processed
- /// any further.
- void processWorklist(StopConditionFn stopCondition);
+ /// value or shaped value meets `stopCondition`. Such values are not
+ /// processed any further.
+ void processWorklist();
/// Bound the given column in the underlying constraint set by the given
/// expression.
@@ -316,6 +317,9 @@ class ValueBoundsConstraintSet {
/// Builder for constructing affine expressions.
Builder builder;
+
+ /// The current stop condition function.
+ StopConditionFn stopCondition = nullptr;
};
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f76d4465d..117ee8e8701ad7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd1478aa9e6..fad221288f190e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8eebaecb9..c3a08ce86082a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
/*stopCondition=*/
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
if (v == forOp.getUpperBound())
return false;
// Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0cecf0d24..1e13e60068ee7f 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
ValueDimList boundOperands;
LogicalResult status = ValueBoundsConstraintSet::computeBound(
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
// Stop when reaching a block argument of the loop body.
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index f2f732f3a21d25..ec710bbacc758f 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
- : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+ MLIRContext *ctx, StopConditionFn stopCondition)
+ : builder(ctx), stopCondition(stopCondition) {}
#ifndef NDEBUG
static void assertValidValueDim(Value value, std::optional<int64_t> dim) {
@@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) {
return value.getDefiningOp();
}
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+ LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
int64_t pos = worklist.front();
worklist.pop();
@@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
// Do not process any further if the stop condition is met.
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
- if (stopCondition(value, maybeDim))
+ if (stopCondition(value, maybeDim, *this)) {
+ LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+ << " (dim: " << maybeDim << ")\n");
continue;
+ }
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
// the worklist.
auto valueBoundsOp =
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+ LLVM_DEBUG(llvm::dbgs()
+ << "Query value bounds for: " << value
+ << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
if (valueBoundsOp) {
if (dim == kIndexValue) {
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
}
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
// If the op does not implement `ValueBoundsOpInterface`, check if it
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
bool closedUB) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
- assert(!stopCondition(value, dim) &&
- "stop condition should not be satisfied for starting point");
#endif // NDEBUG
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext());
+ ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+ assert(!stopCondition(value, dim, cstr) &&
+ "stop condition should not be satisfied for starting point");
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
- cstr.processWorklist(stopCondition);
+ cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
@@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
- return !stopCondition(p.first, maybeDim);
+ return !stopCondition(p.first, maybeDim, cstr);
});
// Compute lower and upper bounds for `valueDim`.
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
bool closedUB) {
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
closedUB);
@@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+ return isIndependent(v);
+ },
closedUB);
}
@@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList operands,
StopConditionFn stopCondition, bool closedUB) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
- ValueBoundsConstraintSet cstr(map.getContext());
- int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+ // Default stop condition if none was specified: Keep adding constraints until
+ // a bound could be computed.
+ int64_t pos;
+ auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ };
+
+ ValueBoundsConstraintSet cstr(
+ map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ pos = cstr.insert(/*isSymbol=*/false);
// Add map and operands to the constraint set. Dimensions are converted to
// symbols. All operands are added to the worklist.
@@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
// Process the backward slice of `operands` (i.e., reverse use-def chain)
- // until `stopCondition` is met.
- if (stopCondition) {
- cstr.processWorklist(stopCondition);
- } else {
- // No stop condition specified: Keep adding constraints until a bound could
- // be computed.
- cstr.processWorklist(
- /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) {
- return cstr.cstr.getConstantBound64(type, pos).has_value();
- });
- }
+ // until the stop condition is met.
+ cstr.processWorklist();
// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 39671a930f2e21..e99a13cdca2f3c 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
// Prepare stop condition. By default, reify in terms of the op's
// operands. No stop condition is used when a constant was requested.
- std::function<bool(Value, std::optional<int64_t>)> stopCondition =
- [&](Value v, std::optional<int64_t> d) {
+ std::function<bool(Value, std::optional<int64_t>,
+ ValueBoundsConstraintSet & cstr)>
+ stopCondition = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
- stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
+ stopCondition = stopCondition = [](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
auto bbArg = dyn_cast<BlockArgument>(v);
if (!bbArg)
return false;
More information about the llvm-branch-commits
mailing list