[Mlir-commits] [mlir] bb2226a - [MLIR][Presburger] Refactor MultiAffineFunction to be defined over universe
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Sep 10 17:12:38 PDT 2022
Author: Groverkss
Date: 2022-09-11T01:12:09+01:00
New Revision: bb2226ac53aa255d7520146ab9e0048cfc43a225
URL: https://github.com/llvm/llvm-project/commit/bb2226ac53aa255d7520146ab9e0048cfc43a225
DIFF: https://github.com/llvm/llvm-project/commit/bb2226ac53aa255d7520146ab9e0048cfc43a225.diff
LOG: [MLIR][Presburger] Refactor MultiAffineFunction to be defined over universe
This patch refactors MAF to be defined over the universe in a given space
instead of being defined over a restricted domain.
The reasoning for this refactor is to store division representation for local
variables explicitly for the function outputs. This change is required for
unionLexMax/Min to support local variables which will be upstreamed after this
patch. Another reason for this refactor is to have a flattened form of
AffineMap as MultiAffineFunction.
Reviewed By: arjunp
Differential Revision: https://reviews.llvm.org/D131864
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/Matrix.h
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/include/mlir/Analysis/Presburger/Utils.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/Matrix.cpp
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/lib/Analysis/Presburger/Utils.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
mlir/unittests/Analysis/Presburger/Utils.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index bf32aafc6019d..e9251ba5d031c 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -128,6 +128,8 @@ class Matrix {
/// Add `scale` multiples of the source row to the target row.
void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale);
+ /// Add `scale` multiples of the rowVec row to the specified row.
+ void addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale);
/// Add `scale` multiples of the source column to the target column.
void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale);
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index c4626a2945f01..63f3ecfca968f 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -22,94 +22,93 @@
namespace mlir {
namespace presburger {
-/// This class represents a multi-affine function whose domain is given by an
-/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
-/// tuple of integer values attached to every point in the polyhedron, with the
-/// value of each element of the tuple given by an affine expression in the vars
-/// of the polyhedron. For example we could have the domain
-///
-/// (x, y) : (x >= 5, y >= x)
-///
-/// and a tuple of three integers defined at every point in the polyhedron:
+/// This class represents a multi-affine function with the domain as Z^d, where
+/// `d` is the number of domain variables of the function. For example:
///
/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
///
-/// In this way every point in the polyhedron has a tuple of integers associated
-/// with it. If the integer polyhedron has local vars, then the output
-/// expressions can use them as well. The output expressions are represented as
-/// a matrix with one row for every element in the output vector one column for
-/// each var, and an extra column at the end for the constant term.
+/// The output expressions are represented as a matrix with one row for every
+/// output, one column for each var including division variables, and an extra
+/// column at the end for the constant term.
///
/// Checking equality of two such functions is supported, as well as finding the
/// value of the function at a specified point.
class MultiAffineFunction {
public:
- MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
- : domainSet(domain), output(output) {}
- MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
- : domainSet(space), output(output) {}
-
- unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolVars(); }
- unsigned getNumOutputs() const { return output.getNumRows(); }
- bool isConsistent() const {
- return output.getNumColumns() == domainSet.getNumVars() + 1;
+ MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
+ : space(space), output(output),
+ divs(space.getNumVars() - space.getNumRangeVars()) {
+ assertIsConsistent();
+ }
+
+ MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
+ const DivisionRepr &divs)
+ : space(space), output(output), divs(divs) {
+ assertIsConsistent();
}
- /// Get the space of the input domain of this function.
- const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
- /// Get the input domain of this function.
- const IntegerPolyhedron &getDomain() const { return domainSet; }
+ unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
+ unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
+ unsigned getNumOutputs() const { return space.getNumRangeVars(); }
+ unsigned getNumDivs() const { return space.getNumLocalVars(); }
+
+ /// Get the space of this function.
+ const PresburgerSpace &getSpace() const { return space; }
+ /// Get the domain/output space of the function. The returned space is a set
+ /// space.
+ PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
+ PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
+
/// Get a matrix with each row representing row^th output expression.
const Matrix &getOutputMatrix() const { return output; }
/// Get the `i^th` output expression.
ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
- /// Insert `num` variables of the specified kind at position `pos`.
- /// Positions are relative to the kind of variable. The coefficient columns
- /// corresponding to the added variables are initialized to zero. Return the
- /// absolute column position (i.e., not relative to the kind of variable)
- /// of the first added variable.
- unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1);
-
- /// Remove the specified range of vars.
- void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit);
-
- /// Given a MAF `other`, merges local variables such that both funcitons
- /// have union of local vars, without changing the set of points in domain or
- /// the output.
- void mergeLocalVars(MultiAffineFunction &other);
-
- /// Return whether the outputs of `this` and `other` agree wherever both
- /// functions are defined, i.e., the outputs should be equal for all points in
- /// the intersection of the domains.
- bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
-
- /// Return whether the `this` and `other` are equal. This is the case if
- /// they lie in the same space, i.e. have the same dimensions, and their
- /// domains are identical and their outputs are equal on their domain.
+ // Remove the specified range of outputs.
+ void removeOutputs(unsigned start, unsigned end);
+
+ /// Given a MAF `other`, merges division variables such that both functions
+ /// have the union of the division vars that exist in the functions.
+ void mergeDivs(MultiAffineFunction &other);
+
+ /// Return the output of the function at the given point.
+ SmallVector<int64_t, 8> valueAt(ArrayRef<int64_t> point) const;
+
+ /// Return whether the `this` and `other` are equal when the domain is
+ /// restricted to `domain`. This is the case if they lie in the same space,
+ /// and their outputs are equal for every point in `domain`.
bool isEqual(const MultiAffineFunction &other) const;
+ bool isEqual(const MultiAffineFunction &other,
+ const IntegerPolyhedron &domain) const;
+ bool isEqual(const MultiAffineFunction &other,
+ const PresburgerSet &domain) const;
- /// Get the value of the function at the specified point. If the point lies
- /// outside the domain, an empty optional is returned.
- Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+ void subtract(const MultiAffineFunction &other);
- /// Truncate the output dimensions to the first `count` dimensions.
- ///
- /// TODO: refactor so that this can be accomplished through removeVarRange.
- void truncateOutput(unsigned count);
+ /// Get this function as a relation.
+ IntegerRelation getAsRelation() const;
void print(raw_ostream &os) const;
void dump() const;
private:
- /// The IntegerPolyhedron representing the domain over which the function is
- /// defined.
- IntegerPolyhedron domainSet;
+ /// Assert that the MAF is consistent.
+ void assertIsConsistent() const;
+
+ /// The space of this function. The domain variables are considered as the
+ /// input variables of the function. The range variables are considered as
+ /// the outputs. The symbols parametrize the function and locals are used to
+ /// represent divisions. Each local variable has a corressponding division
+ /// representation stored in `divs`.
+ PresburgerSpace space;
/// The function's output is a tuple of integers, with the ith element of the
/// tuple defined by the affine expression given by the ith row of this output
/// matrix.
Matrix output;
+
+ /// Storage for division representation for each local variable in space.
+ DivisionRepr divs;
};
/// This class represents a piece-wise MultiAffineFunction. This can be thought
@@ -132,33 +131,47 @@ class MultiAffineFunction {
/// finding the value of the function at a point.
class PWMAFunction {
public:
- PWMAFunction(const PresburgerSpace &space, unsigned numOutputs)
- : space(space), numOutputs(numOutputs) {
- assert(space.getNumDomainVars() == 0 &&
- "Set type space should have zero domain vars.");
+ struct Piece {
+ PresburgerSet domain;
+ MultiAffineFunction output;
+
+ bool isConsistent() const {
+ return domain.getSpace().isCompatible(output.getDomainSpace());
+ }
+ };
+
+ PWMAFunction(const PresburgerSpace &space) : space(space) {
assert(space.getNumLocalVars() == 0 &&
"PWMAFunction cannot have local vars.");
- assert(numOutputs >= 1 && "The function must output something!");
}
+ // Get the space of this function.
const PresburgerSpace &getSpace() const { return space; }
- void addPiece(const MultiAffineFunction &piece);
- void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
- void addPiece(const PresburgerSet &domain, const Matrix &output);
+ // Add a piece ([domain, output] pair) to this function.
+ void addPiece(const Piece &piece);
- const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
unsigned getNumPieces() const { return pieces.size(); }
- unsigned getNumOutputs() const { return numOutputs; }
- unsigned getNumInputs() const { return space.getNumVars(); }
- MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
+ unsigned getNumVarKind(VarKind kind) const {
+ return space.getNumVarKind(kind);
+ }
+ unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
+ unsigned getNumOutputs() const { return space.getNumRangeVars(); }
+ unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
+
+ /// Remove the specified range of outputs.
+ void removeOutputs(unsigned start, unsigned end);
+
+ /// Get the domain/output space of the function. The returned space is a set
+ /// space.
+ PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
+ PresburgerSpace getOutputSpace() const { return space.getDomainSpace(); }
/// Return the domain of this piece-wise MultiAffineFunction. This is the
/// union of the domains of all the pieces.
PresburgerSet getDomain() const;
- /// Return the value at the specified point and an empty optional if the
- /// point does not lie in the domain.
+ /// Return the output of the function at the given point.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
@@ -166,11 +179,6 @@ class PWMAFunction {
/// value at every point in the domain.
bool isEqual(const PWMAFunction &other) const;
- /// Truncate the output dimensions to the first `count` dimensions.
- ///
- /// TODO: refactor so that this can be accomplished through removeVarRange.
- void truncateOutput(unsigned count);
-
/// Return a function defined on the union of the domains of this and func,
/// such that when only one of the functions is defined, it outputs the same
/// as that function, and if both are defined, it outputs the lexmax/lexmin of
@@ -178,8 +186,8 @@ class PWMAFunction {
/// function is not defined either.
///
/// Currently this does not support PWMAFunctions which have pieces containing
- /// local variables.
- /// TODO: Support local variables in peices.
+ /// divisions.
+ /// TODO: Support division in pieces.
PWMAFunction unionLexMin(const PWMAFunction &func);
PWMAFunction unionLexMax(const PWMAFunction &func);
@@ -200,19 +208,17 @@ class PWMAFunction {
///
/// The PresburgerSet returned by `tiebreak` should be disjoint.
/// TODO: Remove this constraint of returning disjoint set.
- PWMAFunction
- unionFunction(const PWMAFunction &func,
- llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
- MultiAffineFunction mafB)>
- tiebreak) const;
+ PWMAFunction unionFunction(
+ const PWMAFunction &func,
+ llvm::function_ref<PresburgerSet(Piece mafA, Piece mafB)> tiebreak) const;
+ /// The space of this function. The domain variables are considered as the
+ /// input variables of the function. The range variables are considered as
+ /// the outputs. The symbols paramterize the function.
PresburgerSpace space;
- /// The list of pieces in this piece-wise MultiAffineFunction.
- SmallVector<MultiAffineFunction, 4> pieces;
-
- /// The number of output vars.
- unsigned numOutputs;
+ // The pieces of the PWMAFunction.
+ SmallVector<Piece, 4> pieces;
};
} // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index ff2c2ad85edd9..03a5dfb0631e3 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -90,6 +90,11 @@ class PresburgerSpace {
numLocals);
}
+ // Get the domain/range space of this space. The returned space is a set
+ // space.
+ PresburgerSpace getDomainSpace() const;
+ PresburgerSpace getRangeSpace() const;
+
unsigned getNumDomainVars() const { return numDomain; }
unsigned getNumRangeVars() const { return numRange; }
unsigned getNumSetDimVars() const { return numRange; }
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 2c3aedea8452f..485a064c6ccef 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -529,9 +529,9 @@ class LexSimplex : public LexSimplexBase {
/// Represents the result of a symbolic lexicographic minimization computation.
struct SymbolicLexMin {
- SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
- : lexmin(domainSpace, numOutputs),
- unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
+ SymbolicLexMin(const PresburgerSpace &space)
+ : lexmin(space),
+ unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
/// This maps assignments of symbols to the corresponding lexmin.
/// Takes no value when no integer sample exists for the assignment or if the
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 6d81c65b0faea..3801cb63af5ea 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -118,7 +118,7 @@ class DivisionRepr {
DivisionRepr(unsigned numVars, unsigned numDivs)
: dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
- DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {}
+ DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {}
unsigned getNumVars() const { return dividends.getNumColumns() - 1; }
unsigned getNumDivs() const { return dividends.getNumRows(); }
@@ -142,16 +142,25 @@ class DivisionRepr {
return dividends.getRow(i);
}
+ // For a given point containing values for each variable other than the
+ // division variables, try to find the values for each division variable from
+ // their division representation.
+ SmallVector<Optional<int64_t>, 4> divValuesAt(ArrayRef<int64_t> point) const;
+
// Get the `i^th` denominator.
unsigned &getDenom(unsigned i) { return denoms[i]; }
unsigned getDenom(unsigned i) const { return denoms[i]; }
ArrayRef<unsigned> getDenoms() const { return denoms; }
- void setDividend(unsigned i, ArrayRef<int64_t> dividend) {
+ void setDiv(unsigned i, ArrayRef<int64_t> dividend, unsigned divisor) {
dividends.setRow(i, dividend);
+ denoms[i] = divisor;
}
+ void insertDiv(unsigned pos, ArrayRef<int64_t> dividend, unsigned divisor);
+ void insertDiv(unsigned pos, unsigned num = 1);
+
/// Removes duplicate divisions. On every possible duplicate division found,
/// `merge(i, j)`, where `i`, `j` are current index of the duplicate
/// divisions, is called and division at index `j` is merged into division at
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 18b48b5cdc8e2..03252ce0f4c8b 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -238,6 +238,7 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
getVarKindEnd(VarKind::Domain));
// Compute the symbolic lexmin of the dims and locals, with the symbols being
// the actual symbols of this set.
+ // The resultant space of lexmin is the space of the relation itself.
SymbolicLexMin result =
SymbolicLexSimplex(*this,
IntegerPolyhedron(PresburgerSpace::getSetSpace(
@@ -248,8 +249,8 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
// We want to return only the lexmin over the dims, so strip the locals from
// the computed lexmin.
- result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
- getNumLocalVars());
+ result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
+ result.lexmin.getNumOutputs());
return result;
}
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index c9767ae3cee2b..c51aa3c922eac 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -192,10 +192,14 @@ void Matrix::fillRow(unsigned row, int64_t value) {
}
void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
+ addToRow(targetRow, getRow(sourceRow), scale);
+}
+
+void Matrix::addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale) {
if (scale == 0)
return;
for (unsigned col = 0; col < nColumns; ++col)
- at(targetRow, col) += scale * at(sourceRow, col);
+ at(row, col) += scale * rowVec[col];
}
void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 18b5d0e7c68dc..d1d3925c59462 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -12,11 +12,25 @@
using namespace mlir;
using namespace presburger;
+void MultiAffineFunction::assertIsConsistent() const {
+ assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
+ output.getNumColumns() &&
+ "Inconsistent number of output columns");
+ assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
+ divs.getNumNonDivs() &&
+ "Inconsistent number of non-division variables in divs");
+ assert(space.getNumRangeVars() == output.getNumRows() &&
+ "Inconsistent number of output rows");
+ assert(space.getNumLocalVars() == divs.getNumDivs() &&
+ "Inconsistent number of divisions.");
+ assert(divs.hasAllReprs() && "All divisions should have a representation");
+}
+
// Return the result of subtracting the two given vectors pointwise.
// The vectors must be of the same size.
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
-static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
- ArrayRef<int64_t> vecB) {
+static SmallVector<int64_t, 8> subtractExprs(ArrayRef<int64_t> vecA,
+ ArrayRef<int64_t> vecB) {
assert(vecA.size() == vecB.size() &&
"Cannot subtract vectors of differing lengths!");
SmallVector<int64_t, 8> result;
@@ -27,152 +41,135 @@ static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
}
PresburgerSet PWMAFunction::getDomain() const {
- PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
- for (const MultiAffineFunction &piece : pieces)
- domain.unionInPlace(piece.getDomain());
+ PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
+ for (const Piece &piece : pieces)
+ domain.unionInPlace(piece.domain);
return domain;
}
-Optional<SmallVector<int64_t, 8>>
+void MultiAffineFunction::print(raw_ostream &os) const {
+ space.print(os);
+ os << "Division Representation:\n";
+ divs.print(os);
+ os << "Output:\n";
+ output.print(os);
+}
+
+SmallVector<int64_t, 8>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
- assert(point.size() == domainSet.getNumDimAndSymbolVars() &&
+ assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
"Point has incorrect dimensionality!");
- Optional<SmallVector<int64_t, 8>> maybeLocalValues =
- getDomain().containsPointNoLocal(point);
- if (!maybeLocalValues)
- return {};
-
- // The point lies in the domain, so we need to compute the output value.
SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
- // The given point didn't include the values of locals which the output is a
- // function of; we have computed one possible set of values and use them
- // here. The function is not allowed to have local vars that take more than
- // one possible value.
- pointHomogenous.append(*maybeLocalValues);
+ // Get the division values at this point.
+ SmallVector<Optional<int64_t>, 8> divValues = divs.divValuesAt(point);
+ // The given point didn't include the values of the divs which the output is a
+ // function of; we have computed one possible set of values and use them here.
+ pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
+ for (const Optional<int64_t> &divVal : divValues)
+ pointHomogenous.push_back(*divVal);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
- pointHomogenous.emplace_back(1);
+ pointHomogenous.push_back(1);
SmallVector<int64_t, 8> result =
output.postMultiplyWithColumn(pointHomogenous);
assert(result.size() == getNumOutputs());
return result;
}
-Optional<SmallVector<int64_t, 8>>
-PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
- assert(point.size() == getNumInputs() &&
- "Point has incorrect dimensionality!");
- for (const MultiAffineFunction &piece : pieces)
- if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
- return output;
- return {};
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
+ assert(space.isCompatible(other.space) &&
+ "Spaces should be compatible for equality check.");
+ return getAsRelation().isEqual(other.getAsRelation());
}
-void MultiAffineFunction::print(raw_ostream &os) const {
- os << "Domain:";
- domainSet.print(os);
- os << "Output:\n";
- output.print(os);
- os << "\n";
-}
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
+ const IntegerPolyhedron &domain) const {
+ assert(space.isCompatible(other.space) &&
+ "Spaces should be compatible for equality check.");
+ IntegerRelation restrictedThis = getAsRelation();
+ restrictedThis.intersectDomain(domain);
-void MultiAffineFunction::dump() const { print(llvm::errs()); }
+ IntegerRelation restrictedOther = other.getAsRelation();
+ restrictedOther.intersectDomain(domain);
-bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
- return getDomainSpace().isCompatible(other.getDomainSpace()) &&
- getDomain().isEqual(other.getDomain()) &&
- isEqualWhereDomainsOverlap(other);
+ return restrictedThis.isEqual(restrictedOther);
}
-unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos,
- unsigned num) {
- assert(kind != VarKind::Domain && "Domain has to be zero in a set");
- unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos;
- output.insertColumns(absolutePos, num);
- return domainSet.insertVar(kind, pos, num);
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
+ const PresburgerSet &domain) const {
+ assert(space.isCompatible(other.space) &&
+ "Spaces should be compatible for equality check.");
+ return llvm::all_of(domain.getAllDisjuncts(),
+ [&](const IntegerRelation &disjunct) {
+ return isEqual(other, IntegerPolyhedron(disjunct));
+ });
}
-void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart,
- unsigned varLimit) {
- output.removeColumns(varStart + domainSet.getVarKindOffset(kind),
- varLimit - varStart);
- domainSet.removeVarRange(kind, varStart, varLimit);
-}
+void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
+ assert(end <= getNumOutputs() && "Invalid range");
-void MultiAffineFunction::truncateOutput(unsigned count) {
- assert(count <= output.getNumRows());
- output.resizeVertically(count);
-}
+ if (start >= end)
+ return;
-void PWMAFunction::truncateOutput(unsigned count) {
- assert(count <= numOutputs);
- for (MultiAffineFunction &piece : pieces)
- piece.truncateOutput(count);
- numOutputs = count;
+ space.removeVarRange(VarKind::Range, start, end);
+ output.removeRows(start, end - start);
}
-void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) {
- // Merge output local vars of both functions without using division
- // information i.e. append local vars of `other` to `this` and insert
- // local vars of `this` to `other` at the start of it's local vars.
- output.insertColumns(domainSet.getVarKindEnd(VarKind::Local),
- other.domainSet.getNumLocalVars());
- other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local),
- domainSet.getNumLocalVars());
+void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
+ assert(space.isCompatible(other.space) && "Functions should be compatible");
- auto merge = [this, &other](unsigned i, unsigned j) -> bool {
- // Merge local at position j into local at position i in function domain.
- domainSet.eliminateRedundantLocalVar(i, j);
- other.domainSet.eliminateRedundantLocalVar(i, j);
+ unsigned nDivs = getNumDivs();
+ unsigned divOffset = divs.getDivOffset();
- unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local);
+ other.divs.insertDiv(0, nDivs);
- // Merge local at position j into local at position i in output domain.
- output.addToColumn(localOffset + j, localOffset + i, 1);
- output.removeColumn(localOffset + j);
- other.output.addToColumn(localOffset + j, localOffset + i, 1);
- other.output.removeColumn(localOffset + j);
+ SmallVector<int64_t, 8> div(other.divs.getNumVars() + 1);
+ for (unsigned i = 0; i < nDivs; ++i) {
+ // Zero fill.
+ std::fill(div.begin(), div.end(), 0);
+ // Fill div with dividend from `divs`. Do not fill the constant.
+ std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
+ div.begin());
+ // Fill constant.
+ div.back() = divs.getDividend(i).back();
+ other.divs.setDiv(i, div, divs.getDenom(i));
+ }
+
+ other.space.insertVar(VarKind::Local, 0, nDivs);
+ other.output.insertColumns(divOffset, nDivs);
+
+ auto merge = [&](unsigned i, unsigned j) {
+ // We only merge from local at pos j to local at pos i, where j > i.
+ if (i >= j)
+ return false;
+ // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
+ // do not want to merge duplicates in `this`, we ignore this call.
+ if (j < nDivs)
+ return false;
+
+ // Merge things in space and output.
+ other.space.removeVarRange(VarKind::Local, j, j + 1);
+ other.output.addToColumn(divOffset + i, divOffset + j, 1);
+ other.output.removeColumn(divOffset + j);
return true;
};
- presburger::mergeLocalVars(domainSet, other.domainSet, merge);
-}
+ other.divs.removeDuplicateDivs(merge);
-bool MultiAffineFunction::isEqualWhereDomainsOverlap(
- MultiAffineFunction other) const {
- if (!getDomainSpace().isCompatible(other.getDomainSpace()))
- return false;
+ unsigned newDivs = other.divs.getNumDivs() - nDivs;
- // `commonFunc` has the same output as `this`.
- MultiAffineFunction commonFunc = *this;
- // After this merge, `commonFunc` and `other` have the same local vars; they
- // are merged.
- commonFunc.mergeLocalVars(other);
- // After this, the domain of `commonFunc` will be the intersection of the
- // domains of `this` and `other`.
- commonFunc.domainSet.append(other.domainSet);
-
- // `commonDomainMatching` contains the subset of the common domain
- // where the outputs of `this` and `other` match.
- //
- // We want to add constraints equating the outputs of `this` and `other`.
- // However, `this` may have difference local vars from `other`, whereas we
- // need both to have the same locals. Accordingly, we use `commonFunc.output`
- // in place of `this->output`, since `commonFunc` has the same output but also
- // has its locals merged.
- IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
- for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
- commonDomainMatching.addEquality(
- subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
-
- // If the whole common domain is a subset of commonDomainMatching, then they
- // are equal and the two functions match on the whole common domain.
- return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
+ space.insertVar(VarKind::Local, nDivs, newDivs);
+ output.insertColumns(divOffset + nDivs, newDivs);
+ divs = other.divs;
+
+ // Check consistency.
+ assertIsConsistent();
+ other.assertIsConsistent();
}
/// Two PWMAFunctions are equal if they have the same dimensionalities,
@@ -188,89 +185,79 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
// overlap, they take the same output value. If `this` and `other` have the
// same domain (checked above), then this check passes iff the two functions
// have the same output at every point in the domain.
- for (const MultiAffineFunction &aPiece : this->pieces)
- for (const MultiAffineFunction &bPiece : other.pieces)
- if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
- return false;
- return true;
+ return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
+ return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
+ PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
+ return pieceA.output.isEqual(pieceB.output, commonDomain);
+ });
+ });
}
-void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
- assert(space.isCompatible(piece.getDomainSpace()) &&
- "Piece to be added is not compatible with this PWMAFunction!");
- assert(piece.isConsistent() && "Piece is internally inconsistent!");
- assert(this->getDomain()
- .intersect(PresburgerSet(piece.getDomain()))
- .isIntegerEmpty() &&
- "New piece's domain overlaps with that of existing pieces!");
+void PWMAFunction::addPiece(const Piece &piece) {
+ assert(piece.isConsistent() && "Piece should be consistent");
pieces.push_back(piece);
}
-void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
- const Matrix &output) {
- addPiece(MultiAffineFunction(domain, output));
-}
-
-void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
- for (const IntegerRelation &newDom : domain.getAllDisjuncts())
- addPiece(IntegerPolyhedron(newDom), output);
-}
-
void PWMAFunction::print(raw_ostream &os) const {
- os << pieces.size() << " pieces:\n";
- for (const MultiAffineFunction &piece : pieces)
- piece.print(os);
+ space.print(os);
+ os << getNumPieces() << " pieces:\n";
+ for (const Piece &piece : pieces) {
+ os << "Domain of piece:\n";
+ piece.domain.print(os);
+ os << "Output of piece\n";
+ piece.output.print(os);
+ }
}
void PWMAFunction::dump() const { print(llvm::errs()); }
PWMAFunction PWMAFunction::unionFunction(
const PWMAFunction &func,
- llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
- MultiAffineFunction maf2)>
- tiebreak) const {
+ llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
assert(getNumOutputs() == func.getNumOutputs() &&
- "Number of outputs of functions should be same.");
+ "Ranges of functions should be same.");
assert(getSpace().isCompatible(func.getSpace()) &&
"Space is not compatible.");
// The algorithm used here is as follows:
- // - Add the output of funcB for the part of the domain where both funcA and
- // funcB are defined, and `tiebreak` chooses the output of funcB.
- // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
- // funcA over funcB.
- // - Add the output of funcB, where funcA is not defined.
-
- // Add parts of the common domain where funcB's output is used. Also
- // add all the parts where funcA's output is used, both common and non-common.
- PWMAFunction result(getSpace(), getNumOutputs());
- for (const MultiAffineFunction &funcA : pieces) {
- PresburgerSet dom(funcA.getDomain());
- for (const MultiAffineFunction &funcB : func.pieces) {
- PresburgerSet better = tiebreak(funcB, funcA);
- // Add the output of funcB, where it is better than output of funcA.
+ // - Add the output of pieceB for the part of the domain where both pieceA and
+ // pieceB are defined, and `tiebreak` chooses the output of pieceB.
+ // - Add the output of pieceA where either pieceB is not defined, or
+ //   both pieces are defined and `tiebreak` chooses pieceA over
+ //   pieceB.
+ // - Add the output of pieceB, where pieceA is not defined.
+
+ // Add parts of the common domain where pieceB's output is used. Also
+ // add all the parts where pieceA's output is used, both common and
+ // non-common.
+ PWMAFunction result(getSpace());
+ for (const Piece &pieceA : pieces) {
+ PresburgerSet dom(pieceA.domain);
+ for (const Piece &pieceB : func.pieces) {
+ PresburgerSet better = tiebreak(pieceB, pieceA);
+ // Add the output of pieceB, where it is better than output of pieceA.
// The disjuncts in "better" will be disjoint as tiebreak should gurantee
// that.
- result.addPiece(better, funcB.getOutputMatrix());
+ result.addPiece({better, pieceB.output});
dom = dom.subtract(better);
}
- // Add output of funcA, where it is better than funcB, or funcB is not
+ // Add output of pieceA, where it is better than pieceB, or pieceB is not
// defined.
//
// `dom` here is guranteed to be disjoint from already added pieces
// because because the pieces added before are either:
// - Subsets of the domain of other MAFs in `this`, which are guranteed
// to be disjoint from `dom`, or
- // - They are one of the pieces added for `funcB`, and we have been
+ // - They are one of the pieces added for `pieceB`, and we have been
// subtracting all such pieces from `dom`, so `dom` is disjoint from those
// pieces as well.
- result.addPiece(dom, funcA.getOutputMatrix());
+ result.addPiece({dom, pieceA.output});
}
- // Add parts of funcB which are not shared with funcA.
+ // Add parts of pieceB which are not shared with pieceA.
PresburgerSet dom = getDomain();
- for (const MultiAffineFunction &funcB : func.pieces)
- result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
+ for (const Piece &pieceB : func.pieces)
+ result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
return result;
}
@@ -280,21 +267,19 @@ PWMAFunction PWMAFunction::unionFunction(
/// taking the lexicographically smaller output and otherwise, by taking the
/// lexicographically larger output.
template <bool lexMin>
-static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
- const MultiAffineFunction &mafB) {
+static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
+ const PWMAFunction::Piece &pieceB) {
// TODO: Support local variables here.
- assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
- "Domain spaces should be compatible.");
- assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
- "Number of outputs of both functions should be same.");
- assert(mafA.getDomain().getNumLocalVars() == 0 &&
+ assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
+ "Pieces should be compatible");
+ assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
"Local variables are not supported yet.");
- PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
- const PresburgerSpace &space = mafA.getDomain().getSpace();
+ PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
+ const PresburgerSpace &space = pieceA.domain.getSpace();
// We first create the set `result`, corresponding to the set where output
- // of mafA is lexicographically larger/smaller than mafB. This is done by
+ // of pieceA is lexicographically larger/smaller than pieceB. This is done by
// creating a PresburgerSet with the following constraints:
//
// (outA[0] > outB[0]) U
@@ -312,14 +297,15 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
- IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
- /*numReservedEqualities=*/mafA.getNumOutputs(),
- /*numReservedCols=*/space.getNumVars() + 1, space);
- for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
+ IntegerPolyhedron levelSet(
+ /*numReservedInequalities=*/1,
+ /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
+ /*numReservedCols=*/space.getNumVars() + 1, space);
+ for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
// Create the expression `outA - outB` for this level.
- SmallVector<int64_t, 8> subExpr =
- subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
+ SmallVector<int64_t, 8> subExpr = subtractExprs(
+ pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
if (lexMin) {
// For lexMin, we add an upper bound of -1:
@@ -343,10 +329,9 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
levelSet.addEquality(subExpr);
}
- // We then intersect `result` with the domain of mafA and mafB, to only
+ // We then intersect `result` with the domain of pieceA and pieceB, to only
// tiebreak on the domain where both are defined.
- result = result.intersect(PresburgerSet(mafA.getDomain()))
- .intersect(PresburgerSet(mafB.getDomain()));
+ result = result.intersect(pieceA.domain).intersect(pieceB.domain);
return result;
}
@@ -358,3 +343,93 @@ PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}
+
+void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
+ assert(space.isCompatible(other.space) &&
+ "Spaces should be compatible for subtraction.");
+
+ MultiAffineFunction copyOther = other;
+ mergeDivs(copyOther);
+ for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
+ output.addToRow(i, copyOther.getOutputExpr(i), -1);
+
+ // Check consistency.
+ assertIsConsistent();
+}
+
+/// Adds division constraints corresponding to local variables, given a
+/// relation and division representations of the local variables in the
+/// relation.
+static void addDivisionConstraints(IntegerRelation &rel,
+ const DivisionRepr &divs) {
+ assert(divs.hasAllReprs() &&
+ "All divisions in divs should have a representation");
+ assert(rel.getNumVars() == divs.getNumVars() &&
+ "Relation and divs should have the same number of vars");
+ assert(rel.getNumLocalVars() == divs.getNumDivs() &&
+ "Relation and divs should have the same number of local vars");
+
+ for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
+ rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
+ divs.getDivOffset() + i));
+ rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
+ divs.getDivOffset() + i));
+ }
+}
+
+IntegerRelation MultiAffineFunction::getAsRelation() const {
+ // Create a relation corresponding to the input space plus the divisions
+ // used in outputs.
+ IntegerRelation result(PresburgerSpace::getRelationSpace(
+ space.getNumDomainVars(), 0, space.getNumSymbolVars(),
+ space.getNumLocalVars()));
+ // Add division constraints corresponding to divisions used in outputs.
+ addDivisionConstraints(result, divs);
+ // The outputs are represented as range variables in the relation. We add
+ // range variables for the outputs.
+ result.insertVar(VarKind::Range, 0, getNumOutputs());
+
+ // Add equalities such that the i^th range variable is equal to the i^th
+ // output expression.
+ SmallVector<int64_t, 8> eq(result.getNumCols());
+ for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
+ // TODO: Add functions to get VarKind offsets in output in MAF and use them
+ // here.
+ // The output expression does not contain range variables, while the
+ // equality does. So, we need to copy all variables and mark all range
+ // variables as 0 in the equality.
+ ArrayRef<int64_t> expr = getOutputExpr(i);
+ // Copy domain variables in `expr` to domain variables in `eq`.
+ std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
+ // Fill the range variables in `eq` as zero.
+ std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
+ eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
+ // Copy remaining variables in `expr` to the remaining variables in `eq`.
+ std::copy(expr.begin() + getNumDomainVars(), expr.end(),
+ eq.begin() + result.getVarKindEnd(VarKind::Range));
+
+ // Set the i^th range var to -1 in `eq` to equate the output expression to
+ // this range var.
+ eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
+ // Add the equality `rangeVar_i = output[i]`.
+ result.addEquality(eq);
+ }
+
+ return result;
+}
+
+void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
+ space.removeVarRange(VarKind::Range, start, end);
+ for (Piece &piece : pieces)
+ piece.output.removeOutputs(start, end);
+}
+
+Optional<SmallVector<int64_t, 8>>
+PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+ assert(point.size() == getNumDomainVars() + getNumSymbolVars());
+
+ for (const Piece &piece : pieces)
+ if (piece.domain.containsPoint(point))
+ return piece.output.valueAt(point);
+ return None;
+}
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index abb8796848063..a79b229318a1f 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -13,6 +13,15 @@
using namespace mlir;
using namespace presburger;
+PresburgerSpace PresburgerSpace::getDomainSpace() const {
+ // TODO: Preserve identifiers here.
+ return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals);
+}
+
+PresburgerSpace PresburgerSpace::getRangeSpace() const {
+ return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
+}
+
unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
if (kind == VarKind::Domain)
return getNumDomainVars();
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index b856c8931114b..d3c66027b98e3 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -466,7 +466,14 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
}
output.appendExtraRow(sample);
}
- result.lexmin.addPiece(domainPoly, output);
+
+ // Store the output in a MultiAffineFunction and add it to the result.
+ PresburgerSpace funcSpace = result.lexmin.getSpace();
+ funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
+
+ result.lexmin.addPiece(
+ {PresburgerSet(domainPoly),
+ MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
}
Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() {
@@ -508,7 +515,10 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
}
SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
- SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol);
+ SymbolicLexMin result(PresburgerSpace::getRelationSpace(
+ /*numDomain=*/domainPoly.getNumDimVars(),
+ /*numRange=*/var.size() - nSymbol,
+ /*numSymbols=*/domainPoly.getNumSymbolVars()));
/// The algorithm is more naturally expressed recursively, but we implement
/// it iteratively here to avoid potential issues with stack overflows in the
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index c3ac8232c6ec5..e22fa02cfce97 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -16,6 +16,8 @@
#include "mlir/Support/MathExtras.h"
#include <numeric>
+#include <numeric>
+
using namespace mlir;
using namespace presburger;
@@ -280,10 +282,8 @@ void presburger::mergeLocalVars(
DivisionRepr divsA = relA.getLocalReprs();
DivisionRepr divsB = relB.getLocalReprs();
- for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) {
- divsA.setDividend(i, divsB.getDividend(i));
- divsA.getDenom(i) = divsB.getDenom(i);
- }
+ for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i)
+ divsA.setDiv(i, divsB.getDividend(i), divsB.getDenom(i));
// Remove duplicate divisions from divsA. The removing duplicate divisions
// call, calls `merge` to effectively merge divisions in relA and relB.
@@ -357,6 +357,55 @@ SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
return coeffs;
}
+SmallVector<Optional<int64_t>, 4>
+DivisionRepr::divValuesAt(ArrayRef<int64_t> point) const {
+ assert(point.size() == getNumNonDivs() && "Incorrect point size");
+
+ SmallVector<Optional<int64_t>, 4> divValues(getNumDivs(), None);
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
+ // Skip this division if its value has already been found.
+ if (divValues[i])
+ continue;
+
+ ArrayRef<int64_t> dividend = getDividend(i);
+ int64_t divVal = 0;
+
+ // Check if we have all the division values required for this division.
+ unsigned j, f;
+ for (j = 0, f = getNumDivs(); j < f; ++j) {
+ if (dividend[getDivOffset() + j] == 0)
+ continue;
+ // Division value required, but not found yet.
+ if (!divValues[j])
+ break;
+ divVal += dividend[getDivOffset() + j] * divValues[j].value();
+ }
+
+ // We have some division values that are still not found, but are required
+ // to find the value of this division.
+ if (j < f)
+ continue;
+
+ // Add the contribution of the non-division variables, given by `point`.
+ divVal = std::inner_product(point.begin(), point.end(), dividend.begin(),
+ divVal);
+ // Add constant.
+ divVal += dividend.back();
+ // Take floor division with denominator.
+ divVal = floorDiv(divVal, denoms[i]);
+
+ // Set div value and continue.
+ divValues[i] = divVal;
+ changed = true;
+ }
+ }
+
+ return divValues;
+}
+
void DivisionRepr::removeDuplicateDivs(
llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
@@ -402,6 +451,23 @@ void DivisionRepr::removeDuplicateDivs(
}
}
+void DivisionRepr::insertDiv(unsigned pos, ArrayRef<int64_t> dividend,
+ unsigned divisor) {
+ assert(pos <= getNumDivs() && "Invalid insertion position");
+ assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size");
+
+ dividends.appendExtraRow(dividend);
+ denoms.insert(denoms.begin() + pos, divisor);
+ dividends.insertColumn(getDivOffset() + pos);
+}
+
+void DivisionRepr::insertDiv(unsigned pos, unsigned num) {
+ assert(pos <= getNumDivs() && "Invalid insertion position");
+ dividends.insertColumns(getDivOffset() + pos, num);
+ dividends.insertRows(pos, num);
+ denoms.insert(denoms.begin() + pos, num, 0);
+}
+
void DivisionRepr::print(raw_ostream &os) const {
os << "Dividends:\n";
dividends.print(os);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 2ed4cd62b4a09..de7413973e1b6 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1171,7 +1171,7 @@ void expectSymbolicIntegerLexMin(
ASSERT_NE(poly.getNumSymbolVars(), 0u);
PWMAFunction expectedLexmin =
- parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(),
+ parsePWMAF(/*numInputs=*/0,
/*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr,
/*numSymbols=*/poly.getNumSymbolVars());
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index c241d5d2f0e81..18efe55c2479a 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -130,7 +130,7 @@ TEST(IntegerRelationTest, symbolicLexmin) {
.findSymbolicIntegerLexMin();
PWMAFunction expectedLexmin =
- parsePWMAF(/*numInputs=*/2,
+ parsePWMAF(/*numInputs=*/1,
/*numOutputs=*/1,
{
{"(a)[b] : (a - b >= 0)", {{1, 0, 0}}}, // a
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index 3b4a479e97ded..417cce15e8428 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -73,14 +73,20 @@ inline PWMAFunction parsePWMAF(
unsigned numSymbols = 0) {
static MLIRContext context;
- PWMAFunction result(PresburgerSpace::getSetSpace(
- /*numDims=*/numInputs - numSymbols, numSymbols),
- numOutputs);
+ PWMAFunction result(
+ PresburgerSpace::getRelationSpace(numInputs, numOutputs, numSymbols));
for (const auto &pair : data) {
IntegerPolyhedron domain = parsePoly(pair.first);
+ PresburgerSpace funcSpace = result.getSpace();
+ funcSpace.insertVar(VarKind::Local, 0, domain.getNumLocalVars());
+
result.addPiece(
- domain, makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second));
+ {PresburgerSet(domain),
+ MultiAffineFunction(
+ funcSpace,
+ makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second),
+ domain.getLocalReprs())});
}
return result;
}
More information about the Mlir-commits
mailing list