[llvm-branch-commits] [mlir] [mlir][SCF] `ValueBoundsConstraintSet`: Support `scf.if` (branches) (PR #85895)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Mar 22 23:00:46 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/85895
>From b4bab14a9451e0dd1663fcbce5718d053ba68e5a Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Sat, 23 Mar 2024 05:59:33 +0000
Subject: [PATCH] [mlir][SCF] `ValueBoundsConstraintSet`: Add preliminary
support for branches
This commit adds support for `scf.if` to `ValueBoundsConstraintSet`.
Example:
```
%0 = scf.if ... -> index {
scf.yield %a : index
} else {
scf.yield %b : index
}
```
The following constraints hold for %0:
* %0 >= min(%a, %b)
* %0 <= max(%a, %b)
Such constraints cannot be added to the constraint set because `IntegerRelation` does not support min/max expressions. However, if we know which of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b:
* %0 >= %a
* %0 <= %b
This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints.
---
.../mlir/Interfaces/ValueBoundsOpInterface.h | 34 +++--
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 61 +++++++++
.../IR/ScalableValueBoundsConstraintSet.cpp | 14 ++-
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 79 +++++++++---
.../SCF/value-bounds-op-interface-impl.mlir | 119 +++++++++++++++++-
5 files changed, 278 insertions(+), 29 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index c281739e0ded2c6..f35432ca0136f36 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -203,6 +203,26 @@ class ValueBoundsConstraintSet
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
+ /// Traverse the IR starting from the given value/dim and populate constraints
+ /// as long as the stop condition is not satisfied. Also process all
+ /// values/dims that are already on the worklist.
+ void populateConstraints(Value value, std::optional<int64_t> dim);
+
+ /// Comparison operator for `ValueBoundsConstraintSet::compare`.
+ enum ComparisonOperator { LT, LE, EQ, GT, GE };
+
+ /// Try to prove that, based on the current state of this constraint set
+ /// (i.e., without analyzing additional IR or adding new constraints), the
+ /// relation "lhs CMP rhs" holds, where CMP is one of LT/LE/EQ/GT/GE.
+ ///
+ /// Return "true" if the specified relation between the two values/dims was
+ /// proven to hold. Return "false" if the specified relation could not be
+ /// proven. This could be because the specified relation does in fact not hold
+ /// or because there is not enough information in the constraint set (or the
+ /// constraint set is already empty). If in doubt, this returns "false".
+ bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
+ Value rhs, std::optional<int64_t> rhsDim);
+
/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
@@ -270,13 +290,13 @@ class ValueBoundsConstraintSet
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
- /// Populates the constraint set for a value/map without actually computing
- /// the bound. Returns the position for the value/map (via the return value
- /// and `posOut` output parameter).
- int64_t populateConstraintsSet(Value value,
- std::optional<int64_t> dim = std::nullopt);
- int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
- int64_t *posOut = nullptr);
+ /// Given an affine map with a single result (and map operands), add a new
+ /// (non-symbol) column to the constraint set that represents the result of
+ /// the map. Traverse additional IR starting from the map operands as needed
+ /// (as long as the stop condition is not satisfied). Also process all
+ /// values/dims that are already on the worklist. Return the position of the
+ /// newly added column.
+ int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);
/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 1e13e60068ee7f6..8e9d1021f93e4b1 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -111,6 +111,66 @@ struct ForOpInterface
}
};
+struct IfOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
+
+ static void populateBounds(scf::IfOp ifOp, Value value,
+ std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ unsigned int resultNum = cast<OpResult>(value).getResultNumber();
+ Value thenValue = ifOp.thenYield().getResults()[resultNum];
+ Value elseValue = ifOp.elseYield().getResults()[resultNum];
+
+ // Populate constraints for both yielded values (and all values on their
+ // backward slices, as long as the current stop condition is not satisfied).
+ cstr.populateConstraints(thenValue, dim);
+ cstr.populateConstraints(elseValue, dim);
+ auto boundsBuilder = cstr.bound(value); // NOTE(review): unused below; confirm intent.
+ if (dim)
+ boundsBuilder[*dim];
+
+ // Compare the yielded values; only relations proven by `compare` are used.
+ // If thenValue <= elseValue:
+ // * result <= elseValue
+ // * result >= thenValue
+ if (cstr.compare(thenValue, dim,
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ elseValue, dim)) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
+ } else {
+ cstr.bound(value) >= thenValue;
+ cstr.bound(value) <= elseValue;
+ }
+ }
+ // If elseValue <= thenValue (also taken when both values are equal):
+ // * result <= thenValue
+ // * result >= elseValue
+ if (cstr.compare(elseValue, dim,
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ thenValue, dim)) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
+ } else {
+ cstr.bound(value) >= elseValue;
+ cstr.bound(value) <= thenValue;
+ }
+ }
+ }
+
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
+ }
+
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<IfOp>(op), value, dim, cstr);
+ }
+};
+
} // namespace
} // namespace scf
} // namespace mlir
@@ -119,5 +179,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
+ scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index 52359fa8a510d35..f8df34843a36312 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -59,12 +59,16 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
ScalableValueBoundsConstraintSet scalableCstr(
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
vscaleMin, vscaleMax);
- int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
+ int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
+ scalableCstr.processWorklist();
- // Project out all variables apart from vscale.
- // This should result in constraints in terms of vscale only.
+ // Project out all columns apart from vscale and the starting point
+ // (value/dim). This should result in constraints in terms of vscale only.
auto projectOutFn = [&](ValueDim p) {
- return p.first != scalableCstr.getVscaleValue();
+ bool isStartingPoint =
+ p.first == value &&
+ p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
+ return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
};
scalableCstr.projectOut(projectOutFn);
@@ -72,7 +76,7 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
scalableCstr.positionToValueDim.size() &&
"inconsistent mapping state");
- // Check that the only symbols left are vscale.
+ // Check that the only columns left are vscale and the starting point.
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
if (i == pos)
continue;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 9028fb3fb767774..dd98da9adc7d967 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -529,7 +529,7 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
- int64_t pos;
+ int64_t pos = 0;
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
@@ -537,7 +537,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
ValueBoundsConstraintSet cstr(
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
- cstr.populateConstraintsSet(map, operands, &pos);
+ pos = cstr.populateConstraints(map, operands);
+ assert(pos == 0 && "expected `map` is the first column");
// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
@@ -546,29 +547,28 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
-int64_t
-ValueBoundsConstraintSet::populateConstraintsSet(Value value,
- std::optional<int64_t> dim) {
+void ValueBoundsConstraintSet::populateConstraints(Value value,
+ std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG
- AffineMap map =
- AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- Builder(value.getContext()).getAffineDimExpr(0));
- return populateConstraintsSet(map, {{value, dim}});
+ // `getExpr` pushes the value/dim onto the worklist (unless it was already
+ // analyzed).
+ (void)getExpr(value, dim);
+ // Process all values/dims on the worklist. This may traverse and analyze
+ // additional IR, depending on the current stop condition.
+ processWorklist();
}
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
- ValueDimList operands,
- int64_t *posOut) {
+int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
+ ValueDimList operands) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
- if (posOut)
- *posOut = pos;
// Add map and operands to the constraint set. Dimensions are converted to
- // symbols. All operands are added to the worklist.
+ // symbols. All operands are added to the worklist (unless they were already
+ // processed).
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
return getExpr(v.first, v.second);
};
@@ -603,6 +603,55 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
{{value1, dim1}, {value2, dim2}});
}
+bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, Value rhs,
+ std::optional<int64_t> rhsDim) {
+ // This function returns "true" if "lhs CMP rhs" is proven to hold.
+ //
+ // Example for ComparisonOperator::LE and index-typed values: We would like to
+ // prove that lhs <= rhs. Proof by contradiction: add the inverse
+ // relation (lhs > rhs) to the constraint set and check if the resulting
+ // constraint set is "empty" (i.e. has no solution). In that case,
+ // lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.
+
+ // We cannot prove anything if the constraint set is already empty.
+ if (cstr.isEmpty()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "cannot compare value/dims: constraint system is already empty\n");
+ return false;
+ }
+
+ // EQ can be expressed as LE and GE.
+ if (cmp == EQ)
+ return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
+ compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
+
+ // Construct the inverse inequality (for the above example: lhs > rhs) in the
+ // flattened ">= 0" form of `IntegerRelation`: lhs - rhs - 1 >= 0. It needs one
+ // coefficient per column (including local vars), the last being the constant.
+ SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+ if (cmp == LT || cmp == LE) {
+ ++eq[getPos(lhs, lhsDim)];
+ --eq[getPos(rhs, rhsDim)];
+ } else if (cmp == GT || cmp == GE) {
+ --eq[getPos(lhs, lhsDim)];
+ ++eq[getPos(rhs, rhsDim)];
+ } else {
+ llvm_unreachable("unsupported comparison operator");
+ }
+ if (cmp == LE || cmp == GE)
+ eq[cstr.getNumCols() - 1] -= 1;
+
+ // Add inequality to the constraint set and check if it made the constraint
+ // set empty. The temporary inequality is removed again afterwards.
+ int64_t ineqPos = cstr.getNumInequalities();
+ cstr.addInequality(eq);
+ bool isEmpty = cstr.isEmpty();
+ cstr.removeInequality(ineqPos);
+ return isEmpty;
+}
+
FailureOr<bool>
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
std::optional<int64_t> dim1,
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index e4d71415924994b..0ea06737886d416 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
-// RUN: -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \
+// RUN: -verify-diagnostics -split-input-file | FileCheck %s
// CHECK-LABEL: func @scf_for(
// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index
@@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in
"test.some_use"(%reify1) : (index) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_constant(
+func.func @scf_if_constant(%c : i1) {
+ // CHECK: arith.constant 4 : index
+ // CHECK: arith.constant 9 : index
+ %c4 = arith.constant 4 : index
+ %c9 = arith.constant 9 : index
+ %r = scf.if %c -> index {
+ scf.yield %c4 : index
+ } else {
+ scf.yield %c9 : index
+ }
+
+ // CHECK: %[[c4:.*]] = arith.constant 4 : index
+ // CHECK: %[[c10:.*]] = arith.constant 10 : index
+ %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+ %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+ // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+ "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+ return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+// CHECK-LABEL: func @scf_if_dynamic(
+// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) {
+ %c4 = arith.constant 4 : index
+ %r = scf.if %c -> index {
+ %add1 = arith.addi %a, %b : index
+ scf.yield %add1 : index
+ } else {
+ %add2 = arith.addi %b, %c4 : index
+ %add3 = arith.addi %add2, %a : index
+ scf.yield %add3 : index
+ }
+
+ // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+ // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]]
+ %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+ %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index)
+ // CHECK: "test.some_use"(%[[lb]], %[[ub]])
+ "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) {
+ %r = scf.if %c -> index {
+ scf.yield %a : index
+ } else {
+ scf.yield %b : index
+ }
+ // The LB would be min(%a, %b), but min/max cannot be represented in the
+ // constraint set, so no bound can be reified.
+ // expected-error @below{{could not reify bound}}
+ %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index)
+ "test.some_use"(%reify1) : (index) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_if_tensor_dim(
+func.func @scf_if_tensor_dim(%c : i1) {
+ // CHECK: arith.constant 4 : index
+ // CHECK: arith.constant 9 : index
+ %c4 = arith.constant 4 : index
+ %c9 = arith.constant 9 : index
+ %t1 = tensor.empty(%c4) : tensor<?xf32>
+ %t2 = tensor.empty(%c9) : tensor<?xf32>
+ %r = scf.if %c -> tensor<?xf32> {
+ scf.yield %t1 : tensor<?xf32>
+ } else {
+ scf.yield %t2 : tensor<?xf32>
+ }
+
+ // CHECK: %[[c4:.*]] = arith.constant 4 : index
+ // CHECK: %[[c10:.*]] = arith.constant 10 : index
+ %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0}
+ : (tensor<?xf32>) -> (index)
+ %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0}
+ : (tensor<?xf32>) -> (index)
+ // CHECK: "test.some_use"(%[[c4]], %[[c10]])
+ "test.some_use"(%reify1, %reify2) : (index, index) -> ()
+ return
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @scf_if_eq(
+// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1)
+func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
+ %c0 = arith.constant 0 : index
+ %r = scf.if %c -> index {
+ %add1 = arith.addi %a, %b : index
+ scf.yield %add1 : index
+ } else {
+ %add2 = arith.addi %b, %c0 : index
+ %add3 = arith.addi %add2, %a : index
+ scf.yield %add3 : index
+ }
+
+ // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]]
+ %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index)
+ // CHECK: "test.some_use"(%[[eq]])
+ "test.some_use"(%reify1) : (index) -> ()
+ return
+}
More information about the llvm-branch-commits
mailing list