[Mlir-commits] [mlir] [mlir][Interfaces] `ValueBoundsOpInterface`: Add API to compare values (PR #86915)
Matthias Springer
llvmlistbot at llvm.org
Mon Apr 8 04:21:13 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86915
>From 772389c51124ba9144f1453986cd217da7cea3f4 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 8 Apr 2024 11:20:03 +0000
Subject: [PATCH] [mlir][Interfaces][WIP] Expose public `compare` API
Also use `compare` API for `areEqual` etc.
---
.../mlir/Interfaces/ValueBoundsOpInterface.h | 57 ++++-
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 31 +--
.../lib/Interfaces/ValueBoundsOpInterface.cpp | 237 +++++++++++++-----
.../value-bounds-op-interface-impl.mlir | 43 +++-
.../SCF/value-bounds-op-interface-impl.mlir | 12 +
.../value-bounds-op-interface-impl.mlir | 16 +-
.../Dialect/Affine/TestReifyValueBounds.cpp | 79 +++++-
7 files changed, 361 insertions(+), 114 deletions(-)
diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
index 3543ab52407a365..1d7bc6ea961cc3a 100644
--- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
@@ -211,7 +211,8 @@ class ValueBoundsConstraintSet
/// Comparison operator for `ValueBoundsConstraintSet::compare`.
enum ComparisonOperator { LT, LE, EQ, GT, GE };
- /// Try to prove that, based on the current state of this constraint set
+ /// Populate constraints for lhs/rhs (until the stop condition is met). Then,
+ /// try to prove that, based on the current state of this constraint set
/// (i.e., without analyzing additional IR or adding new constraints), the
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
///
@@ -220,24 +221,37 @@ class ValueBoundsConstraintSet
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
- bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
- Value rhs, std::optional<int64_t> rhsDim);
+ bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+
+ /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the
+ /// specified relation could not be proven. This could be because the
+ /// specified relation does in fact not hold or because there is not enough
+ /// information in the constraint set. In other words, if we do not know for
+ /// sure, this function returns "false".
+ ///
+ /// This function keeps traversing the backward slice of lhs/rhs until it can
+ /// prove the relation or until it runs out of IR.
+ static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+ static bool compare(AffineMap lhs, ValueDimList lhsOperands,
+ ComparisonOperator cmp, AffineMap rhs,
+ ValueDimList rhsOperands);
+ static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands,
+ ComparisonOperator cmp, AffineMap rhs,
+ ArrayRef<Value> rhsOperands);
/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
/// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are
/// index-typed.
- static FailureOr<bool> areEqual(Value value1, Value value2,
+ static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);
- /// Compute whether the given values/attributes are equal. Return "failure" if
- /// equality could not be determined.
- ///
- /// `ofr1`/`ofr2` must be of index type.
- static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2);
-
/// Return "true" if the given slices are guaranteed to be overlapping.
/// Return "false" if the given slices are guaranteed to be non-overlapping.
/// Return "failure" if unknown.
@@ -294,6 +308,20 @@ class ValueBoundsConstraintSet
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);
+ /// Return "true" if, based on the current state of the constraint system,
+ /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation
+ /// could not be proven. This could be because the specified relation does in
+ /// fact not hold or because there is not enough information in the constraint
+ /// set. In other words, if we do not know for sure, this function returns
+ /// "false".
+ ///
+ /// This function does not analyze any IR and does not populate any additional
+ /// constraints.
+ bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp, OpFoldResult rhs,
+ std::optional<int64_t> rhsDim);
+ bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos);
+
/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
@@ -319,6 +347,10 @@ class ValueBoundsConstraintSet
/// set.
AffineExpr getPosExpr(int64_t pos);
+ /// Return "true" if the given value/dim is mapped (i.e., has a corresponding
+ /// column in the constraint system).
+ bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const;
+
/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
/// "false", a dimension is added. The value/dimension is added to the
/// worklist if `addToWorklist` is set.
@@ -338,6 +370,11 @@ class ValueBoundsConstraintSet
/// dimensions but not for symbols.
int64_t insert(bool isSymbol = true);
+ /// Insert the given affine map and its bound operands as a new column in the
+ /// constraint system. Return the position of the new column. Any operands
+ /// that were not analyzed yet are put on the worklist.
+ int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true);
+
/// Project out the given column in the constraint set.
void projectOut(int64_t pos);
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 72c5aaa2306783e..087ffc438a830a3 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,20 +58,11 @@ struct ForOpInterface
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
Value initArg = forOp.getInitArgs()[iterArgIdx];
- // Populate constraints for the yielded value.
- cstr.populateConstraints(yieldedValue, dim);
- // Populate constraints for the iter_arg. This is just to ensure that the
- // iter_arg is mapped in the constraint set, which is a prerequisite for
- // `compare`. It may lead to a recursive call to this function in case the
- // iter_arg was not visited when the constraints for the yielded value were
- // populated, but no additional work is done.
- cstr.populateConstraints(iterArg, dim);
-
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
- if (cstr.compare(yieldedValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
- dim)) {
+ if (cstr.populateAndCompare(
+ yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ iterArg, dim)) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
@@ -113,10 +104,6 @@ struct IfOpInterface
Value thenValue = ifOp.thenYield().getResults()[resultNum];
Value elseValue = ifOp.elseYield().getResults()[resultNum];
- // Populate constraints for the yielded value (and all values on the
- // backward slice, as long as the current stop condition is not satisfied).
- cstr.populateConstraints(thenValue, dim);
- cstr.populateConstraints(elseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];
@@ -125,9 +112,9 @@ struct IfOpInterface
// If thenValue <= elseValue:
// * result <= elseValue
// * result >= thenValue
- if (cstr.compare(thenValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ if (cstr.populateAndCompare(
+ thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+ elseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -139,9 +126,9 @@ struct IfOpInterface
// If elseValue <= thenValue:
// * result <= thenValue
// * result >= elseValue
- if (cstr.compare(elseValue, dim,
- ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ if (cstr.populateAndCompare(
+ elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
+ thenValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 6e3d6dd3c757591..c138056ab41cc33 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -202,6 +202,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) {
return pos;
}
+int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands,
+ bool isSymbol) {
+ assert(map.getNumResults() == 1 && "expected affine map with one result");
+ int64_t pos = insert(/*isSymbol=*/false);
+
+ // Add map and operands to the constraint set. Dimensions are converted to
+ // symbols. All operands are added to the worklist (unless they were already
+ // processed).
+ auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
+ return getExpr(v.first, v.second);
+ };
+ SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
+ llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
+ SmallVector<AffineExpr> symReplacements = llvm::to_vector(
+ llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
+ addBound(
+ presburger::BoundType::EQ, pos,
+ map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
+
+ return pos;
+}
+
int64_t ValueBoundsConstraintSet::getPos(Value value,
std::optional<int64_t> dim) const {
#ifndef NDEBUG
@@ -224,6 +246,13 @@ AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
}
+bool ValueBoundsConstraintSet::isMapped(Value value,
+ std::optional<int64_t> dim) const {
+ auto it =
+ valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue)));
+ return it != valueDimToPosition.end();
+}
+
static Operation *getOwnerOfValue(Value value) {
if (auto bbArg = dyn_cast<BlockArgument>(value))
return bbArg.getOwner()->getParentOp();
@@ -560,27 +589,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value,
int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
ValueDimList operands) {
- assert(map.getNumResults() == 1 && "expected affine map with one result");
- int64_t pos = insert(/*isSymbol=*/false);
-
- // Add map and operands to the constraint set. Dimensions are converted to
- // symbols. All operands are added to the worklist (unless they were already
- // processed).
- auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
- return getExpr(v.first, v.second);
- };
- SmallVector<AffineExpr> dimReplacements = llvm::to_vector(
- llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper));
- SmallVector<AffineExpr> symReplacements = llvm::to_vector(
- llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper));
- addBound(
- presburger::BoundType::EQ, pos,
- map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements));
-
+ int64_t pos = insert(map, operands, /*isSymbol=*/false);
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
processWorklist();
-
return pos;
}
@@ -600,9 +612,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
{{value1, dim1}, {value2, dim2}});
}
-bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
- ComparisonOperator cmp, Value rhs,
- std::optional<int64_t> rhsDim) {
+bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs,
+ std::optional<int64_t> lhsDim,
+ ComparisonOperator cmp,
+ OpFoldResult rhs,
+ std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+ if (auto lhsVal = dyn_cast<Value>(lhs))
+ assertValidValueDim(lhsVal, lhsDim);
+ if (auto rhsVal = dyn_cast<Value>(rhs))
+ assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
@@ -621,24 +642,32 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
// EQ can be expressed as LE and GE.
if (cmp == EQ)
- return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
- compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
+ return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
+ compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);
// Construct inequality. For the above example: lhs > rhs.
// `IntegerRelation` inequalities are expressed in the "flattened" form and
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
- SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
+ SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+ auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim,
+ int64_t factor) {
+ if (auto constVal = ::getConstantIntValue(ofr)) {
+ eq[cstr.getNumCols() - 1] += *constVal * factor;
+ } else {
+ eq[getPos(cast<Value>(ofr), dim)] += factor;
+ }
+ };
if (cmp == LT || cmp == LE) {
- ++eq[getPos(lhs, lhsDim)];
- --eq[getPos(rhs, rhsDim)];
+ addToEq(lhs, lhsDim, 1);
+ addToEq(rhs, rhsDim, -1);
} else if (cmp == GT || cmp == GE) {
- --eq[getPos(lhs, lhsDim)];
- ++eq[getPos(rhs, rhsDim)];
+ addToEq(lhs, lhsDim, -1);
+ addToEq(rhs, rhsDim, 1);
} else {
llvm_unreachable("unsupported comparison operator");
}
if (cmp == LE || cmp == GE)
- eq[cstr.getNumDimAndSymbolVars()] -= 1;
+ eq[cstr.getNumCols() - 1] -= 1;
// Add inequality to the constraint set and check if it made the constraint
// set empty.
@@ -649,40 +678,128 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
return isEmpty;
}
+bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos,
+ ComparisonOperator cmp,
+ int64_t rhsPos) {
+ // This function returns "true" if "lhs CMP rhs" is proven to hold. For
+ // detailed documentation, see `compareValueDims`.
+
+ // EQ can be expressed as LE and GE.
+ if (cmp == EQ)
+ return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) &&
+ comparePos(lhsPos, ComparisonOperator::GE, rhsPos);
+
+ // Construct inequality.
+ SmallVector<int64_t> eq(cstr.getNumCols(), 0);
+ if (cmp == LT || cmp == LE) {
+ ++eq[lhsPos];
+ --eq[rhsPos];
+ } else if (cmp == GT || cmp == GE) {
+ --eq[lhsPos];
+ ++eq[rhsPos];
+ } else {
+ llvm_unreachable("unsupported comparison operator");
+ }
+ if (cmp == LE || cmp == GE)
+ eq[cstr.getNumCols() - 1] -= 1;
+
+ // Add inequality to the constraint set and check if it made the constraint
+ // set empty.
+ int64_t ineqPos = cstr.getNumInequalities();
+ cstr.addInequality(eq);
+ bool isEmpty = cstr.isEmpty();
+ cstr.removeInequality(ineqPos);
+ return isEmpty;
+}
+
+bool ValueBoundsConstraintSet::populateAndCompare(
+ OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
+ OpFoldResult rhs, std::optional<int64_t> rhsDim) {
+#ifndef NDEBUG
+ if (auto lhsVal = dyn_cast<Value>(lhs))
+ assertValidValueDim(lhsVal, lhsDim);
+ if (auto rhsVal = dyn_cast<Value>(rhs))
+ assertValidValueDim(rhsVal, rhsDim);
+#endif // NDEBUG
+
+ if (auto lhsVal = dyn_cast<Value>(lhs))
+ populateConstraints(lhsVal, lhsDim);
+ if (auto rhsVal = dyn_cast<Value>(rhs))
+ populateConstraints(rhsVal, rhsDim);
+
+ return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(OpFoldResult lhs,
+                                       std::optional<int64_t> lhsDim,
+                                       ComparisonOperator cmp, OpFoldResult rhs,
+                                       std::optional<int64_t> rhsDim) {
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep going until (lhs, lhsDim) and (rhs, rhsDim) are both mapped.
+    if (auto lhsVal = dyn_cast<Value>(lhs))
+      if (!cstr.isMapped(lhsVal, lhsDim))
+        return false;
+    if (auto rhsVal = dyn_cast<Value>(rhs))
+      if (!cstr.isMapped(rhsVal, rhsDim))
+        return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim);
+  };
+
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands,
+                                       ComparisonOperator cmp, AffineMap rhs,
+                                       ValueDimList rhsOperands) {
+  int64_t lhsPos = -1, rhsPos = -1;
+  auto stopCondition = [&](Value v, std::optional<int64_t> dim,
+                           ValueBoundsConstraintSet &cstr) {
+    // Keep processing while lhs/rhs are unset (-1 wraps above any size).
+    if (static_cast<size_t>(lhsPos) >= cstr.positionToValueDim.size() ||
+        static_cast<size_t>(rhsPos) >= cstr.positionToValueDim.size())
+      return false;
+    // Keep processing as long as the relation cannot be proven.
+    return cstr.comparePos(lhsPos, cmp, rhsPos);
+  };
+  ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition);
+  lhsPos = cstr.insert(lhs, lhsOperands);
+  rhsPos = cstr.insert(rhs, rhsOperands);
+  cstr.processWorklist();
+  return cstr.comparePos(lhsPos, cmp, rhsPos);
+}
+
+bool ValueBoundsConstraintSet::compare(AffineMap lhs,
+ ArrayRef<Value> lhsOperands,
+ ComparisonOperator cmp, AffineMap rhs,
+ ArrayRef<Value> rhsOperands) {
+ ValueDimList lhsValueDimOperands =
+ llvm::map_to_vector(lhsOperands, [](Value v) {
+ return std::make_pair(v, std::optional<int64_t>());
+ });
+ ValueDimList rhsValueDimOperands =
+ llvm::map_to_vector(rhsOperands, [](Value v) {
+ return std::make_pair(v, std::optional<int64_t>());
+ });
+ return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs,
+ rhsValueDimOperands);
+}
+
FailureOr<bool>
-ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
+ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2,
std::optional<int64_t> dim1,
std::optional<int64_t> dim2) {
- // Subtract the two values/dimensions from each other. If the result is 0,
- // both are equal.
- FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2);
- if (failed(delta))
- return failure();
- return *delta == 0;
-}
-
-FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1,
- OpFoldResult ofr2) {
- Builder b(ofr1.getContext());
- AffineMap map =
- AffineMap::get(/*dimCount=*/0, /*symbolCount=*/2,
- b.getAffineSymbolExpr(0) - b.getAffineSymbolExpr(1));
- SmallVector<OpFoldResult> ofrOperands;
- ofrOperands.push_back(ofr1);
- ofrOperands.push_back(ofr2);
- SmallVector<Value> valueOperands;
- AffineMap foldedMap =
- foldAttributesIntoMap(b, map, ofrOperands, valueOperands);
- ValueDimList valueDims;
- for (Value v : valueOperands) {
- assert(v.getType().isIndex() && "expected index type");
- valueDims.emplace_back(v, std::nullopt);
- }
- FailureOr<int64_t> delta =
- computeConstantBound(presburger::BoundType::EQ, foldedMap, valueDims);
- if (failed(delta))
- return failure();
- return *delta == 0;
+ if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::EQ,
+ value2, dim2))
+ return true;
+ if (ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::LT,
+ value2, dim2) ||
+ ValueBoundsConstraintSet::compare(value1, dim1, ComparisonOperator::GT,
+ value2, dim2))
+ return false;
+ return failure();
}
FailureOr<bool>
diff --git a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
index 55282e8334abd72..10da91870f49d94 100644
--- a/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir
@@ -79,6 +79,17 @@ func.func @composed_affine_apply(%i1 : index) -> (index) {
}
+// -----
+
+func.func @are_equal(%i1 : index) {
+ %i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
+ %i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
+ %s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
+ // expected-remark @below{{false}}
+ "test.compare"(%i2, %i3) : (index, index) -> ()
+ return
+}
+
// -----
// Test for affine::fullyComposeAndCheckIfEqual
@@ -87,6 +98,36 @@ func.func @composed_are_equal(%i1 : index) {
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
%s = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%i2, %i3]
// expected-remark @below{{different}}
- "test.are_equal"(%i2, %i3) {compose} : (index, index) -> ()
+ "test.compare"(%i2, %i3) {compose} : (index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @compare_affine_max(%a: index, %b: index) {
+ %0 = affine.max affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b]
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> ()
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> ()
+ // expected-remark @below{{false}}
+ "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> ()
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
+ return
+}
+
+// -----
+
+func.func @compare_affine_min(%a: index, %b: index) {
+ %0 = affine.min affine_map<()[s0, s1] -> (s0, s1)>()[%a, %b]
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %a) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{false}}
+ "test.compare"(%0, %a) {cmp = "GT"} : (index, index) -> ()
+ // expected-error @below{{unknown}}
+ "test.compare"(%0, %a) {cmp = "LT"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%0, %a) {cmp = "LE"} : (index, index) -> ()
return
}
diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
index 0ea06737886d416..9ab03da1c9a94f8 100644
--- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir
@@ -219,3 +219,15 @@ func.func @scf_if_eq(%a: index, %b: index, %c : i1) {
"test.some_use"(%reify1) : (index) -> ()
return
}
+
+// -----
+
+func.func @compare_scf_for(%a: index, %b: index, %c: index) {
+ scf.for %iv = %a to %b step %c {
+ // expected-remark @below{{true}}
+ "test.compare"(%iv, %a) {cmp = "GE"} : (index, index) -> ()
+ // expected-remark @below{{true}}
+ "test.compare"(%iv, %b) {cmp = "LT"} : (index, index) -> ()
+ }
+ return
+}
diff --git a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
index 45520da6aeb0b04..0c90bcdb42028c8 100644
--- a/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir
@@ -163,8 +163,8 @@ func.func @dynamic_dims_are_equal(%t: tensor<?xf32>) {
%c0 = arith.constant 0 : index
%dim0 = tensor.dim %t, %c0 : tensor<?xf32>
%dim1 = tensor.dim %t, %c0 : tensor<?xf32>
- // expected-remark @below {{equal}}
- "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+ // expected-remark @below {{true}}
+ "test.compare"(%dim0, %dim1) : (index, index) -> ()
return
}
@@ -175,8 +175,8 @@ func.func @dynamic_dims_are_different(%t: tensor<?xf32>) {
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %t, %c0 : tensor<?xf32>
%val = arith.addi %dim0, %c1 : index
- // expected-remark @below {{different}}
- "test.are_equal"(%dim0, %val) : (index, index) -> ()
+ // expected-remark @below {{false}}
+ "test.compare"(%dim0, %val) : (index, index) -> ()
return
}
@@ -186,8 +186,8 @@ func.func @dynamic_dims_are_maybe_equal_1(%t: tensor<?xf32>) {
%c0 = arith.constant 0 : index
%c5 = arith.constant 5 : index
%dim0 = tensor.dim %t, %c0 : tensor<?xf32>
- // expected-error @below {{could not determine equality}}
- "test.are_equal"(%dim0, %c5) : (index, index) -> ()
+ // expected-error @below {{unknown}}
+ "test.compare"(%dim0, %c5) : (index, index) -> ()
return
}
@@ -198,7 +198,7 @@ func.func @dynamic_dims_are_maybe_equal_2(%t: tensor<?x?xf32>) {
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %t, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %t, %c1 : tensor<?x?xf32>
- // expected-error @below {{could not determine equality}}
- "test.are_equal"(%dim0, %dim1) : (index, index) -> ()
+ // expected-error @below {{unknown}}
+ "test.compare"(%dim0, %dim1) : (index, index) -> ()
return
}
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 4b2b1a06341b717..f38631054fb3c14 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -57,7 +57,7 @@ struct TestReifyValueBounds
} // namespace
-FailureOr<BoundType> parseBoundType(const std::string &type) {
+static FailureOr<BoundType> parseBoundType(const std::string &type) {
if (type == "EQ")
return BoundType::EQ;
if (type == "LB")
@@ -67,6 +67,34 @@ FailureOr<BoundType> parseBoundType(const std::string &type) {
return failure();
}
+static FailureOr<ValueBoundsConstraintSet::ComparisonOperator>
+parseComparisonOperator(const std::string &type) {
+ if (type == "EQ")
+ return ValueBoundsConstraintSet::ComparisonOperator::EQ;
+ if (type == "LT")
+ return ValueBoundsConstraintSet::ComparisonOperator::LT;
+ if (type == "LE")
+ return ValueBoundsConstraintSet::ComparisonOperator::LE;
+ if (type == "GT")
+ return ValueBoundsConstraintSet::ComparisonOperator::GT;
+ if (type == "GE")
+ return ValueBoundsConstraintSet::ComparisonOperator::GE;
+ return failure();
+}
+
+static ValueBoundsConstraintSet::ComparisonOperator
+invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
+ if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
+ return ValueBoundsConstraintSet::ComparisonOperator::GE;
+ if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
+ return ValueBoundsConstraintSet::ComparisonOperator::GT;
+ if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
+ return ValueBoundsConstraintSet::ComparisonOperator::LE;
+ if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
+ return ValueBoundsConstraintSet::ComparisonOperator::LT;
+ llvm_unreachable("unsupported comparison operator");
+}
+
/// Look for "test.reify_bound" ops in the input and replace their results with
/// the reified values.
static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
@@ -215,18 +243,34 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
return failure(result.wasInterrupted());
}
-/// Look for "test.are_equal" ops and emit errors/remarks.
+/// Look for "test.compare" ops and emit errors/remarks.
static LogicalResult testEquality(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
WalkResult result = funcOp.walk([&](Operation *op) {
- // Look for test.are_equal ops.
- if (op->getName().getStringRef() == "test.are_equal") {
+ // Look for test.compare ops.
+ if (op->getName().getStringRef() == "test.compare") {
if (op->getNumOperands() != 2 || !op->getOperand(0).getType().isIndex() ||
!op->getOperand(1).getType().isIndex()) {
op->emitOpError("invalid op");
return WalkResult::skip();
}
+
+ // Get comparison operator.
+ std::string cmpStr = "EQ";
+ if (auto cmpAttr = op->getAttrOfType<StringAttr>("cmp"))
+ cmpStr = cmpAttr.str();
+ auto cmpType = parseComparisonOperator(cmpStr);
+ if (failed(cmpType)) {
+ op->emitOpError("invalid comparison operator");
+ return WalkResult::interrupt();
+ }
+
if (op->hasAttr("compose")) {
+ if (cmpType != ValueBoundsConstraintSet::EQ) {
+ op->emitOpError(
+ "comparison operator must be EQ when 'composed' is specified");
+ return WalkResult::interrupt();
+ }
FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
op->getOperand(0), op->getOperand(1));
if (failed(delta)) {
@@ -236,16 +280,25 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
} else {
op->emitRemark("different");
}
+ return WalkResult::advance();
+ }
+
+ auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
+ return ValueBoundsConstraintSet::compare(
+ /*lhs=*/op->getOperand(0), /*lhsDim=*/std::nullopt, cmp,
+ /*rhs=*/op->getOperand(1), /*rhsDim=*/std::nullopt);
+ };
+ if (compare(*cmpType)) {
+ op->emitRemark("true");
+ } else if (*cmpType != ValueBoundsConstraintSet::EQ &&
+ compare(invertComparisonOperator(*cmpType))) {
+ op->emitRemark("false");
+ } else if (*cmpType == ValueBoundsConstraintSet::EQ &&
+ (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
+ compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
+ op->emitRemark("false");
} else {
- FailureOr<bool> equal = ValueBoundsConstraintSet::areEqual(
- op->getOperand(0), op->getOperand(1));
- if (failed(equal)) {
- op->emitError("could not determine equality");
- } else if (*equal) {
- op->emitRemark("equal");
- } else {
- op->emitRemark("different");
- }
+ op->emitError("unknown");
}
}
return WalkResult::advance();
More information about the Mlir-commits
mailing list