[llvm-branch-commits] [mlir] [mlir][Interfaces] `ValueBoundsOpInterface`: Add API to compare values (PR #86915)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 27 23:25:23 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds a new public API to `ValueBoundsOpInterface` to compare values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT.
The new `ValueBoundsConstraintSet::compare` API replaces and generalizes `ValueBoundsConstraintSet::areEqual`. Not only does it provide additional comparison operators, it also works in cases where the difference between the two values/dims is non-constant. The previous implementation of `areEqual` computed a constant bound of `val1 - val2`, so it could not decide anything when that difference was not a constant.
Note: This commit refactors, generalizes and adds a public API for value/dim comparison. The comparison functionality itself was introduced in #85895 and is already in use for analyzing `scf.if`.
In the long term, this improvement will allow for a more powerful analysis of subset ops. A future commit will update `areOverlappingSlices` to use the new comparison API. (`areEquivalentSlices` is already using the new API.) This will improve subset equivalence/disjointness checks with non-constant offsets/sizes/strides.
---
Patch is 31.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86915.diff
7 Files Affected:
- (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+51-10)
- (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+9-22)
- (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+181-58)
- (modified) mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir (+42-1)
- (modified) mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (+12)
- (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+8-8)
- (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+66-13)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index f35432ca0136f3..d27081fad8c6c0 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
enum ComparisonOperator { LT, LE, EQ, GT, GE };
- /// Try to prove that, based on the current state of this constraint set
+ /// Populate constraints for lhs/rhs (until the stop condition is met). Then,
+ /// try to prove that, based on the current state of this constraint set
/// (i.e., without analyzing additional IR or adding new constraints), the
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
///
@@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
- Value rhs, std::optional<int64_t> rhsDim);
+ bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+
+ /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
+ /// specified relation could not be proven. This could be because the
+ /// specified relation does in fact not hold or because there is not enough
+ /// information in the constraint set. In other words, if we do not know for
+ /// sure, this function returns "false".
+ ///
+ /// This function keeps traversing the backward slice of lhs/rhs until it can
+ /// prove the relation or until it runs out of IR.
+ static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+ /// Variant of `compare` that compares the results of two affine maps. Each
+ /// map must have exactly one result. The map operands are bound in order:
+ /// the first `getNumDims()` operands to the map's dimensions, the remaining
+ /// ones to its symbols.
+ static bool compare(AffineMap lhs, ValueDimList lhsOperands,
+ ComparisonOperator cmp, AffineMap rhs,
+ ValueDimList rhsOperands);
+ /// Same as the `ValueDimList` variant above, but every map operand is an
+ /// index-typed value (i.e., it is paired with a `nullopt` dimension).
+ static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
+ ComparisonOperator cmp, AffineMap rhs,
+ ArrayRef<Value> rhsOperands);
/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
/// index-typed.
- static FailureOr<bool> areEqual(Value value1, Value value2,
+ static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
- /// Compute whether the given values/attributes are equal. Return "failure" if
- /// equality could not be determined.
- ///
- /// `ofr1`/`ofr2` must be of index type.
- static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
-
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
/// Return "failure" if unknown.
@@ -290,6 +304,20 @@ class ValueBoundsConstraintSet
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
+ /// Return "true" if, based on the current state of the constraint system,
+ /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
+ /// could not be proven. This could be because the specified relation does in
+ /// fact not hold or because there is not enough information in the constraint
+ /// set. In other words, if we do not know for sure, this function returns
+ /// "false".
+ ///
+ /// This function does not analyze any IR and does not populate any additional
+ /// constraints.
+ ///
+ /// lhs/rhs that are values (not attributes) must already be mapped in the
+ /// constraint set, because their column positions are looked up.
+ bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+ /// Same as `compareValueDims`, but both sides are given directly as column
+ /// positions in the constraint set. Like `compareValueDims`, it only queries
+ /// the current state of the constraint set.
+ bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
@@ -311,6 +339,14 @@ class ValueBoundsConstraintSet
/// value/dimension exists in the constraint set.
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
+ /// Return an affine expression that represents column `pos` in the constraint
+ /// set.
+ AffineExpr getPosExpr(int64_t pos);
+
+ /// Return "true" if the given value/dim is mapped (i.e., has a corresponding
+ /// column in the constraint system).
+ bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
/// "false", a dimension is added. The value/dimension is added to the
/// worklist if `addToWorklist` is set.
@@ -330,6 +366,11 @@ class ValueBoundsConstraintSet
/// dimensions but not for symbols.
int64_t insert(bool isSymbol = true);
+ /// Insert the given affine map and its bound operands as a new column in the
+ /// constraint system. Return the position of the new column. Any operands
+ /// that were not analyzed yet are put on the worklist.
+ int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 72c5aaa2306783..087ffc438a830a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,20 +58,11 @@ struct ForOpInterface
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];
- // Populate constraints for the yielded value.
- cstr.populateConstraints(yieldedValue, dim);
- // Populate constraints for the iter_arg. This is just to ensure that the
- // iter_arg is mapped in the constraint set, which is a prerequisite for
- // `compare`. It may lead to a recursive call to this function in case the
- // iter_arg was not visited when the constraints for the yielded value were
- // populated, but no additional work is done.
- cstr.populateConstraints(iterArg, dim);
-
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
- if (cstr.compare(yieldedValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
- dim)) {
+ if (cstr.populateAndCompare(
+ yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ iterArg, dim)) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
@@ -113,10 +104,6 @@ struct IfOpInterface
Value thenValue = ifOp.thenYield().getResults()[resultNum];
Value elseValue = ifOp.elseYield().getResults()[resultNum];
- // Populate constraints for the yielded value (and all values on the
- // backward slice, as long as the current stop condition is not satisfied).
- cstr.populateConstraints(thenValue, dim);
- cstr.populateConstraints(elseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];
@@ -125,9 +112,9 @@ struct IfOpInterface
// If thenValue <= elseValue:
// * result <= elseValue
// * result >= thenValue
- if (cstr.compare(thenValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ if (cstr.populateAndCompare(
+ thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+ elseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -139,9 +126,9 @@ struct IfOpInterface
// If elseValue <= thenValue:
// * result <= thenValue
// * result >= elseValue
- if (cstr.compare(elseValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ if (cstr.populateAndCompare(
+ elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+ thenValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index dd98da9adc7d96..d7ffed14daccdd 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -212,6 +212,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
return pos;
}
+int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
+                                         bool isSymbol) {
+  assert(map.getNumResults() == 1 && "expected affine map with one result");
+  // Every map dimension and symbol must be bound to exactly one operand;
+  // otherwise the replacement lists below would be too short or too long.
+  assert(operands.size() == map.getNumInputs() &&
+         "expected one operand per map dimension/symbol");
+  // Insert a new column of the requested kind for the map result. (Callers
+  // that pass `isSymbol` must actually get that kind of column.)
+  int64_t pos = insert(isSymbol);
+
+  // Add map and operands to the constraint set: the map's dimensions and
+  // symbols are replaced with the expressions of the corresponding operands
+  // (the first `getNumDims()` operands bind to dimensions, the rest to
+  // symbols). All operands are added to the worklist (unless they were
+  // already processed).
+  auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+    return getExpr(v.first, v.second);
+  };
+  SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+  SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+      llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+  addBound(
+      presburger::BoundType::EQ, pos,
+      map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+  return pos;
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -227,6 +249,20 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
return it->second;
}
+AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
+  assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
+  // Columns [0, numDims) are dimensions; every column after them is a symbol.
+  int64_t numDims = cstr.getNumDimVars();
+  if (pos < numDims)
+    return builder.getAffineDimExpr(pos);
+  return builder.getAffineSymbolExpr(pos - numDims);
+}
+
+bool ValueBoundsConstraintSet::isMapped(Value value,
+                                        std::optional<int64_t> dim) const {
+  // A value without a dimension (index-typed) is keyed with `kIndexValue`.
+  return valueDimToPosition.count(
+             std::make_pair(value, dim.value_or(kIndexValue))) != 0;
+}
+
static Operation *getOwnerOfValue(Value value) {
if (auto bbArg = dyn_cast<BlockArgument>(value))
return bbArg.getOwner()->getParentOp();
@@ -563,27 +599,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value,
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
ValueDimList operands) {
- assert(map.getNumResults() == 1 && "expected affine map with one result");
- int64_t pos = insert(/*isSymbol=*/false);
-
- // Add map and operands to the constraint set. Dimensions are converted to
- // symbols. All operands are added to the worklist (unless they were already
- // processed).
- auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
- return getExpr(v.first, v.second);
- };
- SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
- llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
- SmallVector<AffineExpr> symReplacements = llvm::to_vector(
- llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
- addBound(
- presburger::BoundType::EQ, pos,
- map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
-
+ // Insert the map result as a new dimension column. `insert` also adds the
+ // map operands to the worklist, which is drained below.
+ int64_t pos = insert(map, operands, /*isSymbol=*/false);
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
processWorklist();
-
return pos;
}
@@ -603,9 +622,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
{{value1, dim1}, {value2, dim2}});
}
-bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, Value rhs,
- std::optional<int64_t> rhsDim) {
+bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
+ std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp,
+ OpFoldResult rhs,
+ std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+ if (auto lhsVal = dyn_cast<Value>(lhs))
+ assertValidValueDim(lhsVal, lhsDim);
+ if (auto rhsVal = dyn_cast<Value>(rhs))
+ assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -624,19 +652,61 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
// EQ can be expressed as LE and GE.
if (cmp == EQ)
- return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
- compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
+ return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
+ compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
// Construct inequality. For the above example: lhs > rhs.
// `IntegerRelation` inequalities are expressed in the "flattened" form and
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
+ SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+ auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
+ int64_t factor) {
+ if (auto constVal = ::getConstantIntValue(ofr)) {
+ eq[cstr.getNumCols() - 1] += *constVal * factor;
+ } else {
+ eq[getPos(cast<Value>(ofr), dim)] += factor;
+ }
+ };
+ if (cmp == LT || cmp == LE) {
+ addToEq(lhs, lhsDim, 1);
+ addToEq(rhs, rhsDim, -1);
+ } else if (cmp == GT || cmp == GE) {
+ addToEq(lhs, lhsDim, -1);
+ addToEq(rhs, rhsDim, 1);
+ } else {
+ llvm_unreachable("unsupported comparison operator");
+ }
+ if (cmp == LE || cmp == GE)
+ eq[cstr.getNumCols() - 1] -= 1;
+
+ // Add inequality to the constraint set and check if it made the constraint
+ // set empty.
+ int64_t ineqPos = cstr.getNumInequalities();
+ cstr.addInequality(eq);
+ bool isEmpty = cstr.isEmpty();
+ cstr.removeInequality(ineqPos);
+ return isEmpty;
+}
+
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+ ComparisonOperator cmp,
+ int64_t rhsPos) {
+ // This function returns "true" if "lhs CMP rhs" is proven to hold. For
+ // detailed documentation, see `compareValueDims`.
+
+ // EQ can be expressed as LE and GE.
+ if (cmp == EQ)
+ return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
+ comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+
+ // Construct inequality.
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
if (cmp == LT || cmp == LE) {
- ++eq[getPos(lhs, lhsDim)];
- --eq[getPos(rhs, rhsDim)];
+ ++eq[lhsPos];
+ --eq[rhsPos];
} else if (cmp == GT || cmp == GE) {
- --eq[getPos(lhs, lhsDim)];
- ++eq[getPos(rhs, rhsDim)];
+ --eq[lhsPos];
+ ++eq[rhsPos];
} else {
llvm_unreachable("unsupported comparison operator");
}
@@ -652,40 +722,93 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
return isEmpty;
}
+bool ValueBoundsConstraintSet::populateAndCompare(
+    OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
+    OpFoldResult rhs, std::optional<int64_t> rhsDim) {
+  // Null if the respective side is an attribute (constant).
+  Value lhsVal = dyn_cast<Value>(lhs);
+  Value rhsVal = dyn_cast<Value>(rhs);
+
+#ifndef NDEBUG
+  if (lhsVal)
+    assertValidValueDim(lhsVal, lhsDim);
+  if (rhsVal)
+    assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
+  // Constants need no constraints; only SSA values are analyzed.
+  if (lhsVal)
+    populateConstraints(lhsVal, lhsDim);
+  if (rhsVal)
+    populateConstraints(rhsVal, rhsDim);
+
+  return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
+                                       std::optional<int64_t> lhsDim,
+                                       ComparisonOperator cmp, OpFoldResult rhs,
+                                       std::optional<int64_t> rhsDim) {
+  // Stop traversing IR as soon as the relation is proven. `compareValueDims`
+  // looks up the columns of lhs/rhs, so it may only be queried once both are
+  // mapped.
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Note: `dim` is the dimension of the value `v` that is currently being
+    // visited. lhs/rhs must be looked up with their own dimensions; querying
+    // them with `dim` could report an unmapped lhs/rhs dimension as mapped,
+    // after which `compareValueDims` would query the position of an unmapped
+    // value/dim.
+    if (auto lhsVal = dyn_cast<Value>(lhs))
+      if (!cstr.isMapped(lhsVal, lhsDim))
+        return false;
+    if (auto rhsVal = dyn_cast<Value>(rhs))
+      if (!cstr.isMapped(rhsVal, rhsDim))
+        return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+  };
+
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ValueDimList rhsOperands) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing as long as the lhs/rhs columns do not exist yet. The
+    // comparison is done on signed integers so that the `-1` sentinel is
+    // handled explicitly instead of relying on unsigned wrap-around.
+    int64_t numPositions = cstr.positionToValueDim.size();
+    if (lhsPos < 0 || rhsPos < 0 || lhsPos >= numPositions ||
+        rhsPos >= numPositions)
+      return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.comparePos(lhsPos, cmp, rhsPos);
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  // Insert both maps as dimension columns (mirroring
+  // `populateConstraints(AffineMap, ValueDimList)`), then analyze the backward
+  // slice of their operands. Inserting alone only queues the operands; without
+  // draining the worklist no IR would be analyzed at all. Draining once, after
+  // both columns exist, lets the stop condition end the traversal as soon as
+  // the relation is proven.
+  lhsPos = cstr.insert(lhs, lhsOperands, /*isSymbol=*/false);
+  rhsPos = cstr.insert(rhs, rhsOperands, /*isSymbol=*/false);
+  cstr.processWorklist();
+  return cstr.comparePos(lhsPos, cmp, rhsPos);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs,
+                                       ArrayRef<Value> lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ArrayRef<Value> rhsOperands) {
+  // Every operand is an index-typed value, so pair each one with `nullopt`.
+  auto toValueDimList = [](ArrayRef<Value> values) {
+    ValueDimList result;
+    for (Value v : values)
+      result.emplace_back(v, std::optional<int64_t>());
+    return result;
+  };
+  return ValueBoundsConstraintSet::compare(lhs, toValueDimList(lhsOperands),
+                                           cmp, rhs,
+                                           toValueDimList(rhsOperands));
+}
+
FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1,
std::optional<int64_t> dim2) {
- // Subtract the two values/dimensions from each other. If the result is 0,
- // both are equal.
- FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
- if (failed(delta))
- return failure();
- return *delta == 0;
-}
-
-FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
- OpFoldR...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/86915
More information about the llvm-branch-commits
mailing list