[llvm-branch-commits] [mlir] [mlir][Interfaces][WIP] `Variable` abstraction for `ValueBoundsOpInterface` (PR #87980)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Apr 8 04:31:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit generalizes and cleans up the `ValueBoundsConstraintSet` API. The API used to provide function overloads for comparing/computing bounds of:
- an index-typed SSA value
- a dimension of a shaped value
- an affine map + operands
This commit removes all overloads. There is now a single entry point for each `compare` variant and each `computeBound` variant. These functions now take a `Variable`, which is internally represented as an affine map and map operands.
This commit also adds support for computing bounds for an affine map + operands. There was previously no public API for that.
WIP until I add a test case for `computeBound` with an `AffineMap`-based `Variable`.
---
Patch is 47.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/87980.diff
15 Files Affected:
- (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+61-56)
- (modified) mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp (+2-4)
- (modified) mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp (+1-1)
- (modified) mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp (+69)
- (modified) mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp (+1-1)
- (modified) mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp (+3-1)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Padding.cpp (+4-2)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp (+2-4)
- (modified) mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp (+2-3)
- (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+10-7)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+1-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp (+2-1)
- (modified) mlir/lib/Dialect/Tensor/Utils/Utils.cpp (+2-2)
- (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+149-188)
- (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 1d7bc6ea961cc3a..3e1502b4f5c357a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include <queue>
@@ -111,6 +112,39 @@ class ValueBoundsConstraintSet
public:
static char ID;
+ /// A variable that can be added to the constraint set as a "column". The
+ /// value bounds infrastructure can compute bounds for variables and compare
+ /// two variables.
+ ///
+ /// Internally, a variable is represented as an affine map and operands.
+ class Variable {
+ public:
+ /// Construct a variable for an index-typed attribute or SSA value.
+ Variable(OpFoldResult ofr);
+
+ /// Construct a variable for an index-typed SSA value.
+ Variable(Value indexValue);
+
+ /// Construct a variable for a dimension of a shaped value.
+ Variable(Value shapedValue, int64_t dim);
+
+ /// Construct a variable for an index-typed attribute/SSA value or for a
+ /// dimension of a shaped value. A non-null dimension must be provided if
+ /// and only if `ofr` is a shaped value.
+ Variable(OpFoldResult ofr, std::optional<int64_t> dim);
+
+ /// Construct a variable for a map and its operands.
+ Variable(AffineMap map, ArrayRef<Variable> mapOperands);
+ Variable(AffineMap map, ArrayRef<Value> mapOperands);
+
+ MLIRContext *getContext() const { return map.getContext(); }
+
+ private:
+ friend class ValueBoundsConstraintSet;
+ AffineMap map;
+ ValueDimList mapOperands;
+ };
+
/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
@@ -121,35 +155,31 @@ class ValueBoundsConstraintSet
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
- /// Compute a bound for the given index-typed value or shape dimension size.
- /// The computed bound is stored in `resultMap`. The operands of the bound are
- /// stored in `mapOperands`. An operand is either an index-type SSA value
- /// or a shaped value and a dimension.
+ /// Compute a bound for the given variable. The computed bound is stored in
+ /// `resultMap`. The operands of the bound are stored in `mapOperands`. An
+ /// operand is either an index-type SSA value or a shaped value and a
+ /// dimension.
///
- /// `dim` must be `nullopt` if and only if `value` is index-typed. The bound
- /// is computed in terms of values/dimensions for which `stopCondition`
- /// evaluates to "true". To that end, the backward slice (reverse use-def
- /// chain) of the given value is visited in a worklist-driven manner and the
- /// constraint set is populated according to `ValueBoundsOpInterface` for each
- /// visited value.
+ /// The bound is computed in terms of values/dimensions for which
+ /// `stopCondition` evaluates to "true". To that end, the backward slice
+ /// (reverse use-def chain) of the given value is visited in a worklist-driven
+ /// manner and the constraint set is populated according to
+ /// `ValueBoundsOpInterface` for each visited value.
///
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
- static LogicalResult computeBound(AffineMap &resultMap,
- ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim,
- StopConditionFn stopCondition,
- bool closedUB = false);
+ static LogicalResult
+ computeBound(AffineMap &resultMap, ValueDimList &mapOperands,
+ presburger::BoundType type, const Variable &var,
+ StopConditionFn stopCondition, bool closedUB = false);
/// Compute a bound in terms of the values/dimensions in `dependencies`. The
/// computed bound consists of only constant terms and dependent values (or
/// dimension sizes thereof).
static LogicalResult
computeDependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueDimList dependencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueDimList dependencies, bool closedUB = false);
/// Compute a bound in that is independent of all values in `independencies`.
///
@@ -161,13 +191,10 @@ class ValueBoundsConstraintSet
/// appear in the computed bound.
static LogicalResult
computeIndependentBound(AffineMap &resultMap, ValueDimList &mapOperands,
- presburger::BoundType type, Value value,
- std::optional<int64_t> dim, ValueRange independencies,
- bool closedUB = false);
+ presburger::BoundType type, const Variable &var,
+ ValueRange independencies, bool closedUB = false);
- /// Compute a constant bound for the given affine map, where dims and symbols
- /// are bound to the given operands. The affine map must have exactly one
- /// result.
+ /// Compute a constant bound for the given variable.
///
/// This function traverses the backward slice of the given operands in a
/// worklist-driven manner until `stopCondition` evaluates to "true". The
@@ -182,16 +209,9 @@ class ValueBoundsConstraintSet
/// By default, lower/equal bounds are closed and upper bounds are open. If
/// `closedUB` is set to "true", upper bounds are also closed.
static FailureOr<int64_t>
- computeConstantBound(presburger::BoundType type, Value value,
- std::optional<int64_t> dim = std::nullopt,
+ computeConstantBound(presburger::BoundType type, const Variable &var,
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
- static FailureOr<int64_t> computeConstantBound(
- presburger::BoundType type, AffineMap map, ArrayRef<Value> mapOperands,
- StopConditionFn stopCondition = nullptr, bool closedUB = false);
/// Compute a constant delta between the given two values. Return "failure"
/// if a constant delta could not be determined.
@@ -221,9 +241,7 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
+ bool populateAndCompare(Variable lhs, ComparisonOperator cmp, Variable rhs);
/// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
/// specified relation could not be proven. This could be because the
@@ -233,24 +251,11 @@ class ValueBoundsConstraintSet
///
/// This function keeps traversing the backward slice of lhs/rhs until could
/// prove the relation or until it ran out of IR.
- static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
- static bool compare(AffineMap lhs, ValueDimList lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ValueDimList rhsOperands);
- static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
- ComparisonOperator cmp, AffineMap rhs,
- ArrayRef<Value> rhsOperands);
-
- /// Compute whether the given values/dimensions are equal. Return "failure" if
+ static bool compare(Variable lhs, ComparisonOperator cmp, Variable rhs);
+
+ /// Compute whether the given variables are equal. Return "failure" if
/// equality could not be determined.
- ///
- /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
- /// index-typed.
- static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
- std::optional<int64_t> dim1 = std::nullopt,
- std::optional<int64_t> dim2 = std::nullopt);
+ static FailureOr<bool> areEqual(Variable var1, Variable var2);
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
@@ -317,9 +322,6 @@ class ValueBoundsConstraintSet
///
/// This function does not analyze any IR and does not populate any additional
/// constraints.
- bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, OpFoldResult rhs,
- std::optional<int64_t> rhsDim);
bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
/// Given an affine map with a single result (and map operands), add a new
@@ -374,6 +376,7 @@ class ValueBoundsConstraintSet
/// constraint system. Return the position of the new column. Any operands
/// that were not analyzed yet are put on the worklist.
int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+ int64_t insert(const Variable &var, bool isSymbol = true);
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
@@ -381,6 +384,8 @@ class ValueBoundsConstraintSet
/// Project out all columns for which the condition holds.
void projectOut(function_ref<bool(ValueDim)> condition);
+ void projectOutAnonymous(std::optional<int64_t> except = std::nullopt);
+
/// Mapping of columns to values/shape dimensions.
SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe7a0f71d1..82a9fb0d490882f 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
- ValueDimList valueDims;
- for (Value v : mapOperands)
- valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::EQ, map, valueDims);
+ presburger::BoundType::EQ,
+ ValueBoundsConstraintSet::Variable(map, mapOperands));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e8701ad7c..6c59df91e8af781 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -25,7 +25,7 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, {value, dim}, stopCondition, closedUB)))
return failure();
// Reify bound.
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index 90895e381c74b5a..411fc117a4d9f5d 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -75,6 +75,75 @@ struct MulIOpInterface
}
};
+struct SelectOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
+ SelectOp> {
+
+ static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ Value value = selectOp.getResult();
+ Value condition = selectOp.getCondition();
+ Value trueValue = selectOp.getTrueValue();
+ Value falseValue = selectOp.getFalseValue();
+
+ if (isa<ShapedType>(condition.getType())) {
+ // If the condition is a shaped type, the condition is applied
+ // element-wise. All three operands must have the same shape.
+ cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
+ cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
+ cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
+ return;
+ }
+
+ // Populate constraints for the true/false values (and all values on the
+ // backward slice, as long as the current stop condition is not satisfied).
+ cstr.populateConstraints(trueValue, dim);
+ cstr.populateConstraints(falseValue, dim);
+ auto boundsBuilder = cstr.bound(value);
+ if (dim)
+ boundsBuilder[*dim];
+
+ // Compare yielded values.
+ // If trueValue <= falseValue:
+ // * result <= falseValue
+ // * result >= trueValue
+ if (cstr.compare(/*lhs=*/{trueValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{falseValue, dim})) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
+ } else {
+ cstr.bound(value) >= trueValue;
+ cstr.bound(value) <= falseValue;
+ }
+ }
+ // If falseValue <= trueValue:
+ // * result <= trueValue
+ // * result >= falseValue
+ if (cstr.compare(/*lhs=*/{falseValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{trueValue, dim})) {
+ if (dim) {
+ cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
+ cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
+ } else {
+ cstr.bound(value) >= falseValue;
+ cstr.bound(value) <= trueValue;
+ }
+ }
+ }
+
+ void populateBoundsForIndexValue(Operation *op, Value value,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
+ }
+
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ populateBounds(cast<SelectOp>(op), dim, cstr);
+ }
+};
} // namespace
} // namespace arith
} // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6ed2e99a2..f87f3d6350c0221 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad221288f190ed..5bb7d83bf1e3f86 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -70,7 +70,9 @@ reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type,
+ ValueBoundsConstraintSet::Variable(value, dim), stopCondition,
+ closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70db2489897..518d2e138c02a97 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, opOperand->get(),
- /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB,
+ {opOperand->get(),
+ /*dim=*/i},
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6c30d049d..71eb59d40836c1f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
- Value materializedSize =
- getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
- ? materializedSize
+ ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508265e7b9f..1f06318cbd60e04 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueRange independencies) {
if (ofr.is<Attribute>())
return ofr;
- Value value = ofr.get<Value>();
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
- boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+ /*closedUB=*/true)))
return failure();
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc438a830a3..17a1c016ea16d5a 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
// An EQ constraint can be added if the y...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/87980
More information about the llvm-branch-commits mailing list