[Mlir-commits] [mlir] ff93064 - [mlir][Interfaces] ValueBoundsOpInterface: Check if two values are equal
Matthias Springer
llvmlistbot at llvm.org
Thu May 25 10:08:59 PDT 2023
Author: Matthias Springer
Date: 2023-05-25T19:08:48+02:00
New Revision: ff9306459ff2f09b18ccbd4b5cefaf833730f42d
URL: https://github.com/llvm/llvm-project/commit/ff9306459ff2f09b18ccbd4b5cefaf833730f42d
DIFF: https://github.com/llvm/llvm-project/commit/ff9306459ff2f09b18ccbd4b5cefaf833730f42d.diff
LOG: [mlir][Interfaces] ValueBoundsOpInterface: Check if two values are equal
Add a helper function that computes whether two SSA values have the same value, utilizing the `ValueBoundsOpInterface` infrastructure. Two SSA values have the same value if an equality bound of 0 can be derived for their difference.
The helper function can also be used to determine if two tensor dimension sizes are equal.
Differential Revision: https://reviews.llvm.org/D151443
Added:
Modified:
mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index c4d446e2ca4bc..367ae28cb773a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -156,6 +156,35 @@ class ValueBoundsConstraintSet {
StopConditionFn stopCondition = nullptr,
bool closedUB = false);
+ /// Compute a constant bound for the given affine map, where dims and symbols
+ /// are bound to the given operands. The affine map must have exactly one
+ /// result.
+ ///
+ /// The first `map.getNumDims()` entries of `mapOperands` are bound to the
+ /// dims of `map`; the remaining entries are bound to its symbols.
+ ///
+ /// This function traverses the backward slice of the given operands in a
+ /// worklist-driven manner until `stopCondition` evaluates to "true". The
+ /// constraint set is populated according to `ValueBoundsOpInterface` for each
+ /// visited value. (No constraints are added for values for which the stop
+ /// condition evaluates to "true".)
+ ///
+ /// The stop condition is optional: If none is specified, the backward slice
+ /// is traversed in a breadth-first manner until a constant bound could be
+ /// computed.
+ ///
+ /// By default, lower/equal bounds are closed and upper bounds are open. If
+ /// `closedUB` is set to "true", upper bounds are also closed.
+ static FailureOr<int64_t> computeConstantBound(
+ presburger::BoundType type, AffineMap map, ValueDimList mapOperands,
+ StopConditionFn stopCondition = nullptr, bool closedUB = false);
+
+ /// Compute whether the given values/dimensions are equal. Return "failure" if
+ /// equality could not be determined.
+ ///
+ /// The check bounds the difference `value1 - value2` (or the difference of
+ /// the respective dimension sizes): "true" is returned only if an equality
+ /// bound of 0 can be derived for it, and "false" is returned only if the two
+ /// can be proven to differ.
+ ///
+ /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
+ /// index-typed.
+ static FailureOr<bool> areEqual(Value value1, Value value2,
+ std::optional<int64_t> dim1 = std::nullopt,
+ std::optional<int64_t> dim2 = std::nullopt);
+
/// Add a bound for the given index-typed value or shaped value. This function
/// returns a builder that adds the bound.
BoundBuilder bound(Value value) { return BoundBuilder(*this, value); }
@@ -199,13 +228,23 @@ class ValueBoundsConstraintSet {
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
- /// "false", a dimension is added.
+ /// "false", a dimension is added. The value/dimension is added to the
+ /// worklist.
///
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
/// cannot be multiplied. Furthermore, bounds can only be queried for
/// dimensions but not for symbols.
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);
+ /// Insert an anonymous column into the constraint set. The column is not
+ /// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
+ /// is added. Returns the position of the new column.
+ ///
+ /// Unlike the value/dimension overload, the new column is not added to the
+ /// worklist (there is no value whose backward slice could be traversed).
+ ///
+ /// Note: There are certain affine restrictions wrt. dimensions. E.g., they
+ /// cannot be multiplied. Furthermore, bounds can only be queried for
+ /// dimensions but not for symbols.
+ int64_t insert(bool isSymbol = true);
+
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
@@ -213,7 +252,7 @@ class ValueBoundsConstraintSet {
void projectOut(function_ref<bool(ValueDim)> condition);
/// Mapping of columns to values/shape dimensions.
- SmallVector<ValueDim> positionToValueDim;
+ SmallVector<std::optional<ValueDim>> positionToValueDim;
/// Reverse mapping of values/shape dimensions to columns.
DenseMap<ValueDim, int64_t> valueDimToPosition;
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 28f34a9644b29..bc7d6b45cba57 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -124,12 +124,24 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
positionToValueDim.insert(positionToValueDim.begin() + pos, valueDim);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
- valueDimToPosition[positionToValueDim[i]] = i;
+ if (positionToValueDim[i].has_value())
+ valueDimToPosition[*positionToValueDim[i]] = i;
worklist.push(pos);
return pos;
}
+/// Append a column that is not associated with any value/dimension and return
+/// its position. Anonymous columns are never put on the worklist.
+int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
+  VarKind kind = isSymbol ? VarKind::Symbol : VarKind::SetDim;
+  int64_t newPos = cstr.appendVar(kind);
+  positionToValueDim.insert(positionToValueDim.begin() + newPos, std::nullopt);
+  // Any columns after `newPos` shifted by one slot; refresh the reverse
+  // mapping for every named column from `newPos` onwards.
+  int64_t numColumns = positionToValueDim.size();
+  for (int64_t col = newPos; col < numColumns; ++col) {
+    const std::optional<ValueDim> &entry = positionToValueDim[col];
+    if (entry)
+      valueDimToPosition[*entry] = col;
+  }
+  return newPos;
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -155,7 +167,9 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
while (!worklist.empty()) {
int64_t pos = worklist.front();
worklist.pop();
- ValueDim valueDim = positionToValueDim[pos];
+ assert(positionToValueDim[pos].has_value() &&
+ "did not expect std::nullopt on worklist");
+ ValueDim valueDim = *positionToValueDim[pos];
Value value = valueDim.first;
int64_t dim = valueDim.second;
@@ -191,20 +205,24 @@ void ValueBoundsConstraintSet::projectOut(int64_t pos) {
assert(pos >= 0 && pos < static_cast<int64_t>(positionToValueDim.size()) &&
"invalid position");
cstr.projectOut(pos);
- bool erased = valueDimToPosition.erase(positionToValueDim[pos]);
- (void)erased;
- assert(erased && "inconsistent reverse mapping");
+ if (positionToValueDim[pos].has_value()) {
+ bool erased = valueDimToPosition.erase(*positionToValueDim[pos]);
+ (void)erased;
+ assert(erased && "inconsistent reverse mapping");
+ }
positionToValueDim.erase(positionToValueDim.begin() + pos);
// Update reverse mapping.
for (int64_t i = pos, e = positionToValueDim.size(); i < e; ++i)
- valueDimToPosition[positionToValueDim[i]] = i;
+ if (positionToValueDim[i].has_value())
+ valueDimToPosition[*positionToValueDim[i]] = i;
}
void ValueBoundsConstraintSet::projectOut(
function_ref<bool(ValueDim)> condition) {
int64_t nextPos = 0;
while (nextPos < static_cast<int64_t>(positionToValueDim.size())) {
- if (condition(positionToValueDim[nextPos])) {
+ if (positionToValueDim[nextPos].has_value() &&
+ condition(*positionToValueDim[nextPos])) {
projectOut(nextPos);
// The column was projected out so another column is now at that position.
// Do not increase the counter.
@@ -332,7 +350,9 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
replacementSymbols.push_back(b.getAffineSymbolExpr(numSymbols++));
}
- ValueBoundsConstraintSet::ValueDim valueDim = cstr.positionToValueDim[i];
+ assert(cstr.positionToValueDim[i].has_value() &&
+ "cannot build affine map in terms of anonymous column");
+ ValueBoundsConstraintSet::ValueDim valueDim = *cstr.positionToValueDim[i];
Value value = valueDim.first;
int64_t dim = valueDim.second;
if (dim == ValueBoundsConstraintSet::kIndexValue) {
@@ -406,10 +426,35 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
assertValidValueDim(value, dim);
#endif // NDEBUG
- // Process the backward slice of `value` (i.e., reverse use-def chain) until
- // `stopCondition` is met.
- ValueBoundsConstraintSet cstr(value.getContext());
- int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
+ AffineMap map =
+ AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+ Builder(value.getContext()).getAffineDimExpr(0));
+ return computeConstantBound(type, map, {{value, dim}}, stopCondition,
+ closedUB);
+}
+
+FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType type, AffineMap map, ValueDimList operands,
+ StopConditionFn stopCondition, bool closedUB) {
+ assert(map.getNumResults() == 1 && "expected affine map with one result");
+ ValueBoundsConstraintSet cstr(map.getContext());
+ int64_t pos = cstr.insert(/*isSymbol=*/false);
+
+ // Add map and operands to the constraint set. Dimensions are converted to
+ // symbols. All operands are added to the worklist.
+ auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+ return cstr.getExpr(v.first, v.second);
+ };
+ SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+ llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+ SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+ llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+ cstr.addBound(
+ presburger::BoundType::EQ, pos,
+ map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+ // Process the backward slice of `operands` (i.e., reverse use-def chain)
+ // until `stopCondition` is met.
if (stopCondition) {
cstr.processWorklist(stopCondition);
} else {
@@ -428,6 +473,27 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
+/// Compute whether two values/dimensions are equal by bounding their
+/// difference `value1 - value2`:
+/// * A constant equality bound decides the question: 0 means "equal", any
+///   other constant means "different".
+/// * Otherwise, a (closed) lower bound > 0 or an (open) upper bound <= 0
+///   proves that the difference is nonzero, i.e., the values differ.
+/// If none of these bounds can be derived, return failure.
+FailureOr<bool>
+ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+                                   std::optional<int64_t> dim1,
+                                   std::optional<int64_t> dim2) {
+#ifndef NDEBUG
+  assertValidValueDim(value1, dim1);
+  assertValidValueDim(value2, dim2);
+#endif // NDEBUG
+
+  // Subtract the two values/dimensions from each other. If the result is 0,
+  // both are equal.
+  Builder b(value1.getContext());
+  AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                 b.getAffineDimExpr(0) - b.getAffineDimExpr(1));
+  ValueDimList operands = {{value1, dim1}, {value2, dim2}};
+  FailureOr<int64_t> eqBound =
+      computeConstantBound(presburger::BoundType::EQ, map, operands);
+  if (succeeded(eqBound))
+    return *eqBound == 0;
+
+  // No constant equality bound exists. The values are still provably
+  // different if their difference is strictly positive or strictly negative.
+  // Lower bounds are closed: `value1 - value2 >= lowerBound`.
+  FailureOr<int64_t> lowerBound =
+      computeConstantBound(presburger::BoundType::LB, map, operands);
+  if (succeeded(lowerBound) && *lowerBound > 0)
+    return false;
+  // Upper bounds are open by default: `value1 - value2 < upperBound`.
+  FailureOr<int64_t> upperBound =
+      computeConstantBound(presburger::BoundType::UB, map, operands);
+  if (succeeded(upperBound) && *upperBound <= 0)
+    return false;
+  return failure();
+}
+
ValueBoundsConstraintSet::BoundBuilder &
ValueBoundsConstraintSet::BoundBuilder::operator[](int64_t dim) {
assert(!this->dim.has_value() && "dim was already set");
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 614c6014fec98..45520da6aeb0b 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -156,3 +156,49 @@ func.func @rank(%t: tensor<5xf32>) -> index {
%1 = "test.reify_bound"(%0) : (index) -> (index)
return %1 : index
}
+
+// -----
+
+func.func @dynamic_dims_are_equal(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %dim1 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-remark @below {{equal}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_different(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  %val = arith.addi %dim0, %c1 : index
+  // expected-remark @below {{different}}
+  "test.are_equal"(%dim0, %val) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_1(%t: tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c5 = arith.constant 5 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %c5) : (index, index) -> ()
+  return
+}
+
+// -----
+
+func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %dim1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  // expected-error @below {{could not determine equality}}
+  "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index dff619efda28d..db3b9a1decf07 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -175,10 +175,38 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return failure(result.wasInterrupted());
}
+/// Look for "test.are_equal" ops and emit errors/remarks: an "equal" or
+/// "different" remark if equality could be decided, a "could not determine
+/// equality" error otherwise. Ops without exactly two index-typed operands are
+/// reported as "invalid op".
+static LogicalResult testEquality(func::FuncOp funcOp) {
+  WalkResult result = funcOp.walk([&](Operation *op) {
+    // Only "test.are_equal" ops are of interest.
+    if (op->getName().getStringRef() != "test.are_equal")
+      return WalkResult::advance();
+    // Expect exactly two index-typed operands.
+    if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
+        !op->getOperand(1).getType().isIndex()) {
+      op->emitOpError("invalid op");
+      return WalkResult::skip();
+    }
+    FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
+        op->getOperand(0), op->getOperand(1));
+    if (failed(equal))
+      op->emitError("could not determine equality");
+    else
+      op->emitRemark(*equal ? "equal" : "different");
+    return WalkResult::advance();
+  });
+  return failure(result.wasInterrupted());
+}
+
void TestReifyValueBounds::runOnOperation() {
if (failed(
testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
signalPassFailure();
+ // Check all "test.are_equal" ops; results are reported as remarks/errors.
+ if (failed(testEquality(getOperation())))
+ signalPassFailure();
}
namespace mlir {
More information about the Mlir-commits
mailing list