[Mlir-commits] [mlir] 56863ad - [MLIR][Presburger] Implement findSymbolicIntegerLexMax for IntegerRelation
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 26 07:17:15 PDT 2023
Author: iambrj
Date: 2023-07-26T19:36:29+05:30
New Revision: 56863adf8eecebb16b4e3a901ba5ddc8b7074f01
URL: https://github.com/llvm/llvm-project/commit/56863adf8eecebb16b4e3a901ba5ddc8b7074f01
DIFF: https://github.com/llvm/llvm-project/commit/56863adf8eecebb16b4e3a901ba5ddc8b7074f01.diff
LOG: [MLIR][Presburger] Implement findSymbolicIntegerLexMax for IntegerRelation
This patch implements findSymbolicIntegerLexMax for IntegerRelation
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D156023
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 9646894736de06..369b31511afdc0 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -29,7 +29,7 @@ class IntegerRelation;
class IntegerPolyhedron;
class PresburgerSet;
class PresburgerRelation;
-struct SymbolicLexMin;
+struct SymbolicLexOpt;
/// The type of bound: equal, lower bound or upper bound.
enum class BoundType { EQ, LB, UB };
@@ -659,15 +659,18 @@ class IntegerRelation {
/// x = a if b <= a, a <= c
/// x = b if a < b, b <= c
///
- /// This function is stored in the `lexmin` function in the result.
+ /// This function is stored in the `lexopt` function in the result.
/// Some assignments to the symbols might make the set empty.
/// Such points are not part of the function's domain.
/// In the above example, this happens when max(a, b) > c.
///
/// For some values of the symbols, the lexmin may be unbounded.
- /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+ /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate
/// `PresburgerSet`, `unboundedDomain`.
- SymbolicLexMin findSymbolicIntegerLexMin() const;
+ SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+ /// Same as findSymbolicIntegerLexMin, but computes the lexmax instead of the lexmin.
+ SymbolicLexOpt findSymbolicIntegerLexMax() const;
/// Return the set difference of this set and the given set, i.e.,
/// return `this \ set`.
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 470d483cbb5648..79a42d6c38d411 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -43,7 +43,7 @@ class GBRSimplex;
/// these constraints that are redundant, i.e. a subset of constraints that
/// doesn't constrain the affine set further after adding the non-redundant
/// constraints. The LexSimplex class provides support for computing the
-/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class
+/// lexicographic minimum of an IntegerRelation. The SymbolicLexOpt class
/// provides support for computing symbolic lexicographic minimums. All of these
/// classes can be constructed from an IntegerRelation, and all inherit common
/// functionality from SimplexBase.
@@ -529,18 +529,18 @@ class LexSimplex : public LexSimplexBase {
std::optional<unsigned> maybeGetNonIntegralVarRow() const;
};
-/// Represents the result of a symbolic lexicographic minimization computation.
-struct SymbolicLexMin {
- SymbolicLexMin(const PresburgerSpace &space)
- : lexmin(space),
+/// Represents the result of a symbolic lexicographic optimization computation.
+struct SymbolicLexOpt {
+ SymbolicLexOpt(const PresburgerSpace &space)
+ : lexopt(space),
unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
- /// This maps assignments of symbols to the corresponding lexmin.
+ /// This maps assignments of symbols to the corresponding lexopt.
/// Takes no value when no integer sample exists for the assignment or if the
- /// lexmin is unbounded.
- PWMAFunction lexmin;
- /// Contains all assignments to the symbols that made the lexmin unbounded.
- /// Note that the symbols of the input set to the symbolic lexmin are dims
+ /// lexopt is unbounded.
+ PWMAFunction lexopt;
+ /// Contains all assignments to the symbols that made the lexopt unbounded.
+ /// Note that the symbols of the input set to the symbolic lexopt are dims
/// of this PrebsurgerSet.
PresburgerSet unboundedDomain;
};
@@ -575,13 +575,13 @@ struct SymbolicLexMin {
/// where it is.
class SymbolicLexSimplex : public LexSimplexBase {
public:
- /// `constraints` is the set for which the symbolic lexmin will be computed.
- /// `symbolDomain` is the set of values of the symbols for which the lexmin
+ /// `constraints` is the set for which the symbolic lexopt will be computed.
+ /// `symbolDomain` is the set of values of the symbols for which the lexopt
/// will be computed. `symbolDomain` should have a dim var for every symbol in
/// `constraints`, and no other vars. `isSymbol` specifies which vars of
/// `constraints` should be considered as symbols.
///
- /// The resulting SymbolicLexMin's space will be compatible with that of
+ /// The resulting SymbolicLexOpt's space will be compatible with that of
/// symbolDomain.
SymbolicLexSimplex(const IntegerRelation &constraints,
const IntegerPolyhedron &symbolDomain,
@@ -594,7 +594,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
"there must be some non-symbols to optimize!");
}
- /// An overload to select some subrange of ids as symbols for lexmin.
+ /// An overload to select some subrange of ids as symbols for lexopt.
/// The symbol ids are the range of ids with absolute index
/// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
@@ -604,7 +604,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
symbolOffset,
symbolDomain.getNumVars())) {}
- /// An overload to select the symbols of `constraints` as symbols for lexmin.
+ /// An overload to select the symbols of `constraints` as symbols for lexopt.
SymbolicLexSimplex(const IntegerRelation &constraints,
const IntegerPolyhedron &symbolDomain)
: SymbolicLexSimplex(constraints,
@@ -614,7 +614,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
"symbolDomain must have as many vars as constraints has symbols!");
}
- /// The lexmin will be stored as a function `lexmin` from symbols to
+ /// The lexmin will be stored as a function `lexopt` from symbols to
/// non-symbols in the result.
///
/// For some values of the symbols, the lexmin may be unbounded.
@@ -622,7 +622,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
///
/// The spaces of the sets in the result are compatible with the symbolDomain
/// passed in the SymbolicLexSimplex constructor.
- SymbolicLexMin computeSymbolicIntegerLexMin();
+ SymbolicLexOpt computeSymbolicIntegerLexMin();
private:
/// Perform all pivots that do not require branching.
@@ -670,7 +670,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// Record a lexmin. The tableau must be consistent with all variables
/// having symbolic samples with integer coefficients.
- void recordOutput(SymbolicLexMin &result) const;
+ void recordOutput(SymbolicLexOpt &result) const;
/// The symbol domain.
IntegerPolyhedron domainPoly;
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 75c6adbf6bbc2b..b2359aea461ae7 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -172,7 +172,7 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
return PresburgerRelation(*this);
// Move all the non-div locals to the end, as the current API to
- // SymbolicLexMin requires these to form a contiguous range.
+ // SymbolicLexOpt requires these to form a contiguous range.
//
// Take a copy so we can perform mutations.
IntegerRelation copy = *this;
@@ -211,13 +211,13 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
// exists, which is the union of the domain of the returned lexmin function
// and the returned set of assignments to the "symbols" that makes the lexmin
// unbounded.
- SymbolicLexMin lexminResult =
+ SymbolicLexOpt lexminResult =
SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
/*numDims=*/copy.getNumVars() - numNonDivLocals)))
.computeSymbolicIntegerLexMin();
PresburgerRelation result =
- lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
+ lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain);
// The result set might lie in the wrong space -- all its ids are dims.
// Set it to the desired space and return.
@@ -227,7 +227,7 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
return result;
}
-SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMin() const {
// Symbol and Domain vars will be used as symbols for symbolic lexmin.
// In other words, for every value of the symbols and domain, return the
// lexmin value of the (range, locals).
@@ -239,7 +239,7 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
// The resultant space of lexmin is the space of the relation itself.
- SymbolicLexMin result =
+ SymbolicLexOpt result =
SymbolicLexSimplex(*this,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
/*numDims=*/getNumDomainVars(),
@@ -249,11 +249,49 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
// We want to return only the lexmin over the dims, so strip the locals from
// the computed lexmin.
- result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
- result.lexmin.getNumOutputs());
+ result.lexopt.removeOutputs(result.lexopt.getNumOutputs() - getNumLocalVars(),
+ result.lexopt.getNumOutputs());
return result;
}
+/// findSymbolicIntegerLexMax is implemented using findSymbolicIntegerLexMin as
+/// follows:
+/// 1. A new relation is created which is `this` relation with the sign of
+///    every range variable flipped in all equalities and inequalities;
+/// 2. findSymbolicIntegerLexMin is called on this range-negated relation to
+///    compute the negated lexmax of `this` relation;
+/// 3. Every output of the negated lexmax is negated back. The piece domains,
+///    the division representations of the pieces' local variables and the
+///    unbounded domain are expressed only over non-range variables, so they
+///    are reused unchanged.
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
+  IntegerRelation flippedRel = *this;
+  // Flip range sign by flipping the sign of range variables in all constraints.
+  for (unsigned j = getNumDomainVars(),
+                b = getNumDomainVars() + getNumRangeVars();
+       j < b; j++) {
+    for (unsigned i = 0, a = getNumEqualities(); i < a; i++)
+      flippedRel.atEq(i, j) = -1 * atEq(i, j);
+    for (unsigned i = 0, a = getNumInequalities(); i < a; i++)
+      flippedRel.atIneq(i, j) = -1 * atIneq(i, j);
+  }
+
+  // The lexmin of the flipped relation is the negated lexmax of this one.
+  SymbolicLexOpt flippedLexMin = flippedRel.findSymbolicIntegerLexMin();
+  SymbolicLexOpt symbolicIntegerLexMax(flippedLexMin.lexopt.getSpace());
+
+  // Get the lexmax by negating every output of each flipped piece.
+  for (const auto &flippedPiece : flippedLexMin.lexopt.getAllPieces()) {
+    Matrix mat = flippedPiece.output.getOutputMatrix();
+    for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
+      mat.negateRow(i);
+    // Keep the piece's division representation: pieces recorded by
+    // SymbolicLexSimplex::recordOutput may carry local vars (introduced by
+    // symbolic integer cuts) whose divisions must stay attached. Dropping
+    // them via the two-argument constructor loses that information.
+    MultiAffineFunction maf(flippedPiece.output.getSpace(), mat,
+                            flippedPiece.output.getDivs());
+    symbolicIntegerLexMax.lexopt.addPiece({flippedPiece.domain, maf});
+  }
+  symbolicIntegerLexMax.unboundedDomain = flippedLexMin.unboundedDomain;
+  return symbolicIntegerLexMax;
+}
+
PresburgerRelation
IntegerRelation::subtract(const PresburgerRelation &set) const {
return PresburgerRelation(*this).subtract(set);
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 61c39bd315f187..eff312b69e1de1 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -435,9 +435,9 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
return moveRowUnknownToColumn(cutRow);
}
-void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
+void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const {
Matrix output(0, domainPoly.getNumVars() + 1);
- output.reserveRows(result.lexmin.getNumOutputs());
+ output.reserveRows(result.lexopt.getNumOutputs());
for (const Unknown &u : var) {
if (u.isSymbol)
continue;
@@ -469,10 +469,10 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
}
// Store the output in a MultiAffineFunction and add it the result.
- PresburgerSpace funcSpace = result.lexmin.getSpace();
+ PresburgerSpace funcSpace = result.lexopt.getSpace();
funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
- result.lexmin.addPiece(
+ result.lexopt.addPiece(
{PresburgerSet(domainPoly),
MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
}
@@ -515,8 +515,8 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
return success();
}
-SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
- SymbolicLexMin result(PresburgerSpace::getRelationSpace(
+SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
+ SymbolicLexOpt result(PresburgerSpace::getRelationSpace(
/*numDomain=*/domainPoly.getNumDimVars(),
/*numRange=*/var.size() - nSymbol,
/*numSymbols=*/domainPoly.getNumSymbolVars()));
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 6beb9384c8bf22..ba035e84ff1fd7 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1198,13 +1198,13 @@ void expectSymbolicIntegerLexMin(
ASSERT_NE(poly.getNumDimVars(), 0u);
ASSERT_NE(poly.getNumSymbolVars(), 0u);
- SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
+ SymbolicLexOpt result = poly.findSymbolicIntegerLexMin();
if (expectedLexminRepr.empty()) {
- EXPECT_TRUE(result.lexmin.getDomain().isIntegerEmpty());
+ EXPECT_TRUE(result.lexopt.getDomain().isIntegerEmpty());
} else {
PWMAFunction expectedLexmin = parsePWMAF(expectedLexminRepr);
- EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin));
+ EXPECT_TRUE(result.lexopt.isEqual(expectedLexmin));
}
if (expectedUnboundedDomainRepr.empty()) {
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 1a9241bfe3ffcd..dd20e058e358dd 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -126,7 +126,7 @@ TEST(IntegerRelationTest, applyDomainAndRange) {
}
TEST(IntegerRelationTest, symbolicLexmin) {
- SymbolicLexMin lexmin =
+ SymbolicLexOpt lexmin =
parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1)
.findSymbolicIntegerLexMin();
@@ -135,5 +135,43 @@ TEST(IntegerRelationTest, symbolicLexmin) {
{"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (b)"}, // b
});
EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty());
- EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin));
+ EXPECT_TRUE(lexmin.lexopt.isEqual(expectedLexmin));
+}
+
+TEST(IntegerRelationTest, symbolicLexmax) {
+  // Parses `relStr` as a relation whose first `numDomain` vars form the
+  // domain, computes its symbolic integer lexmax, and checks that the lexmax
+  // is bounded everywhere and equals `expected`.
+  auto expectBoundedLexmax = [](const char *relStr, unsigned numDomain,
+                                const PWMAFunction &expected) {
+    SymbolicLexOpt lexmax =
+        parseRelationFromSet(relStr, numDomain).findSymbolicIntegerLexMax();
+    EXPECT_TRUE(lexmax.unboundedDomain.isIntegerEmpty());
+    EXPECT_TRUE(lexmax.lexopt.isEqual(expected));
+  };
+
+  // x <= a, x <= b: the lexmax of x is min(a, b).
+  expectBoundedLexmax("(a, x)[b] : (a - x >= 0, b - x >= 0)", 1,
+                      parsePWMAF({
+                          {"(a)[b] : (a - b >= 0)", "(a)[b] -> (b)"},
+                          {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (a)"},
+                      }));
+
+  // Triangle i, j >= 0, i + j <= N: the lexmax of j is N - i.
+  expectBoundedLexmax("(i, j)[N] : (i >= 0, j >= 0, N - i - j >= 0)", 1,
+                      parsePWMAF({
+                          {"(i)[N] : (i >= 0, N - i >= 0)", "(i)[N] -> (N - i)"},
+                      }));
+
+  // The lexmax of y is min(x + 2N, 4N - x), split at x = N.
+  expectBoundedLexmax(
+      "(x, y)[N] : (x >= 0, 2 * N - x >= 0, y >= 0, x - y "
+      "+ 2 * N >= 0, 4 * N - x - y >= 0)",
+      1,
+      parsePWMAF({{"(x)[N] : (x >= 0, 2 * N - x >= 0, x - N - 1 >= 0)",
+                   "(x)[N] -> (4 * N - x)"},
+                  {"(x)[N] : (x >= 0, 2 * N - x >= 0, -x + N >= 0)",
+                   "(x)[N] -> (x + 2 * N)"}}));
}
More information about the Mlir-commits
mailing list