[Mlir-commits] [mlir] 63dead2 - Introduce subtraction for FlatAffineConstraints

Alex Zinenko llvmlistbot at llvm.org
Wed Oct 7 08:31:17 PDT 2020


Author: Arjun P
Date: 2020-10-07T17:31:06+02:00
New Revision: 63dead2096cd6a2190ba11071938b937be8bf159

URL: https://github.com/llvm/llvm-project/commit/63dead2096cd6a2190ba11071938b937be8bf159
DIFF: https://github.com/llvm/llvm-project/commit/63dead2096cd6a2190ba11071938b937be8bf159.diff

LOG: Introduce subtraction for FlatAffineConstraints

Subtraction is a foundational arithmetic operation that is often used when computing, for example, data transfer sets or cache hits. Since the result of subtraction need not be a convex polytope, a new class `PresburgerSet` is introduced to represent unions of convex polytopes.

Reviewed By: ftynse, bondhugula

Differential Revision: https://reviews.llvm.org/D87068

Added: 
    mlir/include/mlir/Analysis/PresburgerSet.h
    mlir/lib/Analysis/PresburgerSet.cpp
    mlir/unittests/Analysis/PresburgerSetTest.cpp

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/lib/Analysis/CMakeLists.txt
    mlir/lib/Analysis/Presburger/CMakeLists.txt
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/unittests/Analysis/AffineStructuresTest.cpp
    mlir/unittests/Analysis/CMakeLists.txt
    mlir/unittests/Analysis/Presburger/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index d64a24e713d1..25071db100e3 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -97,6 +97,13 @@ class FlatAffineConstraints {
       ids.append(idArgs.begin(), idArgs.end());
   }
 
+  /// Return the universe system: it carries no equalities or inequalities, so
+  /// every point with the given numbers of dimensions and symbols satisfies it.
+  static FlatAffineConstraints getUniverse(unsigned numDims = 0,
+                                           unsigned numSymbols = 0) {
+    FlatAffineConstraints universe(numDims, numSymbols);
+    return universe;
+  }
+
   /// Create a flat affine constraint system from an AffineValueMap or a list of
   /// these. The constructed system will only include equalities.
   explicit FlatAffineConstraints(const AffineValueMap &avm);
@@ -153,6 +160,10 @@ class FlatAffineConstraints {
   /// Returns such a point if one exists, or an empty Optional otherwise.
   Optional<SmallVector<int64_t, 8>> findIntegerSample() const;
 
+  /// Returns true if the given point satisfies the constraints, or false
+  /// otherwise. The point supplies one value per identifier, i.e., one per
+  /// non-constant column of the constraint rows; equalities must evaluate to
+  /// zero and inequalities to a non-negative value at the point.
+  bool containsPoint(ArrayRef<int64_t> point) const;
+
   // Clones this object.
   std::unique_ptr<FlatAffineConstraints> clone() const;
 

diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 209382013de2..05d241e60958 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -169,6 +169,9 @@ class Simplex {
   /// Rollback to a snapshot. This invalidates all later snapshots.
   void rollback(unsigned snapshot);
 
+  /// Add all the constraints from the given FlatAffineConstraints, which must
+  /// have as many identifiers as this simplex has variables. All inequalities
+  /// are added first, in order, followed by all equalities, in order.
+  void intersectFlatAffineConstraints(const FlatAffineConstraints &fac);
+
   /// Compute the maximum or minimum value of the given row, depending on
   /// direction. The specified row is never pivoted.
   ///

diff  --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
new file mode 100644
index 000000000000..1f3a10a8a624
--- /dev/null
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -0,0 +1,112 @@
+//===- PresburgerSet.h - MLIR PresburgerSet Class ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A class to represent unions of FlatAffineConstraints.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGERSET_H
+#define MLIR_ANALYSIS_PRESBURGERSET_H
+
+#include "mlir/Analysis/AffineStructures.h"
+
+namespace mlir {
+
+/// This class can represent a union of FlatAffineConstraints, with support for
+/// union, intersection, subtraction and complement operations, as well as
+/// sampling.
+///
+/// The FlatAffineConstraints (FACs) are stored in a vector, and the set
+/// represents the union of these FACs. An empty list corresponds to the empty
+/// set.
+///
+/// Note that there are no invariants guaranteed on the list of FACs other than
+/// that they are all in the same space, i.e., they all have the same number of
+/// dimensions and symbols. For example, the FACs may overlap each other.
+class PresburgerSet {
+public:
+  /// Construct a set consisting of the single FAC `fac`, living in the same
+  /// space (dimensions and symbols) as `fac`.
+  explicit PresburgerSet(const FlatAffineConstraints &fac);
+
+  /// Return the number of FACs in the union.
+  unsigned getNumFACs() const;
+
+  /// Return the number of real dimensions.
+  unsigned getNumDims() const;
+
+  /// Return the number of symbolic dimensions.
+  unsigned getNumSyms() const;
+
+  /// Return a reference to the list of FlatAffineConstraints.
+  ArrayRef<FlatAffineConstraints> getAllFlatAffineConstraints() const;
+
+  /// Return the FlatAffineConstraints at the specified index.
+  const FlatAffineConstraints &getFlatAffineConstraints(unsigned index) const;
+
+  /// Mutate this set, turning it into the union of this set and the given
+  /// FlatAffineConstraints.
+  void unionFACInPlace(const FlatAffineConstraints &fac);
+
+  /// Mutate this set, turning it into the union of this set and the given set.
+  void unionSetInPlace(const PresburgerSet &set);
+
+  /// Return the union of this set and the given set.
+  PresburgerSet unionSet(const PresburgerSet &set) const;
+
+  /// Return the intersection of this set and the given set.
+  PresburgerSet intersect(const PresburgerSet &set) const;
+
+  /// Return true if the set contains the given point, or false otherwise.
+  bool containsPoint(ArrayRef<int64_t> point) const;
+
+  /// Print the set's internal state.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// Return the complement of this set.
+  PresburgerSet complement() const;
+
+  /// Return the set difference of this set and the given set, i.e.,
+  /// return `this \ set`.
+  PresburgerSet subtract(const PresburgerSet &set) const;
+
+  /// Return a universe set with `nDim` dimensions and `nSym` symbols that
+  /// contains all points.
+  static PresburgerSet getUniverse(unsigned nDim = 0, unsigned nSym = 0);
+  /// Return an empty set with `nDim` dimensions and `nSym` symbols that
+  /// contains no points.
+  static PresburgerSet getEmptySet(unsigned nDim = 0, unsigned nSym = 0);
+
+  /// Return true if all the sets in the union are known to be integer empty,
+  /// false otherwise. Only intended for sets without symbols.
+  bool isIntegerEmpty() const;
+
+  /// Find an integer sample from this set. On success, the sample is stored in
+  /// `sample` and true is returned; otherwise false is returned. This should
+  /// not be called if any of the FACs in the union are unbounded, and is only
+  /// intended for sets without symbols.
+  bool findIntegerSample(SmallVectorImpl<int64_t> &sample);
+
+private:
+  /// Construct an empty PresburgerSet with `nDim` dimensions and `nSym`
+  /// symbols. Explicit so that a lone `unsigned` never silently converts into
+  /// an empty set.
+  explicit PresburgerSet(unsigned nDim = 0, unsigned nSym = 0)
+      : nDim(nDim), nSym(nSym) {}
+
+  /// Return the set difference fac \ set.
+  static PresburgerSet getSetDifference(FlatAffineConstraints fac,
+                                        const PresburgerSet &set);
+
+  /// Number of identifiers corresponding to real dimensions.
+  unsigned nDim;
+
+  /// Number of symbolic dimensions, unknown but constant for analysis, as in
+  /// FlatAffineConstraints.
+  unsigned nSym;
+
+  /// The list of flatAffineConstraints that this set is the union of.
+  SmallVector<FlatAffineConstraints, 2> flatAffineConstraints;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGERSET_H

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 5b7f4d4982d0..341dde523e8b 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1056,6 +1056,33 @@ FlatAffineConstraints::findIntegerSample() const {
   return Simplex(*this).findIntegerSample();
 }
 
+/// Evaluate an affine expression at a point. `expr` holds one coefficient per
+/// coordinate of `point`, followed by the constant term.
+static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
+  assert(expr.size() == 1 + point.size() &&
+         "Dimensionalities of point and expresion don't match!");
+  int64_t sum = expr.back();
+  for (unsigned d = 0, n = point.size(); d < n; ++d)
+    sum += expr[d] * point[d];
+  return sum;
+}
+
+/// Return true if `point` satisfies every constraint of this system. A point
+/// satisfies an equality iff the value of the equality at that point is zero,
+/// and it satisfies an inequality iff the value of the inequality at that
+/// point is non-negative.
+bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
+  // Check the size up front: valueAt only asserts it per constraint, so a
+  // mis-sized point would go unnoticed on a system with no constraints.
+  assert(point.size() == getNumIds() &&
+         "Point must have one coordinate per identifier!");
+  for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
+    if (valueAt(getEquality(i), point) != 0)
+      return false;
+  for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
+    if (valueAt(getInequality(i), point) < 0)
+      return false;
+  return true;
+}
+
 /// Tightens inequalities given that we are dealing with integer spaces. This is
 /// analogous to the GCD test but applied to inequalities. The constant term can
 /// be reduced to the preceding multiple of the GCD of the coefficients, i.e.,

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 217a94995c0a..4e334c94bd83 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
   Liveness.cpp
   LoopAnalysis.cpp
   NestedMatcher.cpp
+  PresburgerSet.cpp
   SliceAnalysis.cpp
   Utils.cpp
   )
@@ -25,7 +26,6 @@ add_mlir_library(MLIRAnalysis
   MLIRCallInterfaces
   MLIRControlFlowInterfaces
   MLIRInferTypeOpInterface
-  MLIRPresburger
   MLIRSCF
   )
 
@@ -34,6 +34,7 @@ add_mlir_library(MLIRLoopAnalysis
   AffineStructures.cpp
   LoopAnalysis.cpp
   NestedMatcher.cpp
+  PresburgerSet.cpp
   Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
@@ -51,4 +52,4 @@ add_mlir_library(MLIRLoopAnalysis
   MLIRSCF
   )
   
-add_subdirectory(Presburger)
+add_subdirectory(Presburger)
\ No newline at end of file

diff  --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
index 2561013696d9..49cdd5ac1431 100644
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -1,4 +1,4 @@
 add_mlir_library(MLIRPresburger
   Simplex.cpp
   Matrix.cpp
-  )
+  )
\ No newline at end of file

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index db1e48f50e8e..65a8a689164f 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -451,6 +451,16 @@ void Simplex::rollback(unsigned snapshot) {
   }
 }
 
+/// Add all the constraints from the given FlatAffineConstraints.
+///
+/// All inequalities of `fac` are added first, in order, followed by all of its
+/// equalities, in order. PresburgerSet's subtraction code indexes the newly
+/// added simplex constraints assuming exactly this layout, so the order must
+/// be kept stable.
+void Simplex::intersectFlatAffineConstraints(const FlatAffineConstraints &fac) {
+  // The constraint system must be over the same variables as the simplex.
+  assert(fac.getNumIds() == numVariables() &&
+         "FlatAffineConstraints must have same dimensionality as simplex");
+  for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+    addInequality(fac.getInequality(i));
+  for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+    addEquality(fac.getEquality(i));
+}
+
 Optional<Fraction> Simplex::computeRowOptimum(Direction direction,
                                               unsigned row) {
   // Keep trying to find a pivot for the row in the specified direction.

diff  --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
new file mode 100644
index 000000000000..323dc3e56d54
--- /dev/null
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -0,0 +1,316 @@
+//===- PresburgerSet.cpp - MLIR PresburgerSet Class -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/PresburgerSet.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
+
+using namespace mlir;
+
+/// Build a set whose only disjunct is `fac`, in the same space (dimensions and
+/// symbols) as `fac`.
+PresburgerSet::PresburgerSet(const FlatAffineConstraints &fac)
+    : nDim(fac.getNumDimIds()), nSym(fac.getNumSymbolIds()) {
+  flatAffineConstraints.push_back(fac);
+}
+
+/// Number of disjuncts (FACs) whose union forms this set.
+unsigned PresburgerSet::getNumFACs() const {
+  return static_cast<unsigned>(flatAffineConstraints.size());
+}
+
+/// Number of dimension identifiers of the space this set lives in.
+unsigned PresburgerSet::getNumDims() const {
+  return nDim;
+}
+
+/// Number of symbol identifiers of the space this set lives in.
+unsigned PresburgerSet::getNumSyms() const {
+  return nSym;
+}
+
+/// Non-owning view of all the disjuncts, referring to this set's storage.
+ArrayRef<FlatAffineConstraints>
+PresburgerSet::getAllFlatAffineConstraints() const {
+  return ArrayRef<FlatAffineConstraints>(flatAffineConstraints);
+}
+
+/// The disjunct at position `index`, which must be less than getNumFACs().
+const FlatAffineConstraints &
+PresburgerSet::getFlatAffineConstraints(unsigned index) const {
+  assert(index < flatAffineConstraints.size() && "index out of bounds!");
+  return flatAffineConstraints[index];
+}
+
+/// Assert that the FlatAffineConstraints and PresburgerSet live in
+/// compatible spaces, i.e., have the same numbers of dimensions and symbols.
+static void assertDimensionsCompatible(const FlatAffineConstraints &fac,
+                                       const PresburgerSet &set) {
+  // Adjacent string literals are concatenated with no separator, so the
+  // first literal of each message must end with a space.
+  assert(fac.getNumDimIds() == set.getNumDims() &&
+         "Number of dimensions of the FlatAffineConstraints and PresburgerSet "
+         "do not match!");
+  assert(fac.getNumSymbolIds() == set.getNumSyms() &&
+         "Number of symbols of the FlatAffineConstraints and PresburgerSet "
+         "do not match!");
+}
+
+/// Check, in debug builds, that `setA` and `setB` agree on their number of
+/// dimensions and on their number of symbols.
+static void assertDimensionsCompatible(const PresburgerSet &setA,
+                                       const PresburgerSet &setB) {
+  assert(setB.getNumDims() == setA.getNumDims() &&
+         "Number of dimensions of the PresburgerSets do not match!");
+  assert(setB.getNumSyms() == setA.getNumSyms() &&
+         "Number of symbols of the PresburgerSets do not match!");
+}
+
+/// Append `fac` as a new disjunct, so that this set becomes its union with
+/// `fac`. `fac` must live in the same space as this set.
+void PresburgerSet::unionFACInPlace(const FlatAffineConstraints &fac) {
+  assertDimensionsCompatible(fac, *this);
+  flatAffineConstraints.emplace_back(fac);
+}
+
+/// Mutate this set, turning it into the union of this set and the given set.
+///
+/// This is accomplished by simply adding all the FACs of the given set to this
+/// set.
+void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
+  assertDimensionsCompatible(set, *this);
+  // Self-union: appending to the very vector being iterated would invalidate
+  // the iterators (and the element references) on reallocation, so iterate
+  // over a copy instead.
+  if (&set == this) {
+    PresburgerSet copy = set;
+    unionSetInPlace(copy);
+    return;
+  }
+  for (const FlatAffineConstraints &fac : set.flatAffineConstraints)
+    unionFACInPlace(fac);
+}
+
+/// Return a new set that is the union of this set and `set`; neither operand
+/// is modified.
+PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
+  assertDimensionsCompatible(set, *this);
+  PresburgerSet combined(*this);
+  combined.unionSetInPlace(set);
+  return combined;
+}
+
+/// A point belongs to the union iff at least one disjunct contains it; the
+/// disjuncts are checked in order and the search stops at the first hit.
+bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
+  return llvm::any_of(flatAffineConstraints,
+                      [&](const FlatAffineConstraints &fac) {
+                        return fac.containsPoint(point);
+                      });
+}
+
+/// Return the set of all points in the space with `nDim` dimensions and
+/// `nSym` symbols: a single disjunct that has no constraints.
+PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
+  PresburgerSet universe(nDim, nSym);
+  universe.unionFACInPlace(FlatAffineConstraints::getUniverse(nDim, nSym));
+  return universe;
+}
+
+/// Return the empty set in the given space: a union of zero disjuncts.
+PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
+  PresburgerSet empty(nDim, nSym);
+  return empty;
+}
+
+/// Return the intersection of this set with the given set.
+///
+/// We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
+/// as (S_1 and T_1) or (S_1 and T_2) or ..., dropping every pairwise
+/// intersection that isEmpty() reports as empty.
+PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
+  assertDimensionsCompatible(set, *this);
+
+  PresburgerSet result(nDim, nSym);
+  for (const FlatAffineConstraints &csA : flatAffineConstraints) {
+    for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
+      FlatAffineConstraints intersection(csA);
+      intersection.append(csB);
+      // unionFACInPlace takes a const reference and stores its own copy.
+      if (!intersection.isEmpty())
+        result.unionFACInPlace(intersection);
+    }
+  }
+  return result;
+}
+
+/// Return a copy of `coeffs` in which every element has its sign flipped.
+static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
+  SmallVector<int64_t, 8> negated(coeffs.begin(), coeffs.end());
+  for (int64_t &coeff : negated)
+    coeff = -coeff;
+  return negated;
+}
+
+/// Return the complement of the given inequality.
+///
+/// The complement of a_1 x_1 + ... + a_n x_n + c >= 0 is
+/// a_1 x_1 + ... + a_n x_n + c < 0, i.e.,
+/// -a_1 x_1 - ... - a_n x_n - c - 1 >= 0, where subtracting 1 turns the strict
+/// inequality into a non-strict one over the integers.
+static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
+  assert(!ineq.empty() && "Inequality must have at least a constant term!");
+  SmallVector<int64_t, 8> coeffs = getNegatedCoeffs(ineq);
+  --coeffs.back();
+  return coeffs;
+}
+
+/// Return the set difference b \ s and accumulate the result into `result`.
+/// `simplex` must correspond to b.
+///
+/// In the following, V denotes union, ^ denotes intersection, \ denotes set
+/// difference and ~ denotes complement.
+/// Let b be the FlatAffineConstraints and s = (V_i s_i) be the set. We want
+/// b \ (V_i s_i).
+///
+/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
+/// b \ s_i = b ^ ~s_i, we partition ~s_i based on the first violated
+/// inequality:
+/// ~s_i = (~s_i1) V (s_i1 ^ ~s_i2) V (s_i1 ^ s_i2 ^ ~s_i3) V ...
+/// And the required result is (b ^ ~s_i1) V (b ^ s_i1 ^ ~s_i2) V ...
+/// We recurse by subtracting V_{j > i} s_j from each of these parts and
+/// returning the union of the results. Each equality is handled as a
+/// conjunction of two inequalities.
+///
+/// As a heuristic, we try adding all the constraints and check if simplex
+/// says that the intersection is empty. Also, in the process we find out that
+/// some constraints are redundant. These redundant constraints are ignored.
+///
+/// Both `b` and `simplex` are modified during the recursion, but are restored
+/// to their original states before this function returns.
+static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
+                                const PresburgerSet &s, unsigned i,
+                                PresburgerSet &result) {
+  if (i == s.getNumFACs()) {
+    result.unionFACInPlace(b);
+    return;
+  }
+  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
+  unsigned initialSnapshot = simplex.getSnapshot();
+  // Simplex index of the first constraint contributed by sI.
+  unsigned simplexOffset = simplex.numConstraints();
+  simplex.intersectFlatAffineConstraints(sI);
+
+  if (simplex.isEmpty()) {
+    // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
+    simplex.rollback(initialSnapshot);
+    subtractRecursively(b, simplex, s, i + 1, result);
+    return;
+  }
+
+  // sI contributes its inequalities followed by two inequalities per
+  // equality (positive, then negated); the indexing below assumes this layout.
+  unsigned numSIConstraints =
+      sI.getNumInequalities() + 2 * sI.getNumEqualities();
+  simplex.detectRedundant();
+  llvm::SmallBitVector isMarkedRedundant;
+  for (unsigned j = 0; j < numSIConstraints; ++j)
+    isMarkedRedundant.push_back(simplex.isMarkedRedundant(simplexOffset + j));
+
+  simplex.rollback(initialSnapshot);
+
+  // Recurse with the part b ^ ~ineq. Note that b is modified throughout
+  // subtractRecursively. At the time this function is called, the current b is
+  // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
+  // inequality, s_{i,j+1}. This function recurses into the next level i + 1
+  // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
+  auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
+    unsigned snapshot = simplex.getSnapshot();
+    b.addInequality(ineq);
+    simplex.addInequality(ineq);
+    subtractRecursively(b, simplex, s, i + 1, result);
+    b.removeInequality(b.getNumInequalities() - 1);
+    simplex.rollback(snapshot);
+  };
+
+  // For each inequality ineq, we first recurse with the part where ineq
+  // is not satisfied, and then add the ineq to b and simplex because
+  // ineq must be satisfied by all later parts.
+  auto processInequality = [&](ArrayRef<int64_t> ineq) {
+    recurseWithInequality(getComplementIneq(ineq));
+    b.addInequality(ineq);
+    simplex.addInequality(ineq);
+  };
+
+  // processInequality appends some additional constraints to b. We want to
+  // rollback b to its initial state before returning, which we will do by
+  // removing all constraints beyond the original number of inequalities
+  // and equalities, so we store these counts first.
+  unsigned originalNumIneqs = b.getNumInequalities();
+  unsigned originalNumEqs = b.getNumEqualities();
+
+  for (unsigned j = 0, e = sI.getNumInequalities(); j < e; ++j) {
+    if (isMarkedRedundant[j])
+      continue;
+    processInequality(sI.getInequality(j));
+  }
+
+  // Redundancy bits for the equalities start after those of the inequalities.
+  unsigned eqOffset = sI.getNumInequalities();
+  for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
+    ArrayRef<int64_t> coeffs = sI.getEquality(j);
+    // Same as the above loop for inequalities, done once each for the positive
+    // and negative inequalities that make up this equality.
+    if (!isMarkedRedundant[eqOffset + 2 * j])
+      processInequality(coeffs);
+    if (!isMarkedRedundant[eqOffset + 2 * j + 1])
+      processInequality(getNegatedCoeffs(coeffs));
+  }
+
+  // Rollback b and simplex to their initial states. The loop variables are
+  // named `k` so that they do not shadow the parameter `i`.
+  for (unsigned k = b.getNumInequalities(); k > originalNumIneqs; --k)
+    b.removeInequality(k - 1);
+
+  for (unsigned k = b.getNumEqualities(); k > originalNumEqs; --k)
+    b.removeEquality(k - 1);
+
+  simplex.rollback(initialSnapshot);
+}
+
+/// Return the set difference fac \ set.
+///
+/// `fac` is taken by value because subtractRecursively mutates it while it
+/// works (restoring it before returning), so a const reference cannot be
+/// used.
+PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
+                                              const PresburgerSet &set) {
+  assertDimensionsCompatible(fac, set);
+  unsigned numDims = fac.getNumDimIds();
+  unsigned numSyms = fac.getNumSymbolIds();
+  // If the GCD test already shows fac to be empty, so is the difference.
+  if (fac.isEmptyByGCDTest())
+    return PresburgerSet::getEmptySet(numDims, numSyms);
+
+  PresburgerSet difference(numDims, numSyms);
+  Simplex simplex(fac);
+  subtractRecursively(fac, simplex, set, 0, difference);
+  return difference;
+}
+
+/// Return the complement of this set, computed as universe \ this.
+PresburgerSet PresburgerSet::complement() const {
+  FlatAffineConstraints universe =
+      FlatAffineConstraints::getUniverse(getNumDims(), getNumSyms());
+  return getSetDifference(std::move(universe), *this);
+}
+
+/// Return the set difference `this \ set`.
+///
+/// Subtraction distributes over a union on the left:
+/// (V_i t_i) \ s = V_i (t_i \ s), so each disjunct of this set is handled
+/// independently and the partial results are unioned together.
+PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
+  assertDimensionsCompatible(set, *this);
+  PresburgerSet difference(nDim, nSym);
+  for (const FlatAffineConstraints &disjunct : flatAffineConstraints) {
+    PresburgerSet part = getSetDifference(disjunct, set);
+    difference.unionSetInPlace(part);
+  }
+  return difference;
+}
+
+/// Return true if every disjunct is known to contain no integer points, false
+/// otherwise. Only intended for sets without symbols.
+bool PresburgerSet::isIntegerEmpty() const {
+  assert(nSym == 0 && "isIntegerEmpty is intended for non-symbolic sets");
+  return llvm::all_of(flatAffineConstraints,
+                      [](const FlatAffineConstraints &fac) {
+                        return fac.isIntegerEmpty();
+                      });
+}
+
+/// Scan the disjuncts in order; for the first one that has an integer point,
+/// write that point to `sample` and return true. Return false, leaving
+/// `sample` untouched, if no disjunct has one. Only intended for sets without
+/// symbols.
+bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
+  assert(nSym == 0 && "findIntegerSample is intended for non-symbolic sets");
+  for (const FlatAffineConstraints &fac : flatAffineConstraints) {
+    Optional<SmallVector<int64_t, 8>> maybeSample = fac.findIntegerSample();
+    if (!maybeSample)
+      continue;
+    sample = std::move(*maybeSample);
+    return true;
+  }
+  return false;
+}
+
+/// Print the number of disjuncts, then each disjunct followed by a newline.
+void PresburgerSet::print(raw_ostream &os) const {
+  os << getNumFACs() << " FlatAffineConstraints:\n";
+  for (unsigned idx = 0, e = getNumFACs(); idx < e; ++idx) {
+    getFlatAffineConstraints(idx).print(os);
+    os << '\n';
+  }
+}
+
+/// Print the set to llvm::errs().
+void PresburgerSet::dump() const { print(llvm::errs()); }

diff  --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
index bf47f4c302a7..6fcb1c489cfc 100644
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -15,22 +15,11 @@
 
 namespace mlir {
 
-/// Evaluate the value of the given affine expression at the specified point.
-/// The expression is a list of coefficients for the dimensions followed by the
-/// constant term.
-int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
-  assert(expr.size() == 1 + point.size());
-  int64_t value = expr.back();
-  for (unsigned i = 0; i < point.size(); ++i)
-    value += expr[i] * point[i];
-  return value;
-}
-
 /// If 'hasValue' is true, check that findIntegerSample returns a valid sample
 /// for the FlatAffineConstraints fac.
 ///
 /// If hasValue is false, check that findIntegerSample does not return None.
-void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
+static void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
   Optional<SmallVector<int64_t, 8>> maybeSample = fac.findIntegerSample();
   if (!hasValue) {
     EXPECT_FALSE(maybeSample.hasValue());
@@ -41,16 +30,13 @@ void checkSample(bool hasValue, const FlatAffineConstraints &fac) {
     }
   } else {
     ASSERT_TRUE(maybeSample.hasValue());
-    for (unsigned i = 0; i < fac.getNumEqualities(); ++i)
-      EXPECT_EQ(valueAt(fac.getEquality(i), *maybeSample), 0);
-    for (unsigned i = 0; i < fac.getNumInequalities(); ++i)
-      EXPECT_GE(valueAt(fac.getInequality(i), *maybeSample), 0);
+    EXPECT_TRUE(fac.containsPoint(*maybeSample));
   }
 }
 
 /// Construct a FlatAffineConstraints from a set of inequality and
 /// equality constraints.
-FlatAffineConstraints
+static FlatAffineConstraints
 makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
                        ArrayRef<SmallVector<int64_t, 4>> eqs) {
   FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims);
@@ -66,9 +52,9 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
 /// orderings may cause the algorithm to proceed differently. At least some of
 ///.these permutations should make it past the heuristics and test the
 /// implementation of the GBR algorithm itself.
-void checkPermutationsSample(bool hasValue, unsigned nDim,
-                             ArrayRef<SmallVector<int64_t, 4>> ineqs,
-                             ArrayRef<SmallVector<int64_t, 4>> eqs) {
+static void checkPermutationsSample(bool hasValue, unsigned nDim,
+                                    ArrayRef<SmallVector<int64_t, 4>> ineqs,
+                                    ArrayRef<SmallVector<int64_t, 4>> eqs) {
   SmallVector<unsigned, 4> perm(nDim);
   std::iota(perm.begin(), perm.end(), 0);
   auto permute = [&perm](ArrayRef<int64_t> coeffs) {

diff  --git a/mlir/unittests/Analysis/CMakeLists.txt b/mlir/unittests/Analysis/CMakeLists.txt
index 16d084dc452f..6317aeb8df89 100644
--- a/mlir/unittests/Analysis/CMakeLists.txt
+++ b/mlir/unittests/Analysis/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRAnalysisTests
   AffineStructuresTest.cpp
+  PresburgerSetTest.cpp
 )
 
 target_link_libraries(MLIRAnalysisTests

diff  --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index 0cfda9b0c8aa..5dd69edfad08 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -5,3 +5,4 @@ add_mlir_unittest(MLIRPresburgerTests
 
 target_link_libraries(MLIRPresburgerTests
   PRIVATE MLIRPresburger)
+

diff  --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
new file mode 100644
index 000000000000..99a0e8622232
--- /dev/null
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -0,0 +1,524 @@
+//===- PresburgerSetTest.cpp - Tests for PresburgerSet --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for PresburgerSet. Each test works by applying an
+// operation (union, intersection, subtraction, or complement) to one or two
+// sets and checking, for a set of points, that the resulting set contains the
+// point iff the result is supposed to contain it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/PresburgerSet.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace mlir {
+
+/// Checks the union operation pointwise: for every probe point, the union of
+/// s and t must contain it exactly when s contains it or t contains it.
+static void testUnionAtPoints(PresburgerSet s, PresburgerSet t,
+                              ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet result = s.unionSet(t);
+  for (const auto &pt : points) {
+    // Query both operands unconditionally, then compare with the result.
+    const bool memberOfS = s.containsPoint(pt);
+    const bool memberOfT = t.containsPoint(pt);
+    EXPECT_EQ(result.containsPoint(pt), memberOfS || memberOfT);
+  }
+}
+
+/// Checks the intersection operation pointwise: for every probe point, the
+/// intersection of s and t must contain it exactly when both s and t do.
+static void testIntersectAtPoints(PresburgerSet s, PresburgerSet t,
+                                  ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet result = s.intersect(t);
+  for (const auto &pt : points) {
+    // Query both operands unconditionally, then compare with the result.
+    const bool memberOfS = s.containsPoint(pt);
+    const bool memberOfT = t.containsPoint(pt);
+    EXPECT_EQ(result.containsPoint(pt), memberOfS && memberOfT);
+  }
+}
+
+/// Compute the set difference s \ t, and check that each of the given points
+/// belongs to the difference iff it belongs to s and does not belong to t.
+static void testSubtractAtPoints(PresburgerSet s, PresburgerSet t,
+                                 ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet diff = s.subtract(t);
+  for (const SmallVector<int64_t, 4> &point : points) {
+    bool inS = s.containsPoint(point);
+    bool inT = t.containsPoint(point);
+    bool inDiff = diff.containsPoint(point);
+    // A point of t can never survive the subtraction; any other point is in
+    // the difference exactly when it is in s.
+    EXPECT_EQ(inDiff, inS && !inT);
+  }
+}
+
+/// Compute the complement of s, and check that each of the given points
+/// belongs to the complement iff it does not belong to s. Also check that
+/// complementing twice gives back a set that agrees with s on every point.
+static void testComplementAtPoints(PresburgerSet s,
+                                   ArrayRef<SmallVector<int64_t, 4>> points) {
+  PresburgerSet complementSet = s.complement();
+  // The original code called complement() on the result and discarded the
+  // returned set; check the double complement instead of throwing it away.
+  PresburgerSet doubleComplement = complementSet.complement();
+  for (const SmallVector<int64_t, 4> &point : points) {
+    bool inS = s.containsPoint(point);
+    EXPECT_EQ(complementSet.containsPoint(point), !inS);
+    EXPECT_EQ(doubleComplement.containsPoint(point), inS);
+  }
+}
+
+/// Builds a FlatAffineConstraints with `dims` dimensions from the given
+/// inequality and equality rows. Each row lists one coefficient per dimension
+/// followed by the constant term (e.g. {1, -2} encodes x - 2, as the callers'
+/// comments show).
+static FlatAffineConstraints
+makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+                       ArrayRef<SmallVector<int64_t, 4>> eqs) {
+  // Size the system for the given rows; dims + 1 columns leaves room for the
+  // constant term.
+  FlatAffineConstraints result(ineqs.size(), eqs.size(), dims + 1, dims);
+  // Equalities are added first, then inequalities.
+  for (ArrayRef<int64_t> row : eqs)
+    result.addEquality(row);
+  for (ArrayRef<int64_t> row : ineqs)
+    result.addInequality(row);
+  return result;
+}
+
+/// Builds a FlatAffineConstraints with `dims` dimensions that has only the
+/// given inequality rows and no equalities.
+static FlatAffineConstraints
+makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
+  ArrayRef<SmallVector<int64_t, 4>> noEqualities;
+  return makeFACFromConstraints(dims, ineqs, noEqualities);
+}
+
+/// Builds a PresburgerSet with `dims` dimensions that is the union of the
+/// given constraint systems; an empty list yields the empty set.
+static PresburgerSet makeSetFromFACs(unsigned dims,
+                                     ArrayRef<FlatAffineConstraints> facs) {
+  PresburgerSet result = PresburgerSet::getEmptySet(dims);
+  for (unsigned i = 0, e = facs.size(); i < e; ++i)
+    result.unionFACInPlace(facs[i]);
+  return result;
+}
+
+/// Checks containsPoint on a union of two disjoint 1-D intervals and on a
+/// 2-D union of a parallelogram and a square.
+TEST(SetTest, containsPoint) {
+  PresburgerSet setA =
+      makeSetFromFACs(1, {
+                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
+                                                  {-1, 8}}),  // x <= 8.
+                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
+                                                  {-1, 20}}), // x <= 20.
+                         });
+  for (int x = 0; x <= 21; ++x) {
+    if ((2 <= x && x <= 8) || (10 <= x && x <= 20))
+      EXPECT_TRUE(setA.containsPoint({x}));
+    else
+      EXPECT_FALSE(setA.containsPoint({x}));
+  }
+
+  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} union
+  // a square with opposite corners (2, 2) and (10, 10).
+  PresburgerSet setB =
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 1, -4},   // x + y >= 4.
+                                               {-1, -1, 32}, // x + y <= 32.
+                                               {1, -1, -2},  // x - y >= 2.
+                                               {-1, 1, 16},  // x - y <= 16.
+                                           }),
+                          makeFACFromIneqs(2, {
+                                                  {1, 0, -2},  // x >= 2.
+                                                  {0, 1, -2},  // y >= 2.
+                                                  {-1, 0, 10}, // x <= 10.
+                                                  {0, -1, 10}  // y <= 10.
+                                              })});
+
+  // The loop variables must be signed: with `unsigned y = -6` the initial
+  // value wraps to a huge number and the inner loop never executes. Signed
+  // types also keep x + y and x - y from wrapping when they are negative.
+  for (int x = 1; x <= 25; ++x) {
+    for (int y = -6; y <= 16; ++y) {
+      if (4 <= x + y && x + y <= 32 && 2 <= x - y && x - y <= 16)
+        EXPECT_TRUE(setB.containsPoint({x, y}));
+      else if (2 <= x && x <= 10 && 2 <= y && y <= 10)
+        EXPECT_TRUE(setB.containsPoint({x, y}));
+      else
+        EXPECT_FALSE(setB.containsPoint({x, y}));
+    }
+  }
+}
+
+/// Union of a set of two disjoint intervals with the universe and the empty
+/// set, plus every pairing of the universe and the empty set with each other.
+TEST(SetTest, Union) {
+  PresburgerSet twoIntervals =
+      makeSetFromFACs(1, {
+                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
+                                                  {-1, 8}}),  // x <= 8.
+                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
+                                                  {-1, 20}}), // x <= 20.
+                         });
+  PresburgerSet universe = PresburgerSet::getUniverse(1);
+  PresburgerSet emptySet = PresburgerSet::getEmptySet(1);
+
+  // Points just inside and just outside each interval endpoint.
+  SmallVector<SmallVector<int64_t, 4>, 8> endpointProbes = {
+      {1}, {2}, {8}, {9}, {10}, {20}, {21}};
+  // A few points around the origin.
+  SmallVector<SmallVector<int64_t, 4>, 4> originProbes = {{1}, {2}, {0}, {-1}};
+
+  // Universe union set.
+  testUnionAtPoints(universe, twoIntervals, endpointProbes);
+  // Empty set union set.
+  testUnionAtPoints(emptySet, twoIntervals, endpointProbes);
+  // Empty set union universe.
+  testUnionAtPoints(emptySet, universe, originProbes);
+  // Universe union empty set.
+  testUnionAtPoints(universe, emptySet, originProbes);
+  // Empty set union empty set.
+  testUnionAtPoints(emptySet, emptySet, originProbes);
+}
+
+/// Intersection of a set of two disjoint intervals with the universe and the
+/// empty set, plus pairings of the universe and the empty set.
+TEST(SetTest, Intersect) {
+  PresburgerSet twoIntervals =
+      makeSetFromFACs(1, {
+                             makeFACFromIneqs(1, {{1, -2},    // x >= 2.
+                                                  {-1, 8}}),  // x <= 8.
+                             makeFACFromIneqs(1, {{1, -10},   // x >= 10.
+                                                  {-1, 20}}), // x <= 20.
+                         });
+  PresburgerSet universe = PresburgerSet::getUniverse(1);
+  PresburgerSet emptySet = PresburgerSet::getEmptySet(1);
+
+  // Points just inside and just outside each interval endpoint.
+  SmallVector<SmallVector<int64_t, 4>, 8> endpointProbes = {
+      {1}, {2}, {8}, {9}, {10}, {20}, {21}};
+  // A few points around the origin.
+  SmallVector<SmallVector<int64_t, 4>, 4> originProbes = {{1}, {2}, {0}, {-1}};
+
+  // Universe intersection set.
+  testIntersectAtPoints(universe, twoIntervals, endpointProbes);
+  // Empty set intersection set.
+  testIntersectAtPoints(emptySet, twoIntervals, endpointProbes);
+  // Empty set intersection universe.
+  testIntersectAtPoints(emptySet, universe, originProbes);
+  // Universe intersection empty set.
+  testIntersectAtPoints(universe, emptySet, originProbes);
+  // Universe intersection universe.
+  testIntersectAtPoints(universe, universe, originProbes);
+}
+
+/// Checks subtraction on 1-D and 2-D sets: results that split into several
+/// pieces, results with holes, empty results, sets with equality constraints,
+/// and subtrahends with many redundant inequalities.
+TEST(SetTest, Subtract) {
+  // The universe minus ([2, 8] U [10, 20]).
+  testSubtractAtPoints(
+      makeSetFromFACs(1, {makeFACFromIneqs(1, {})}),
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneqs(1, {{1, -2},    // x >= 2.
+                                               {-1, 8}}),  // x <= 8.
+                          makeFACFromIneqs(1, {{1, -10},   // x >= 10.
+                                               {-1, 20}}), // x <= 20.
+                      }),
+      {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+
+  // ((-infinity, 0] U [3, 4] U [6, 7]) - ([2, 3] U [5, 6])
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneqs(1,
+                                           {
+                                               {-1, 0} // x <= 0.
+                                           }),
+                          makeFACFromIneqs(1,
+                                           {
+                                               {1, -3}, // x >= 3.
+                                               {-1, 4}  // x <= 4.
+                                           }),
+                          makeFACFromIneqs(1,
+                                           {
+                                               {1, -6}, // x >= 6.
+                                               {-1, 7}  // x <= 7.
+                                           }),
+                      }),
+      makeSetFromFACs(1, {makeFACFromIneqs(1,
+                                           {
+                                               {1, -2}, // x >= 2.
+                                               {-1, 3}, // x <= 3.
+                                           }),
+                          makeFACFromIneqs(1,
+                                           {
+                                               {1, -5}, // x >= 5.
+                                               {-1, 6}  // x <= 6.
+                                           })}),
+      {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
+
+  // {[x, y] : x >= y} minus {[x, y] : x + y >= 0}. The expected result is
+  // {[x, y] : x >= y and x + y <= -1}.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, -1, 0} // x >= y.
+                                           })}),
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 1, 0} // x >= -y.
+                                           })}),
+      {{0, 1}, {1, 1}, {1, 0}, {1, -1}, {0, -1}});
+
+  // A rectangle with corners at (2, 2) and (10, 10), minus
+  // a rectangle with corners at (5, -10) and (7, 100).
+  // This splits the former rectangle into two pieces, (2, 2) to (4, 10) and
+  // (8, 2) to (10, 10).
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 0, -2},  // x >= 2.
+                                               {0, 1, -2},  // y >= 2.
+                                               {-1, 0, 10}, // x <= 10.
+                                               {0, -1, 10}  // y <= 10.
+                                           })}),
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 0, -5},   // x >= 5.
+                                               {0, 1, 10},   // y >= -10.
+                                               {-1, 0, 7},   // x <= 7.
+                                               {0, -1, 100}, // y <= 100.
+                                           })}),
+      {{1, 2},  {2, 2},  {4, 2},  {5, 2},  {7, 2},  {8, 2},  {11, 2},
+       {1, 1},  {2, 1},  {4, 1},  {5, 1},  {7, 1},  {8, 1},  {11, 1},
+       {1, 10}, {2, 10}, {4, 10}, {5, 10}, {7, 10}, {8, 10}, {11, 10},
+       {1, 11}, {2, 11}, {4, 11}, {5, 11}, {7, 11}, {8, 11}, {11, 11}});
+
+  // A rectangle with corners at (2, 2) and (10, 10), minus
+  // a rectangle with corners at (5, 4) and (7, 8).
+  // This creates a hole in the middle of the former rectangle, and the
+  // resulting set can be represented as a union of four rectangles.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 0, -2},  // x >= 2.
+                                               {0, 1, -2},  // y >= 2.
+                                               {-1, 0, 10}, // x <= 10.
+                                               {0, -1, 10}  // y <= 10.
+                                           })}),
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 0, -5}, // x >= 5.
+                                               {0, 1, -4}, // y >= 4.
+                                               {-1, 0, 7}, // x <= 7.
+                                               {0, -1, 8}, // y <= 8.
+                                           })}),
+      {{1, 1},
+       {2, 2},
+       {10, 10},
+       {11, 11},
+       {5, 4},
+       {7, 4},
+       {5, 8},
+       {7, 8},
+       {4, 4},
+       {8, 4},
+       {4, 8},
+       {8, 8}});
+
+  // The second set is a superset of the first one, since on the line x + y = 0,
+  // y <= 1 is equivalent to x >= -1. So the result is empty.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromConstraints(2,
+                                                 {
+                                                     {1, 0, 0} // x >= 0.
+                                                 },
+                                                 {
+                                                     {1, 1, 0} // x + y = 0.
+                                                 })}),
+      makeSetFromFACs(2, {makeFACFromConstraints(2,
+                                                 {
+                                                     {0, -1, 1} // y <= 1.
+                                                 },
+                                                 {
+                                                     {1, 1, 0} // x + y = 0.
+                                                 })}),
+      {{0, 0},
+       {1, -1},
+       {2, -2},
+       {-1, 1},
+       {-2, 2},
+       {1, 1},
+       {-1, -1},
+       {-1, 1},
+       {1, -1}});
+
+  // The result should be {0} U {2}.
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneqs(1, {{1, 0},    // x >= 0.
+                                               {-1, 2}}), // x <= 2.
+                      }),
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -1} // x = 1.
+                                                 }),
+                      }),
+      {{-1}, {0}, {1}, {2}, {3}});
+
+  // Sets with lots of redundant inequalities to test the redundancy heuristic.
+  // (the heuristic is for the subtrahend, the second set which is the one being
+  // subtracted)
+
+  // A parallelogram with vertices {(3, 1), (10, -6), (24, 8), (17, 15)} minus
+  // a triangle with vertices {(2, 2), (10, 2), (10, 10)}.
+  testSubtractAtPoints(
+      makeSetFromFACs(2, {makeFACFromIneqs(2,
+                                           {
+                                               {1, 1, -4},   // x + y >= 4.
+                                               {-1, -1, 32}, // x + y <= 32.
+                                               {1, -1, -2},  // x - y >= 2.
+                                               {-1, 1, 16},  // x - y <= 16.
+                                           })}),
+      makeSetFromFACs(
+          2, {makeFACFromIneqs(2,
+                               {
+                                   {1, 0, -2},   // x >= 2. [redundant]
+                                   {0, 1, -2},   // y >= 2.
+                                   {-1, 0, 10},  // x <= 10.
+                                   {0, -1, 10},  // y <= 10. [redundant]
+                                   {1, 1, -2},   // x + y >= 2. [redundant]
+                                   {-1, -1, 30}, // x + y <= 30. [redundant]
+                                   {1, -1, 0},   // x - y >= 0.
+                                   {-1, 1, 10},  // x - y <= 10. [redundant]
+                               })}),
+      {{1, 2},  {2, 2},   {3, 2},   {4, 2},  {1, 1},   {2, 1},   {3, 1},
+       {4, 1},  {2, 0},   {3, 0},   {4, 0},  {5, 0},   {10, 2},  {11, 2},
+       {10, 1}, {10, 10}, {10, 11}, {10, 9}, {11, 10}, {10, -6}, {11, -6},
+       {24, 8}, {24, 7},  {17, 15}, {16, 15}});
+
+  // ((-infinity, -5] U [3, 3] U [4, 4] U [5, 5]) - ([-10, -2] U [3, 4] U [6,
+  // 7])
+  testSubtractAtPoints(
+      makeSetFromFACs(1,
+                      {
+                          makeFACFromIneqs(1,
+                                           {
+                                               {-1, -5}, // x <= -5.
+                                           }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -3} // x = 3.
+                                                 }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -4} // x = 4.
+                                                 }),
+                          makeFACFromConstraints(1, {},
+                                                 {
+                                                     {1, -5} // x = 5.
+                                                 }),
+                      }),
+      makeSetFromFACs(
+          1,
+          {
+              makeFACFromIneqs(1,
+                               {
+                                   {-1, -2},  // x <= -2.
+                                   {1, 10},   // x >= -10.
+                                   {-1, 0},   // x <= 0. [redundant]
+                                   {-1, 10},  // x <= 10. [redundant]
+                                   {1, 100},  // x >= -100. [redundant]
+                                   {1, 50}    // x >= -50. [redundant]
+                               }),
+              makeFACFromIneqs(1,
+                               {
+                                   {1, -3}, // x >= 3.
+                                   {-1, 4}, // x <= 4.
+                                   {1, 1},  // x >= -1. [redundant]
+                                   {1, 7},  // x >= -7. [redundant]
+                                   {-1, 10} // x <= 10. [redundant]
+                               }),
+              makeFACFromIneqs(1,
+                               {
+                                   {1, -6}, // x >= 6.
+                                   {-1, 7}, // x <= 7.
+                                   {1, 1},  // x >= -1. [redundant]
+                                   {1, 3},  // x >= -3. [redundant]
+                                   {-1, 10} // x <= 10. [redundant]
+                               }),
+          }),
+      {{-6},
+       {-5},
+       {-4},
+       {-9},
+       {-10},
+       {-11},
+       {0},
+       {1},
+       {2},
+       {3},
+       {4},
+       {5},
+       {6},
+       {7},
+       {8}});
+}
+
+/// Complement of the universe, of the empty set, and of a 2-D square.
+TEST(SetTest, Complement) {
+  SmallVector<SmallVector<int64_t, 4>, 10> oneDimProbes = {
+      {-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}};
+
+  // Complement of universe.
+  testComplementAtPoints(PresburgerSet::getUniverse(1), oneDimProbes);
+
+  // Complement of empty set.
+  testComplementAtPoints(PresburgerSet::getEmptySet(1), oneDimProbes);
+
+  // Complement of the square with opposite corners (2, 2) and (10, 10); the
+  // probes sit on and just outside its corners and edges.
+  FlatAffineConstraints square = makeFACFromIneqs(2, {
+                                                         {1, 0, -2},  // x >= 2.
+                                                         {0, 1, -2},  // y >= 2.
+                                                         {-1, 0, 10}, // x <= 10.
+                                                         {0, -1, 10}  // y <= 10.
+                                                     });
+  testComplementAtPoints(makeSetFromFACs(2, {square}),
+                         {{1, 1},
+                          {2, 1},
+                          {1, 2},
+                          {2, 2},
+                          {2, 3},
+                          {3, 2},
+                          {10, 10},
+                          {10, 11},
+                          {11, 10},
+                          {2, 10},
+                          {2, 11},
+                          {1, 10}});
+}
+
+} // namespace mlir


        


More information about the Mlir-commits mailing list