[Mlir-commits] [mlir] 8a7ead6 - [MLIR][Presburger] Support computing a representation of a set that only has locals that are divs
Arjun P
llvmlistbot at llvm.org
Sat Jun 25 06:23:34 PDT 2022
Author: Arjun P
Date: 2022-06-25T14:23:32+01:00
New Revision: 8a7ead691bad29b86017d9e42fa63a57c8c0d629
URL: https://github.com/llvm/llvm-project/commit/8a7ead691bad29b86017d9e42fa63a57c8c0d629
DIFF: https://github.com/llvm/llvm-project/commit/8a7ead691bad29b86017d9e42fa63a57c8c0d629.diff
LOG: [MLIR][Presburger] Support computing a representation of a set that only has locals that are divs
This paves the way for integer-exact projection, and for supporting
non-division locals in subtraction, complement, and equality checks.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D127463
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4a866c17dd3b1..935307d4bb88e 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -26,6 +26,7 @@ namespace presburger {
class IntegerRelation;
class IntegerPolyhedron;
+// Forward-declared because IntegerPolyhedron::computeReprWithOnlyDivLocals()
+// returns a PresburgerSet by value.
+class PresburgerSet;
/// An IntegerRelation represents the set of points from a PresburgerSpace that
/// satisfy a list of affine constraints. Affine constraints can be inequalities
@@ -93,6 +94,17 @@ class IntegerRelation {
/// Returns a reference to the underlying space.
const PresburgerSpace &getSpace() const { return space; }
+ /// Set the space to `oSpace`, which should have the same number of ids as
+ /// the current space. Only the space is replaced; the constraints are left
+ /// untouched.
+ void setSpace(const PresburgerSpace &oSpace);
+
+ /// Set the space to `oSpace`, which should not have any local ids.
+ /// `oSpace` can have fewer ids than the current space; in that case, the
+ /// extra (trailing) ids in `this` that are not accounted for by `oSpace`
+ /// will be considered as local ids. `oSpace` should not have more ids than
+ /// the current space; this will result in an assert failure.
+ void setSpaceExceptLocals(const PresburgerSpace &oSpace);
+
/// Returns a copy of the space without locals.
PresburgerSpace getSpaceWithoutLocals() const {
return PresburgerSpace::getRelationSpace(space.getNumDomainIds(),
@@ -497,6 +509,9 @@ class IntegerRelation {
/// locals that have been added to `this`.
unsigned mergeLocalIds(IntegerRelation &other);
+ /// Check whether all local ids have a division representation, as found by
+ /// getLocalReprs(). Returns true when there are no local ids.
+ bool hasOnlyDivLocals() const;
+
/// Changes the partition between dimensions and symbols. Depending on the new
/// symbol count, either a chunk of dimensional identifiers immediately before
/// the split become symbols, or some of the symbols immediately after the
@@ -739,6 +754,12 @@ class IntegerPolyhedron : public IntegerRelation {
/// first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+ /// Compute an equivalent representation of the same set, such that all local
+ /// ids have division representations. The result may be a union of convex
+ /// disjuncts, and those disjuncts may have their own local ids, each of
+ /// which corresponds to a division. The result lies in the space of `this`
+ /// with its local ids removed.
+ PresburgerSet computeReprWithOnlyDivLocals() const;
+
/// Compute the symbolic integer lexmin of the polyhedron.
/// This finds, for every assignment to the symbols, the lexicographically
/// minimum value attained by the dimensions. For example, the symbolic lexmin
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index e4aa36599537b..89a3deb30e689 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -55,6 +55,14 @@ class PresburgerRelation {
const PresburgerSpace &getSpace() const { return space; }
+ /// Set the space to `oSpace`. `oSpace` should not contain any local ids.
+ /// `oSpace` need not have the same number of ids as the current space;
+ /// it could have more or fewer. If it has fewer, the extra ids of each
+ /// disjunct become locals of that disjunct. It can also have more, in which
+ /// case the disjuncts will have fewer locals. If its total number of ids
+ /// exceeds that of some disjunct, an assert failure will occur.
+ void setSpace(const PresburgerSpace &oSpace);
+
/// Return a reference to the list of disjuncts.
ArrayRef<IntegerRelation> getAllDisjuncts() const;
@@ -117,6 +125,9 @@ class PresburgerRelation {
/// disjuncts in the union.
PresburgerRelation coalesce() const;
+ /// Check whether all local ids in all disjuncts have a div representation,
+ /// i.e. whether IntegerRelation::hasOnlyDivLocals() holds for every
+ /// disjunct. Returns true for a relation with no disjuncts.
+ bool hasOnlyDivLocals() const;
+
/// Print the set's internal state.
void print(raw_ostream &os) const;
void dump() const;
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 4caaf78bd4716..f583c59ef24be 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -572,10 +572,28 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// `constraints`, and no other ids.
SymbolicLexSimplex(const IntegerPolyhedron &constraints,
const IntegerPolyhedron &symbolDomain)
- : LexSimplexBase(constraints), domainPoly(symbolDomain),
- domainSimplex(symbolDomain) {
- assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
- assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
+ : SymbolicLexSimplex(constraints,
+ constraints.getIdKindOffset(IdKind::Symbol),
+ symbolDomain) {
+ // Delegates with the symbol ids of `constraints` as the symbol range.
+ assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds());
+ }
+
+ /// An overload to select some other subrange of ids as symbols for lexmin.
+ /// The symbol ids are the range of ids with absolute index
+ /// [symbolOffset, symbolOffset + symbolDomain.getNumIds()); every other id
+ /// of `constraints` is a non-symbol whose lexmin is computed.
+ /// symbolDomain should only have dim ids.
+ SymbolicLexSimplex(const IntegerPolyhedron &constraints,
+ unsigned symbolOffset,
+ const IntegerPolyhedron &symbolDomain)
+ : LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset,
+ symbolDomain.getNumIds()),
+ domainPoly(symbolDomain), domainSimplex(symbolDomain) {
+ // TODO consider supporting this case. It amounts
+ // to just returning the input constraints.
+ // NOTE(review): domainPoly is the *symbol* domain, so the condition below
+ // requires at least one symbol, while the message (and the TODO) talk about
+ // non-symbols; confirm which case is meant to be rejected here.
+ assert(domainPoly.getNumIds() > 0 &&
+ "there must be some non-symbols to optimize!");
+ assert(domainPoly.getNumIds() == domainPoly.getNumDimIds());
+ intersectIntegerRelation(constraints);
}
/// The lexmin will be stored as a function `lexmin` from symbols to
@@ -583,6 +601,9 @@ class SymbolicLexSimplex : public LexSimplexBase {
///
/// For some values of the symbols, the lexmin may be unbounded.
/// These parts of the symbol domain will be stored in `unboundedDomain`.
+ ///
+ /// The sets in the result (the domain of `lexmin` and `unboundedDomain`) lie
+ /// in spaces compatible with the symbolDomain passed to the
+ /// SymbolicLexSimplex constructor.
SymbolicLexMin computeSymbolicIntegerLexMin();
private:
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 25d89f93d93d1..7376747663e62 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -38,6 +38,19 @@ std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
return std::make_unique<IntegerPolyhedron>(*this);
}
+/// Swap in a new space while leaving every constraint untouched. The new space
+/// must describe exactly as many ids as the current one.
+void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
+  assert(oSpace.getNumIds() == space.getNumIds() && "invalid space!");
+  space = oSpace;
+}
+
+/// Adopt the local-free space `oSpace`; whatever ids of `this` lie beyond the
+/// ids described by `oSpace` are re-labelled as locals.
+void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
+  assert(oSpace.getNumLocalIds() == 0 && "no locals should be present!");
+  assert(oSpace.getNumIds() <= getNumIds() && "invalid space!");
+  // Must be computed before `space` is overwritten, since getNumIds() reads it.
+  const unsigned numExtraIds = getNumIds() - oSpace.getNumIds();
+  space = oSpace;
+  // `oSpace` has no locals, so position 0 of the local range is the very end.
+  space.insertId(IdKind::Local, 0, numExtraIds);
+}
+
void IntegerRelation::append(const IntegerRelation &other) {
assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
@@ -152,6 +165,67 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
}
+/// Compute an equivalent representation of this set in which every local id
+/// has a division representation.
+///
+/// Locals without a division representation are moved to the end of the
+/// local range and eliminated through a symbolic lexmin that treats every
+/// other id as a "symbol". The result lies in the space of `this` without its
+/// local ids.
+PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
+  // If there are no locals, we're done.
+  if (getNumLocalIds() == 0)
+    return PresburgerSet(*this);
+
+  // Move all the non-div locals to the end, as the current API to
+  // SymbolicLexMin requires these to form a contiguous range.
+  //
+  // Take a copy so we can perform mutations.
+  IntegerPolyhedron copy = *this;
+  std::vector<MaybeLocalRepr> reprs;
+  copy.getLocalReprs(reprs);
+
+  // Iterate through all the locals. The last `numNonDivLocals` are the locals
+  // that have been scanned already and do not have division representations.
+  unsigned numNonDivLocals = 0;
+  unsigned offset = copy.getIdKindOffset(IdKind::Local);
+  for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) {
+    if (!reprs[i]) {
+      // Whenever we come across a local that does not have a division
+      // representation, we swap it to the `numNonDivLocals`-th last position
+      // and increment `numNonDivLocals`. `reprs` also needs to be swapped.
+      // `i` is deliberately not advanced: the local swapped into position `i`
+      // has not been examined yet.
+      copy.swapId(offset + i, offset + e - numNonDivLocals - 1);
+      std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
+      ++numNonDivLocals;
+      continue;
+    }
+    ++i;
+  }
+
+  // If there are no non-div locals, we're done.
+  if (numNonDivLocals == 0)
+    return PresburgerSet(*this);
+
+  // We computeSymbolicIntegerLexMin by considering the non-div locals as
+  // "non-symbols" and considering everything else as "symbols". This will
+  // compute a function mapping assignments to "symbols" to the
+  // lexicographically minimal valid assignment of "non-symbols", when a
+  // satisfying assignment exists. It separately returns the set of assignments
+  // to the "symbols" such that a satisfying assignment to the "non-symbols"
+  // exists but the lexmin is unbounded. We basically want to find the set of
+  // values of the "symbols" such that an assignment to the "non-symbols"
+  // exists, which is the union of the domain of the returned lexmin function
+  // and the returned set of assignments to the "symbols" that makes the lexmin
+  // unbounded.
+  SymbolicLexMin lexminResult =
+      SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
+                         IntegerPolyhedron(PresburgerSpace::getSetSpace(
+                             /*numDims=*/copy.getNumIds() - numNonDivLocals)))
+          .computeSymbolicIntegerLexMin();
+  PresburgerSet result =
+      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
+
+  // The result set might lie in the wrong space -- all its ids are dims.
+  // Set it to the desired space and return. The local is named `resultSpace`
+  // so that it does not shadow the `space` data member.
+  PresburgerSpace resultSpace = getSpace();
+  resultSpace.removeIdRange(IdKind::Local, 0, getNumLocalIds());
+  result.setSpace(resultSpace);
+  return result;
+}
+
SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
@@ -1120,6 +1194,13 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
return relA.getNumLocalIds() - oldALocals;
}
+/// A local id counts as a "div local" when getLocalReprs() yields a non-empty
+/// MaybeLocalRepr for it; stop at the first one that does not.
+bool IntegerRelation::hasOnlyDivLocals() const {
+  std::vector<MaybeLocalRepr> localReprs;
+  getLocalReprs(localReprs);
+  for (const MaybeLocalRepr &localRepr : localReprs)
+    if (!localRepr)
+      return false;
+  return true;
+}
+
void IntegerRelation::removeDuplicateDivs() {
std::vector<SmallVector<int64_t, 8>> divs;
SmallVector<unsigned, 4> denoms;
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 9ce59d769d43c..1b5d48b9cf896 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -21,6 +21,13 @@ PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
unionInPlace(disjunct);
}
+/// Replace the space of this relation with `oSpace`, which must not contain
+/// any local ids. Each disjunct keeps its constraints; any of its ids beyond
+/// those of `oSpace` become locals of that disjunct (see
+/// IntegerRelation::setSpaceExceptLocals).
+void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
+  // The precondition is on the incoming space. Checking it here also covers
+  // relations with no disjuncts, for which setSpaceExceptLocals (which repeats
+  // the check) is never reached.
+  assert(oSpace.getNumLocalIds() == 0 && "no locals should be present");
+  space = oSpace;
+  for (IntegerRelation &disjunct : disjuncts)
+    disjunct.setSpaceExceptLocals(space);
+}
+
unsigned PresburgerRelation::getNumDisjuncts() const {
return disjuncts.size();
}
@@ -770,6 +777,12 @@ PresburgerRelation PresburgerRelation::coalesce() const {
return SetCoalescer(*this).coalesce();
}
+/// The union has only div locals exactly when each of its disjuncts does.
+bool PresburgerRelation::hasOnlyDivLocals() const {
+  for (const IntegerRelation &disjunct : disjuncts)
+    if (!disjunct.hasOnlyDivLocals())
+      return false;
+  return true;
+}
+
void PresburgerRelation::print(raw_ostream &os) const {
os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
for (const IntegerRelation &disjunct : disjuncts) {
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index ba3a0024f9732..0c98f488ea074 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -751,6 +751,54 @@ TEST(SetTest, computeVolume) {
/*resultBound=*/{});
}
+// Project out the last `numToProject` dims of `poly` (i.e., convert them to
+// locals), compute a representation with only div locals, and check that it
+// has only div locals, lies in a compatible space, and agrees with `poly` on
+// each of `points`. File-local helper, hence `static`.
+static void testComputeReprAtPoints(IntegerPolyhedron poly,
+                                    ArrayRef<SmallVector<int64_t, 4>> points,
+                                    unsigned numToProject) {
+  poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
+                     poly.getNumDimIds(), IdKind::Local);
+  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
+  EXPECT_TRUE(repr.hasOnlyDivLocals());
+  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
+  for (const SmallVector<int64_t, 4> &point : points) {
+    EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(),
+              repr.containsPoint(point));
+  }
+}
+
+// Project out the last `numToProject` dims of `poly` (i.e., convert them to
+// locals), compute a representation with only div locals, and check that it
+// has only div locals, lies in a compatible space, and equals `expected`.
+// File-local helper, hence `static`.
+static void testComputeRepr(IntegerPolyhedron poly,
+                            const PresburgerSet &expected,
+                            unsigned numToProject) {
+  poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
+                     poly.getNumDimIds(), IdKind::Local);
+  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
+  EXPECT_TRUE(repr.hasOnlyDivLocals());
+  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
+  EXPECT_TRUE(repr.isEqual(expected));
+}
+
+TEST(SetTest, computeReprWithOnlyDivLocals) {
+ // Nothing is projected and the poly has no locals, so the set should come
+ // back unchanged.
+ testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"),
+ {{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}},
+ /*numToProject=*/0);
+ // `e` is pinned by the equality x == 2*e, so it presumably already has a
+ // division representation and no lexmin is needed.
+ testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"),
+ {{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1);
+
+ // Tests to check that the space is preserved. In both, the projected `y`
+ // is unconstrained and so has no division representation; the second also
+ // carries an existing division local (w floordiv 2).
+ testComputeReprAtPoints(parsePoly("(x, y)[z, w] : ()"), {},
+ /*numToProject=*/1);
+ testComputeReprAtPoints(parsePoly("(x, y)[z, w] : (z - (w floordiv 2) == 0)"),
+ {},
+ /*numToProject=*/1);
+
+ // Bezout's lemma: for integer constants a, b, the values a*e + b*f takes
+ // over integers e, f are exactly the multiples of gcd(a, b). Here
+ // gcd(15, 21) = 3, so projecting out e and f leaves the multiples of 3.
+ testComputeRepr(
+ parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"),
+ PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})),
+ /*numToProject=*/2);
+}
+
TEST(SetTest, subtractOutputSizeRegression) {
PresburgerSet set1 =
parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});
More information about the Mlir-commits
mailing list