[Mlir-commits] [mlir] 6d6f6c4 - [MLIR][Presburger] use arbitrary-precision arithmetic with MPInt instead of int64_t
Arjun P
llvmlistbot at llvm.org
Wed Sep 14 07:47:30 PDT 2022
Author: Arjun P
Date: 2022-09-14T15:47:41+01:00
New Revision: 6d6f6c4d3f40b2d0220fca73b27b8e74e64ba55a
URL: https://github.com/llvm/llvm-project/commit/6d6f6c4d3f40b2d0220fca73b27b8e74e64ba55a
DIFF: https://github.com/llvm/llvm-project/commit/6d6f6c4d3f40b2d0220fca73b27b8e74e64ba55a.diff
LOG: [MLIR][Presburger] use arbitrary-precision arithmetic with MPInt instead of int64_t
Only the main Presburger library under the Presburger directory has been switched to use arbitrary precision. Users have been changed to just cast returned values back to int64_t or to use newly added convenience functions that perform the same cast internally.
The performance impact of this has been tested by checking test runtimes after copy-pasting 100 copies of each function. Affine/simplify-structures.mlir goes from 0.76s to 0.80s after this patch. Its performance sees no regression compared to its original performance at commit 18a06d4f3a7474d062d1fe7d405813ed2e40b4fc before a series of patches that I landed to offset the performance overhead of switching to arbitrary precision.
Affine/canonicalize.mlir and SCF/canonicalize.mlir show no noticeable difference, staying at 2.02s and about 2.35s respectively.
Also, for Affine and SCF tests as a whole (no copy-pasting), the runtime remains about 0.09s on average before and after.
Reviewed By: bondhugula
Differential Revision: https://reviews.llvm.org/D129510
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/Fraction.h
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/LinearTransform.h
mlir/include/mlir/Analysis/Presburger/Matrix.h
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/include/mlir/Analysis/Presburger/Utils.h
mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/LinearTransform.cpp
mlir/lib/Analysis/Presburger/Matrix.cpp
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/lib/Analysis/Presburger/Utils.cpp
mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
mlir/unittests/Analysis/Presburger/SimplexTest.cpp
mlir/unittests/Analysis/Presburger/Utils.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/Fraction.h b/mlir/include/mlir/Analysis/Presburger/Fraction.h
index e26c47f6e341..c51b6c972bf8 100644
--- a/mlir/include/mlir/Analysis/Presburger/Fraction.h
+++ b/mlir/include/mlir/Analysis/Presburger/Fraction.h
@@ -14,6 +14,7 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_FRACTION_H
#define MLIR_ANALYSIS_PRESBURGER_FRACTION_H
+#include "mlir/Analysis/Presburger/MPInt.h"
#include "mlir/Support/MathExtras.h"
namespace mlir {
@@ -29,30 +30,34 @@ struct Fraction {
Fraction() = default;
/// Construct a Fraction from a numerator and denominator.
- Fraction(int64_t oNum, int64_t oDen) : num(oNum), den(oDen) {
+ Fraction(const MPInt &oNum, const MPInt &oDen) : num(oNum), den(oDen) {
if (den < 0) {
num = -num;
den = -den;
}
}
+ /// Overloads for passing literals.
+ Fraction(const MPInt &num, int64_t den) : Fraction(num, MPInt(den)) {}
+ Fraction(int64_t num, const MPInt &den) : Fraction(MPInt(num), den) {}
+ Fraction(int64_t num, int64_t den) : Fraction(MPInt(num), MPInt(den)) {}
// Return the value of the fraction as an integer. This should only be called
// when the fraction's value is really an integer.
- int64_t getAsInteger() const {
+ MPInt getAsInteger() const {
assert(num % den == 0 && "Get as integer called on non-integral fraction!");
return num / den;
}
/// The numerator and denominator, respectively. The denominator is always
/// positive.
- int64_t num{0}, den{1};
+ MPInt num{0}, den{1};
};
/// Three-way comparison between two fractions.
/// Returns +1, 0, and -1 if the first fraction is greater than, equal to, or
/// less than the second fraction, respectively.
-inline int compare(Fraction x, Fraction y) {
- int64_t diff = x.num * y.den - y.num * x.den;
+inline int compare(const Fraction &x, const Fraction &y) {
+ MPInt diff = x.num * y.den - y.num * x.den;
if (diff > 0)
return +1;
if (diff < 0)
@@ -60,25 +65,37 @@ inline int compare(Fraction x, Fraction y) {
return 0;
}
-inline int64_t floor(Fraction f) { return floorDiv(f.num, f.den); }
+inline MPInt floor(const Fraction &f) { return floorDiv(f.num, f.den); }
-inline int64_t ceil(Fraction f) { return ceilDiv(f.num, f.den); }
+inline MPInt ceil(const Fraction &f) { return ceilDiv(f.num, f.den); }
-inline Fraction operator-(Fraction x) { return Fraction(-x.num, x.den); }
+inline Fraction operator-(const Fraction &x) { return Fraction(-x.num, x.den); }
-inline bool operator<(Fraction x, Fraction y) { return compare(x, y) < 0; }
+inline bool operator<(const Fraction &x, const Fraction &y) {
+ return compare(x, y) < 0;
+}
-inline bool operator<=(Fraction x, Fraction y) { return compare(x, y) <= 0; }
+inline bool operator<=(const Fraction &x, const Fraction &y) {
+ return compare(x, y) <= 0;
+}
-inline bool operator==(Fraction x, Fraction y) { return compare(x, y) == 0; }
+inline bool operator==(const Fraction &x, const Fraction &y) {
+ return compare(x, y) == 0;
+}
-inline bool operator!=(Fraction x, Fraction y) { return compare(x, y) != 0; }
+inline bool operator!=(const Fraction &x, const Fraction &y) {
+ return compare(x, y) != 0;
+}
-inline bool operator>(Fraction x, Fraction y) { return compare(x, y) > 0; }
+inline bool operator>(const Fraction &x, const Fraction &y) {
+ return compare(x, y) > 0;
+}
-inline bool operator>=(Fraction x, Fraction y) { return compare(x, y) >= 0; }
+inline bool operator>=(const Fraction &x, const Fraction &y) {
+ return compare(x, y) >= 0;
+}
-inline Fraction operator*(Fraction x, Fraction y) {
+inline Fraction operator*(const Fraction &x, const Fraction &y) {
return Fraction(x.num * y.num, x.den * y.den);
}
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 320eb1536176..2c4aa9f6d1eb 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -133,14 +133,24 @@ class IntegerRelation {
bool isSubsetOf(const IntegerRelation &other) const;
/// Returns the value at the specified equality row and column.
- inline int64_t atEq(unsigned i, unsigned j) const { return equalities(i, j); }
- inline int64_t &atEq(unsigned i, unsigned j) { return equalities(i, j); }
+ inline MPInt atEq(unsigned i, unsigned j) const { return equalities(i, j); }
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ inline int64_t atEq64(unsigned i, unsigned j) const {
+ return int64_t(equalities(i, j));
+ }
+ inline MPInt &atEq(unsigned i, unsigned j) { return equalities(i, j); }
/// Returns the value at the specified inequality row and column.
- inline int64_t atIneq(unsigned i, unsigned j) const {
+ inline MPInt atIneq(unsigned i, unsigned j) const {
return inequalities(i, j);
}
- inline int64_t &atIneq(unsigned i, unsigned j) { return inequalities(i, j); }
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ inline int64_t atIneq64(unsigned i, unsigned j) const {
+ return int64_t(inequalities(i, j));
+ }
+ inline MPInt &atIneq(unsigned i, unsigned j) { return inequalities(i, j); }
unsigned getNumConstraints() const {
return getNumInequalities() + getNumEqualities();
@@ -174,13 +184,20 @@ class IntegerRelation {
return inequalities.getNumReservedRows();
}
- inline ArrayRef<int64_t> getEquality(unsigned idx) const {
+ inline ArrayRef<MPInt> getEquality(unsigned idx) const {
return equalities.getRow(idx);
}
-
- inline ArrayRef<int64_t> getInequality(unsigned idx) const {
+ inline ArrayRef<MPInt> getInequality(unsigned idx) const {
return inequalities.getRow(idx);
}
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ inline SmallVector<int64_t, 8> getEquality64(unsigned idx) const {
+ return getInt64Vec(equalities.getRow(idx));
+ }
+ inline SmallVector<int64_t, 8> getInequality64(unsigned idx) const {
+ return getInt64Vec(inequalities.getRow(idx));
+ }
/// Get the number of vars of the specified kind.
unsigned getNumVarKind(VarKind kind) const {
@@ -245,9 +262,13 @@ class IntegerRelation {
unsigned appendVar(VarKind kind, unsigned num = 1);
/// Adds an inequality (>= 0) from the coefficients specified in `inEq`.
- void addInequality(ArrayRef<int64_t> inEq);
+ void addInequality(ArrayRef<MPInt> inEq);
+ void addInequality(ArrayRef<int64_t> inEq) {
+ addInequality(getMPIntVec(inEq));
+ }
/// Adds an equality from the coefficients specified in `eq`.
- void addEquality(ArrayRef<int64_t> eq);
+ void addEquality(ArrayRef<MPInt> eq);
+ void addEquality(ArrayRef<int64_t> eq) { addEquality(getMPIntVec(eq)); }
/// Eliminate the `posB^th` local variable, replacing every instance of it
/// with the `posA^th` local variable. This should be used when the two
@@ -282,7 +303,7 @@ class IntegerRelation {
/// For a generic integer sampling operation, findIntegerSample is more
/// robust and should be preferred. Note that Domain is minimized first, then
/// range.
- MaybeOptimum<SmallVector<int64_t, 8>> findIntegerLexMin() const;
+ MaybeOptimum<SmallVector<MPInt, 8>> findIntegerLexMin() const;
/// Swap the posA^th variable with the posB^th variable.
virtual void swapVar(unsigned posA, unsigned posB);
@@ -292,7 +313,10 @@ class IntegerRelation {
/// Sets the `values.size()` variables starting at `po`s to the specified
/// values and removes them.
- void setAndEliminate(unsigned pos, ArrayRef<int64_t> values);
+ void setAndEliminate(unsigned pos, ArrayRef<MPInt> values);
+ void setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
+ setAndEliminate(pos, getMPIntVec(values));
+ }
/// Replaces the contents of this IntegerRelation with `other`.
virtual void clearAndCopyFrom(const IntegerRelation &other);
@@ -337,20 +361,27 @@ class IntegerRelation {
///
/// Returns an integer sample point if one exists, or an empty Optional
/// otherwise. The returned value also includes values of local ids.
- Optional<SmallVector<int64_t, 8>> findIntegerSample() const;
+ Optional<SmallVector<MPInt, 8>> findIntegerSample() const;
/// Compute an overapproximation of the number of integer points in the
/// relation. Symbol vars currently not supported. If the computed
/// overapproximation is infinite, an empty optional is returned.
- Optional<uint64_t> computeVolume() const;
+ Optional<MPInt> computeVolume() const;
/// Returns true if the given point satisfies the constraints, or false
/// otherwise. Takes the values of all vars including locals.
- bool containsPoint(ArrayRef<int64_t> point) const;
+ bool containsPoint(ArrayRef<MPInt> point) const;
+ bool containsPoint(ArrayRef<int64_t> point) const {
+ return containsPoint(getMPIntVec(point));
+ }
/// Given the values of non-local vars, return a satisfying assignment to the
/// local if one exists, or an empty optional otherwise.
- Optional<SmallVector<int64_t, 8>>
- containsPointNoLocal(ArrayRef<int64_t> point) const;
+ Optional<SmallVector<MPInt, 8>>
+ containsPointNoLocal(ArrayRef<MPInt> point) const;
+ Optional<SmallVector<MPInt, 8>>
+ containsPointNoLocal(ArrayRef<int64_t> point) const {
+ return containsPointNoLocal(getMPIntVec(point));
+ }
/// Returns a `DivisonRepr` representing the division representation of local
/// variables in the constraint system.
@@ -367,17 +398,26 @@ class IntegerRelation {
enum BoundType { EQ, LB, UB };
/// Adds a constant bound for the specified variable.
- void addBound(BoundType type, unsigned pos, int64_t value);
+ void addBound(BoundType type, unsigned pos, const MPInt &value);
+ void addBound(BoundType type, unsigned pos, int64_t value) {
+ addBound(type, pos, MPInt(value));
+ }
/// Adds a constant bound for the specified expression.
- void addBound(BoundType type, ArrayRef<int64_t> expr, int64_t value);
+ void addBound(BoundType type, ArrayRef<MPInt> expr, const MPInt &value);
+ void addBound(BoundType type, ArrayRef<int64_t> expr, int64_t value) {
+ addBound(type, getMPIntVec(expr), MPInt(value));
+ }
/// Adds a new local variable as the floordiv of an affine function of other
/// variables, the coefficients of which are provided in `dividend` and with
/// respect to a positive constant `divisor`. Two constraints are added to the
/// system to capture equivalence with the floordiv:
/// q = dividend floordiv c <=> c*q <= dividend <= c*q + c - 1.
- void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor);
+ void addLocalFloorDiv(ArrayRef<MPInt> dividend, const MPInt &divisor);
+ void addLocalFloorDiv(ArrayRef<int64_t> dividend, int64_t divisor) {
+ addLocalFloorDiv(getMPIntVec(dividend), MPInt(divisor));
+ }
/// Projects out (aka eliminates) `num` variables starting at position
/// `pos`. The resulting constraint system is the shadow along the dimensions
@@ -432,15 +472,38 @@ class IntegerRelation {
/// lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three
/// symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See comments at
/// function definition for examples.
- Optional<int64_t> getConstantBoundOnDimSize(
+ Optional<MPInt> getConstantBoundOnDimSize(
+ unsigned pos, SmallVectorImpl<MPInt> *lb = nullptr,
+ MPInt *boundFloorDivisor = nullptr, SmallVectorImpl<MPInt> *ub = nullptr,
+ unsigned *minLbPos = nullptr, unsigned *minUbPos = nullptr) const;
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ Optional<int64_t> getConstantBoundOnDimSize64(
unsigned pos, SmallVectorImpl<int64_t> *lb = nullptr,
int64_t *boundFloorDivisor = nullptr,
SmallVectorImpl<int64_t> *ub = nullptr, unsigned *minLbPos = nullptr,
- unsigned *minUbPos = nullptr) const;
+ unsigned *minUbPos = nullptr) const {
+ SmallVector<MPInt, 8> ubMPInt, lbMPInt;
+ MPInt boundFloorDivisorMPInt;
+ Optional<MPInt> result = getConstantBoundOnDimSize(
+ pos, &lbMPInt, &boundFloorDivisorMPInt, &ubMPInt, minLbPos, minUbPos);
+ if (lb)
+ *lb = getInt64Vec(lbMPInt);
+ if (ub)
+ *ub = getInt64Vec(ubMPInt);
+ if (boundFloorDivisor)
+ *boundFloorDivisor = int64_t(boundFloorDivisorMPInt);
+ return result.transform(int64FromMPInt);
+ }
/// Returns the constant bound for the pos^th variable if there is one;
/// None otherwise.
- Optional<int64_t> getConstantBound(BoundType type, unsigned pos) const;
+ Optional<MPInt> getConstantBound(BoundType type, unsigned pos) const;
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ Optional<int64_t> getConstantBound64(BoundType type, unsigned pos) const {
+ return getConstantBound(type, pos).transform(int64FromMPInt);
+ }
/// Removes constraints that are independent of (i.e., do not have a
/// coefficient) variables in the range [pos, pos + num).
@@ -619,7 +682,13 @@ class IntegerRelation {
/// Returns the constant lower bound bound if isLower is true, and the upper
/// bound if isLower is false.
template <bool isLower>
- Optional<int64_t> computeConstantLowerOrUpperBound(unsigned pos);
+ Optional<MPInt> computeConstantLowerOrUpperBound(unsigned pos);
+ /// The same, but casts to int64_t. This is unsafe and will assert-fail if the
+ /// value does not fit in an int64_t.
+ template <bool isLower>
+ Optional<int64_t> computeConstantLowerOrUpperBound64(unsigned pos) {
+ return computeConstantLowerOrUpperBound<isLower>(pos).map(int64FromMPInt);
+ }
/// Eliminates a single variable at `position` from equality and inequality
/// constraints. Returns `success` if the variable was eliminated, and
diff --git a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
index d3f44c8fbc0d..589dc17084c7 100644
--- a/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
+++ b/mlir/include/mlir/Analysis/Presburger/LinearTransform.h
@@ -40,14 +40,13 @@ class LinearTransform {
// The given vector is interpreted as a row vector v. Post-multiply v with
// this transform, say T, and return vT.
- SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+ SmallVector<MPInt, 8> preMultiplyWithRow(ArrayRef<MPInt> rowVec) const {
return matrix.preMultiplyWithRow(rowVec);
}
// The given vector is interpreted as a column vector v. Pre-multiply v with
// this transform, say T, and return Tv.
- SmallVector<int64_t, 8>
- postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+ SmallVector<MPInt, 8> postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
return matrix.postMultiplyWithColumn(colVec);
}
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index c9fff4ae5683..659a0d77edd6 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -14,6 +14,7 @@
#ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H
#define MLIR_ANALYSIS_PRESBURGER_MATRIX_H
+#include "mlir/Analysis/Presburger/MPInt.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
@@ -48,21 +49,21 @@ class Matrix {
static Matrix identity(unsigned dimension);
/// Access the element at the specified row and column.
- int64_t &at(unsigned row, unsigned column) {
+ MPInt &at(unsigned row, unsigned column) {
assert(row < nRows && "Row outside of range");
assert(column < nColumns && "Column outside of range");
return data[row * nReservedColumns + column];
}
- int64_t at(unsigned row, unsigned column) const {
+ MPInt at(unsigned row, unsigned column) const {
assert(row < nRows && "Row outside of range");
assert(column < nColumns && "Column outside of range");
return data[row * nReservedColumns + column];
}
- int64_t &operator()(unsigned row, unsigned column) { return at(row, column); }
+ MPInt &operator()(unsigned row, unsigned column) { return at(row, column); }
- int64_t operator()(unsigned row, unsigned column) const {
+ MPInt operator()(unsigned row, unsigned column) const {
return at(row, column);
}
@@ -86,11 +87,11 @@ class Matrix {
void reserveRows(unsigned rows);
/// Get a [Mutable]ArrayRef corresponding to the specified row.
- MutableArrayRef<int64_t> getRow(unsigned row);
- ArrayRef<int64_t> getRow(unsigned row) const;
+ MutableArrayRef<MPInt> getRow(unsigned row);
+ ArrayRef<MPInt> getRow(unsigned row) const;
/// Set the specified row to `elems`.
- void setRow(unsigned row, ArrayRef<int64_t> elems);
+ void setRow(unsigned row, ArrayRef<MPInt> elems);
/// Insert columns having positions pos, pos + 1, ... pos + count - 1.
/// Columns that were at positions 0 to pos - 1 will stay where they are;
@@ -124,15 +125,24 @@ class Matrix {
void copyRow(unsigned sourceRow, unsigned targetRow);
- void fillRow(unsigned row, int64_t value);
+ void fillRow(unsigned row, const MPInt &value);
+ void fillRow(unsigned row, int64_t value) { fillRow(row, MPInt(value)); }
/// Add `scale` multiples of the source row to the target row.
- void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale);
+ void addToRow(unsigned sourceRow, unsigned targetRow, const MPInt &scale);
+ void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
+ addToRow(sourceRow, targetRow, MPInt(scale));
+ }
/// Add `scale` multiples of the rowVec row to the specified row.
- void addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale);
+ void addToRow(unsigned row, ArrayRef<MPInt> rowVec, const MPInt &scale);
/// Add `scale` multiples of the source column to the target column.
- void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale);
+ void addToColumn(unsigned sourceColumn, unsigned targetColumn,
+ const MPInt &scale);
+ void addToColumn(unsigned sourceColumn, unsigned targetColumn,
+ int64_t scale) {
+ addToColumn(sourceColumn, targetColumn, MPInt(scale));
+ }
/// Negate the specified column.
void negateColumn(unsigned column);
@@ -142,19 +152,18 @@ class Matrix {
/// Divide the first `nCols` of the specified row by their GCD.
/// Returns the GCD of the first `nCols` of the specified row.
- int64_t normalizeRow(unsigned row, unsigned nCols);
+ MPInt normalizeRow(unsigned row, unsigned nCols);
/// Divide the columns of the specified row by their GCD.
/// Returns the GCD of the columns of the specified row.
- int64_t normalizeRow(unsigned row);
+ MPInt normalizeRow(unsigned row);
/// The given vector is interpreted as a row vector v. Post-multiply v with
/// this matrix, say M, and return vM.
- SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
+ SmallVector<MPInt, 8> preMultiplyWithRow(ArrayRef<MPInt> rowVec) const;
/// The given vector is interpreted as a column vector v. Pre-multiply v with
/// this matrix, say M, and return Mv.
- SmallVector<int64_t, 8>
- postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
+ SmallVector<MPInt, 8> postMultiplyWithColumn(ArrayRef<MPInt> colVec) const;
/// Resize the matrix to the specified dimensions. If a dimension is smaller,
/// the values are truncated; if it is bigger, the new values are initialized
@@ -171,7 +180,7 @@ class Matrix {
unsigned appendExtraRow();
/// Same as above, but copy the given elements into the row. The length of
/// `elems` must be equal to the number of columns.
- unsigned appendExtraRow(ArrayRef<int64_t> elems);
+ unsigned appendExtraRow(ArrayRef<MPInt> elems);
/// Print the matrix.
void print(raw_ostream &os) const;
@@ -190,7 +199,7 @@ class Matrix {
/// Stores the data. data.size() is equal to nRows * nReservedColumns.
/// data.capacity() / nReservedColumns is the number of reserved rows.
- SmallVector<int64_t, 16> data;
+ SmallVector<MPInt, 16> data;
};
} // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 63f3ecfca968..172e202d69e8 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -62,7 +62,7 @@ class MultiAffineFunction {
/// Get a matrix with each row representing row^th output expression.
const Matrix &getOutputMatrix() const { return output; }
/// Get the `i^th` output expression.
- ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
+ ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
// Remove the specified range of outputs.
void removeOutputs(unsigned start, unsigned end);
@@ -71,8 +71,11 @@ class MultiAffineFunction {
/// have the union of the division vars that exist in the functions.
void mergeDivs(MultiAffineFunction &other);
- /// Return the output of the function at the given point.
- SmallVector<int64_t, 8> valueAt(ArrayRef<int64_t> point) const;
+ //// Return the output of the function at the given point.
+ SmallVector<MPInt, 8> valueAt(ArrayRef<MPInt> point) const;
+ SmallVector<MPInt, 8> valueAt(ArrayRef<int64_t> point) const {
+ return valueAt(getMPIntVec(point));
+ }
/// Return whether the `this` and `other` are equal when the domain is
/// restricted to `domain`. This is the case if they lie in the same space,
@@ -172,7 +175,10 @@ class PWMAFunction {
PresburgerSet getDomain() const;
/// Return the output of the function at the given point.
- Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+ Optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<MPInt> point) const;
+ Optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<int64_t> point) const {
+ return valueAt(getMPIntVec(point));
+ }
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
/// they have the same dimensions, the same domain and they take the same
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index 8b4a4fe82e0f..541e89671f08 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -83,7 +83,10 @@ class PresburgerRelation {
PresburgerRelation intersect(const PresburgerRelation &set) const;
/// Return true if the set contains the given point, and false otherwise.
- bool containsPoint(ArrayRef<int64_t> point) const;
+ bool containsPoint(ArrayRef<MPInt> point) const;
+ bool containsPoint(ArrayRef<int64_t> point) const {
+ return containsPoint(getMPIntVec(point));
+ }
/// Return the complement of this set. All local variables in the set must
/// correspond to floor divisions.
@@ -108,7 +111,7 @@ class PresburgerRelation {
/// Find an integer sample from the given set. This should not be called if
/// any of the disjuncts in the union are unbounded.
- bool findIntegerSample(SmallVectorImpl<int64_t> &sample);
+ bool findIntegerSample(SmallVectorImpl<MPInt> &sample);
/// Compute an overapproximation of the number of integer points in the
/// disjunct. Symbol vars are currently not supported. If the computed
@@ -117,7 +120,7 @@ class PresburgerRelation {
/// This currently just sums up the overapproximations of the volumes of the
/// disjuncts, so the approximation might be far from the true volume in the
/// case when there is a lot of overlap between disjuncts.
- Optional<uint64_t> computeVolume() const;
+ Optional<MPInt> computeVolume() const;
/// Simplifies the representation of a PresburgerRelation.
///
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 485a064c6cce..34b77a029af0 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -166,7 +166,7 @@ class SimplexBase {
/// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding inequality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
- virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
+ virtual void addInequality(ArrayRef<MPInt> coeffs) = 0;
/// Returns the number of variables in the tableau.
unsigned getNumVariables() const;
@@ -177,14 +177,16 @@ class SimplexBase {
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding equality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
- void addEquality(ArrayRef<int64_t> coeffs);
+ void addEquality(ArrayRef<MPInt> coeffs);
/// Add new variables to the end of the list of variables.
void appendVariable(unsigned count = 1);
/// Append a new variable to the simplex and constrain it such that its only
/// integer value is the floor div of `coeffs` and `denom`.
- void addDivisionVariable(ArrayRef<int64_t> coeffs, int64_t denom);
+ ///
+ /// `denom` must be positive.
+ void addDivisionVariable(ArrayRef<MPInt> coeffs, const MPInt &denom);
/// Mark the tableau as being empty.
void markEmpty();
@@ -293,7 +295,7 @@ class SimplexBase {
/// con.
///
/// Returns the index of the new Unknown in con.
- unsigned addRow(ArrayRef<int64_t> coeffs, bool makeRestricted = false);
+ unsigned addRow(ArrayRef<MPInt> coeffs, bool makeRestricted = false);
/// Swap the two rows/columns in the tableau and associated data structures.
void swapRows(unsigned i, unsigned j);
@@ -421,7 +423,7 @@ class LexSimplexBase : public SimplexBase {
///
/// This just adds the inequality to the tableau and does not try to create a
/// consistent tableau configuration.
- void addInequality(ArrayRef<int64_t> coeffs) final;
+ void addInequality(ArrayRef<MPInt> coeffs) final;
/// Get a snapshot of the current state. This is used for rolling back.
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@@ -493,15 +495,15 @@ class LexSimplex : public LexSimplexBase {
///
/// Note: this should be used only when the lexmin is really needed. To obtain
/// any integer sample, use Simplex::findIntegerSample as that is more robust.
- MaybeOptimum<SmallVector<int64_t, 8>> findIntegerLexMin();
+ MaybeOptimum<SmallVector<MPInt, 8>> findIntegerLexMin();
/// Return whether the specified inequality is redundant/separate for the
/// polytope. Redundant means every point satisfies the given inequality, and
/// separate means no point satisfies it.
///
/// These checks are integer-exact.
- bool isSeparateInequality(ArrayRef<int64_t> coeffs);
- bool isRedundantInequality(ArrayRef<int64_t> coeffs);
+ bool isSeparateInequality(ArrayRef<MPInt> coeffs);
+ bool isRedundantInequality(ArrayRef<MPInt> coeffs);
private:
/// Returns the current sample point, which may contain non-integer (rational)
@@ -654,11 +656,11 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// Get the numerator of the symbolic sample of the specific row.
/// This is an affine expression in the symbols with integer coefficients.
/// The last element is the constant term. This ignores the big M coefficient.
- SmallVector<int64_t, 8> getSymbolicSampleNumerator(unsigned row) const;
+ SmallVector<MPInt, 8> getSymbolicSampleNumerator(unsigned row) const;
/// Get an affine inequality in the symbols with integer coefficients that
/// holds iff the symbolic sample of the specified row is non-negative.
- SmallVector<int64_t, 8> getSymbolicSampleIneq(unsigned row) const;
+ SmallVector<MPInt, 8> getSymbolicSampleIneq(unsigned row) const;
/// Return whether all the coefficients of the symbolic sample are integers.
///
@@ -708,7 +710,7 @@ class Simplex : public SimplexBase {
///
/// This also tries to restore the tableau configuration to a consistent
/// state and marks the Simplex empty if this is not possible.
- void addInequality(ArrayRef<int64_t> coeffs) final;
+ void addInequality(ArrayRef<MPInt> coeffs) final;
/// Compute the maximum or minimum value of the given row, depending on
/// direction. The specified row is never pivoted. On return, the row may
@@ -724,7 +726,7 @@ class Simplex : public SimplexBase {
/// Returns a Fraction denoting the optimum, or a null value if no optimum
/// exists, i.e., if the expression is unbounded in this direction.
MaybeOptimum<Fraction> computeOptimum(Direction direction,
- ArrayRef<int64_t> coeffs);
+ ArrayRef<MPInt> coeffs);
/// Returns whether the perpendicular of the specified constraint is a
/// is a direction along which the polytope is bounded.
@@ -766,8 +768,8 @@ class Simplex : public SimplexBase {
/// Returns a (min, max) pair denoting the minimum and maximum integer values
/// of the given expression. If no integer value exists, both results will be
/// of kind Empty.
- std::pair<MaybeOptimum<int64_t>, MaybeOptimum<int64_t>>
- computeIntegerBounds(ArrayRef<int64_t> coeffs);
+ std::pair<MaybeOptimum<MPInt>, MaybeOptimum<MPInt>>
+ computeIntegerBounds(ArrayRef<MPInt> coeffs);
/// Returns true if the polytope is unbounded, i.e., extends to infinity in
/// some direction. Otherwise, returns false.
@@ -779,7 +781,7 @@ class Simplex : public SimplexBase {
/// Returns an integer sample point if one exists, or None
/// otherwise. This should only be called for bounded sets.
- Optional<SmallVector<int64_t, 8>> findIntegerSample();
+ Optional<SmallVector<MPInt, 8>> findIntegerSample();
enum class IneqType { Redundant, Cut, Separate };
@@ -789,13 +791,13 @@ class Simplex : public SimplexBase {
/// Redundant The inequality is satisfied in the polytope
/// Cut The inequality is satisfied by some points, but not by others
/// Separate The inequality is not satisfied by any point
- IneqType findIneqType(ArrayRef<int64_t> coeffs);
+ IneqType findIneqType(ArrayRef<MPInt> coeffs);
/// Check if the specified inequality already holds in the polytope.
- bool isRedundantInequality(ArrayRef<int64_t> coeffs);
+ bool isRedundantInequality(ArrayRef<MPInt> coeffs);
/// Check if the specified equality already holds in the polytope.
- bool isRedundantEquality(ArrayRef<int64_t> coeffs);
+ bool isRedundantEquality(ArrayRef<MPInt> coeffs);
/// Returns true if this Simplex's polytope is a rational subset of `rel`.
/// Otherwise, returns false.
@@ -803,7 +805,7 @@ class Simplex : public SimplexBase {
/// Returns the current sample point if it is integral. Otherwise, returns
/// None.
- Optional<SmallVector<int64_t, 8>> getSamplePointIfIntegral() const;
+ Optional<SmallVector<MPInt, 8>> getSamplePointIfIntegral() const;
/// Returns the current sample point, which may contain non-integer (rational)
/// coordinates. Returns an empty optional when the tableau is empty.
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 3801cb63af5e..fca0161139f8 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -109,14 +109,15 @@ struct MaybeLocalRepr {
/// system. The coefficients of the dividends are stored in order:
/// [nonLocalVars, localVars, constant]. Each local variable may or may not have
/// a representation. If the local does not have a representation, the dividend
-/// of the division has no meaning and the denominator is zero.
+/// of the division has no meaning and the denominator is zero. If it has a
+/// representation, the denominator will be positive.
///
/// The i^th division here, represents the division representation of the
/// variable at position `divOffset + i` in the constraint system.
class DivisionRepr {
public:
DivisionRepr(unsigned numVars, unsigned numDivs)
- : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
+ : dividends(numDivs, numVars + 1), denoms(numDivs, MPInt(0)) {}
DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {}
@@ -135,30 +136,26 @@ class DivisionRepr {
void clearRepr(unsigned i) { denoms[i] = 0; }
// Get the dividend of the `i^th` division.
- MutableArrayRef<int64_t> getDividend(unsigned i) {
- return dividends.getRow(i);
- }
- ArrayRef<int64_t> getDividend(unsigned i) const {
- return dividends.getRow(i);
- }
+ MutableArrayRef<MPInt> getDividend(unsigned i) { return dividends.getRow(i); }
+ ArrayRef<MPInt> getDividend(unsigned i) const { return dividends.getRow(i); }
// For a given point containing values for each variable other than the
// division variables, try to find the values for each division variable from
// their division representation.
- SmallVector<Optional<int64_t>, 4> divValuesAt(ArrayRef<int64_t> point) const;
+ SmallVector<Optional<MPInt>, 4> divValuesAt(ArrayRef<MPInt> point) const;
// Get the `i^th` denominator.
- unsigned &getDenom(unsigned i) { return denoms[i]; }
- unsigned getDenom(unsigned i) const { return denoms[i]; }
+ MPInt &getDenom(unsigned i) { return denoms[i]; }
+ MPInt getDenom(unsigned i) const { return denoms[i]; }
- ArrayRef<unsigned> getDenoms() const { return denoms; }
+ ArrayRef<MPInt> getDenoms() const { return denoms; }
- void setDiv(unsigned i, ArrayRef<int64_t> dividend, unsigned divisor) {
+ void setDiv(unsigned i, ArrayRef<MPInt> dividend, const MPInt &divisor) {
dividends.setRow(i, dividend);
denoms[i] = divisor;
}
- void insertDiv(unsigned pos, ArrayRef<int64_t> dividend, unsigned divisor);
+ void insertDiv(unsigned pos, ArrayRef<MPInt> dividend, const MPInt &divisor);
void insertDiv(unsigned pos, unsigned num = 1);
/// Removes duplicate divisions. On every possible duplicate division found,
@@ -183,7 +180,8 @@ class DivisionRepr {
/// Denominators of each division. If a denominator of a division is `0`, the
/// division variable is considered to not have a division representation.
- SmallVector<unsigned, 4> denoms;
+ /// Otherwise, the denominator is positive.
+ SmallVector<MPInt, 4> denoms;
};
/// If `q` is defined to be equal to `expr floordiv d`, this equivalent to
@@ -200,10 +198,13 @@ class DivisionRepr {
///
/// The coefficient of `q` in `dividend` must be zero, as it is not allowed for
/// local variable to be a floor division of an expression involving itself.
-SmallVector<int64_t, 8> getDivUpperBound(ArrayRef<int64_t> dividend,
- int64_t divisor, unsigned localVarIdx);
-SmallVector<int64_t, 8> getDivLowerBound(ArrayRef<int64_t> dividend,
- int64_t divisor, unsigned localVarIdx);
+/// The divisor must be positive.
+SmallVector<MPInt, 8> getDivUpperBound(ArrayRef<MPInt> dividend,
+ const MPInt &divisor,
+ unsigned localVarIdx);
+SmallVector<MPInt, 8> getDivLowerBound(ArrayRef<MPInt> dividend,
+ const MPInt &divisor,
+ unsigned localVarIdx);
llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
unsigned numSet);
@@ -216,14 +217,22 @@ llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
SmallVector<MPInt, 8> getMPIntVec(ArrayRef<int64_t> range);
/// Return the given array as an array of int64_t.
SmallVector<int64_t, 8> getInt64Vec(ArrayRef<MPInt> range);
+
/// Returns the `MaybeLocalRepr` struct which contains the indices of the
/// constraints that can be expressed as a floordiv of an affine function. If
-/// the representation could be computed, `dividend` and `denominator` are set.
-/// If the representation could not be computed, the kind attribute in
-/// `MaybeLocalRepr` is set to None.
+/// the representation could be computed, `dividend` and `divisor` are set,
+/// in which case, denominator will be positive. If the representation could
+/// not be computed, the kind attribute in `MaybeLocalRepr` is set to None.
MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst,
ArrayRef<bool> foundRepr, unsigned pos,
- MutableArrayRef<int64_t> dividend,
+ MutableArrayRef<MPInt> dividend,
+ MPInt &divisor);
+
+/// The following overload using int64_t is required for a callsite in
+/// AffineStructures.h.
+MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst,
+ ArrayRef<bool> foundRepr, unsigned pos,
+ SmallVector<int64_t, 8> &dividend,
unsigned &divisor);
/// Given two relations, A and B, add additional local vars to the sets such
@@ -242,26 +251,25 @@ void mergeLocalVars(IntegerRelation &relA, IntegerRelation &relB,
llvm::function_ref<bool(unsigned i, unsigned j)> merge);
/// Compute the gcd of the range.
-int64_t gcdRange(ArrayRef<int64_t> range);
+MPInt gcdRange(ArrayRef<MPInt> range);
/// Divide the range by its gcd and return the gcd.
-int64_t normalizeRange(MutableArrayRef<int64_t> range);
+MPInt normalizeRange(MutableArrayRef<MPInt> range);
/// Normalize the given (numerator, denominator) pair by dividing out the
/// common factors between them. The numerator here is an affine expression
-/// with integer coefficients.
-void normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom);
+/// with integer coefficients. The denominator must be positive.
+void normalizeDiv(MutableArrayRef<MPInt> num, MPInt &denom);
/// Return `coeffs` with all the elements negated.
-SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs);
+SmallVector<MPInt, 8> getNegatedCoeffs(ArrayRef<MPInt> coeffs);
/// Return the complement of the given inequality.
///
/// The complement of a_1 x_1 + ... + a_n x_ + c >= 0 is
/// a_1 x_1 + ... + a_n x_ + c < 0, i.e., -a_1 x_1 - ... - a_n x_ - c - 1 >= 0,
/// since all the variables are constrained to be integers.
-SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq);
-
+SmallVector<MPInt, 8> getComplementIneq(ArrayRef<MPInt> ineq);
} // namespace presburger
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index 0781748f2f8b..9caa589a69bf 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -320,7 +320,7 @@ struct MemRefRegion {
SmallVectorImpl<int64_t> *lb = nullptr,
int64_t *lbFloorDivisor = nullptr) const {
assert(pos < getRank() && "invalid position");
- return cst.getConstantBoundOnDimSize(pos, lb);
+ return cst.getConstantBoundOnDimSize64(pos, lb);
}
/// Returns the size of this MemRefRegion in bytes.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 03252ce0f4c8..34c64a3911fe 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -104,10 +104,9 @@ IntegerRelation::findRationalLexMin() const {
return maybeLexMin;
}
-MaybeOptimum<SmallVector<int64_t, 8>>
-IntegerRelation::findIntegerLexMin() const {
+MaybeOptimum<SmallVector<MPInt, 8>> IntegerRelation::findIntegerLexMin() const {
assert(getNumSymbolVars() == 0 && "Symbols are not supported!");
- MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin =
+ MaybeOptimum<SmallVector<MPInt, 8>> maybeLexMin =
LexSimplex(*this).findIntegerLexMin();
if (!maybeLexMin.isBounded())
@@ -124,8 +123,8 @@ IntegerRelation::findIntegerLexMin() const {
return maybeLexMin;
}
-static bool rangeIsZero(ArrayRef<int64_t> range) {
- return llvm::all_of(range, [](int64_t x) { return x == 0; });
+static bool rangeIsZero(ArrayRef<MPInt> range) {
+ return llvm::all_of(range, [](const MPInt &x) { return x == 0; });
}
static void removeConstraintsInvolvingVarRange(IntegerRelation &poly,
@@ -273,14 +272,14 @@ unsigned IntegerRelation::appendVar(VarKind kind, unsigned num) {
return insertVar(kind, pos, num);
}
-void IntegerRelation::addEquality(ArrayRef<int64_t> eq) {
+void IntegerRelation::addEquality(ArrayRef<MPInt> eq) {
assert(eq.size() == getNumCols());
unsigned row = equalities.appendExtraRow();
for (unsigned i = 0, e = eq.size(); i < e; ++i)
equalities(row, i) = eq[i];
}
-void IntegerRelation::addInequality(ArrayRef<int64_t> inEq) {
+void IntegerRelation::addInequality(ArrayRef<MPInt> inEq) {
assert(inEq.size() == getNumCols());
unsigned row = inequalities.appendExtraRow();
for (unsigned i = 0, e = inEq.size(); i < e; ++i)
@@ -445,7 +444,7 @@ bool IntegerRelation::hasConsistentState() const {
return true;
}
-void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<int64_t> values) {
+void IntegerRelation::setAndEliminate(unsigned pos, ArrayRef<MPInt> values) {
if (values.empty())
return;
assert(pos + values.size() <= getNumVars() &&
@@ -471,7 +470,7 @@ void IntegerRelation::clearAndCopyFrom(const IntegerRelation &other) {
bool IntegerRelation::findConstraintWithNonZeroAt(unsigned colIdx, bool isEq,
unsigned *rowIdx) const {
assert(colIdx < getNumCols() && "position out of bounds");
- auto at = [&](unsigned rowIdx) -> int64_t {
+ auto at = [&](unsigned rowIdx) -> MPInt {
return isEq ? atEq(rowIdx, colIdx) : atIneq(rowIdx, colIdx);
};
unsigned e = isEq ? getNumEqualities() : getNumInequalities();
@@ -498,7 +497,7 @@ bool IntegerRelation::hasInvalidConstraint() const {
for (unsigned i = 0, e = numRows; i < e; ++i) {
unsigned j;
for (j = 0; j < numCols - 1; ++j) {
- int64_t v = isEq ? atEq(i, j) : atIneq(i, j);
+ MPInt v = isEq ? atEq(i, j) : atIneq(i, j);
// Skip rows with non-zero variable coefficients.
if (v != 0)
break;
@@ -508,7 +507,7 @@ bool IntegerRelation::hasInvalidConstraint() const {
}
// Check validity of constant term at 'numCols - 1' w.r.t 'isEq'.
// Example invalid constraints include: '1 == 0' or '-1 >= 0'
- int64_t v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
+ MPInt v = isEq ? atEq(i, numCols - 1) : atIneq(i, numCols - 1);
if ((isEq && v != 0) || (!isEq && v < 0)) {
return true;
}
@@ -530,26 +529,26 @@ static void eliminateFromConstraint(IntegerRelation *constraints,
// Skip if equality 'rowIdx' if same as 'pivotRow'.
if (isEq && rowIdx == pivotRow)
return;
- auto at = [&](unsigned i, unsigned j) -> int64_t {
+ auto at = [&](unsigned i, unsigned j) -> MPInt {
return isEq ? constraints->atEq(i, j) : constraints->atIneq(i, j);
};
- int64_t leadCoeff = at(rowIdx, pivotCol);
+ MPInt leadCoeff = at(rowIdx, pivotCol);
// Skip if leading coefficient at 'rowIdx' is already zero.
if (leadCoeff == 0)
return;
- int64_t pivotCoeff = constraints->atEq(pivotRow, pivotCol);
- int64_t sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
- int64_t lcm = std::lcm(pivotCoeff, leadCoeff);
- int64_t pivotMultiplier = sign * (lcm / std::abs(pivotCoeff));
- int64_t rowMultiplier = lcm / std::abs(leadCoeff);
+ MPInt pivotCoeff = constraints->atEq(pivotRow, pivotCol);
+ int sign = (leadCoeff * pivotCoeff > 0) ? -1 : 1;
+ MPInt lcm = presburger::lcm(pivotCoeff, leadCoeff);
+ MPInt pivotMultiplier = sign * (lcm / abs(pivotCoeff));
+ MPInt rowMultiplier = lcm / abs(leadCoeff);
unsigned numCols = constraints->getNumCols();
for (unsigned j = 0; j < numCols; ++j) {
// Skip updating column 'j' if it was just eliminated.
if (j >= elimColStart && j < pivotCol)
continue;
- int64_t v = pivotMultiplier * constraints->atEq(pivotRow, j) +
- rowMultiplier * at(rowIdx, j);
+ MPInt v = pivotMultiplier * constraints->atEq(pivotRow, j) +
+ rowMultiplier * at(rowIdx, j);
isEq ? constraints->atEq(rowIdx, j) = v
: constraints->atIneq(rowIdx, j) = v;
}
@@ -653,16 +652,15 @@ bool IntegerRelation::isEmpty() const {
// has an integer solution iff:
//
// GCD of c_1, c_2, ..., c_n divides c_0.
-//
bool IntegerRelation::isEmptyByGCDTest() const {
assert(hasConsistentState());
unsigned numCols = getNumCols();
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
- uint64_t gcd = std::abs(atEq(i, 0));
+ MPInt gcd = abs(atEq(i, 0));
for (unsigned j = 1; j < numCols - 1; ++j) {
- gcd = std::gcd(gcd, (uint64_t)std::abs(atEq(i, j)));
+ gcd = presburger::gcd(gcd, abs(atEq(i, j)));
}
- int64_t v = std::abs(atEq(i, numCols - 1));
+ MPInt v = abs(atEq(i, numCols - 1));
if (gcd > 0 && (v % gcd != 0)) {
return true;
}
@@ -765,7 +763,7 @@ bool IntegerRelation::isIntegerEmpty() const { return !findIntegerSample(); }
///
/// Concatenating the samples from B and C gives a sample v in S*T, so the
/// returned sample T*v is a sample in S.
-Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
+Optional<SmallVector<MPInt, 8>> IntegerRelation::findIntegerSample() const {
// First, try the GCD test heuristic.
if (isEmptyByGCDTest())
return {};
@@ -804,7 +802,7 @@ Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
boundedSet.removeVarRange(numBoundedDims, boundedSet.getNumVars());
// 3) Try to obtain a sample from the bounded set.
- Optional<SmallVector<int64_t, 8>> boundedSample =
+ Optional<SmallVector<MPInt, 8>> boundedSample =
Simplex(boundedSet).findIntegerSample();
if (!boundedSample)
return {};
@@ -843,7 +841,7 @@ Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
// amount for the shrunken cone.
for (unsigned i = 0, e = cone.getNumInequalities(); i < e; ++i) {
for (unsigned j = 0; j < cone.getNumVars(); ++j) {
- int64_t coeff = cone.atIneq(i, j);
+ MPInt coeff = cone.atIneq(i, j);
if (coeff < 0)
cone.atIneq(i, cone.getNumVars()) += coeff;
}
@@ -860,10 +858,10 @@ Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
SmallVector<Fraction, 8> shrunkenConeSample =
*shrunkenConeSimplex.getRationalSample();
- SmallVector<int64_t, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));
+ SmallVector<MPInt, 8> coneSample(llvm::map_range(shrunkenConeSample, ceil));
// 6) Return transform * concat(boundedSample, coneSample).
- SmallVector<int64_t, 8> &sample = *boundedSample;
+ SmallVector<MPInt, 8> &sample = *boundedSample;
sample.append(coneSample.begin(), coneSample.end());
return transform.postMultiplyWithColumn(sample);
}
@@ -871,10 +869,10 @@ Optional<SmallVector<int64_t, 8>> IntegerRelation::findIntegerSample() const {
/// Helper to evaluate an affine expression at a point.
/// The expression is a list of coefficients for the dimensions followed by the
/// constant term.
-static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
+static MPInt valueAt(ArrayRef<MPInt> expr, ArrayRef<MPInt> point) {
assert(expr.size() == 1 + point.size() &&
"Dimensionalities of point and expression don't match!");
- int64_t value = expr.back();
+ MPInt value = expr.back();
for (unsigned i = 0; i < point.size(); ++i)
value += expr[i] * point[i];
return value;
@@ -883,7 +881,7 @@ static int64_t valueAt(ArrayRef<int64_t> expr, ArrayRef<int64_t> point) {
/// A point satisfies an equality iff the value of the equality at the
/// expression is zero, and it satisfies an inequality iff the value of the
/// inequality at that point is non-negative.
-bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
+bool IntegerRelation::containsPoint(ArrayRef<MPInt> point) const {
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i) {
if (valueAt(getEquality(i), point) != 0)
return false;
@@ -903,8 +901,8 @@ bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
/// compute the values of the locals that have division representations and
/// only use the integer emptiness check for the locals that don't have this.
/// Handling this correctly requires ordering the divs, though.
-Optional<SmallVector<int64_t, 8>>
-IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
+Optional<SmallVector<MPInt, 8>>
+IntegerRelation::containsPointNoLocal(ArrayRef<MPInt> point) const {
assert(point.size() == getNumVars() - getNumLocalVars() &&
"Point should contain all vars except locals!");
assert(getVarKindOffset(VarKind::Local) == getNumVars() - getNumLocalVars() &&
@@ -961,9 +959,9 @@ void IntegerRelation::gcdTightenInequalities() {
unsigned numCols = getNumCols();
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i) {
// Normalize the constraint and tighten the constant term by the GCD.
- int64_t gcd = inequalities.normalizeRow(i, getNumCols() - 1);
+ MPInt gcd = inequalities.normalizeRow(i, getNumCols() - 1);
if (gcd > 1)
- atIneq(i, numCols - 1) = mlir::floorDiv(atIneq(i, numCols - 1), gcd);
+ atIneq(i, numCols - 1) = floorDiv(atIneq(i, numCols - 1), gcd);
}
}
@@ -1082,14 +1080,14 @@ void IntegerRelation::removeRedundantConstraints() {
equalities.resizeVertically(pos);
}
-Optional<uint64_t> IntegerRelation::computeVolume() const {
+Optional<MPInt> IntegerRelation::computeVolume() const {
assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
Simplex simplex(*this);
// If the polytope is rationally empty, there are certainly no integer
// points.
if (simplex.isEmpty())
- return 0;
+ return MPInt(0);
// Just find the maximum and minimum integer value of each non-local var
// separately, thus finding the number of integer values each such var can
@@ -1105,8 +1103,8 @@ Optional<uint64_t> IntegerRelation::computeVolume() const {
//
// If there is no such empty dimension, if any dimension is unbounded we
// just return the result as unbounded.
- uint64_t count = 1;
- SmallVector<int64_t, 8> dim(getNumVars() + 1);
+ MPInt count(1);
+ SmallVector<MPInt, 8> dim(getNumVars() + 1);
bool hasUnboundedVar = false;
for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i) {
dim[i] = 1;
@@ -1126,13 +1124,13 @@ Optional<uint64_t> IntegerRelation::computeVolume() const {
// In this case there are no valid integer points and the volume is
// definitely zero.
if (min.getBoundedOptimum() > max.getBoundedOptimum())
- return 0;
+ return MPInt(0);
count *= (*max - *min + 1);
}
if (count == 0)
- return 0;
+ return MPInt(0);
if (hasUnboundedVar)
return {};
return count;
@@ -1224,7 +1222,7 @@ void IntegerRelation::removeRedundantLocalVars() {
for (i = 0, e = getNumEqualities(); i < e; ++i) {
// Find a local variable to eliminate using ith equality.
for (j = getNumDimAndSymbolVars(), f = getNumVars(); j < f; ++j)
- if (std::abs(atEq(i, j)) == 1)
+ if (abs(atEq(i, j)) == 1)
break;
// Local variable can be eliminated using ith equality.
@@ -1282,7 +1280,8 @@ void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
removeVarRange(srcKind, varStart, varLimit);
}
-void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
+void IntegerRelation::addBound(BoundType type, unsigned pos,
+ const MPInt &value) {
assert(pos < getNumCols());
if (type == BoundType::EQ) {
unsigned row = equalities.appendExtraRow();
@@ -1296,8 +1295,8 @@ void IntegerRelation::addBound(BoundType type, unsigned pos, int64_t value) {
}
}
-void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
- int64_t value) {
+void IntegerRelation::addBound(BoundType type, ArrayRef<MPInt> expr,
+ const MPInt &value) {
assert(type != BoundType::EQ && "EQ not implemented");
assert(expr.size() == getNumCols());
unsigned row = inequalities.appendExtraRow();
@@ -1312,15 +1311,15 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<int64_t> expr,
/// respect to a positive constant 'divisor'. Two constraints are added to the
/// system to capture equivalence with the floordiv.
/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
-void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
- int64_t divisor) {
+void IntegerRelation::addLocalFloorDiv(ArrayRef<MPInt> dividend,
+ const MPInt &divisor) {
assert(dividend.size() == getNumCols() && "incorrect dividend size");
assert(divisor > 0 && "positive divisor expected");
appendVar(VarKind::Local);
- SmallVector<int64_t, 8> dividendCopy(dividend.begin(), dividend.end());
- dividendCopy.insert(dividendCopy.end() - 1, 0);
+ SmallVector<MPInt, 8> dividendCopy(dividend.begin(), dividend.end());
+ dividendCopy.insert(dividendCopy.end() - 1, MPInt(0));
addInequality(
getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
addInequality(
@@ -1336,7 +1335,7 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
bool symbolic = false) {
assert(pos < cst.getNumVars() && "invalid position");
for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
- int64_t v = cst.atEq(r, pos);
+ MPInt v = cst.atEq(r, pos);
if (v * v != 1)
continue;
unsigned c;
@@ -1365,7 +1364,7 @@ LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
// atEq(rowIdx, pos) is either -1 or 1.
assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
- int64_t constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
+ MPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
setAndEliminate(pos, constVal);
return success();
}
@@ -1391,10 +1390,9 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
// s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
// ceil(s0 - 7 / 8) = floor(s0 / 8)).
-Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
- unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
- SmallVectorImpl<int64_t> *ub, unsigned *minLbPos,
- unsigned *minUbPos) const {
+Optional<MPInt> IntegerRelation::getConstantBoundOnDimSize(
+ unsigned pos, SmallVectorImpl<MPInt> *lb, MPInt *boundFloorDivisor,
+ SmallVectorImpl<MPInt> *ub, unsigned *minLbPos, unsigned *minUbPos) const {
assert(pos < getNumDimVars() && "Invalid variable position");
// Find an equality for 'pos'^th variable that equates it to some function
@@ -1406,7 +1404,7 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
// TODO: this can be handled in the future by using the explicit
// representation of the local vars.
if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
- [](int64_t coeff) { return coeff == 0; }))
+ [](const MPInt &coeff) { return coeff == 0; }))
return None;
// This variable can only take a single value.
@@ -1416,7 +1414,7 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
if (ub)
ub->resize(getNumSymbolVars() + 1);
for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
- int64_t v = atEq(eqPos, pos);
+ MPInt v = atEq(eqPos, pos);
// atEq(eqRow, pos) is either -1 or 1.
assert(v * v == 1);
(*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
@@ -1433,7 +1431,7 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
*minLbPos = eqPos;
if (minUbPos)
*minUbPos = eqPos;
- return 1;
+ return MPInt(1);
}
// Check if the variable appears at all in any of the inequalities.
@@ -1457,7 +1455,7 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
/*eqIndices=*/nullptr, /*offset=*/0,
/*num=*/getNumDimVars());
- Optional<int64_t> minDiff;
+ Optional<MPInt> minDiff;
unsigned minLbPosition = 0, minUbPosition = 0;
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
@@ -1474,11 +1472,11 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
}
if (j < getNumCols() - 1)
continue;
- int64_t diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
- atIneq(lbPos, getNumCols() - 1) + 1,
- atIneq(lbPos, pos));
+ MPInt diff = ceilDiv(atIneq(ubPos, getNumCols() - 1) +
+ atIneq(lbPos, getNumCols() - 1) + 1,
+ atIneq(lbPos, pos));
// This bound is non-negative by definition.
- diff = std::max<int64_t>(diff, 0);
+ diff = std::max<MPInt>(diff, MPInt(0));
if (minDiff == None || diff < minDiff) {
minDiff = diff;
minLbPosition = lbPos;
@@ -1518,7 +1516,7 @@ Optional<int64_t> IntegerRelation::getConstantBoundOnDimSize(
}
template <bool isLower>
-Optional<int64_t>
+Optional<MPInt>
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
// Project to 'pos'.
@@ -1540,7 +1538,7 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
// If it doesn't, there isn't a bound on it.
return None;
- Optional<int64_t> minOrMaxConst;
+ Optional<MPInt> minOrMaxConst;
// Take the max across all const lower bounds (or min across all constant
// upper bounds).
@@ -1561,9 +1559,9 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
// Not a constant bound.
continue;
- int64_t boundConst =
- isLower ? mlir::ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
- : mlir::floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
+ MPInt boundConst =
+ isLower ? ceilDiv(-atIneq(r, getNumCols() - 1), atIneq(r, 0))
+ : floorDiv(atIneq(r, getNumCols() - 1), -atIneq(r, 0));
if (isLower) {
if (minOrMaxConst == None || boundConst > minOrMaxConst)
minOrMaxConst = boundConst;
@@ -1575,8 +1573,8 @@ IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
return minOrMaxConst;
}
-Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
- unsigned pos) const {
+Optional<MPInt> IntegerRelation::getConstantBound(BoundType type,
+ unsigned pos) const {
if (type == BoundType::LB)
return IntegerRelation(*this)
.computeConstantLowerOrUpperBound</*isLower=*/true>(pos);
@@ -1585,13 +1583,13 @@ Optional<int64_t> IntegerRelation::getConstantBound(BoundType type,
.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
assert(type == BoundType::EQ && "expected EQ");
- Optional<int64_t> lb =
+ Optional<MPInt> lb =
IntegerRelation(*this).computeConstantLowerOrUpperBound</*isLower=*/true>(
pos);
- Optional<int64_t> ub =
+ Optional<MPInt> ub =
IntegerRelation(*this)
.computeConstantLowerOrUpperBound</*isLower=*/false>(pos);
- return (lb && ub && *lb == *ub) ? Optional<int64_t>(*ub) : None;
+ return (lb && ub && *lb == *ub) ? Optional<MPInt>(*ub) : None;
}
// A simple (naive and conservative) check for hyper-rectangularity.
@@ -1632,10 +1630,10 @@ void IntegerRelation::removeTrivialRedundancy() {
// A map used to detect redundancy stemming from constraints that only differ
// in their constant term. The value stored is <row position, const term>
// for a given row.
- SmallDenseMap<ArrayRef<int64_t>, std::pair<unsigned, int64_t>>
+ SmallDenseMap<ArrayRef<MPInt>, std::pair<unsigned, MPInt>>
rowsWithoutConstTerm;
// To unique rows.
- SmallDenseSet<ArrayRef<int64_t>, 8> rowSet;
+ SmallDenseSet<ArrayRef<MPInt>, 8> rowSet;
// Check if constraint is of the form <non-negative-constant> >= 0.
auto isTriviallyValid = [&](unsigned r) -> bool {
@@ -1649,8 +1647,8 @@ void IntegerRelation::removeTrivialRedundancy() {
// Detect and mark redundant constraints.
SmallVector<bool, 256> redunIneq(getNumInequalities(), false);
for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
- int64_t *rowStart = &inequalities(r, 0);
- auto row = ArrayRef<int64_t>(rowStart, getNumCols());
+ MPInt *rowStart = &inequalities(r, 0);
+ auto row = ArrayRef<MPInt>(rowStart, getNumCols());
if (isTriviallyValid(r) || !rowSet.insert(row).second) {
redunIneq[r] = true;
continue;
@@ -1660,8 +1658,8 @@ void IntegerRelation::removeTrivialRedundancy() {
// everything other than the one with the smallest constant term redundant.
// (eg: among i - 16j - 5 >= 0, i - 16j - 1 >=0, i - 16j - 7 >= 0, the
// former two are redundant).
- int64_t constTerm = atIneq(r, getNumCols() - 1);
- auto rowWithoutConstTerm = ArrayRef<int64_t>(rowStart, getNumCols() - 1);
+ MPInt constTerm = atIneq(r, getNumCols() - 1);
+ auto rowWithoutConstTerm = ArrayRef<MPInt>(rowStart, getNumCols() - 1);
const auto &ret =
rowsWithoutConstTerm.insert({rowWithoutConstTerm, {r, constTerm}});
if (!ret.second) {
@@ -1817,19 +1815,19 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
// integer exact.
for (auto ubPos : ubIndices) {
for (auto lbPos : lbIndices) {
- SmallVector<int64_t, 4> ineq;
+ SmallVector<MPInt, 4> ineq;
ineq.reserve(newRel.getNumCols());
- int64_t lbCoeff = atIneq(lbPos, pos);
+ MPInt lbCoeff = atIneq(lbPos, pos);
// Note that in the comments above, ubCoeff is the negation of the
// coefficient in the canonical form as the view taken here is that of the
// term being moved to the other size of '>='.
- int64_t ubCoeff = -atIneq(ubPos, pos);
+ MPInt ubCoeff = -atIneq(ubPos, pos);
// TODO: refactor this loop to avoid all branches inside.
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
continue;
assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
- int64_t lcm = std::lcm(lbCoeff, ubCoeff);
+ MPInt lcm = presburger::lcm(lbCoeff, ubCoeff);
ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
atIneq(lbPos, l) * (lcm / lbCoeff));
assert(lcm > 0 && "lcm should be positive!");
@@ -1854,7 +1852,7 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
// Copy over the constraints not involving this variable.
for (auto nbPos : nbIndices) {
- SmallVector<int64_t, 4> ineq;
+ SmallVector<MPInt, 4> ineq;
ineq.reserve(getNumCols() - 1);
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
@@ -1869,7 +1867,7 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
// Copy over the equalities.
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
- SmallVector<int64_t, 4> eq;
+ SmallVector<MPInt, 4> eq;
eq.reserve(newRel.getNumCols());
for (unsigned l = 0, e = getNumCols(); l < e; l++) {
if (l == pos)
@@ -1933,7 +1931,7 @@ enum BoundCmpResult { Greater, Less, Equal, Unknown };
/// Compares two affine bounds whose coefficients are provided in 'first' and
/// 'second'. The last coefficient is the constant term.
-static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+static BoundCmpResult compareBounds(ArrayRef<MPInt> a, ArrayRef<MPInt> b) {
assert(a.size() == b.size());
// For the bounds to be comparable, their corresponding variable
@@ -1985,20 +1983,20 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
getCommonConstraints(*this, otherCst, commonCst);
- std::vector<SmallVector<int64_t, 8>> boundingLbs;
- std::vector<SmallVector<int64_t, 8>> boundingUbs;
+ std::vector<SmallVector<MPInt, 8>> boundingLbs;
+ std::vector<SmallVector<MPInt, 8>> boundingUbs;
boundingLbs.reserve(2 * getNumDimVars());
boundingUbs.reserve(2 * getNumDimVars());
// To hold lower and upper bounds for each dimension.
- SmallVector<int64_t, 4> lb, otherLb, ub, otherUb;
+ SmallVector<MPInt, 4> lb, otherLb, ub, otherUb;
// To compute min of lower bounds and max of upper bounds for each dimension.
- SmallVector<int64_t, 4> minLb(getNumSymbolVars() + 1);
- SmallVector<int64_t, 4> maxUb(getNumSymbolVars() + 1);
+ SmallVector<MPInt, 4> minLb(getNumSymbolVars() + 1);
+ SmallVector<MPInt, 4> maxUb(getNumSymbolVars() + 1);
// To compute final new lower and upper bounds for the union.
- SmallVector<int64_t, 8> newLb(getNumCols()), newUb(getNumCols());
+ SmallVector<MPInt, 8> newLb(getNumCols()), newUb(getNumCols());
- int64_t lbFloorDivisor, otherLbFloorDivisor;
+ MPInt lbFloorDivisor, otherLbFloorDivisor;
for (unsigned d = 0, e = getNumDimVars(); d < e; ++d) {
auto extent = getConstantBoundOnDimSize(d, &lb, &lbFloorDivisor, &ub);
if (!extent.has_value())
@@ -2061,7 +2059,7 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
// Copy over the symbolic part + constant term.
std::copy(minLb.begin(), minLb.end(), newLb.begin() + getNumDimVars());
std::transform(newLb.begin() + getNumDimVars(), newLb.end(),
- newLb.begin() + getNumDimVars(), std::negate<int64_t>());
+ newLb.begin() + getNumDimVars(), std::negate<MPInt>());
std::copy(maxUb.begin(), maxUb.end(), newUb.begin() + getNumDimVars());
boundingLbs.push_back(newLb);
diff --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
index ba9d34b998ac..12b4a7c50556 100644
--- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp
+++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
@@ -25,7 +25,7 @@ static void modEntryColumnOperation(Matrix &m, unsigned row, unsigned sourceCol,
assert(m(row, sourceCol) != 0 && "Cannot divide by zero!");
assert((m(row, sourceCol) > 0 && m(row, targetCol) > 0) &&
"Operands must be positive!");
- int64_t ratio = m(row, targetCol) / m(row, sourceCol);
+ MPInt ratio = m(row, targetCol) / m(row, sourceCol);
m.addToColumn(sourceCol, targetCol, -ratio);
otherMatrix.addToColumn(sourceCol, targetCol, -ratio);
}
@@ -116,21 +116,21 @@ IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const {
IntegerRelation result(rel.getSpace());
for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) {
- ArrayRef<int64_t> eq = rel.getEquality(i);
+ ArrayRef<MPInt> eq = rel.getEquality(i);
- int64_t c = eq.back();
+ const MPInt &c = eq.back();
- SmallVector<int64_t, 8> newEq = preMultiplyWithRow(eq.drop_back());
+ SmallVector<MPInt, 8> newEq = preMultiplyWithRow(eq.drop_back());
newEq.push_back(c);
result.addEquality(newEq);
}
for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i) {
- ArrayRef<int64_t> ineq = rel.getInequality(i);
+ ArrayRef<MPInt> ineq = rel.getInequality(i);
- int64_t c = ineq.back();
+ const MPInt &c = ineq.back();
- SmallVector<int64_t, 8> newIneq = preMultiplyWithRow(ineq.drop_back());
+ SmallVector<MPInt, 8> newIneq = preMultiplyWithRow(ineq.drop_back());
newIneq.push_back(c);
result.addInequality(newIneq);
}
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index c51aa3c922ea..4fbee3faf7f9 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -41,7 +41,7 @@ unsigned Matrix::appendExtraRow() {
return nRows - 1;
}
-unsigned Matrix::appendExtraRow(ArrayRef<int64_t> elems) {
+unsigned Matrix::appendExtraRow(ArrayRef<MPInt> elems) {
assert(elems.size() == nColumns && "elems must match row length!");
unsigned row = appendExtraRow();
for (unsigned col = 0; col < nColumns; ++col)
@@ -84,15 +84,15 @@ void Matrix::swapColumns(unsigned column, unsigned otherColumn) {
std::swap(at(row, column), at(row, otherColumn));
}
-MutableArrayRef<int64_t> Matrix::getRow(unsigned row) {
+MutableArrayRef<MPInt> Matrix::getRow(unsigned row) {
return {&data[row * nReservedColumns], nColumns};
}
-ArrayRef<int64_t> Matrix::getRow(unsigned row) const {
+ArrayRef<MPInt> Matrix::getRow(unsigned row) const {
return {&data[row * nReservedColumns], nColumns};
}
-void Matrix::setRow(unsigned row, ArrayRef<int64_t> elems) {
+void Matrix::setRow(unsigned row, ArrayRef<MPInt> elems) {
assert(elems.size() == getNumColumns() &&
"elems size must match row length!");
for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
@@ -115,7 +115,7 @@ void Matrix::insertColumns(unsigned pos, unsigned count) {
for (int ci = nReservedColumns - 1; ci >= 0; --ci) {
unsigned r = ri;
unsigned c = ci;
- int64_t &dest = data[r * nReservedColumns + c];
+ MPInt &dest = data[r * nReservedColumns + c];
if (c >= nColumns) { // NOLINT
// Out of bounds columns are zero-initialized. NOLINT because clang-tidy
// complains about this branch being the same as the c >= pos one.
@@ -186,16 +186,18 @@ void Matrix::copyRow(unsigned sourceRow, unsigned targetRow) {
at(targetRow, c) = at(sourceRow, c);
}
-void Matrix::fillRow(unsigned row, int64_t value) {
+void Matrix::fillRow(unsigned row, const MPInt &value) {
for (unsigned col = 0; col < nColumns; ++col)
at(row, col) = value;
}
-void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
+void Matrix::addToRow(unsigned sourceRow, unsigned targetRow,
+ const MPInt &scale) {
addToRow(targetRow, getRow(sourceRow), scale);
}
-void Matrix::addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale) {
+void Matrix::addToRow(unsigned row, ArrayRef<MPInt> rowVec,
+ const MPInt &scale) {
if (scale == 0)
return;
for (unsigned col = 0; col < nColumns; ++col)
@@ -203,7 +205,7 @@ void Matrix::addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale) {
}
void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,
- int64_t scale) {
+ const MPInt &scale) {
if (scale == 0)
return;
for (unsigned row = 0, e = getNumRows(); row < e; ++row)
@@ -220,31 +222,30 @@ void Matrix::negateRow(unsigned row) {
at(row, column) = -at(row, column);
}
-int64_t Matrix::normalizeRow(unsigned row, unsigned cols) {
+MPInt Matrix::normalizeRow(unsigned row, unsigned cols) {
return normalizeRange(getRow(row).slice(0, cols));
}
-int64_t Matrix::normalizeRow(unsigned row) {
+MPInt Matrix::normalizeRow(unsigned row) {
return normalizeRow(row, getNumColumns());
}
-SmallVector<int64_t, 8>
-Matrix::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+SmallVector<MPInt, 8> Matrix::preMultiplyWithRow(ArrayRef<MPInt> rowVec) const {
assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
- SmallVector<int64_t, 8> result(getNumColumns(), 0);
+ SmallVector<MPInt, 8> result(getNumColumns(), MPInt(0));
for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
for (unsigned i = 0, e = getNumRows(); i < e; ++i)
result[col] += rowVec[i] * at(i, col);
return result;
}
-SmallVector<int64_t, 8>
-Matrix::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+SmallVector<MPInt, 8>
+Matrix::postMultiplyWithColumn(ArrayRef<MPInt> colVec) const {
assert(getNumColumns() == colVec.size() &&
"Invalid column vector dimension!");
- SmallVector<int64_t, 8> result(getNumRows(), 0);
+ SmallVector<MPInt, 8> result(getNumRows(), MPInt(0));
for (unsigned row = 0, e = getNumRows(); row < e; row++)
for (unsigned i = 0, e = getNumColumns(); i < e; i++)
result[row] += at(row, i) * colVec[i];
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index d1d3925c5946..7dc1a804111b 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -29,11 +29,11 @@ void MultiAffineFunction::assertIsConsistent() const {
// Return the result of subtracting the two given vectors pointwise.
// The vectors must be of the same size.
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
-static SmallVector<int64_t, 8> subtractExprs(ArrayRef<int64_t> vecA,
- ArrayRef<int64_t> vecB) {
+static SmallVector<MPInt, 8> subtractExprs(ArrayRef<MPInt> vecA,
+ ArrayRef<MPInt> vecB) {
assert(vecA.size() == vecB.size() &&
"Cannot subtract vectors of differing lengths!");
- SmallVector<int64_t, 8> result;
+ SmallVector<MPInt, 8> result;
result.reserve(vecA.size());
for (unsigned i = 0, e = vecA.size(); i < e; ++i)
result.push_back(vecA[i] - vecB[i]);
@@ -55,27 +55,26 @@ void MultiAffineFunction::print(raw_ostream &os) const {
output.print(os);
}
-SmallVector<int64_t, 8>
-MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
+SmallVector<MPInt, 8>
+MultiAffineFunction::valueAt(ArrayRef<MPInt> point) const {
assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
"Point has incorrect dimensionality!");
- SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
+ SmallVector<MPInt, 8> pointHomogenous{llvm::to_vector(point)};
// Get the division values at this point.
- SmallVector<Optional<int64_t>, 8> divValues = divs.divValuesAt(point);
+ SmallVector<Optional<MPInt>, 8> divValues = divs.divValuesAt(point);
// The given point didn't include the values of the divs which the output is a
// function of; we have computed one possible set of values and use them here.
pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
- for (const Optional<int64_t> &divVal : divValues)
+ for (const Optional<MPInt> &divVal : divValues)
pointHomogenous.push_back(*divVal);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
- pointHomogenous.push_back(1);
- SmallVector<int64_t, 8> result =
- output.postMultiplyWithColumn(pointHomogenous);
+ pointHomogenous.emplace_back(1);
+ SmallVector<MPInt, 8> result = output.postMultiplyWithColumn(pointHomogenous);
assert(result.size() == getNumOutputs());
return result;
}
@@ -127,7 +126,7 @@ void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
other.divs.insertDiv(0, nDivs);
- SmallVector<int64_t, 8> div(other.divs.getNumVars() + 1);
+ SmallVector<MPInt, 8> div(other.divs.getNumVars() + 1);
for (unsigned i = 0; i < nDivs; ++i) {
// Zero fill.
std::fill(div.begin(), div.end(), 0);
@@ -304,7 +303,7 @@ static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
// Create the expression `outA - outB` for this level.
- SmallVector<int64_t, 8> subExpr = subtractExprs(
+ SmallVector<MPInt, 8> subExpr = subtractExprs(
pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
if (lexMin) {
@@ -312,13 +311,13 @@ static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
// outA - outB <= -1
// outA <= outB - 1
// outA < outB
- levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
+ levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
} else {
// For lexMax, we add a lower bound of 1:
// outA - outB >= 1
// outA > outB + 1
// outA > outB
- levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
+ levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
}
// Union the set with the result.
@@ -351,7 +350,7 @@ void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
MultiAffineFunction copyOther = other;
mergeDivs(copyOther);
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
- output.addToRow(i, copyOther.getOutputExpr(i), -1);
+ output.addToRow(i, copyOther.getOutputExpr(i), MPInt(-1));
// Check consistency.
assertIsConsistent();
@@ -391,14 +390,14 @@ IntegerRelation MultiAffineFunction::getAsRelation() const {
// Add equalities such that the i^th range variable is equal to the i^th
// output expression.
- SmallVector<int64_t, 8> eq(result.getNumCols());
+ SmallVector<MPInt, 8> eq(result.getNumCols());
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
// TODO: Add functions to get VarKind offsets in output in MAF and use them
// here.
// The output expression does not contain range variables, while the
// equality does. So, we need to copy all variables and mark all range
// variables as 0 in the equality.
- ArrayRef<int64_t> expr = getOutputExpr(i);
+ ArrayRef<MPInt> expr = getOutputExpr(i);
// Copy domain variables in `expr` to domain variables in `eq`.
std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
// Fill the range variables in `eq` as zero.
@@ -424,8 +423,8 @@ void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
piece.output.removeOutputs(start, end);
}
-Optional<SmallVector<int64_t, 8>>
-PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+Optional<SmallVector<MPInt, 8>>
+PWMAFunction::valueAt(ArrayRef<MPInt> point) const {
assert(point.size() == getNumDomainVars() + getNumSymbolVars());
for (const Piece &piece : pieces)
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index c6a88f285c1c..df086012c8d6 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -68,7 +68,7 @@ PresburgerRelation::unionSet(const PresburgerRelation &set) const {
}
/// A point is contained in the union iff any of the parts contain the point.
-bool PresburgerRelation::containsPoint(ArrayRef<int64_t> point) const {
+bool PresburgerRelation::containsPoint(ArrayRef<MPInt> point) const {
return llvm::any_of(disjuncts, [&](const IntegerRelation &disjunct) {
return (disjunct.containsPointNoLocal(point));
});
@@ -121,15 +121,15 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
///
/// For every eq `coeffs == 0` there are two possible ineqs to index into.
/// The first is coeffs >= 0 and the second is coeffs <= 0.
-static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
- unsigned idx) {
+static SmallVector<MPInt, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
+ unsigned idx) {
assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
"idx out of bounds!");
if (idx < rel.getNumInequalities())
return llvm::to_vector<8>(rel.getInequality(idx));
idx -= rel.getNumInequalities();
- ArrayRef<int64_t> eqCoeffs = rel.getEquality(idx / 2);
+ ArrayRef<MPInt> eqCoeffs = rel.getEquality(idx / 2);
if (idx % 2 == 0)
return llvm::to_vector<8>(eqCoeffs);
@@ -389,7 +389,7 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
// state before adding this complement constraint, and add s_ij to b.
simplex.rollback(frame.simplexSnapshot);
b.truncate(frame.bCounts);
- SmallVector<int64_t, 8> ineq =
+ SmallVector<MPInt, 8> ineq =
getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed);
b.addInequality(ineq);
simplex.addInequality(ineq);
@@ -407,7 +407,7 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
frame.simplexSnapshot = simplex.getSnapshot();
unsigned idx = frame.ineqsToProcess.back();
- SmallVector<int64_t, 8> ineq =
+ SmallVector<MPInt, 8> ineq =
getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx));
b.addInequality(ineq);
simplex.addInequality(ineq);
@@ -459,10 +459,10 @@ bool PresburgerRelation::isIntegerEmpty() const {
return llvm::all_of(disjuncts, std::mem_fn(&IntegerRelation::isIntegerEmpty));
}
-bool PresburgerRelation::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
+bool PresburgerRelation::findIntegerSample(SmallVectorImpl<MPInt> &sample) {
// A sample exists iff any of the disjuncts contains a sample.
for (const IntegerRelation &disjunct : disjuncts) {
- if (Optional<SmallVector<int64_t, 8>> opt = disjunct.findIntegerSample()) {
+ if (Optional<SmallVector<MPInt, 8>> opt = disjunct.findIntegerSample()) {
sample = std::move(*opt);
return true;
}
@@ -470,13 +470,13 @@ bool PresburgerRelation::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
return false;
}
-Optional<uint64_t> PresburgerRelation::computeVolume() const {
+Optional<MPInt> PresburgerRelation::computeVolume() const {
assert(getNumSymbolVars() == 0 && "Symbols are not yet supported!");
// The sum of the volumes of the disjuncts is a valid overapproximation of the
// volume of their union, even if they overlap.
- uint64_t result = 0;
+ MPInt result(0);
for (const IntegerRelation &disjunct : disjuncts) {
- Optional<uint64_t> volume = disjunct.computeVolume();
+ Optional<MPInt> volume = disjunct.computeVolume();
if (!volume)
return {};
result += *volume;
@@ -511,20 +511,20 @@ class presburger::SetCoalescer {
/// The list of all inversed equalities during typing. This ensures that
/// the constraints exist even after the typing function has concluded.
- SmallVector<SmallVector<int64_t, 2>, 2> negEqs;
+ SmallVector<SmallVector<MPInt, 2>, 2> negEqs;
/// `redundantIneqsA` is the inequalities of `a` that are redundant for `b`
/// (similarly for `cuttingIneqsA`, `redundantIneqsB`, and `cuttingIneqsB`).
- SmallVector<ArrayRef<int64_t>, 2> redundantIneqsA;
- SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsA;
+ SmallVector<ArrayRef<MPInt>, 2> redundantIneqsA;
+ SmallVector<ArrayRef<MPInt>, 2> cuttingIneqsA;
- SmallVector<ArrayRef<int64_t>, 2> redundantIneqsB;
- SmallVector<ArrayRef<int64_t>, 2> cuttingIneqsB;
+ SmallVector<ArrayRef<MPInt>, 2> redundantIneqsB;
+ SmallVector<ArrayRef<MPInt>, 2> cuttingIneqsB;
/// Given a Simplex `simp` and one of its inequalities `ineq`, check
/// that the facet of `simp` where `ineq` holds as an equality is contained
/// within `a`.
- bool isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp);
+ bool isFacetContained(ArrayRef<MPInt> ineq, Simplex &simp);
/// Removes redundant constraints from `disjunct`, adds it to `disjuncts` and
/// removes the disjuncts at position `i` and `j`. Updates `simplices` to
@@ -548,13 +548,13 @@ class presburger::SetCoalescer {
/// Types the inequality `ineq` according to its `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
- LogicalResult typeInequality(ArrayRef<int64_t> ineq, Simplex &simp);
+ LogicalResult typeInequality(ArrayRef<MPInt> ineq, Simplex &simp);
/// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
/// -`eq` >= 0 according to their `IneqType` for `simp` into
/// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
/// inequalities were encountered. Otherwise, returns failure.
- LogicalResult typeEquality(ArrayRef<int64_t> eq, Simplex &simp);
+ LogicalResult typeEquality(ArrayRef<MPInt> eq, Simplex &simp);
/// Replaces the element at position `i` with the last element and erases
/// the last element for both `disjuncts` and `simplices`.
@@ -631,10 +631,10 @@ PresburgerRelation SetCoalescer::coalesce() {
/// Given a Simplex `simp` and one of its inequalities `ineq`, check
/// that all inequalities of `cuttingIneqsB` are redundant for the facet of
/// `simp` where `ineq` holds as an equality is contained within `a`.
-bool SetCoalescer::isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp) {
+bool SetCoalescer::isFacetContained(ArrayRef<MPInt> ineq, Simplex &simp) {
SimplexRollbackScopeExit scopeExit(simp);
simp.addEquality(ineq);
- return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
+ return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef<MPInt> curr) {
return simp.isRedundantInequality(curr);
});
}
@@ -696,23 +696,23 @@ LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
/// redundant ones are, so only the cutting ones remain to be checked.
Simplex &simp = simplices[i];
IntegerRelation &disjunct = disjuncts[i];
- if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<int64_t> curr) {
+ if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<MPInt> curr) {
return !isFacetContained(curr, simp);
}))
return failure();
IntegerRelation newSet(disjunct.getSpace());
- for (ArrayRef<int64_t> curr : redundantIneqsA)
+ for (ArrayRef<MPInt> curr : redundantIneqsA)
newSet.addInequality(curr);
- for (ArrayRef<int64_t> curr : redundantIneqsB)
+ for (ArrayRef<MPInt> curr : redundantIneqsB)
newSet.addInequality(curr);
addCoalescedDisjunct(i, j, newSet);
return success();
}
-LogicalResult SetCoalescer::typeInequality(ArrayRef<int64_t> ineq,
+LogicalResult SetCoalescer::typeInequality(ArrayRef<MPInt> ineq,
Simplex &simp) {
Simplex::IneqType type = simp.findIneqType(ineq);
if (type == Simplex::IneqType::Redundant)
@@ -724,11 +724,11 @@ LogicalResult SetCoalescer::typeInequality(ArrayRef<int64_t> ineq,
return success();
}
-LogicalResult SetCoalescer::typeEquality(ArrayRef<int64_t> eq, Simplex &simp) {
+LogicalResult SetCoalescer::typeEquality(ArrayRef<MPInt> eq, Simplex &simp) {
if (typeInequality(eq, simp).failed())
return failure();
negEqs.push_back(getNegatedCoeffs(eq));
- ArrayRef<int64_t> inv(negEqs.back());
+ ArrayRef<MPInt> inv(negEqs.back());
if (typeInequality(inv, simp).failed())
return failure();
return success();
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index d3c66027b98e..f08af75a6084 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -22,10 +22,10 @@ const int nullIndex = std::numeric_limits<int>::max();
// Return a + scale*b;
LLVM_ATTRIBUTE_UNUSED
-static SmallVector<int64_t, 8>
-scaleAndAddForAssert(ArrayRef<int64_t> a, int64_t scale, ArrayRef<int64_t> b) {
+static SmallVector<MPInt, 8>
+scaleAndAddForAssert(ArrayRef<MPInt> a, const MPInt &scale, ArrayRef<MPInt> b) {
assert(a.size() == b.size());
- SmallVector<int64_t, 8> res;
+ SmallVector<MPInt, 8> res;
res.reserve(a.size());
for (unsigned i = 0, e = a.size(); i < e; ++i)
res.push_back(a[i] + scale * b[i]);
@@ -101,7 +101,7 @@ unsigned SimplexBase::addZeroRow(bool makeRestricted) {
/// Add a new row to the tableau corresponding to the given constant term and
/// list of coefficients. The coefficients are specified as a vector of
/// (variable index, coefficient) pairs.
-unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
+unsigned SimplexBase::addRow(ArrayRef<MPInt> coeffs, bool makeRestricted) {
assert(coeffs.size() == var.size() + 1 &&
"Incorrect number of coefficients!");
assert(var.size() + getNumFixedCols() == getNumColumns() &&
@@ -124,7 +124,7 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
//
// Symbols don't use the big M parameter since they do not get lex
// optimized.
- int64_t bigMCoeff = 0;
+ MPInt bigMCoeff(0);
for (unsigned i = 0; i < coeffs.size() - 1; ++i)
if (!var[i].isSymbol)
bigMCoeff -= coeffs[i];
@@ -150,9 +150,9 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
// row, scaled by the coefficient for the variable, accounting for the two
// rows potentially having different denominators. The new denominator is
// the lcm of the two.
- int64_t lcm = std::lcm(tableau(newRow, 0), tableau(pos, 0));
- int64_t nRowCoeff = lcm / tableau(newRow, 0);
- int64_t idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0));
+ MPInt lcm = presburger::lcm(tableau(newRow, 0), tableau(pos, 0));
+ MPInt nRowCoeff = lcm / tableau(newRow, 0);
+ MPInt idxRowCoeff = coeffs[i] * (lcm / tableau(pos, 0));
tableau(newRow, 0) = lcm;
for (unsigned col = 1, e = getNumColumns(); col < e; ++col)
tableau(newRow, col) =
@@ -165,7 +165,7 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
}
namespace {
-bool signMatchesDirection(int64_t elem, Direction direction) {
+bool signMatchesDirection(const MPInt &elem, Direction direction) {
assert(elem != 0 && "elem should not be 0");
return direction == Direction::Up ? elem > 0 : elem < 0;
}
@@ -261,7 +261,7 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
/// The constraint is violated when added (it would be useless otherwise)
/// so we immediately try to move it to a column.
LogicalResult LexSimplexBase::addCut(unsigned row) {
- int64_t d = tableau(row, 0);
+ MPInt d = tableau(row, 0);
unsigned cutRow = addZeroRow(/*makeRestricted=*/true);
tableau(cutRow, 0) = d;
tableau(cutRow, 1) = -mod(-tableau(row, 1), d); // -c%d.
@@ -285,7 +285,7 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
return {};
}
-MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() {
+MaybeOptimum<SmallVector<MPInt, 8>> LexSimplex::findIntegerLexMin() {
// We first try to make the tableau consistent.
if (restoreRationalConsistency().failed())
return OptimumKind::Empty;
@@ -316,19 +316,19 @@ MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() {
llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
}
-bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
+bool LexSimplex::isSeparateInequality(ArrayRef<MPInt> coeffs) {
SimplexRollbackScopeExit scopeExit(*this);
addInequality(coeffs);
return findIntegerLexMin().isEmpty();
}
-bool LexSimplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
+bool LexSimplex::isRedundantInequality(ArrayRef<MPInt> coeffs) {
return isSeparateInequality(getComplementIneq(coeffs));
}
-SmallVector<int64_t, 8>
+SmallVector<MPInt, 8>
SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const {
- SmallVector<int64_t, 8> sample;
+ SmallVector<MPInt, 8> sample;
sample.reserve(nSymbol + 1);
for (unsigned col = 3; col < 3 + nSymbol; ++col)
sample.push_back(tableau(row, col));
@@ -336,9 +336,9 @@ SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const {
return sample;
}
-SmallVector<int64_t, 8>
+SmallVector<MPInt, 8>
SymbolicLexSimplex::getSymbolicSampleIneq(unsigned row) const {
- SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(row);
+ SmallVector<MPInt, 8> sample = getSymbolicSampleNumerator(row);
// The inequality is equivalent to the GCD-normalized one.
normalizeRange(sample);
return sample;
@@ -351,13 +351,14 @@ void LexSimplexBase::appendSymbol() {
nSymbol++;
}
-static bool isRangeDivisibleBy(ArrayRef<int64_t> range, int64_t divisor) {
+static bool isRangeDivisibleBy(ArrayRef<MPInt> range, const MPInt &divisor) {
assert(divisor > 0 && "divisor must be positive!");
- return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; });
+ return llvm::all_of(range,
+ [divisor](const MPInt &x) { return x % divisor == 0; });
}
bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
- int64_t denom = tableau(row, 0);
+ MPInt denom = tableau(row, 0);
return tableau(row, 1) % denom == 0 &&
isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom);
}
@@ -396,7 +397,7 @@ bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
/// This constraint is violated when added so we immediately try to move it to a
/// column.
LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
- int64_t d = tableau(row, 0);
+ MPInt d = tableau(row, 0);
if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) {
// The coefficients of symbols in the symbol numerator are divisible
// by the denominator, so we can add the constraint directly,
@@ -405,9 +406,9 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
}
// Construct the division variable `q = ((-c%d) + sum_i (-a_i%d)s_i)/d`.
- SmallVector<int64_t, 8> divCoeffs;
+ SmallVector<MPInt, 8> divCoeffs;
divCoeffs.reserve(nSymbol + 1);
- int64_t divDenom = d;
+ MPInt divDenom = d;
for (unsigned col = 3; col < 3 + nSymbol; ++col)
divCoeffs.push_back(mod(-tableau(row, col), divDenom)); // (-a_i%d)s_i
divCoeffs.push_back(mod(-tableau(row, 1), divDenom)); // -c%d.
@@ -448,7 +449,7 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
return;
}
- int64_t denom = tableau(u.pos, 0);
+ MPInt denom = tableau(u.pos, 0);
if (tableau(u.pos, 2) < denom) {
// M + u has a sample value of fM + something, where f < 1, so
// u = (f - 1)M + something, which has a negative coefficient for M,
@@ -459,8 +460,8 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
assert(tableau(u.pos, 2) == denom &&
"Coefficient of M should not be greater than 1!");
- SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(u.pos);
- for (int64_t &elem : sample) {
+ SmallVector<MPInt, 8> sample = getSymbolicSampleNumerator(u.pos);
+ for (MPInt &elem : sample) {
assert(elem % denom == 0 && "coefficients must be integral!");
elem /= denom;
}
@@ -557,7 +558,7 @@ SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
continue;
}
- SmallVector<int64_t, 8> symbolicSample;
+ SmallVector<MPInt, 8> symbolicSample;
unsigned splitRow = 0;
for (unsigned e = getNumRows(); splitRow < e; ++splitRow) {
if (tableau(splitRow, 2) > 0)
@@ -642,7 +643,7 @@ SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
// was negative.
assert(u.orientation == Orientation::Row &&
"The split row should have been returned to row orientation!");
- SmallVector<int64_t, 8> splitIneq =
+ SmallVector<MPInt, 8> splitIneq =
getComplementIneq(getSymbolicSampleIneq(u.pos));
normalizeRange(splitIneq);
if (moveRowUnknownToColumn(u.pos).failed()) {
@@ -818,7 +819,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
// all possible values of the symbols.
auto getSampleChangeCoeffForVar = [this, row](unsigned col,
const Unknown &u) -> Fraction {
- int64_t a = tableau(row, col);
+ MPInt a = tableau(row, col);
if (u.orientation == Orientation::Column) {
// Pivot column case.
if (u.pos == col)
@@ -833,7 +834,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
return {1, 1};
// Non-pivot row case.
- int64_t c = tableau(u.pos, col);
+ MPInt c = tableau(u.pos, col);
return {c, a};
};
@@ -867,7 +868,7 @@ Optional<SimplexBase::Pivot> Simplex::findPivot(int row,
Direction direction) const {
Optional<unsigned> col;
for (unsigned j = 2, e = getNumColumns(); j < e; ++j) {
- int64_t elem = tableau(row, j);
+ MPInt elem = tableau(row, j);
if (elem == 0)
continue;
@@ -1016,18 +1017,18 @@ Optional<unsigned> Simplex::findPivotRow(Optional<unsigned> skipRow,
// retConst being used uninitialized in the initialization of `diff` below. In
// reality, these are always initialized when that line is reached since these
// are set whenever retRow is set.
- int64_t retElem = 0, retConst = 0;
+ MPInt retElem, retConst;
for (unsigned row = nRedundant, e = getNumRows(); row < e; ++row) {
if (skipRow && row == *skipRow)
continue;
- int64_t elem = tableau(row, col);
+ MPInt elem = tableau(row, col);
if (elem == 0)
continue;
if (!unknownFromRow(row).restricted)
continue;
if (signMatchesDirection(elem, direction))
continue;
- int64_t constTerm = tableau(row, 1);
+ MPInt constTerm = tableau(row, 1);
if (!retRow) {
retRow = row;
@@ -1036,7 +1037,7 @@ Optional<unsigned> Simplex::findPivotRow(Optional<unsigned> skipRow,
continue;
}
- int64_t diff = retConst * elem - constTerm * retElem;
+ MPInt diff = retConst * elem - constTerm * retElem;
if ((diff == 0 && rowUnknown[row] < rowUnknown[*retRow]) ||
(diff != 0 && !signMatchesDirection(diff, direction))) {
retRow = row;
@@ -1087,7 +1088,7 @@ void SimplexBase::markEmpty() {
/// We add the inequality and mark it as restricted. We then try to make its
/// sample value non-negative. If this is not possible, the tableau has become
/// empty and we mark it as such.
-void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
+void Simplex::addInequality(ArrayRef<MPInt> coeffs) {
unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true);
LogicalResult result = restoreRow(con[conIndex]);
if (failed(result))
@@ -1100,10 +1101,10 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
///
/// We simply add two opposing inequalities, which force the expression to
/// be zero.
-void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
+void SimplexBase::addEquality(ArrayRef<MPInt> coeffs) {
addInequality(coeffs);
- SmallVector<int64_t, 8> negatedCoeffs;
- for (int64_t coeff : coeffs)
+ SmallVector<MPInt, 8> negatedCoeffs;
+ for (const MPInt &coeff : coeffs)
negatedCoeffs.emplace_back(-coeff);
addInequality(negatedCoeffs);
}
@@ -1278,17 +1279,18 @@ void SimplexBase::rollback(unsigned snapshot) {
///
/// This constrains the remainder `coeffs - denom*q` to be in the
/// range `[0, denom - 1]`, which fixes the integer value of the quotient `q`.
-void SimplexBase::addDivisionVariable(ArrayRef<int64_t> coeffs, int64_t denom) {
- assert(denom != 0 && "Cannot divide by zero!\n");
+void SimplexBase::addDivisionVariable(ArrayRef<MPInt> coeffs,
+ const MPInt &denom) {
+ assert(denom > 0 && "Denominator must be positive!");
appendVariable();
- SmallVector<int64_t, 8> ineq(coeffs.begin(), coeffs.end());
- int64_t constTerm = ineq.back();
+ SmallVector<MPInt, 8> ineq(coeffs.begin(), coeffs.end());
+ MPInt constTerm = ineq.back();
ineq.back() = -denom;
ineq.push_back(constTerm);
addInequality(ineq);
- for (int64_t &coeff : ineq)
+ for (MPInt &coeff : ineq)
coeff = -coeff;
ineq.back() += denom - 1;
addInequality(ineq);
@@ -1338,7 +1340,7 @@ MaybeOptimum<Fraction> Simplex::computeRowOptimum(Direction direction,
/// Compute the optimum of the specified expression in the specified direction,
/// or None if it is unbounded.
MaybeOptimum<Fraction> Simplex::computeOptimum(Direction direction,
- ArrayRef<int64_t> coeffs) {
+ ArrayRef<MPInt> coeffs) {
if (empty)
return OptimumKind::Empty;
@@ -1447,7 +1449,7 @@ bool Simplex::isUnbounded() {
if (empty)
return false;
- SmallVector<int64_t, 8> dir(var.size() + 1);
+ SmallVector<MPInt, 8> dir(var.size() + 1);
for (unsigned i = 0; i < var.size(); ++i) {
dir[i] = 1;
@@ -1557,14 +1559,14 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
} else {
// If the variable is in row position, its sample value is the
// entry in the constant column divided by the denominator.
- int64_t denom = tableau(u.pos, 0);
+ MPInt denom = tableau(u.pos, 0);
sample.emplace_back(tableau(u.pos, 1), denom);
}
}
return sample;
}
-void LexSimplexBase::addInequality(ArrayRef<int64_t> coeffs) {
+void LexSimplexBase::addInequality(ArrayRef<MPInt> coeffs) {
addRow(coeffs, /*makeRestricted=*/true);
}
@@ -1589,7 +1591,7 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
// If the variable is in row position, its sample value is the
// entry in the constant column divided by the denominator.
- int64_t denom = tableau(u.pos, 0);
+ MPInt denom = tableau(u.pos, 0);
if (usingBigM)
if (tableau(u.pos, 2) != denom)
return OptimumKind::Unbounded;
@@ -1598,14 +1600,14 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
return sample;
}
-Optional<SmallVector<int64_t, 8>> Simplex::getSamplePointIfIntegral() const {
+Optional<SmallVector<MPInt, 8>> Simplex::getSamplePointIfIntegral() const {
// If the tableau is empty, no sample point exists.
if (empty)
return {};
// The value will always exist since the Simplex is non-empty.
SmallVector<Fraction, 8> rationalSample = *getRationalSample();
- SmallVector<int64_t, 8> integerSample;
+ SmallVector<MPInt, 8> integerSample;
integerSample.reserve(var.size());
for (const Fraction &coord : rationalSample) {
// If the sample is non-integral, return None.
@@ -1637,14 +1639,14 @@ class presburger::GBRSimplex {
/// Add an equality dotProduct(dir, x - y) == 0.
/// First pushes a snapshot for the current simplex state to the stack so
/// that this can be rolled back later.
- void addEqualityForDirection(ArrayRef<int64_t> dir) {
- assert(llvm::any_of(dir, [](int64_t x) { return x != 0; }) &&
+ void addEqualityForDirection(ArrayRef<MPInt> dir) {
+ assert(llvm::any_of(dir, [](const MPInt &x) { return x != 0; }) &&
"Direction passed is the zero vector!");
snapshotStack.push_back(simplex.getSnapshot());
simplex.addEquality(getCoeffsForDirection(dir));
}
/// Compute max(dotProduct(dir, x - y)).
- Fraction computeWidth(ArrayRef<int64_t> dir) {
+ Fraction computeWidth(ArrayRef<MPInt> dir) {
MaybeOptimum<Fraction> maybeWidth =
simplex.computeOptimum(Direction::Up, getCoeffsForDirection(dir));
assert(maybeWidth.isBounded() && "Width should be bounded!");
@@ -1653,9 +1655,9 @@ class presburger::GBRSimplex {
/// Compute max(dotProduct(dir, x - y)) and save the dual variables for only
/// the direction equalities to `dual`.
- Fraction computeWidthAndDuals(ArrayRef<int64_t> dir,
- SmallVectorImpl<int64_t> &dual,
- int64_t &dualDenom) {
+ Fraction computeWidthAndDuals(ArrayRef<MPInt> dir,
+ SmallVectorImpl<MPInt> &dual,
+ MPInt &dualDenom) {
// We can't just call into computeWidth or computeOptimum since we need to
// access the state of the tableau after computing the optimum, and these
// functions rollback the insertion of the objective function into the
@@ -1723,12 +1725,12 @@ class presburger::GBRSimplex {
/// i.e., dir_1 * x_1 + dir_2 * x_2 + ... + dir_n * x_n
/// - dir_1 * y_1 - dir_2 * y_2 - ... - dir_n * y_n,
/// where n is the dimension of the original polytope.
- SmallVector<int64_t, 8> getCoeffsForDirection(ArrayRef<int64_t> dir) {
+ SmallVector<MPInt, 8> getCoeffsForDirection(ArrayRef<MPInt> dir) {
assert(2 * dir.size() == simplex.getNumVariables() &&
"Direction vector has wrong dimensionality");
- SmallVector<int64_t, 8> coeffs(dir.begin(), dir.end());
+ SmallVector<MPInt, 8> coeffs(dir.begin(), dir.end());
coeffs.reserve(2 * dir.size());
- for (int64_t coeff : dir)
+ for (const MPInt &coeff : dir)
coeffs.push_back(-coeff);
coeffs.emplace_back(0); // constant term
return coeffs;
@@ -1805,8 +1807,8 @@ void Simplex::reduceBasis(Matrix &basis, unsigned level) {
GBRSimplex gbrSimplex(*this);
SmallVector<Fraction, 8> width;
- SmallVector<int64_t, 8> dual;
- int64_t dualDenom;
+ SmallVector<MPInt, 8> dual;
+ MPInt dualDenom;
// Finds the value of u that minimizes width_i(b_{i+1} + u*b_i), caches the
// duals from this computation, sets b_{i+1} to b_{i+1} + u*b_i, and returns
@@ -1829,11 +1831,11 @@ void Simplex::reduceBasis(Matrix &basis, unsigned level) {
auto updateBasisWithUAndGetFCandidate = [&](unsigned i) -> Fraction {
assert(i < level + dual.size() && "dual_i is not known!");
- int64_t u = floorDiv(dual[i - level], dualDenom);
+ MPInt u = floorDiv(dual[i - level], dualDenom);
basis.addToRow(i, i + 1, u);
if (dual[i - level] % dualDenom != 0) {
- SmallVector<int64_t, 8> candidateDual[2];
- int64_t candidateDualDenom[2];
+ SmallVector<MPInt, 8> candidateDual[2];
+ MPInt candidateDualDenom[2];
Fraction widthI[2];
// Initially u is floor(dual) and basis reflects this.
@@ -1860,11 +1862,13 @@ void Simplex::reduceBasis(Matrix &basis, unsigned level) {
// Check the value at u - 1.
assert(gbrSimplex.computeWidth(scaleAndAddForAssert(
- basis.getRow(i + 1), -1, basis.getRow(i))) >= widthI[j] &&
+ basis.getRow(i + 1), MPInt(-1), basis.getRow(i))) >=
+ widthI[j] &&
"Computed u value does not minimize the width!");
// Check the value at u + 1.
assert(gbrSimplex.computeWidth(scaleAndAddForAssert(
- basis.getRow(i + 1), +1, basis.getRow(i))) >= widthI[j] &&
+ basis.getRow(i + 1), MPInt(+1), basis.getRow(i))) >=
+ widthI[j] &&
"Computed u value does not minimize the width!");
dual = std::move(candidateDual[j]);
@@ -1964,7 +1968,7 @@ void Simplex::reduceBasis(Matrix &basis, unsigned level) {
///
/// To avoid potentially arbitrarily large recursion depths leading to stack
/// overflows, this algorithm is implemented iteratively.
-Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
+Optional<SmallVector<MPInt, 8>> Simplex::findIntegerSample() {
if (empty)
return {};
@@ -1975,9 +1979,9 @@ Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
// The snapshot just before constraining a direction to a value at each level.
SmallVector<unsigned, 8> snapshotStack;
// The maximum value in the range of the direction for each level.
- SmallVector<int64_t, 8> upperBoundStack;
+ SmallVector<MPInt, 8> upperBoundStack;
// The next value to try constraining the basis vector to at each level.
- SmallVector<int64_t, 8> nextValueStack;
+ SmallVector<MPInt, 8> nextValueStack;
snapshotStack.reserve(basis.getNumRows());
upperBoundStack.reserve(basis.getNumRows());
@@ -1997,7 +2001,7 @@ Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
// just come down a level ("recursed"). Find the lower and upper bounds.
// If there is more than one integer point in the range, perform
// generalized basis reduction.
- SmallVector<int64_t, 8> basisCoeffs =
+ SmallVector<MPInt, 8> basisCoeffs =
llvm::to_vector<8>(basis.getRow(level));
basisCoeffs.emplace_back(0);
@@ -2049,7 +2053,7 @@ Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
// to the snapshot of the starting state at this level. (in the "recursed"
// case this has no effect)
rollback(snapshotStack.back());
- int64_t nextValue = nextValueStack.back();
+ MPInt nextValue = nextValueStack.back();
++nextValueStack.back();
if (nextValue > upperBoundStack.back()) {
// We have exhausted the range and found no solution. Pop the stack and
@@ -2062,8 +2066,8 @@ Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
}
// Try the next value in the range and "recurse" into the next level.
- SmallVector<int64_t, 8> basisCoeffs(basis.getRow(level).begin(),
- basis.getRow(level).end());
+ SmallVector<MPInt, 8> basisCoeffs(basis.getRow(level).begin(),
+ basis.getRow(level).end());
basisCoeffs.push_back(-nextValue);
addEquality(basisCoeffs);
level++;
@@ -2074,11 +2078,11 @@ Optional<SmallVector<int64_t, 8>> Simplex::findIntegerSample() {
/// Compute the minimum and maximum integer values the expression can take. We
/// compute each separately.
-std::pair<MaybeOptimum<int64_t>, MaybeOptimum<int64_t>>
-Simplex::computeIntegerBounds(ArrayRef<int64_t> coeffs) {
- MaybeOptimum<int64_t> minRoundedUp(
+std::pair<MaybeOptimum<MPInt>, MaybeOptimum<MPInt>>
+Simplex::computeIntegerBounds(ArrayRef<MPInt> coeffs) {
+ MaybeOptimum<MPInt> minRoundedUp(
computeOptimum(Simplex::Direction::Down, coeffs).map(ceil));
- MaybeOptimum<int64_t> maxRoundedDown(
+ MaybeOptimum<MPInt> maxRoundedDown(
computeOptimum(Simplex::Direction::Up, coeffs).map(floor));
return {minRoundedUp, maxRoundedDown};
}
@@ -2149,7 +2153,7 @@ bool Simplex::isRationalSubsetOf(const IntegerRelation &rel) {
/// maximum satisfy it. Hence, it is a cut inequality. If both are < 0, no
/// points of the polytope satisfy the inequality, which means it is a separate
/// inequality.
-Simplex::IneqType Simplex::findIneqType(ArrayRef<int64_t> coeffs) {
+Simplex::IneqType Simplex::findIneqType(ArrayRef<MPInt> coeffs) {
MaybeOptimum<Fraction> minimum = computeOptimum(Direction::Down, coeffs);
if (minimum.isBounded() && *minimum >= Fraction(0, 1)) {
return IneqType::Redundant;
@@ -2164,7 +2168,7 @@ Simplex::IneqType Simplex::findIneqType(ArrayRef<int64_t> coeffs) {
/// Checks whether the type of the inequality with coefficients `coeffs`
/// is Redundant.
-bool Simplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
+bool Simplex::isRedundantInequality(ArrayRef<MPInt> coeffs) {
assert(!empty &&
"It is not meaningful to ask about redundancy in an empty set!");
return findIneqType(coeffs) == IneqType::Redundant;
@@ -2174,7 +2178,7 @@ bool Simplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
/// the existing constraints. This is redundant when `coeffs` is already
/// always zero under the existing constraints. `coeffs` is always zero
/// when the minimum and maximum value that `coeffs` can take are both zero.
-bool Simplex::isRedundantEquality(ArrayRef<int64_t> coeffs) {
+bool Simplex::isRedundantEquality(ArrayRef<MPInt> coeffs) {
assert(!empty &&
"It is not meaningful to ask about redundancy in an empty set!");
MaybeOptimum<Fraction> minimum = computeOptimum(Direction::Down, coeffs);
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index e22fa02cfce9..5da3e7a0c81f 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/MPInt.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include <numeric>
@@ -23,14 +24,16 @@ using namespace presburger;
/// Normalize a division's `dividend` and the `divisor` by their GCD. For
/// example: if the dividend and divisor are [2,0,4] and 4 respectively,
-/// they get normalized to [1,0,2] and 2.
-static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend,
- unsigned &divisor) {
+/// they get normalized to [1,0,2] and 2. The divisor must be non-negative;
+/// it is allowed for the divisor to be zero, but nothing is done in this case.
+static void normalizeDivisionByGCD(MutableArrayRef<MPInt> dividend,
+ MPInt &divisor) {
+ assert(divisor >= 0 && "divisor must be non-negative!");
if (divisor == 0 || dividend.empty())
return;
// We take the absolute value of dividend's coefficients to make sure that
// `gcd` is positive.
- int64_t gcd = std::gcd(std::abs(dividend.front()), int64_t(divisor));
+ MPInt gcd = presburger::gcd(abs(dividend.front()), divisor);
// The reason for ignoring the constant term is as follows.
// For a division:
@@ -40,14 +43,14 @@ static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend,
// Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not
// influence the result of the floor division and thus, can be ignored.
for (size_t i = 1, m = dividend.size() - 1; i < m; i++) {
- gcd = std::gcd(std::abs(dividend[i]), gcd);
+ gcd = presburger::gcd(abs(dividend[i]), gcd);
if (gcd == 1)
return;
}
// Normalize the dividend and the denominator.
std::transform(dividend.begin(), dividend.end(), dividend.begin(),
- [gcd](int64_t &n) { return floorDiv(n, gcd); });
+ [gcd](MPInt &n) { return floorDiv(n, gcd); });
divisor /= gcd;
}
@@ -87,12 +90,11 @@ static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend,
/// -divisor * var + expr - c >= 0 <-- Upper bound for 'var'
///
/// If successful, `expr` is set to dividend of the division and `divisor` is
-/// set to the denominator of the division. The final division expression is
-/// normalized by GCD.
+/// set to the denominator of the division, which will be positive.
+/// The final division expression is normalized by GCD.
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
unsigned ubIneq, unsigned lbIneq,
- MutableArrayRef<int64_t> expr,
- unsigned &divisor) {
+ MutableArrayRef<MPInt> expr, MPInt &divisor) {
assert(pos <= cst.getNumVars() && "Invalid variable position");
assert(ubIneq <= cst.getNumInequalities() &&
@@ -100,6 +102,8 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
assert(lbIneq <= cst.getNumInequalities() &&
"Invalid upper bound inequality position");
assert(expr.size() == cst.getNumCols() && "Invalid expression size");
+ assert(cst.atIneq(lbIneq, pos) > 0 && "lbIneq is not a lower bound!");
+ assert(cst.atIneq(ubIneq, pos) < 0 && "ubIneq is not an upper bound!");
// Extract divisor from the lower bound.
divisor = cst.atIneq(lbIneq, pos);
@@ -117,12 +121,12 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
// Then, check if the constant term is of the proper form.
// Due to the form of the upper/lower bound inequalities, the sum of their
// constants is `divisor - 1 - c`. From this, we can extract c:
- int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
- cst.atIneq(ubIneq, cst.getNumCols() - 1);
- int64_t c = divisor - 1 - constantSum;
+ MPInt constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+ cst.atIneq(ubIneq, cst.getNumCols() - 1);
+ MPInt c = divisor - 1 - constantSum;
- // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also
- // implictly checks that `divisor` is positive.
+ // Check if `c` satisfies the condition `0 <= c <= divisor - 1`.
+ // This also implicitly checks that `divisor` is positive.
if (!(0 <= c && c <= divisor - 1)) // NOLINT
return failure();
@@ -154,8 +158,8 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
/// set to the denominator of the division. The final division expression is
/// normalized by GCD.
static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
- unsigned eqInd, MutableArrayRef<int64_t> expr,
- unsigned &divisor) {
+ unsigned eqInd, MutableArrayRef<MPInt> expr,
+ MPInt &divisor) {
assert(pos <= cst.getNumVars() && "Invalid variable position");
assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
@@ -164,10 +168,10 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
// Extract divisor, the divisor can be negative and hence its sign information
// is stored in `signDiv` to reverse the sign of dividend's coefficients.
// Equality must involve the pos-th variable and hence `tempDiv` != 0.
- int64_t tempDiv = cst.atEq(eqInd, pos);
+ MPInt tempDiv = cst.atEq(eqInd, pos);
if (tempDiv == 0)
return failure();
- int64_t signDiv = tempDiv < 0 ? -1 : 1;
+ int signDiv = tempDiv < 0 ? -1 : 1;
// The divisor is always a positive integer.
divisor = tempDiv * signDiv;
@@ -186,7 +190,7 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
// explicit representation has not been found yet, otherwise returns `true`.
static bool checkExplicitRepresentation(const IntegerRelation &cst,
ArrayRef<bool> foundRepr,
- ArrayRef<int64_t> dividend,
+ ArrayRef<MPInt> dividend,
unsigned pos) {
// Exit to avoid circular dependencies between divisions.
for (unsigned c = 0, e = cst.getNumVars(); c < e; ++c) {
@@ -215,9 +219,11 @@ static bool checkExplicitRepresentation(const IntegerRelation &cst,
/// the representation could be computed, `dividend` and `denominator` are set.
/// If the representation could not be computed, the kind attribute in
/// `MaybeLocalRepr` is set to None.
-MaybeLocalRepr presburger::computeSingleVarRepr(
- const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
- MutableArrayRef<int64_t> dividend, unsigned &divisor) {
+MaybeLocalRepr presburger::computeSingleVarRepr(const IntegerRelation &cst,
+ ArrayRef<bool> foundRepr,
+ unsigned pos,
+ MutableArrayRef<MPInt> dividend,
+ MPInt &divisor) {
assert(pos < cst.getNumVars() && "invalid position");
assert(foundRepr.size() == cst.getNumVars() &&
"Size of foundRepr does not match total number of variables");
@@ -256,6 +262,18 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
return repr;
}
+MaybeLocalRepr presburger::computeSingleVarRepr(
+ const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
+ SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
+ SmallVector<MPInt, 8> dividendMPInt(cst.getNumCols());
+ MPInt divisorMPInt;
+ MaybeLocalRepr result =
+ computeSingleVarRepr(cst, foundRepr, pos, dividendMPInt, divisorMPInt);
+ dividend = getInt64Vec(dividendMPInt);
+ divisor = unsigned(int64_t(divisorMPInt));
+ return result;
+}
+
llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
unsigned setOffset,
unsigned numSet) {
@@ -290,78 +308,80 @@ void presburger::mergeLocalVars(
divsA.removeDuplicateDivs(merge);
}
-SmallVector<int64_t, 8> presburger::getDivUpperBound(ArrayRef<int64_t> dividend,
- int64_t divisor,
- unsigned localVarIdx) {
+SmallVector<MPInt, 8> presburger::getDivUpperBound(ArrayRef<MPInt> dividend,
+ const MPInt &divisor,
+ unsigned localVarIdx) {
+ assert(divisor > 0 && "divisor must be positive!");
assert(dividend[localVarIdx] == 0 &&
"Local to be set to division must have zero coeff!");
- SmallVector<int64_t, 8> ineq(dividend.begin(), dividend.end());
+ SmallVector<MPInt, 8> ineq(dividend.begin(), dividend.end());
ineq[localVarIdx] = -divisor;
return ineq;
}
-SmallVector<int64_t, 8> presburger::getDivLowerBound(ArrayRef<int64_t> dividend,
- int64_t divisor,
- unsigned localVarIdx) {
+SmallVector<MPInt, 8> presburger::getDivLowerBound(ArrayRef<MPInt> dividend,
+ const MPInt &divisor,
+ unsigned localVarIdx) {
+ assert(divisor > 0 && "divisor must be positive!");
assert(dividend[localVarIdx] == 0 &&
"Local to be set to division must have zero coeff!");
- SmallVector<int64_t, 8> ineq(dividend.size());
+ SmallVector<MPInt, 8> ineq(dividend.size());
std::transform(dividend.begin(), dividend.end(), ineq.begin(),
- std::negate<int64_t>());
+ std::negate<MPInt>());
ineq[localVarIdx] = divisor;
ineq.back() += divisor - 1;
return ineq;
}
-int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
- int64_t gcd = 0;
- for (int64_t elem : range) {
- gcd = std::gcd((uint64_t)gcd, (uint64_t)std::abs(elem));
+MPInt presburger::gcdRange(ArrayRef<MPInt> range) {
+ MPInt gcd(0);
+ for (const MPInt &elem : range) {
+ gcd = presburger::gcd(gcd, abs(elem));
if (gcd == 1)
return gcd;
}
return gcd;
}
-int64_t presburger::normalizeRange(MutableArrayRef<int64_t> range) {
- int64_t gcd = gcdRange(range);
- if (gcd == 0 || gcd == 1)
+MPInt presburger::normalizeRange(MutableArrayRef<MPInt> range) {
+ MPInt gcd = gcdRange(range);
+ if ((gcd == 0) || (gcd == 1))
return gcd;
- for (int64_t &elem : range)
+ for (MPInt &elem : range)
elem /= gcd;
return gcd;
}
-void presburger::normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom) {
+void presburger::normalizeDiv(MutableArrayRef<MPInt> num, MPInt &denom) {
assert(denom > 0 && "denom must be positive!");
- int64_t gcd = std::gcd(gcdRange(num), denom);
- for (int64_t &coeff : num)
+ MPInt gcd = presburger::gcd(gcdRange(num), denom);
+ for (MPInt &coeff : num)
coeff /= gcd;
denom /= gcd;
}
-SmallVector<int64_t, 8> presburger::getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
- SmallVector<int64_t, 8> negatedCoeffs;
+SmallVector<MPInt, 8> presburger::getNegatedCoeffs(ArrayRef<MPInt> coeffs) {
+ SmallVector<MPInt, 8> negatedCoeffs;
negatedCoeffs.reserve(coeffs.size());
- for (int64_t coeff : coeffs)
+ for (const MPInt &coeff : coeffs)
negatedCoeffs.emplace_back(-coeff);
return negatedCoeffs;
}
-SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
- SmallVector<int64_t, 8> coeffs;
+SmallVector<MPInt, 8> presburger::getComplementIneq(ArrayRef<MPInt> ineq) {
+ SmallVector<MPInt, 8> coeffs;
coeffs.reserve(ineq.size());
- for (int64_t coeff : ineq)
+ for (const MPInt &coeff : ineq)
coeffs.emplace_back(-coeff);
--coeffs.back();
return coeffs;
}
-SmallVector<Optional<int64_t>, 4>
-DivisionRepr::divValuesAt(ArrayRef<int64_t> point) const {
+SmallVector<Optional<MPInt>, 4>
+DivisionRepr::divValuesAt(ArrayRef<MPInt> point) const {
assert(point.size() == getNumNonDivs() && "Incorrect point size");
- SmallVector<Optional<int64_t>, 4> divValues(getNumDivs(), None);
+ SmallVector<Optional<MPInt>, 4> divValues(getNumDivs(), None);
bool changed = true;
while (changed) {
changed = false;
@@ -370,8 +390,8 @@ DivisionRepr::divValuesAt(ArrayRef<int64_t> point) const {
if (divValues[i])
continue;
- ArrayRef<int64_t> dividend = getDividend(i);
- int64_t divVal = 0;
+ ArrayRef<MPInt> dividend = getDividend(i);
+ MPInt divVal(0);
// Check if we have all the division values required for this division.
unsigned j, f;
@@ -451,8 +471,8 @@ void DivisionRepr::removeDuplicateDivs(
}
}
-void DivisionRepr::insertDiv(unsigned pos, ArrayRef<int64_t> dividend,
- unsigned divisor) {
+void DivisionRepr::insertDiv(unsigned pos, ArrayRef<MPInt> dividend,
+ const MPInt &divisor) {
assert(pos <= getNumDivs() && "Invalid insertion position");
assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size");
@@ -465,7 +485,7 @@ void DivisionRepr::insertDiv(unsigned pos, unsigned num) {
assert(pos <= getNumDivs() && "Invalid insertion position");
dividends.insertColumns(getDivOffset() + pos, num);
dividends.insertRows(pos, num);
- denoms.insert(denoms.begin() + pos, num, 0);
+ denoms.insert(denoms.begin() + pos, num, MPInt(0));
}
void DivisionRepr::print(raw_ostream &os) const {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
index 5c0917d865d8..274c941477c8 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -430,10 +430,12 @@ static void computeDirectionVector(
dependenceComponents->resize(numCommonLoops);
for (unsigned j = 0; j < numCommonLoops; ++j) {
(*dependenceComponents)[j].op = commonLoops[j].getOperation();
- auto lbConst = dependenceDomain->getConstantBound(IntegerPolyhedron::LB, j);
+ auto lbConst =
+ dependenceDomain->getConstantBound64(IntegerPolyhedron::LB, j);
(*dependenceComponents)[j].lb =
lbConst.value_or(std::numeric_limits<int64_t>::min());
- auto ubConst = dependenceDomain->getConstantBound(IntegerPolyhedron::UB, j);
+ auto ubConst =
+ dependenceDomain->getConstantBound64(IntegerPolyhedron::UB, j);
(*dependenceComponents)[j].ub =
ubConst.value_or(std::numeric_limits<int64_t>::max());
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 8d7d63eead32..9e86d1ea58c6 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -755,14 +755,14 @@ static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
// Check for the aforementioned conditions in each equality.
for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
curEquality < numEqualities; curEquality++) {
- int64_t coefficientAtPos = cst.atEq(curEquality, pos);
+ int64_t coefficientAtPos = cst.atEq64(curEquality, pos);
// If current equality does not involve `var_r`, continue to the next
// equality.
if (coefficientAtPos == 0)
continue;
// Constant term should be 0 in this equality.
- if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
+ if (cst.atEq64(curEquality, cst.getNumCols() - 1) != 0)
continue;
// Traverse through the equality and construct the dividend expression
@@ -784,7 +784,7 @@ static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
// Ignore var_r.
if (curVar == pos)
continue;
- int64_t coefficientOfCurVar = cst.atEq(curEquality, curVar);
+ int64_t coefficientOfCurVar = cst.atEq64(curEquality, curVar);
// Ignore vars that do not contribute to the current equality.
if (coefficientOfCurVar == 0)
continue;
@@ -825,8 +825,8 @@ static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
// Express `var_r` as `var_n % divisor` and store the expression in `memo`.
if (quotientCount >= 1) {
- auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
- dimExpr.getPosition());
+ auto ub = cst.getConstantBound64(
+ FlatAffineValueConstraints::BoundType::UB, dimExpr.getPosition());
// If `var_n` has an upperbound that is less than the divisor, mod can be
// eliminated altogether.
if (ub && *ub < divisor)
@@ -910,7 +910,7 @@ FlatAffineValueConstraints::getLowerAndUpperBound(
lbExprs.reserve(lbIndices.size() + eqIndices.size());
// Lower bound expressions.
for (auto idx : lbIndices) {
- auto ineq = getInequality(idx);
+ auto ineq = getInequality64(idx);
// Extract the lower bound (in terms of other coeff's + const), i.e., if
// i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
// - 1.
@@ -928,7 +928,7 @@ FlatAffineValueConstraints::getLowerAndUpperBound(
ubExprs.reserve(ubIndices.size() + eqIndices.size());
// Upper bound expressions.
for (auto idx : ubIndices) {
- auto ineq = getInequality(idx);
+ auto ineq = getInequality64(idx);
// Extract the upper bound (in terms of other coeff's + const).
addCoeffs(ineq, ub);
auto expr =
@@ -941,7 +941,7 @@ FlatAffineValueConstraints::getLowerAndUpperBound(
// Equalities. It's both a lower and a upper bound.
SmallVector<int64_t, 4> b;
for (auto idx : eqIndices) {
- auto eq = getEquality(idx);
+ auto eq = getEquality64(idx);
addCoeffs(eq, b);
if (eq[pos + offset] > 0)
std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
@@ -1004,8 +1004,8 @@ void FlatAffineValueConstraints::getSliceBounds(
if (memo[pos])
continue;
- auto lbConst = getConstantBound(BoundType::LB, pos);
- auto ubConst = getConstantBound(BoundType::UB, pos);
+ auto lbConst = getConstantBound64(BoundType::LB, pos);
+ auto ubConst = getConstantBound64(BoundType::UB, pos);
if (lbConst.has_value() && ubConst.has_value()) {
// Detect equality to a constant.
if (lbConst.value() == ubConst.value()) {
@@ -1042,7 +1042,7 @@ void FlatAffineValueConstraints::getSliceBounds(
for (j = 0, e = getNumVars(); j < e; ++j) {
if (j == pos)
continue;
- int64_t c = atEq(idx, j);
+ int64_t c = atEq64(idx, j);
if (c == 0)
continue;
// If any of the involved IDs hasn't been found yet, we can't proceed.
@@ -1056,8 +1056,8 @@ void FlatAffineValueConstraints::getSliceBounds(
continue;
// Add constant term to AffineExpr.
- expr = expr + atEq(idx, getNumVars());
- int64_t vPos = atEq(idx, pos);
+ expr = expr + atEq64(idx, getNumVars());
+ int64_t vPos = atEq64(idx, pos);
assert(vPos != 0 && "expected non-zero here");
if (vPos > 0)
expr = (-expr).floorDiv(vPos);
@@ -1116,7 +1116,7 @@ void FlatAffineValueConstraints::getSliceBounds(
if (!lbMap || lbMap.getNumResults() > 1) {
LLVM_DEBUG(llvm::dbgs()
<< "WARNING: Potentially over-approximating slice lb\n");
- auto lbConst = getConstantBound(BoundType::LB, pos + offset);
+ auto lbConst = getConstantBound64(BoundType::LB, pos + offset);
if (lbConst.has_value()) {
lbMap =
AffineMap::get(numMapDims, numMapSymbols,
@@ -1126,7 +1126,7 @@ void FlatAffineValueConstraints::getSliceBounds(
if (!ubMap || ubMap.getNumResults() > 1) {
LLVM_DEBUG(llvm::dbgs()
<< "WARNING: Potentially over-approximating slice ub\n");
- auto ubConst = getConstantBound(BoundType::UB, pos + offset);
+ auto ubConst = getConstantBound64(BoundType::UB, pos + offset);
if (ubConst.has_value()) {
ubMap = AffineMap::get(
numMapDims, numMapSymbols,
@@ -1486,7 +1486,7 @@ void FlatAffineValueConstraints::getIneqAsAffineValueMap(
auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
// Compute the AffineExpr lower/upper bound for this inequality.
- ArrayRef<int64_t> inequality = getInequality(ineqPos);
+ SmallVector<int64_t, 8> inequality = getInequality64(ineqPos);
SmallVector<int64_t, 8> bound;
bound.reserve(getNumCols() - 1);
// Everything other than the coefficient at `pos`.
@@ -1560,10 +1560,10 @@ FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const {
exprs.reserve(getNumConstraints());
for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
- exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
- localExprs, context));
+ exprs.push_back(getAffineExprFromFlatForm(getEquality64(i), numDims,
+ numSyms, localExprs, context));
for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
- exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
+ exprs.push_back(getAffineExprFromFlatForm(getInequality64(i), numDims,
numSyms, localExprs, context));
return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
}
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 05d55f8514bb..0621cf34615b 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -371,7 +371,7 @@ Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
for (unsigned d = 0; d < rank; d++) {
SmallVector<int64_t, 4> lb;
Optional<int64_t> diff =
- cstWithShapeBounds.getConstantBoundOnDimSize(d, &lb, &lbDivisor);
+ cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
if (diff.has_value()) {
diffConstant = diff.value();
assert(diffConstant >= 0 && "Dim size bound can't be negative");
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 703ff5b5d744..9546cb3814bf 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1801,7 +1801,7 @@ MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b,
newShape[d] = -1;
} else {
// The lower bound for the shape is always zero.
- auto ubConst = fac.getConstantBound(IntegerPolyhedron::UB, d);
+ auto ubConst = fac.getConstantBound64(IntegerPolyhedron::UB, d);
// For a static memref and an affine map with no symbols, this is
// always bounded.
assert(ubConst && "should always have an upper bound");
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index d70e81cc6d45..c6cf0522e5b9 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -315,7 +315,7 @@ void getUpperBoundForIndex(Value value, AffineMap &boundMap,
// of the terminals of the index computation.
unsigned pos = getPosition(value);
if (constantRequired) {
- auto ubConst = constraints.getConstantBound(
+ auto ubConst = constraints.getConstantBound64(
FlatAffineValueConstraints::BoundType::UB, pos);
if (!ubConst)
return;
diff --git a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
index c2d004089faa..ee92ecf8d7c8 100644
--- a/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/AffineCanonicalizationUtils.cpp
@@ -189,7 +189,7 @@ canonicalizeMinMaxOp(RewriterBase &rewriter, Operation *op, AffineMap map,
// Skip unused operands and operands that are already constants.
if (!newOperands[i] || getConstantIntValue(newOperands[i]))
continue;
- if (auto bound = constraints.getConstantBound(IntegerPolyhedron::EQ, i))
+ if (auto bound = constraints.getConstantBound64(IntegerPolyhedron::EQ, i))
newOperands[i] =
rewriter.create<arith::ConstantIndexOp>(op->getLoc(), *bound);
}
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index de7413973e1b..e8a71edf3151 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -39,8 +39,8 @@ makeSetFromConstraints(unsigned ids, ArrayRef<SmallVector<int64_t, 4>> ineqs,
return set;
}
-static void dump(ArrayRef<int64_t> vec) {
- for (int64_t x : vec)
+static void dump(ArrayRef<MPInt> vec) {
+ for (const MPInt &x : vec)
llvm::errs() << x << ' ';
llvm::errs() << '\n';
}
@@ -58,8 +58,8 @@ static void dump(ArrayRef<int64_t> vec) {
/// opposite of hasSample.
static void checkSample(bool hasSample, const IntegerPolyhedron &poly,
TestFunction fn = TestFunction::Sample) {
- Optional<SmallVector<int64_t, 8>> maybeSample;
- MaybeOptimum<SmallVector<int64_t, 8>> maybeLexMin;
+ Optional<SmallVector<MPInt, 8>> maybeSample;
+ MaybeOptimum<SmallVector<MPInt, 8>> maybeLexMin;
switch (fn) {
case TestFunction::Sample:
maybeSample = poly.findIntegerSample();
@@ -426,6 +426,12 @@ TEST(IntegerPolyhedronTest, FindSampleTest) {
"-7*x - 4*y + z + 1 >= 0,"
"2*x - 7*y - 8*z - 7 >= 0,"
"9*x + 8*y - 9*z - 7 >= 0)"));
+
+ checkSample(
+ true,
+ parsePoly(
+ "(x) : (1152921504606846977*(x floordiv 1152921504606846977) == x, "
+ "1152921504606846976*(x floordiv 1152921504606846976) == x)"));
}
TEST(IntegerPolyhedronTest, IsIntegerEmptyTest) {
@@ -569,10 +575,10 @@ TEST(IntegerPolyhedronTest, removeRedundantConstraintsTest) {
// y >= 128x >= 0.
poly5.removeRedundantConstraints();
EXPECT_EQ(poly5.getNumInequalities(), 3u);
- SmallVector<int64_t, 8> redundantConstraint = {0, 1, 0};
+ SmallVector<MPInt, 8> redundantConstraint = getMPIntVec({0, 1, 0});
for (unsigned i = 0; i < 3; ++i) {
// Ensure that the removed constraint was the redundant constraint [3].
- EXPECT_NE(poly5.getInequality(i), ArrayRef<int64_t>(redundantConstraint));
+ EXPECT_NE(poly5.getInequality(i), ArrayRef<MPInt>(redundantConstraint));
}
}
@@ -611,11 +617,12 @@ TEST(IntegerPolyhedronTest, addConstantLowerBound) {
static void checkDivisionRepresentation(
IntegerPolyhedron &poly,
const std::vector<SmallVector<int64_t, 8>> &expectedDividends,
- ArrayRef<unsigned> expectedDenominators) {
+ ArrayRef<int64_t> expectedDenominators) {
DivisionRepr divs = poly.getLocalReprs();
// Check that the `denominators` and `expectedDenominators` match.
- EXPECT_TRUE(expectedDenominators == divs.getDenoms());
+ EXPECT_EQ(ArrayRef<MPInt>(getMPIntVec(expectedDenominators)),
+ divs.getDenoms());
// Check that the `dividends` and `expectedDividends` match. If the
// denominator for a division is zero, we ignore its dividend.
@@ -637,7 +644,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprSimple) {
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 0, 0, 4},
{1, 0, 0, 100}};
- SmallVector<unsigned, 8> denoms = {10, 10};
+ SmallVector<int64_t, 8> denoms = {10, 10};
// Check if floordivs can be computed when no other inequalities exist
// and floor divs do not depend on each other.
@@ -656,7 +663,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprConstantFloorDiv) {
std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0, 0, 0, 0, 3},
{0, 0, 0, 0, 0, 0, 2}};
- SmallVector<unsigned, 8> denoms = {1, 1};
+ SmallVector<int64_t, 8> denoms = {1, 1};
// Check if floordivs with constant numerator can be computed.
checkDivisionRepresentation(poly, divisions, denoms);
@@ -680,7 +687,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprRecursive) {
{3, 0, 9, 2, 2, 0, 0, 10},
{0, 1, -123, 2, 0, -4, 0, 10}};
- SmallVector<unsigned, 8> denoms = {3, 5, 3};
+ SmallVector<int64_t, 8> denoms = {3, 5, 3};
// Check if floordivs which may depend on other floordivs can be computed.
checkDivisionRepresentation(poly, divisions, denoms);
@@ -701,7 +708,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprTightUpperBound) {
poly.removeRedundantConstraints();
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 0, 0}};
- SmallVector<unsigned, 8> denoms = {3};
+ SmallVector<int64_t, 8> denoms = {3};
// Check if the divisions can be computed even with a tighter upper bound.
checkDivisionRepresentation(poly, divisions, denoms);
@@ -714,7 +721,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprTightUpperBound) {
poly.convertToLocal(VarKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
- SmallVector<unsigned, 8> denoms = {4};
+ SmallVector<int64_t, 8> denoms = {4};
// Check if the divisions can be computed even with a tighter upper bound.
checkDivisionRepresentation(poly, divisions, denoms);
@@ -728,7 +735,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
poly.convertToLocal(VarKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0}};
- SmallVector<unsigned, 8> denoms = {4};
+ SmallVector<int64_t, 8> denoms = {4};
checkDivisionRepresentation(poly, divisions, denoms);
}
@@ -738,7 +745,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
poly.convertToLocal(VarKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0}};
- SmallVector<unsigned, 8> denoms = {4};
+ SmallVector<int64_t, 8> denoms = {4};
checkDivisionRepresentation(poly, divisions, denoms);
}
@@ -748,7 +755,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
poly.convertToLocal(VarKind::SetDim, 2, 3);
std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 2}};
- SmallVector<unsigned, 8> denoms = {3};
+ SmallVector<int64_t, 8> denoms = {3};
checkDivisionRepresentation(poly, divisions, denoms);
}
@@ -764,7 +771,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEqualityAndInequality) {
std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0, 1},
{1, 1, 0, 0, 0}};
- SmallVector<unsigned, 8> denoms = {4, 3};
+ SmallVector<int64_t, 8> denoms = {4, 3};
checkDivisionRepresentation(poly, divisions, denoms);
}
@@ -777,7 +784,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNoRepr) {
poly.convertToLocal(VarKind::SetDim, 1, 2);
std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
- SmallVector<unsigned, 8> denoms = {0};
+ SmallVector<int64_t, 8> denoms = {0};
// Check that no division is computed.
checkDivisionRepresentation(poly, divisions, denoms);
@@ -793,7 +800,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNegConstNormalize) {
// = floor((1/3) + (-1 - x)/2)
// = floor((-1 - x)/2).
std::vector<SmallVector<int64_t, 8>> divisions = {{-1, 0, -1}};
- SmallVector<unsigned, 8> denoms = {2};
+ SmallVector<int64_t, 8> denoms = {2};
checkDivisionRepresentation(poly, divisions, denoms);
}
@@ -1061,7 +1068,7 @@ TEST(IntegerPolyhedronTest, negativeDividends) {
// Merging triggers normalization.
std::vector<SmallVector<int64_t, 8>> divisions = {{-1, 0, 0, 1},
{-1, 0, 0, -2}};
- SmallVector<unsigned, 8> denoms = {2, 3};
+ SmallVector<int64_t, 8> denoms = {2, 3};
checkDivisionRepresentation(poly1, divisions, denoms);
}
@@ -1139,9 +1146,9 @@ TEST(IntegerPolyhedronTest, findRationalLexMin) {
}
void expectIntegerLexMin(const IntegerPolyhedron &poly, ArrayRef<int64_t> min) {
- auto lexMin = poly.findIntegerLexMin();
+ MaybeOptimum<SmallVector<MPInt, 8>> lexMin = poly.findIntegerLexMin();
ASSERT_TRUE(lexMin.isBounded());
- EXPECT_EQ(ArrayRef<int64_t>(*lexMin), min);
+ EXPECT_EQ(*lexMin, getMPIntVec(min));
}
void expectNoIntegerLexMin(OptimumKind kind, const IntegerPolyhedron &poly) {
@@ -1389,8 +1396,8 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
static void
expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
- Optional<uint64_t> trueVolume,
- Optional<uint64_t> resultBound) {
+ Optional<int64_t> trueVolume,
+ Optional<int64_t> resultBound) {
expectComputedVolumeIsValidOverapprox(poly.computeVolume(), trueVolume,
resultBound);
}
@@ -1442,19 +1449,24 @@ TEST(IntegerPolyhedronTest, computeVolume) {
/*trueVolume=*/{}, /*resultBound=*/{});
}
+bool containsPointNoLocal(const IntegerPolyhedron &poly,
+ ArrayRef<int64_t> point) {
+ return poly.containsPointNoLocal(getMPIntVec(point)).has_value();
+}
+
TEST(IntegerPolyhedronTest, containsPointNoLocal) {
IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)");
- EXPECT_TRUE(poly1.containsPointNoLocal({0}));
- EXPECT_FALSE(poly1.containsPointNoLocal({1}));
+ EXPECT_TRUE(containsPointNoLocal(poly1, {0}));
+ EXPECT_FALSE(containsPointNoLocal(poly1, {1}));
IntegerPolyhedron poly2 = parsePoly(
"(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)");
- EXPECT_TRUE(poly2.containsPointNoLocal({6}));
- EXPECT_FALSE(poly2.containsPointNoLocal({4}));
+ EXPECT_TRUE(containsPointNoLocal(poly2, {6}));
+ EXPECT_FALSE(containsPointNoLocal(poly2, {4}));
IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)");
- EXPECT_TRUE(poly3.containsPointNoLocal({0, 0}));
- EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
+ EXPECT_TRUE(containsPointNoLocal(poly3, {0, 0}));
+ EXPECT_FALSE(containsPointNoLocal(poly3, {1, 0}));
}
TEST(IntegerPolyhedronTest, truncateEqualityRegressionTest) {
diff --git a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
index 5aad9f96ae21..32d9e532e1f6 100644
--- a/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/LinearTransformTest.cpp
@@ -23,8 +23,7 @@ void testColumnEchelonForm(const Matrix &m, unsigned expectedRank) {
// In column echelon form, each row's last non-zero value can be at most one
// column to the right of the last non-zero column among the previous rows.
for (unsigned row = 0, nRows = m.getNumRows(); row < nRows; ++row) {
- SmallVector<int64_t, 8> rowVec =
- transform.preMultiplyWithRow(m.getRow(row));
+ SmallVector<MPInt, 8> rowVec = transform.preMultiplyWithRow(m.getRow(row));
for (unsigned col = lastAllowedNonZeroCol + 1, nCols = m.getNumColumns();
col < nCols; ++col) {
EXPECT_EQ(rowVec[col], 0);
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index fd1a2a79ebc7..8e0f1c2217f2 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -780,8 +780,8 @@ TEST(SetTest, coalesceDivOtherContained) {
static void
expectComputedVolumeIsValidOverapprox(const PresburgerSet &set,
- Optional<uint64_t> trueVolume,
- Optional<uint64_t> resultBound) {
+ Optional<int64_t> trueVolume,
+ Optional<int64_t> resultBound) {
expectComputedVolumeIsValidOverapprox(set.computeVolume(), trueVolume,
resultBound);
}
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
index df03935e16f3..f1a41e0fd0fc 100644
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -17,6 +17,30 @@
using namespace mlir;
using namespace presburger;
+/// Convenience functions to pass literals to Simplex.
+void addInequality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
+ simplex.addInequality(getMPIntVec(coeffs));
+}
+void addEquality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
+ simplex.addEquality(getMPIntVec(coeffs));
+}
+bool isRedundantInequality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+ return simplex.isRedundantInequality(getMPIntVec(coeffs));
+}
+bool isRedundantInequality(LexSimplex &simplex, ArrayRef<int64_t> coeffs) {
+ return simplex.isRedundantInequality(getMPIntVec(coeffs));
+}
+bool isRedundantEquality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+ return simplex.isRedundantEquality(getMPIntVec(coeffs));
+}
+bool isSeparateInequality(LexSimplex &simplex, ArrayRef<int64_t> coeffs) {
+ return simplex.isSeparateInequality(getMPIntVec(coeffs));
+}
+
+Simplex::IneqType findIneqType(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+ return simplex.findIneqType(getMPIntVec(coeffs));
+}
+
/// Take a snapshot, add constraints making the set empty, and rollback.
/// The set should not be empty after rolling back. We add additional
/// constraints after the set is already empty and roll back the addition
@@ -25,17 +49,17 @@ using namespace presburger;
TEST(SimplexTest, emptyRollback) {
Simplex simplex(2);
// (u - v) >= 0
- simplex.addInequality({1, -1, 0});
+ addInequality(simplex, {1, -1, 0});
ASSERT_FALSE(simplex.isEmpty());
unsigned snapshot = simplex.getSnapshot();
// (u - v) <= -1
- simplex.addInequality({-1, 1, -1});
+ addInequality(simplex, {-1, 1, -1});
ASSERT_TRUE(simplex.isEmpty());
unsigned snapshot2 = simplex.getSnapshot();
// (u - v) <= -3
- simplex.addInequality({-1, 1, -3});
+ addInequality(simplex, {-1, 1, -3});
ASSERT_TRUE(simplex.isEmpty());
simplex.rollback(snapshot2);
@@ -49,9 +73,9 @@ TEST(SimplexTest, emptyRollback) {
/// constraints.
TEST(SimplexTest, addEquality_separate) {
Simplex simplex(1);
- simplex.addInequality({1, -1}); // x >= 1.
+ addInequality(simplex, {1, -1}); // x >= 1.
ASSERT_FALSE(simplex.isEmpty());
- simplex.addEquality({1, 0}); // x == 0.
+ addEquality(simplex, {1, 0}); // x == 0.
EXPECT_TRUE(simplex.isEmpty());
}
@@ -59,7 +83,7 @@ void expectInequalityMakesSetEmpty(Simplex &simplex, ArrayRef<int64_t> coeffs,
bool expect) {
ASSERT_FALSE(simplex.isEmpty());
unsigned snapshot = simplex.getSnapshot();
- simplex.addInequality(coeffs);
+ addInequality(simplex, coeffs);
EXPECT_EQ(simplex.isEmpty(), expect);
simplex.rollback(snapshot);
}
@@ -82,7 +106,7 @@ TEST(SimplexTest, addInequality_rollback) {
expectInequalityMakesSetEmpty(simplex, checkCoeffs[1], false);
for (int i = 0; i < 4; i++)
- simplex.addInequality(coeffs[(run + i) % 4]);
+ addInequality(simplex, coeffs[(run + i) % 4]);
expectInequalityMakesSetEmpty(simplex, checkCoeffs[0], true);
expectInequalityMakesSetEmpty(simplex, checkCoeffs[1], true);
@@ -100,9 +124,9 @@ Simplex simplexFromConstraints(unsigned nDim,
ArrayRef<SmallVector<int64_t, 8>> eqs) {
Simplex simplex(nDim);
for (const auto &ineq : ineqs)
- simplex.addInequality(ineq);
+ addInequality(simplex, ineq);
for (const auto &eq : eqs)
- simplex.addEquality(eq);
+ addEquality(simplex, eq);
return simplex;
}
@@ -235,7 +259,7 @@ TEST(SimplexTest, getSamplePointIfIntegral) {
/// Some basic sanity checks involving zero or one variables.
TEST(SimplexTest, isMarkedRedundant_no_var_ge_zero) {
Simplex simplex(0);
- simplex.addInequality({0}); // 0 >= 0.
+ addInequality(simplex, {0}); // 0 >= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
@@ -244,7 +268,7 @@ TEST(SimplexTest, isMarkedRedundant_no_var_ge_zero) {
TEST(SimplexTest, isMarkedRedundant_no_var_eq) {
Simplex simplex(0);
- simplex.addEquality({0}); // 0 == 0.
+ addEquality(simplex, {0}); // 0 == 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_TRUE(simplex.isMarkedRedundant(0));
@@ -252,7 +276,7 @@ TEST(SimplexTest, isMarkedRedundant_no_var_eq) {
TEST(SimplexTest, isMarkedRedundant_pos_var_eq) {
Simplex simplex(1);
- simplex.addEquality({1, 0}); // x == 0.
+ addEquality(simplex, {1, 0}); // x == 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
@@ -261,7 +285,7 @@ TEST(SimplexTest, isMarkedRedundant_pos_var_eq) {
TEST(SimplexTest, isMarkedRedundant_zero_var_eq) {
Simplex simplex(1);
- simplex.addEquality({0, 0}); // 0x == 0.
+ addEquality(simplex, {0, 0}); // 0x == 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_TRUE(simplex.isMarkedRedundant(0));
@@ -269,7 +293,7 @@ TEST(SimplexTest, isMarkedRedundant_zero_var_eq) {
TEST(SimplexTest, isMarkedRedundant_neg_var_eq) {
Simplex simplex(1);
- simplex.addEquality({-1, 0}); // -x == 0.
+ addEquality(simplex, {-1, 0}); // -x == 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_FALSE(simplex.isMarkedRedundant(0));
@@ -277,7 +301,7 @@ TEST(SimplexTest, isMarkedRedundant_neg_var_eq) {
TEST(SimplexTest, isMarkedRedundant_pos_var_ge) {
Simplex simplex(1);
- simplex.addInequality({1, 0}); // x >= 0.
+ addInequality(simplex, {1, 0}); // x >= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_FALSE(simplex.isMarkedRedundant(0));
@@ -285,7 +309,7 @@ TEST(SimplexTest, isMarkedRedundant_pos_var_ge) {
TEST(SimplexTest, isMarkedRedundant_zero_var_ge) {
Simplex simplex(1);
- simplex.addInequality({0, 0}); // 0x >= 0.
+ addInequality(simplex, {0, 0}); // 0x >= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_TRUE(simplex.isMarkedRedundant(0));
@@ -293,7 +317,7 @@ TEST(SimplexTest, isMarkedRedundant_zero_var_ge) {
TEST(SimplexTest, isMarkedRedundant_neg_var_ge) {
Simplex simplex(1);
- simplex.addInequality({-1, 0}); // x <= 0.
+ addInequality(simplex, {-1, 0}); // x <= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_FALSE(simplex.isMarkedRedundant(0));
@@ -304,9 +328,9 @@ TEST(SimplexTest, isMarkedRedundant_neg_var_ge) {
TEST(SimplexTest, isMarkedRedundant_no_redundant) {
Simplex simplex(3);
- simplex.addEquality({-1, 0, 1, 0}); // u = w.
- simplex.addInequality({-1, 16, 0, 15}); // 15 - (u - 16v) >= 0.
- simplex.addInequality({1, -16, 0, 0}); // (u - 16v) >= 0.
+ addEquality(simplex, {-1, 0, 1, 0}); // u = w.
+ addInequality(simplex, {-1, 16, 0, 15}); // 15 - (u - 16v) >= 0.
+ addInequality(simplex, {1, -16, 0, 0}); // (u - 16v) >= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
@@ -319,14 +343,14 @@ TEST(SimplexTest, isMarkedRedundant_repeated_constraints) {
Simplex simplex(3);
// [4] to [7] are repeats of [0] to [3].
- simplex.addInequality({0, -1, 0, 1}); // [0]: y <= 1.
- simplex.addInequality({-1, 0, 8, 7}); // [1]: 8z >= x - 7.
- simplex.addInequality({1, 0, -8, 0}); // [2]: 8z <= x.
- simplex.addInequality({0, 1, 0, 0}); // [3]: y >= 0.
- simplex.addInequality({-1, 0, 8, 7}); // [4]: 8z >= 7 - x.
- simplex.addInequality({1, 0, -8, 0}); // [5]: 8z <= x.
- simplex.addInequality({0, 1, 0, 0}); // [6]: y >= 0.
- simplex.addInequality({0, -1, 0, 1}); // [7]: y <= 1.
+ addInequality(simplex, {0, -1, 0, 1}); // [0]: y <= 1.
+ addInequality(simplex, {-1, 0, 8, 7}); // [1]: 8z >= x - 7.
+ addInequality(simplex, {1, 0, -8, 0}); // [2]: 8z <= x.
+ addInequality(simplex, {0, 1, 0, 0}); // [3]: y >= 0.
+ addInequality(simplex, {-1, 0, 8, 7}); // [4]: 8z >= 7 - x.
+ addInequality(simplex, {1, 0, -8, 0}); // [5]: 8z <= x.
+ addInequality(simplex, {0, 1, 0, 0}); // [6]: y >= 0.
+ addInequality(simplex, {0, -1, 0, 1}); // [7]: y <= 1.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
@@ -343,14 +367,14 @@ TEST(SimplexTest, isMarkedRedundant_repeated_constraints) {
TEST(SimplexTest, isMarkedRedundant) {
Simplex simplex(3);
- simplex.addInequality({0, -1, 0, 1}); // [0]: y <= 1.
- simplex.addInequality({1, 0, 0, -1}); // [1]: x >= 1.
- simplex.addInequality({-1, 0, 0, 2}); // [2]: x <= 2.
- simplex.addInequality({-1, 0, 2, 7}); // [3]: 2z >= x - 7.
- simplex.addInequality({1, 0, -2, 0}); // [4]: 2z <= x.
- simplex.addInequality({0, 1, 0, 0}); // [5]: y >= 0.
- simplex.addInequality({0, 1, -2, 1}); // [6]: y >= 2z - 1.
- simplex.addInequality({-1, 1, 0, 1}); // [7]: y >= x - 1.
+ addInequality(simplex, {0, -1, 0, 1}); // [0]: y <= 1.
+ addInequality(simplex, {1, 0, 0, -1}); // [1]: x >= 1.
+ addInequality(simplex, {-1, 0, 0, 2}); // [2]: x <= 2.
+ addInequality(simplex, {-1, 0, 2, 7}); // [3]: 2z >= x - 7.
+ addInequality(simplex, {1, 0, -2, 0}); // [4]: 2z <= x.
+ addInequality(simplex, {0, 1, 0, 0}); // [5]: y >= 0.
+ addInequality(simplex, {0, 1, -2, 1}); // [6]: y >= 2z - 1.
+ addInequality(simplex, {-1, 1, 0, 1}); // [7]: y >= x - 1.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
@@ -372,12 +396,12 @@ TEST(SimplexTest, isMarkedRedundant) {
TEST(SimplexTest, isMarkedRedundantTiledLoopNestConstraints) {
Simplex simplex(3); // Variables are x, y, N.
- simplex.addInequality({1, 0, 0, 0}); // [0]: x >= 0.
- simplex.addInequality({-32, 0, 1, -1}); // [1]: 32x <= N - 1.
- simplex.addInequality({0, 1, 0, 0}); // [2]: y >= 0.
- simplex.addInequality({-32, 1, 0, 0}); // [3]: y >= 32x.
- simplex.addInequality({32, -1, 0, 31}); // [4]: y <= 32x + 31.
- simplex.addInequality({0, -1, 1, -1}); // [5]: y <= N - 1.
+ addInequality(simplex, {1, 0, 0, 0}); // [0]: x >= 0.
+ addInequality(simplex, {-32, 0, 1, -1}); // [1]: 32x <= N - 1.
+ addInequality(simplex, {0, 1, 0, 0}); // [2]: y >= 0.
+ addInequality(simplex, {-32, 1, 0, 0}); // [3]: y >= 32x.
+ addInequality(simplex, {32, -1, 0, 31}); // [4]: y <= 32x + 31.
+ addInequality(simplex, {0, -1, 1, -1}); // [5]: y <= N - 1.
// [3] and [0] imply [2], as we have y >= 32x >= 0.
// [3] and [5] imply [1], as we have 32x <= y <= N - 1.
simplex.detectRedundant();
@@ -391,11 +415,11 @@ TEST(SimplexTest, isMarkedRedundantTiledLoopNestConstraints) {
TEST(SimplexTest, pivotRedundantRegressionTest) {
Simplex simplex(2);
- simplex.addInequality({-1, 0, -1}); // x <= -1.
+ addInequality(simplex, {-1, 0, -1}); // x <= -1.
unsigned snapshot = simplex.getSnapshot();
- simplex.addInequality({-1, 0, -2}); // x <= -2.
- simplex.addInequality({-3, 0, -6});
+ addInequality(simplex, {-1, 0, -2}); // x <= -2.
+ addInequality(simplex, {-3, 0, -6});
// This first marks x <= -1 as redundant. Then it performs some more pivots
// to check if the other constraints are redundant. Pivot must update the
@@ -408,14 +432,14 @@ TEST(SimplexTest, pivotRedundantRegressionTest) {
// The maximum value of x should be -1.
simplex.rollback(snapshot);
MaybeOptimum<Fraction> maxX =
- simplex.computeOptimum(Simplex::Direction::Up, {1, 0, 0});
+ simplex.computeOptimum(Simplex::Direction::Up, getMPIntVec({1, 0, 0}));
EXPECT_TRUE(maxX.isBounded() && *maxX == Fraction(-1, 1));
}
TEST(SimplexTest, addInequality_already_redundant) {
Simplex simplex(1);
- simplex.addInequality({1, -1}); // x >= 1.
- simplex.addInequality({1, 0}); // x >= 0.
+ addInequality(simplex, {1, -1}); // x >= 1.
+ addInequality(simplex, {1, 0}); // x >= 0.
simplex.detectRedundant();
ASSERT_FALSE(simplex.isEmpty());
EXPECT_FALSE(simplex.isMarkedRedundant(0));
@@ -431,8 +455,8 @@ TEST(SimplexTest, appendVariable) {
EXPECT_EQ(simplex.getNumVariables(), 2u);
int64_t yMin = 2, yMax = 5;
- simplex.addInequality({0, 1, -yMin}); // y >= 2.
- simplex.addInequality({0, -1, yMax}); // y <= 5.
+ addInequality(simplex, {0, 1, -yMin}); // y >= 2.
+ addInequality(simplex, {0, -1, yMax}); // y <= 5.
unsigned snapshot2 = simplex.getSnapshot();
simplex.appendVariable(2);
@@ -441,9 +465,9 @@ TEST(SimplexTest, appendVariable) {
EXPECT_EQ(simplex.getNumVariables(), 2u);
EXPECT_EQ(simplex.getNumConstraints(), 2u);
- EXPECT_EQ(
- simplex.computeIntegerBounds({0, 1, 0}),
- std::make_pair(MaybeOptimum<int64_t>(yMin), MaybeOptimum<int64_t>(yMax)));
+ EXPECT_EQ(simplex.computeIntegerBounds(getMPIntVec({0, 1, 0})),
+ std::make_pair(MaybeOptimum<MPInt>(MPInt(yMin)),
+ MaybeOptimum<MPInt>(MPInt(yMax))));
simplex.rollback(snapshot1);
EXPECT_EQ(simplex.getNumVariables(), 1u);
@@ -452,54 +476,54 @@ TEST(SimplexTest, appendVariable) {
TEST(SimplexTest, isRedundantInequality) {
Simplex simplex(2);
- simplex.addInequality({0, -1, 2}); // y <= 2.
- simplex.addInequality({1, 0, 0}); // x >= 0.
- simplex.addEquality({-1, 1, 0}); // y = x.
+ addInequality(simplex, {0, -1, 2}); // y <= 2.
+ addInequality(simplex, {1, 0, 0}); // x >= 0.
+ addEquality(simplex, {-1, 1, 0}); // y = x.
- EXPECT_TRUE(simplex.isRedundantInequality({-1, 0, 2})); // x <= 2.
- EXPECT_TRUE(simplex.isRedundantInequality({0, 1, 0})); // y >= 0.
+ EXPECT_TRUE(isRedundantInequality(simplex, {-1, 0, 2})); // x <= 2.
+ EXPECT_TRUE(isRedundantInequality(simplex, {0, 1, 0})); // y >= 0.
- EXPECT_FALSE(simplex.isRedundantInequality({-1, 0, -1})); // x <= -1.
- EXPECT_FALSE(simplex.isRedundantInequality({0, 1, -2})); // y >= 2.
- EXPECT_FALSE(simplex.isRedundantInequality({0, 1, -1})); // y >= 1.
+ EXPECT_FALSE(isRedundantInequality(simplex, {-1, 0, -1})); // x <= -1.
+ EXPECT_FALSE(isRedundantInequality(simplex, {0, 1, -2})); // y >= 2.
+ EXPECT_FALSE(isRedundantInequality(simplex, {0, 1, -1})); // y >= 1.
}
TEST(SimplexTest, ineqType) {
Simplex simplex(2);
- simplex.addInequality({0, -1, 2}); // y <= 2.
- simplex.addInequality({1, 0, 0}); // x >= 0.
- simplex.addEquality({-1, 1, 0}); // y = x.
-
- EXPECT_TRUE(simplex.findIneqType({-1, 0, 2}) ==
- Simplex::IneqType::Redundant); // x <= 2.
- EXPECT_TRUE(simplex.findIneqType({0, 1, 0}) ==
- Simplex::IneqType::Redundant); // y >= 0.
-
- EXPECT_TRUE(simplex.findIneqType({0, 1, -1}) ==
- Simplex::IneqType::Cut); // y >= 1.
- EXPECT_TRUE(simplex.findIneqType({-1, 0, 1}) ==
- Simplex::IneqType::Cut); // x <= 1.
- EXPECT_TRUE(simplex.findIneqType({0, 1, -2}) ==
- Simplex::IneqType::Cut); // y >= 2.
-
- EXPECT_TRUE(simplex.findIneqType({-1, 0, -1}) ==
- Simplex::IneqType::Separate); // x <= -1.
+ addInequality(simplex, {0, -1, 2}); // y <= 2.
+ addInequality(simplex, {1, 0, 0}); // x >= 0.
+ addEquality(simplex, {-1, 1, 0}); // y = x.
+
+ EXPECT_EQ(findIneqType(simplex, {-1, 0, 2}),
+ Simplex::IneqType::Redundant); // x <= 2.
+ EXPECT_EQ(findIneqType(simplex, {0, 1, 0}),
+ Simplex::IneqType::Redundant); // y >= 0.
+
+ EXPECT_EQ(findIneqType(simplex, {0, 1, -1}),
+ Simplex::IneqType::Cut); // y >= 1.
+ EXPECT_EQ(findIneqType(simplex, {-1, 0, 1}),
+ Simplex::IneqType::Cut); // x <= 1.
+ EXPECT_EQ(findIneqType(simplex, {0, 1, -2}),
+ Simplex::IneqType::Cut); // y >= 2.
+
+ EXPECT_EQ(findIneqType(simplex, {-1, 0, -1}),
+ Simplex::IneqType::Separate); // x <= -1.
}
TEST(SimplexTest, isRedundantEquality) {
Simplex simplex(2);
- simplex.addInequality({0, -1, 2}); // y <= 2.
- simplex.addInequality({1, 0, 0}); // x >= 0.
- simplex.addEquality({-1, 1, 0}); // y = x.
+ addInequality(simplex, {0, -1, 2}); // y <= 2.
+ addInequality(simplex, {1, 0, 0}); // x >= 0.
+ addEquality(simplex, {-1, 1, 0}); // y = x.
- EXPECT_TRUE(simplex.isRedundantEquality({-1, 1, 0})); // y = x.
- EXPECT_TRUE(simplex.isRedundantEquality({1, -1, 0})); // x = y.
+ EXPECT_TRUE(isRedundantEquality(simplex, {-1, 1, 0})); // y = x.
+ EXPECT_TRUE(isRedundantEquality(simplex, {1, -1, 0})); // x = y.
- EXPECT_FALSE(simplex.isRedundantEquality({0, 1, -1})); // y = 1.
+ EXPECT_FALSE(isRedundantEquality(simplex, {0, 1, -1})); // y = 1.
- simplex.addEquality({0, -1, 2}); // y = 2.
+ addEquality(simplex, {0, -1, 2}); // y = 2.
- EXPECT_TRUE(simplex.isRedundantEquality({-1, 0, 2})); // x = 2.
+ EXPECT_TRUE(isRedundantEquality(simplex, {-1, 0, 2})); // x = 2.
}
TEST(SimplexTest, IsRationalSubsetOf) {
@@ -541,27 +565,27 @@ TEST(SimplexTest, IsRationalSubsetOf) {
TEST(SimplexTest, addDivisionVariable) {
Simplex simplex(/*nVar=*/1);
- simplex.addDivisionVariable({1, 0}, 2);
- simplex.addInequality({1, 0, -3}); // x >= 3.
- simplex.addInequality({-1, 0, 9}); // x <= 9.
- Optional<SmallVector<int64_t, 8>> sample = simplex.findIntegerSample();
+ simplex.addDivisionVariable(getMPIntVec({1, 0}), MPInt(2));
+ addInequality(simplex, {1, 0, -3}); // x >= 3.
+ addInequality(simplex, {-1, 0, 9}); // x <= 9.
+ Optional<SmallVector<MPInt, 8>> sample = simplex.findIntegerSample();
ASSERT_TRUE(sample.has_value());
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
}
TEST(SimplexTest, LexIneqType) {
LexSimplex simplex(/*nVar=*/1);
- simplex.addInequality({2, -1}); // x >= 1/2.
+ addInequality(simplex, {2, -1}); // x >= 1/2.
// Redundant inequality x >= 2/3.
- EXPECT_TRUE(simplex.isRedundantInequality({3, -2}));
- EXPECT_FALSE(simplex.isSeparateInequality({3, -2}));
+ EXPECT_TRUE(isRedundantInequality(simplex, {3, -2}));
+ EXPECT_FALSE(isSeparateInequality(simplex, {3, -2}));
// Separate inequality x <= 2/3.
- EXPECT_FALSE(simplex.isRedundantInequality({-3, 2}));
- EXPECT_TRUE(simplex.isSeparateInequality({-3, 2}));
+ EXPECT_FALSE(isRedundantInequality(simplex, {-3, 2}));
+ EXPECT_TRUE(isSeparateInequality(simplex, {-3, 2}));
// Cut inequality x <= 1.
- EXPECT_FALSE(simplex.isRedundantInequality({-1, 1}));
- EXPECT_FALSE(simplex.isSeparateInequality({-1, 1}));
+ EXPECT_FALSE(isRedundantInequality(simplex, {-1, 1}));
+ EXPECT_FALSE(isSeparateInequality(simplex, {-1, 1}));
}
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index 417cce15e842..b839b628173a 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -93,7 +93,7 @@ inline PWMAFunction parsePWMAF(
/// lhs and rhs represent non-negative integers or positive infinity. The
/// infinity case corresponds to when the Optional is empty.
-inline bool infinityOrUInt64LE(Optional<uint64_t> lhs, Optional<uint64_t> rhs) {
+inline bool infinityOrUInt64LE(Optional<MPInt> lhs, Optional<MPInt> rhs) {
// No constraint.
if (!rhs)
return true;
@@ -107,15 +107,24 @@ inline bool infinityOrUInt64LE(Optional<uint64_t> lhs, Optional<uint64_t> rhs) {
/// the true volume `trueVolume`, while also being at least as good an
/// approximation as `resultBound`.
inline void
-expectComputedVolumeIsValidOverapprox(Optional<uint64_t> computedVolume,
- Optional<uint64_t> trueVolume,
- Optional<uint64_t> resultBound) {
+expectComputedVolumeIsValidOverapprox(const Optional<MPInt> &computedVolume,
+ const Optional<MPInt> &trueVolume,
+ const Optional<MPInt> &resultBound) {
assert(infinityOrUInt64LE(trueVolume, resultBound) &&
"can't expect result to be less than the true volume");
EXPECT_TRUE(infinityOrUInt64LE(trueVolume, computedVolume));
EXPECT_TRUE(infinityOrUInt64LE(computedVolume, resultBound));
}
+inline void
+expectComputedVolumeIsValidOverapprox(const Optional<MPInt> &computedVolume,
+ Optional<int64_t> trueVolume,
+ Optional<int64_t> resultBound) {
+ expectComputedVolumeIsValidOverapprox(computedVolume,
+ trueVolume.transform(mpintFromInt64),
+ resultBound.transform(mpintFromInt64));
+}
+
} // namespace presburger
} // namespace mlir
More information about the Mlir-commits mailing list