[Mlir-commits] [mlir] 479c4f6 - [MLIR][Presburger] Refactor division representation to DivisionRepr
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 7 07:06:43 PDT 2022
Author: Groverkss
Date: 2022-07-07T15:05:28+01:00
New Revision: 479c4f648a021f1efdc30312bab804a71447e15f
URL: https://github.com/llvm/llvm-project/commit/479c4f648a021f1efdc30312bab804a71447e15f
DIFF: https://github.com/llvm/llvm-project/commit/479c4f648a021f1efdc30312bab804a71447e15f.diff
LOG: [MLIR][Presburger] Refactor division representation to DivisionRepr
This patch refactors existing implementations of division representation storage
into a new class, DivisionRepr. This refactoring is done so that the common
division utilities can be shared in an upcoming patch.
Reviewed By: arjunp
Differential Revision: https://reviews.llvm.org/D129146
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/Matrix.h
mlir/include/mlir/Analysis/Presburger/Utils.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/Matrix.cpp
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/lib/Analysis/Presburger/Utils.cpp
mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index a49d50e081a13..333669f69c2e1 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -352,22 +352,16 @@ class IntegerRelation {
Optional<SmallVector<int64_t, 8>>
containsPointNoLocal(ArrayRef<int64_t> point) const;
- /// Find equality and pairs of inequality constraints identified by their
- /// position indices, using which an explicit representation for each local
- /// variable can be computed. The indices of the constraints are stored in
- /// `MaybeLocalRepr` struct. If no such pair can be found, the kind attribute
- /// in `MaybeLocalRepr` is set to None.
+ /// Returns a `DivisonRepr` representing the division representation of local
+ /// variables in the constraint system.
///
- /// The dividends of the explicit representations are stored in `dividends`
- /// and the denominators in `denominators`. If no explicit representation
- /// could be found for the `i^th` local variable, `denominators[i]` is set
- /// to 0.
- void getLocalReprs(std::vector<SmallVector<int64_t, 8>> &dividends,
- SmallVector<unsigned, 4> &denominators,
- std::vector<MaybeLocalRepr> &repr) const;
- void getLocalReprs(std::vector<MaybeLocalRepr> &repr) const;
- void getLocalReprs(std::vector<SmallVector<int64_t, 8>> &dividends,
- SmallVector<unsigned, 4> &denominators) const;
+ /// If `repr` is not `nullptr`, the equality and pairs of inequality
+ /// constraints identified by their position indices using which an explicit
+ /// representation for each local variable can be computed are set in `repr`
+ /// in the form of a `MaybeLocalRepr` struct. If no such inequality
+ /// pair/equality can be found, the kind attribute in `MaybeLocalRepr` is set
+ /// to None.
+ DivisionRepr getLocalReprs(std::vector<MaybeLocalRepr> *repr = nullptr) const;
/// The type of bound: equal, lower bound or upper bound.
enum BoundType { EQ, LB, UB };
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 5b53cc5f045f9..bf32aafc6019d 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -89,6 +89,9 @@ class Matrix {
MutableArrayRef<int64_t> getRow(unsigned row);
ArrayRef<int64_t> getRow(unsigned row) const;
+ /// Set the specified row to `elems`.
+ void setRow(unsigned row, ArrayRef<int64_t> elems);
+
/// Insert columns having positions pos, pos + 1, ... pos + count - 1.
/// Columns that were at positions 0 to pos - 1 will stay where they are;
/// columns that were at positions pos to nColumns - 1 will be pushed to the
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 52321c5413def..e322cb9189396 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -17,6 +17,8 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
+
namespace mlir {
namespace presburger {
@@ -102,6 +104,80 @@ struct MaybeLocalRepr {
} repr;
};
+/// Class storing division representation of local variables of a constraint
+/// system. The coefficients of the dividends are stored in order:
+/// [nonLocalVars, localVars, constant]. Each local variable may or may not have
+/// a representation. If the local does not have a representation, the dividend
+/// of the division has no meaning and the denominator is zero.
+///
+/// The i^th division here, represents the division representation of the
+/// variable at position `divOffset + i` in the constraint system.
+class DivisionRepr {
+public:
+ DivisionRepr(unsigned numVars, unsigned numDivs)
+ : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
+
+ DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {}
+
+ unsigned getNumVars() const { return dividends.getNumColumns() - 1; }
+ unsigned getNumDivs() const { return dividends.getNumRows(); }
+ unsigned getNumNonDivs() const { return getNumVars() - getNumDivs(); }
+ // Get the offset from where division variables start.
+ unsigned getDivOffset() const { return getNumVars() - getNumDivs(); }
+
+ // Check whether the `i^th` division has a division representation or not.
+ bool hasRepr(unsigned i) const { return denoms[i] != 0; }
+ // Check whether all the divisions have a division representation or not.
+ bool hasAllReprs() const {
+ return all_of(denoms, [](unsigned denom) { return denom != 0; });
+ }
+
+ // Clear the division representation of the i^th local variable.
+ void clearRepr(unsigned i) { denoms[i] = 0; }
+
+ // Get the dividend of the `i^th` division.
+ MutableArrayRef<int64_t> getDividend(unsigned i) {
+ return dividends.getRow(i);
+ }
+ ArrayRef<int64_t> getDividend(unsigned i) const {
+ return dividends.getRow(i);
+ }
+
+ // Get the `i^th` denominator.
+ unsigned &getDenom(unsigned i) { return denoms[i]; }
+ unsigned getDenom(unsigned i) const { return denoms[i]; }
+
+ ArrayRef<unsigned> getDenoms() const { return denoms; }
+
+ void setDividend(unsigned i, ArrayRef<int64_t> dividend) {
+ dividends.setRow(i, dividend);
+ }
+
+ /// Removes duplicate divisions. On every possible duplicate division found,
+ /// `merge(i, j)`, where `i`, `j` are current index of the duplicate
+ /// divisions, is called and division at index `j` is merged into division at
+ /// index `i`. If `merge(i, j)` returns `true`, the divisions are merged i.e.
+ /// `j^th` division gets eliminated and it's each instance is replaced by
+ /// `i^th` division. If it returns `false`, the divisions are not merged.
+ /// `merge` can also do side effects, For example it can merge the local
+ /// variables in IntegerRelation.
+ void
+ removeDuplicateDivs(llvm::function_ref<bool(unsigned i, unsigned j)> merge);
+
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+private:
+ /// Each row of the Matrix represents a single division dividend. The
+ /// `i^th` row represents the dividend of the variable at `divOffset + i`
+ /// in the constraint system (and the `i^th` local variable).
+ Matrix dividends;
+
+ /// Denominators of each division. If a denominator of a division is `0`, the
+ /// division variable is considered to not have a division representation.
+ SmallVector<unsigned, 4> denoms;
+};
+
/// If `q` is defined to be equal to `expr floordiv d`, this equivalent to
/// saying that `q` is an integer and `q` is subject to the inequalities
/// `0 <= expr - d*q <= c - 1` (quotient remainder theorem).
@@ -135,25 +211,9 @@ llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
/// `MaybeLocalRepr` is set to None.
MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst,
ArrayRef<bool> foundRepr, unsigned pos,
- SmallVector<int64_t, 8> &dividend,
+ MutableArrayRef<int64_t> dividend,
unsigned &divisor);
-/// Given dividends of divisions `divs` and denominators `denoms`, detects and
-/// removes duplicate divisions. `localOffset` is the offset in dividend of a
-/// division from where local variables start.
-///
-/// On every possible duplicate division found, `merge(i, j)`, where `i`, `j`
-/// are current index of the duplicate divisions, is called and division at
-/// index `j` is merged into division at index `i`. If `merge(i, j)` returns
-/// `true`, the divisions are merged i.e. `j^th` division gets eliminated and
-/// it's each instance is replaced by `i^th` division. If it returns `false`,
-/// the divisions are not merged. `merge` can also do side effects, For example
-/// it can merge the local variables in IntegerRelation.
-void removeDuplicateDivs(
- std::vector<SmallVector<int64_t, 8>> &divs,
- SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
- llvm::function_ref<bool(unsigned i, unsigned j)> merge);
-
/// Given two relations, A and B, add additional local vars to the sets such
/// that both have the union of the local vars in each set, without changing
/// the set of points that lie in A and B.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index b455f4818acf6..d2427eb39b734 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -175,8 +175,8 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
//
// Take a copy so we can perform mutations.
IntegerRelation copy = *this;
- std::vector<MaybeLocalRepr> reprs;
- copy.getLocalReprs(reprs);
+ std::vector<MaybeLocalRepr> reprs(getNumLocalVars());
+ copy.getLocalReprs(&reprs);
// Iterate through all the locals. The last `numNonDivLocals` are the locals
// that have been scanned already and do not have division representations.
@@ -912,56 +912,39 @@ IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
return copy.findIntegerSample();
}
-void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
- std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalVars());
- SmallVector<unsigned, 4> denominators(getNumLocalVars());
- getLocalReprs(dividends, denominators, repr);
-}
-
-void IntegerRelation::getLocalReprs(
- std::vector<SmallVector<int64_t, 8>> &dividends,
- SmallVector<unsigned, 4> &denominators) const {
- std::vector<MaybeLocalRepr> repr(getNumLocalVars());
- getLocalReprs(dividends, denominators, repr);
-}
-
-void IntegerRelation::getLocalReprs(
- std::vector<SmallVector<int64_t, 8>> &dividends,
- SmallVector<unsigned, 4> &denominators,
- std::vector<MaybeLocalRepr> &repr) const {
-
- repr.resize(getNumLocalVars());
- dividends.resize(getNumLocalVars());
- denominators.resize(getNumLocalVars());
-
+DivisionRepr
+IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> *repr) const {
SmallVector<bool, 8> foundRepr(getNumVars(), false);
for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
foundRepr[i] = true;
- unsigned divOffset = getNumDimAndSymbolVars();
+ unsigned localOffset = getVarKindOffset(VarKind::Local);
+ DivisionRepr divs(getNumVars(), getNumLocalVars());
bool changed;
do {
// Each time changed is true, at end of this iteration, one or more local
// vars have been detected as floor divs.
changed = false;
for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) {
- if (!foundRepr[i + divOffset]) {
- MaybeLocalRepr res = computeSingleVarRepr(
- *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
- if (!res)
+ if (!foundRepr[i + localOffset]) {
+ MaybeLocalRepr res =
+ computeSingleVarRepr(*this, foundRepr, localOffset + i,
+ divs.getDividend(i), divs.getDenom(i));
+ if (!res) {
+ // No representation was found, so clear the representation and
+ // continue.
+ divs.clearRepr(i);
continue;
- foundRepr[i + divOffset] = true;
- repr[i] = res;
+ }
+ foundRepr[localOffset + i] = true;
+ if (repr)
+ (*repr)[i] = res;
changed = true;
}
}
} while (changed);
- // Set 0 denominator for variables for which no division representation
- // could be found.
- for (unsigned i = 0, e = repr.size(); i < e; ++i)
- if (!repr[i])
- denominators[i] = 0;
+ return divs;
}
/// Tightens inequalities given that we are dealing with integer spaces. This is
@@ -1211,23 +1194,16 @@ unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) {
}
bool IntegerRelation::hasOnlyDivLocals() const {
- std::vector<MaybeLocalRepr> reprs;
- getLocalReprs(reprs);
- return llvm::all_of(reprs,
- [](const MaybeLocalRepr &repr) { return bool(repr); });
+ return getLocalReprs().hasAllReprs();
}
void IntegerRelation::removeDuplicateDivs() {
- std::vector<SmallVector<int64_t, 8>> divs;
- SmallVector<unsigned, 4> denoms;
-
- getLocalReprs(divs, denoms);
+ DivisionRepr divs = getLocalReprs();
auto merge = [this](unsigned i, unsigned j) -> bool {
eliminateRedundantLocalVar(i, j);
return true;
};
- presburger::removeDuplicateDivs(divs, denoms,
- getVarKindOffset(VarKind::Local), merge);
+ divs.removeDuplicateDivs(merge);
}
/// Removes local variables using equalities. Each equality is checked if it
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 17bbb8185ed4e..c9767ae3cee2b 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -92,6 +92,13 @@ ArrayRef<int64_t> Matrix::getRow(unsigned row) const {
return {&data[row * nReservedColumns], nColumns};
}
+void Matrix::setRow(unsigned row, ArrayRef<int64_t> elems) {
+ assert(elems.size() == getNumColumns() &&
+ "elems size must match row length!");
+ for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
+ at(row, i) = elems[i];
+}
+
void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); }
void Matrix::insertColumns(unsigned pos, unsigned count) {
if (count == 0)
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index a2e524ea9fd7c..c6a88f285c1c0 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -253,10 +253,8 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
//
// Careful! This has to be done after the merge above; otherwise, the
// dividends won't contain the new ids inserted during the merge.
- std::vector<MaybeLocalRepr> repr;
- std::vector<SmallVector<int64_t, 8>> dividends;
- SmallVector<unsigned, 4> divisors;
- sI.getLocalReprs(dividends, divisors, repr);
+ std::vector<MaybeLocalRepr> repr(sI.getNumLocalVars());
+ DivisionRepr divs = sI.getLocalReprs(&repr);
// Mark which inequalities of sI are division inequalities and add all
// such inequalities to b.
@@ -301,10 +299,10 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
// not be because they were never a part of sI; we just infer them
// from the equality and add them only to b.
b.addInequality(
- getDivLowerBound(dividends[i], divisors[i],
+ getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
sI.getVarKindOffset(VarKind::Local) + i));
b.addInequality(
- getDivUpperBound(dividends[i], divisors[i],
+ getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
sI.getVarKindOffset(VarKind::Local) + i));
}
}
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 9a6fe16fe2e36..978ced1a967f8 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -21,7 +21,7 @@ using namespace presburger;
/// Normalize a division's `dividend` and the `divisor` by their GCD. For
/// example: if the dividend and divisor are [2,0,4] and 4 respectively,
/// they get normalized to [1,0,2] and 2.
-static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
+static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend,
unsigned &divisor) {
if (divisor == 0 || dividend.empty())
return;
@@ -89,7 +89,7 @@ static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
/// normalized by GCD.
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
unsigned ubIneq, unsigned lbIneq,
- SmallVector<int64_t, 8> &expr,
+ MutableArrayRef<int64_t> expr,
unsigned &divisor) {
assert(pos <= cst.getNumVars() && "Invalid variable position");
@@ -97,6 +97,7 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
"Invalid upper bound inequality position");
assert(lbIneq <= cst.getNumInequalities() &&
"Invalid upper bound inequality position");
+ assert(expr.size() == cst.getNumCols() && "Invalid expression size");
// Extract divisor from the lower bound.
divisor = cst.atIneq(lbIneq, pos);
@@ -126,7 +127,6 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
// The inequality pair can be used to extract the division.
// Set `expr` to the dividend of the division except the constant term, which
// is set below.
- expr.resize(cst.getNumCols(), 0);
for (i = 0, e = cst.getNumVars(); i < e; ++i)
if (i != pos)
expr[i] = cst.atIneq(ubIneq, i);
@@ -152,11 +152,12 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
/// set to the denominator of the division. The final division expression is
/// normalized by GCD.
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
- unsigned eqInd, SmallVector<int64_t, 8> &expr,
+ unsigned eqInd, MutableArrayRef<int64_t> expr,
unsigned &divisor) {
assert(pos <= cst.getNumVars() && "Invalid variable position");
assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
+ assert(expr.size() == cst.getNumCols() && "Invalid expression size");
// Extract divisor, the divisor can be negative and hence its sign information
// is stored in `signDiv` to reverse the sign of dividend's coefficients.
@@ -169,7 +170,6 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
// The divisor is always a positive integer.
divisor = tempDiv * signDiv;
- expr.resize(cst.getNumCols(), 0);
for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
if (i != pos)
expr[i] = -signDiv * cst.atEq(eqInd, i);
@@ -215,10 +215,11 @@ static bool checkExplicitRepresentation(const IntegerRelation &cst,
/// `MaybeLocalRepr` is set to None.
MaybeLocalRepr presburger::computeSingleVarRepr(
const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
- SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
+ MutableArrayRef<int64_t> dividend, unsigned &divisor) {
assert(pos < cst.getNumVars() && "invalid position");
assert(foundRepr.size() == cst.getNumVars() &&
"Size of foundRepr does not match total number of variables");
+ assert(dividend.size() == cst.getNumCols() && "Invalid dividend size");
SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, &eqIndices);
@@ -261,57 +262,6 @@ llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
return vec;
}
-void presburger::removeDuplicateDivs(
- std::vector<SmallVector<int64_t, 8>> &divs,
- SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
- llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
-
- // Find and merge duplicate divisions.
- // TODO: Add division normalization to support divisions that differ by
- // a constant.
- // TODO: Add division ordering such that a division representation for local
- // variable at position `i` only depends on local variables at position <
- // `i`. This would make sure that all divisions depending on other local
- // variables that can be merged, are merged.
- for (unsigned i = 0; i < divs.size(); ++i) {
- // Check if a division representation exists for the `i^th` local var.
- if (denoms[i] == 0)
- continue;
- // Check if a division exists which is a duplicate of the division at `i`.
- for (unsigned j = i + 1; j < divs.size(); ++j) {
- // Check if a division representation exists for the `j^th` local var.
- if (denoms[j] == 0)
- continue;
- // Check if the denominators match.
- if (denoms[i] != denoms[j])
- continue;
- // Check if the representations are equal.
- if (divs[i] != divs[j])
- continue;
-
- // Merge divisions at position `j` into division at position `i`. If
- // merge fails, do not merge these divs.
- bool mergeResult = merge(i, j);
- if (!mergeResult)
- continue;
-
- // Update division information to reflect merging.
- for (unsigned k = 0, g = divs.size(); k < g; ++k) {
- SmallVector<int64_t, 8> &div = divs[k];
- if (denoms[k] != 0) {
- div[localOffset + i] += div[localOffset + j];
- div.erase(div.begin() + localOffset + j);
- }
- }
-
- divs.erase(divs.begin() + j);
- denoms.erase(denoms.begin() + j);
- // Since `j` can never be zero, we do not need to worry about overflows.
- --j;
- }
- }
-}
-
void presburger::mergeLocalVars(
IntegerRelation &relA, IntegerRelation &relB,
llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
@@ -327,23 +277,17 @@ void presburger::mergeLocalVars(
relB.insertVar(VarKind::Local, 0, initLocals);
// Get division representations from each rel.
- std::vector<SmallVector<int64_t, 8>> divsA, divsB;
- SmallVector<unsigned, 4> denomsA, denomsB;
- relA.getLocalReprs(divsA, denomsA);
- relB.getLocalReprs(divsB, denomsB);
-
- // Copy division information for relB into `divsA` and `denomsA`, so that
- // these have the combined division information of both rels. Since newly
- // added local variables in relA and relB have no constraints, they will not
- // have any division representation.
- std::copy(divsB.begin() + initLocals, divsB.end(),
- divsA.begin() + initLocals);
- std::copy(denomsB.begin() + initLocals, denomsB.end(),
- denomsA.begin() + initLocals);
-
- // Merge all divisions by removing duplicate divisions.
- unsigned localOffset = relA.getVarKindOffset(VarKind::Local);
- presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
+ DivisionRepr divsA = relA.getLocalReprs();
+ DivisionRepr divsB = relB.getLocalReprs();
+
+ for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) {
+ divsA.setDividend(i, divsB.getDividend(i));
+ divsA.getDenom(i) = divsB.getDenom(i);
+ }
+
+ // Remove duplicate divisions from divsA. The removing duplicate divisions
+ // call, calls `merge` to effectively merge divisions in relA and relB.
+ divsA.removeDuplicateDivs(merge);
}
SmallVector<int64_t, 8> presburger::getDivUpperBound(ArrayRef<int64_t> dividend,
@@ -412,3 +356,59 @@ SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
--coeffs.back();
return coeffs;
}
+
+void DivisionRepr::removeDuplicateDivs(
+ llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
+
+ // Find and merge duplicate divisions.
+ // TODO: Add division normalization to support divisions that differ by
+ // a constant.
+ // TODO: Add division ordering such that a division representation for local
+ // variable at position `i` only depends on local variables at position <
+ // `i`. This would make sure that all divisions depending on other local
+ // variables that can be merged, are merged.
+ for (unsigned i = 0; i < getNumDivs(); ++i) {
+ // Check if a division representation exists for the `i^th` local var.
+ if (denoms[i] == 0)
+ continue;
+ // Check if a division exists which is a duplicate of the division at `i`.
+ for (unsigned j = i + 1; j < getNumDivs(); ++j) {
+ // Check if a division representation exists for the `j^th` local var.
+ if (denoms[j] == 0)
+ continue;
+ // Check if the denominators match.
+ if (denoms[i] != denoms[j])
+ continue;
+ // Check if the representations are equal.
+ if (dividends.getRow(i) != dividends.getRow(j))
+ continue;
+
+ // Merge divisions at position `j` into division at position `i`. If
+ // merge fails, do not merge these divs.
+ bool mergeResult = merge(i, j);
+ if (!mergeResult)
+ continue;
+
+ // Update division information to reflect merging.
+ unsigned divOffset = getDivOffset();
+ dividends.addToColumn(divOffset + j, divOffset + i, /*scale=*/1);
+ dividends.removeColumn(divOffset + j);
+ dividends.removeRow(j);
+ denoms.erase(denoms.begin() + j);
+
+ // Since `j` can never be zero, we do not need to worry about overflows.
+ --j;
+ }
+ }
+}
+
+void DivisionRepr::print(raw_ostream &os) const {
+ os << "Dividends:\n";
+ dividends.print(os);
+ os << "Denominators\n";
+ for (unsigned i = 0, e = denoms.size(); i < e; ++i)
+ os << denoms[i] << " ";
+ os << "\n";
+}
+
+void DivisionRepr::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index ffe92cd1bcff7..593d818127384 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -865,7 +865,7 @@ static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
if (exprs[i])
foundRepr[i] = true;
- SmallVector<int64_t, 8> dividend;
+ SmallVector<int64_t, 8> dividend(cst.getNumCols());
unsigned divisor;
auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 581558cf9bbfb..a6c721d453461 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -608,23 +608,19 @@ TEST(IntegerPolyhedronTest, addConstantLowerBound) {
static void checkDivisionRepresentation(
IntegerPolyhedron &poly,
const std::vector<SmallVector<int64_t, 8>> &expectedDividends,
- const SmallVectorImpl<unsigned> &expectedDenominators) {
- std::vector<SmallVector<int64_t, 8>> dividends;
- SmallVector<unsigned, 4> denominators;
-
- poly.getLocalReprs(dividends, denominators);
+ ArrayRef<unsigned> expectedDenominators) {
+ DivisionRepr divs = poly.getLocalReprs();
// Check that the `denominators` and `expectedDenominators` match.
- EXPECT_TRUE(expectedDenominators == denominators);
+ EXPECT_TRUE(expectedDenominators == divs.getDenoms());
// Check that the `dividends` and `expectedDividends` match. If the
// denominator for a division is zero, we ignore its dividend.
- EXPECT_TRUE(dividends.size() == expectedDividends.size());
- for (unsigned i = 0, e = dividends.size(); i < e; ++i) {
- if (denominators[i] != 0) {
- EXPECT_TRUE(expectedDividends[i] == dividends[i]);
- }
- }
+ EXPECT_TRUE(divs.getNumDivs() == expectedDividends.size());
+ for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i)
+ if (divs.hasRepr(i))
+ for (unsigned j = 0, f = divs.getNumVars() + 1; j < f; ++j)
+ EXPECT_TRUE(expectedDividends[i][j] == divs.getDividend(i)[j]);
}
TEST(IntegerPolyhedronTest, computeLocalReprSimple) {
More information about the Mlir-commits
mailing list