[Mlir-commits] [mlir] 4a57f5d - [MLIR] PresburgerSet: support divisions in operations
Arjun P
llvmlistbot at llvm.org
Fri Sep 24 03:06:51 PDT 2021
Author: Arjun P
Date: 2021-09-24T15:36:47+05:30
New Revision: 4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77
URL: https://github.com/llvm/llvm-project/commit/4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77
DIFF: https://github.com/llvm/llvm-project/commit/4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77.diff
LOG: [MLIR] PresburgerSet: support divisions in operations
Add support for intersecting, subtracting, complementing and checking equality of sets having divisions.
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D110138
Added:
Modified:
mlir/include/mlir/Analysis/PresburgerSet.h
mlir/lib/Analysis/PresburgerSet.cpp
mlir/unittests/Analysis/PresburgerSetTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
index 4863c876745aa..41dfc0018efab 100644
--- a/mlir/include/mlir/Analysis/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -67,17 +67,18 @@ class PresburgerSet {
void print(raw_ostream &os) const;
void dump() const;
- /// Return the complement of this set. Computing the complement of a set
- /// containing divisions is not yet supported.
+ /// Return the complement of this set. All local variables in the set must
+ /// correspond to floor divisions.
PresburgerSet complement() const;
/// Return the set difference of this set and the given set, i.e.,
- /// return `this \ set`. Subtracting when either set contains divisions is not
- /// yet supported.
+ /// return `this \ set`. All local variables in `set` must correspond
+ /// to floor divisions, but local variables in `this` need not correspond to
+ /// divisions.
PresburgerSet subtract(const PresburgerSet &set) const;
/// Return true if this set is equal to the given set, and false otherwise.
- /// Checking equality when either set contains divisions is not yet supported.
+ /// All local variables in both sets must correspond to floor divisions.
bool isEqual(const PresburgerSet &set) const;
/// Return a universe set of the specified type that contains all points.
diff --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
index 81fc8cc1eafd2..7beeb99d340bb 100644
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -106,16 +106,20 @@ PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
//
// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
// as (S_1 and T_1) or (S_1 and T_2) or ...
+//
+// If S_i or T_j have local variables, then S_i and T_j contains the local
+// variables of both.
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
PresburgerSet result(nDim, nSym);
for (const FlatAffineConstraints &csA : flatAffineConstraints) {
for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
- FlatAffineConstraints intersection(csA);
- intersection.append(csB);
- if (!intersection.isEmpty())
- result.unionFACInPlace(std::move(intersection));
+ FlatAffineConstraints csACopy = csA, csBCopy = csB;
+ csACopy.mergeLocalIds(csBCopy);
+ csACopy.append(std::move(csBCopy));
+ if (!csACopy.isEmpty())
+ result.unionFACInPlace(std::move(csACopy));
}
}
return result;
@@ -160,6 +164,17 @@ static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
///
+/// Note that the same approach works even if an inequality involves a floor
+/// division. For example, the complement of x <= 7*floor(x/7) is still
+/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
+/// (or the complements of those inequalities), b \ s_i may contain the
+/// divisions present in both b and s_i. Therefore, we need to add the local
+/// division variables of both b and s_i to each part in the result. This means
+/// adding the local variables of both b and s_i, as well as the corresponding
+/// division inequalities to each part. Since the division inequalities are
+/// added to each part, we can skip the parts where the complement of any
+/// division inequality is added, as these parts will become empty anyway.
+///
/// As a heuristic, we try adding all the constraints and check if simplex
/// says that the intersection is empty. If it is, then subtracting this FAC is
/// a no-op and we just skip it. Also, in the process we find out that some
@@ -174,27 +189,63 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
result.unionFACInPlace(b);
return;
}
- const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
- assert(sI.getNumLocalIds() == 0 &&
- "Subtracting sets with divisions is not yet supported!");
+ FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
+ unsigned bInitNumLocals = b.getNumLocalIds();
+
+ // Find out which inequalities of sI correspond to division inequalities for
+ // the local variables of sI.
+ std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
+ sI.getNumLocalIds());
+ sI.getLocalReprLbUbPairs(repr);
+
+ // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
+ // sI's locals.
+ b.mergeLocalIds(sI);
+
+ // Mark which inequalities of sI are division inequalities and add all such
+ // inequalities to b.
+ llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+ for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
+ assert(maybePair &&
+ "Subtraction is not supported when a representation of the local "
+ "variables of the subtrahend cannot be found!");
+
+ b.addInequality(sI.getInequality(maybePair->first));
+ b.addInequality(sI.getInequality(maybePair->second));
+
+ assert(maybePair->first != maybePair->second &&
+ "Upper and lower bounds must be different inequalities!");
+ isDivInequality[maybePair->first] = true;
+ isDivInequality[maybePair->second] = true;
+ }
+
unsigned initialSnapshot = simplex.getSnapshot();
unsigned offset = simplex.getNumConstraints();
+ unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
+ simplex.appendVariable(numLocalsAdded);
+
+ unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectFlatAffineConstraints(sI);
if (simplex.isEmpty()) {
/// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
simplex.rollback(initialSnapshot);
+ b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+ b.getNumLocalIds());
subtractRecursively(b, simplex, s, i + 1, result);
return;
}
simplex.detectRedundant();
- llvm::SmallBitVector isMarkedRedundant;
- for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
- j++)
- isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
- simplex.rollback(initialSnapshot);
+ // Equalities are added to simplex as a pair of inequalities.
+ unsigned totalNewSimplexInequalities =
+ 2 * sI.getNumEqualities() + sI.getNumInequalities();
+ llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+ for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+ isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
+
+ simplex.rollback(snapshotBeforeIntersect);
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
@@ -223,20 +274,28 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
- unsigned originalNumIneqs = b.getNumInequalities();
- unsigned originalNumEqs = b.getNumEqualities();
+ unsigned bInitNumIneqs = b.getNumInequalities();
+ unsigned bInitNumEqs = b.getNumEqualities();
+ // Process all the inequalities, ignoring redundant inequalities and division
+ // inequalities. The result is correct whether or not we ignore these, but
+ // ignoring them makes the result simpler.
for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
if (isMarkedRedundant[j])
continue;
+ if (isDivInequality[j])
+ continue;
processInequality(sI.getInequality(j));
}
offset = sI.getNumInequalities();
for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
- const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
- // Same as the above loop for inequalities, done once each for the positive
- // and negative inequalities that make up this equality.
+ ArrayRef<int64_t> coeffs = sI.getEquality(j);
+ // For each equality, process the positive and negative inequalities that
+ // make up this equality. If Simplex found an inequality to be redundant, we
+ // skip it as above to make the result simpler. Divisions are always
+ // represented in terms of inequalities and not equalities, so we do not
+ // check for division inequalities here.
if (!isMarkedRedundant[offset + 2 * j])
processInequality(coeffs);
if (!isMarkedRedundant[offset + 2 * j + 1])
@@ -244,11 +303,10 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
}
// Rollback b and simplex to their initial states.
- for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
- b.removeInequality(i - 1);
-
- for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
- b.removeEquality(i - 1);
+ b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+ b.getNumLocalIds());
+ b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+ b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
simplex.rollback(initialSnapshot);
}
@@ -261,8 +319,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
const PresburgerSet &set) {
assertDimensionsCompatible(fac, set);
- assert(fac.getNumLocalIds() == 0 &&
- "Subtracting sets with divisions is not yet supported!");
if (fac.isEmptyByGCDTest())
return PresburgerSet::getEmptySet(fac.getNumDimIds(),
fac.getNumSymbolIds());
diff --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
index ac065fa74879c..4ae76d7f0c329 100644
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -80,12 +80,17 @@ static void testComplementAtPoints(PresburgerSet s,
}
/// Construct a FlatAffineConstraints from a set of inequality and
-/// equality constraints.
+/// equality constraints. `numIds` is the total number of ids, of which
+/// `numLocals` is the number of local ids.
static FlatAffineConstraints
-makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
- ArrayRef<SmallVector<int64_t, 4>> eqs) {
- FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims,
- /*numSymbols=*/0, /*numLocals=*/0);
+makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+ ArrayRef<SmallVector<int64_t, 4>> eqs,
+ unsigned numLocals = 0) {
+ FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
+ /*numReservedEqualities=*/eqs.size(),
+ /*numReservedCols=*/numIds + 1,
+ /*numDims=*/numIds - numLocals,
+ /*numSymbols=*/0, numLocals);
for (const SmallVector<int64_t, 4> &eq : eqs)
fac.addEquality(eq);
for (const SmallVector<int64_t, 4> &ineq : ineqs)
@@ -93,14 +98,22 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
return fac;
}
+/// Construct a FlatAffineConstraints having `numDims` dimensions from the given
+/// set of inequality constraints. This is a convenience function to be used
+/// when the FAC to be constructed does not have any local ids and does not have
+/// equalities.
static FlatAffineConstraints
-makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
- return makeFACFromConstraints(dims, ineqs, {});
+makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
+ return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
}
-static PresburgerSet makeSetFromFACs(unsigned dims,
+/// Construct a PresburgerSet having `numDims` dimensions and no symbols from
+/// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
+/// `numDims` dimensions and no symbols, although it can have any number of
+/// local ids.
+static PresburgerSet makeSetFromFACs(unsigned numDims,
ArrayRef<FlatAffineConstraints> facs) {
- PresburgerSet set = PresburgerSet::getEmptySet(dims);
+ PresburgerSet set = PresburgerSet::getEmptySet(numDims);
for (const FlatAffineConstraints &fac : facs)
set.unionFACInPlace(fac);
return set;
@@ -592,4 +605,37 @@ TEST(SetTest, isEqual) {
EXPECT_FALSE(rect.complement().isEqual(square.complement()));
}
+void expectEqual(PresburgerSet s, PresburgerSet t) {
+ EXPECT_TRUE(s.isEqual(t));
+}
+
+void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
+
+TEST(SetTest, divisions) {
+ // Note: we currently need to add the equalities as inequalities to the FAC
+ // since detecting divisions based on equalities is not yet supported.
+
+ // evens = {x : exists q, x = 2q}.
+ PresburgerSet evens{
+ makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
+ // odds = {x : exists q, x = 2q + 1}.
+ PresburgerSet odds{
+ makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
+ // multiples3 = {x : exists q, x = 3q}.
+ PresburgerSet multiples3{
+ makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
+ // multiples6 = {x : exists q, x = 6q}.
+ PresburgerSet multiples6{
+ makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
+
+ // evens /\ odds = empty.
+ expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
+ // evens U odds = universe.
+ expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
+ expectEqual(evens.complement(), odds);
+ expectEqual(odds.complement(), evens);
+ // even multiples of 3 = multiples of 6.
+ expectEqual(multiples3.intersect(evens), multiples6);
+}
+
} // namespace mlir
More information about the Mlir-commits
mailing list