[Mlir-commits] [mlir] 4418669 - [MLIR][Presburger] PWMAFunction::valueAt: support local ids
Arjun P
llvmlistbot at llvm.org
Wed Mar 23 17:42:11 PDT 2022
Author: Arjun P
Date: 2022-03-24T00:42:21Z
New Revision: 4418669f1e6c429b679a942f971a7ae148cdccc8
URL: https://github.com/llvm/llvm-project/commit/4418669f1e6c429b679a942f971a7ae148cdccc8
DIFF: https://github.com/llvm/llvm-project/commit/4418669f1e6c429b679a942f971a7ae148cdccc8.diff
LOG: [MLIR][Presburger] PWMAFunction::valueAt: support local ids
Add a baseline implementation of support for local ids for `PWMAFunction::valueAt`. This can be made more efficient later if needed by handling locals with known div representations separately.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D122144
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 41c0500e52367..2bfcea29af316 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -285,11 +285,12 @@ class IntegerRelation : public PresburgerLocalSpace {
Optional<uint64_t> computeVolume() const;
/// Returns true if the given point satisfies the constraints, or false
- /// otherwise.
- ///
- /// Note: currently, if the relation contains local ids, the values of
- /// the local ids must also be provided.
+ /// otherwise. Takes the values of all ids including locals.
bool containsPoint(ArrayRef<int64_t> point) const;
+ /// Given the values of all non-local ids, return a satisfying assignment to
+ /// the local ids (in the order the locals are stored) if one exists, or an
+ /// empty optional otherwise. `point` must have exactly one value per
+ /// non-local id.
+ Optional<SmallVector<int64_t, 8>>
+ containsPointNoLocal(ArrayRef<int64_t> point) const;
/// Find equality and pairs of inequality contraints identified by their
/// position indices, using which an explicit representation for each local
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 66b66940fd72a..f0519acaab279 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -41,8 +41,7 @@ namespace presburger {
/// each id, and an extra column at the end for the constant term.
///
/// Checking equality of two such functions is supported, as well as finding the
-/// value of the function at a specified point. Note that local ids in the
-/// domain are not yet supported for finding the value at a point.
+/// value of the function at a specified point.
class MultiAffineFunction : protected IntegerPolyhedron {
public:
/// We use protected inheritance to avoid inheriting the whole public
@@ -114,8 +113,6 @@ class MultiAffineFunction : protected IntegerPolyhedron {
/// Get the value of the function at the specified point. If the point lies
/// outside the domain, an empty optional is returned.
- ///
- /// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
void print(raw_ostream &os) const;
@@ -145,8 +142,7 @@ class MultiAffineFunction : protected IntegerPolyhedron {
/// symbolic ids.
///
/// Support is provided to compare equality of two such functions as well as
-/// finding the value of the function at a point. Note that local ids in the
-/// piece are not supported for the latter.
+/// finding the value of the function at a point.
class PWMAFunction : public PresburgerSpace {
public:
PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
@@ -170,8 +166,6 @@ class PWMAFunction : public PresburgerSpace {
/// Return the value at the specified point and an empty optional if the
/// point does not lie in the domain.
- ///
- /// Note: domains with local ids are not yet supported, and will assert-fail.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
/// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 4b41c23c0475c..d0c3744fa3cab 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -784,6 +784,25 @@ bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
return true;
}
+/// Substitutes the given values for the non-local ids and then searches for an
+/// integer assignment to the remaining (local) ids. The sample found, if any,
+/// is returned as the values of the locals.
+///
+/// TODO: this could be made more efficient by handling divisions separately.
+/// Locals that have a division representation could be computed directly,
+/// leaving the integer emptiness check only for the remaining locals. Doing
+/// this correctly requires ordering the divs, though.
+Optional<SmallVector<int64_t, 8>>
+IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
+  assert(getNumIds() - getNumLocalIds() == point.size() &&
+         "Point should contain all ids except locals!");
+  assert(getNumIds() - getNumLocalIds() == getIdKindOffset(IdKind::Local) &&
+         "This function depends on locals being stored last!");
+  IntegerRelation localsOnly(*this);
+  localsOnly.setAndEliminate(0, point);
+  return localsOnly.findIntegerSample();
+}
+
void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
SmallVector<unsigned, 4> denominators(getNumLocalIds());
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 4ce7882ef5466..41ea913c4f292 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -36,19 +36,26 @@ PresburgerSet PWMAFunction::getDomain() const {
Optional<SmallVector<int64_t, 8>>
MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
- assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
- assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
+ assert(point.size() == getNumDimAndSymbolIds() &&
+ "Point has incorrect dimensionality!");
- if (!getDomain().containsPoint(point))
+ Optional<SmallVector<int64_t, 8>> maybeLocalValues =
+ getDomain().containsPointNoLocal(point);
+ if (!maybeLocalValues)
return {};
// The point lies in the domain, so we need to compute the output value.
+ SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
+ // The given point didn't include the values of locals which the output is a
+ // function of; we have computed one possible set of values and use them
+ // here. The function is not allowed to have local ids that take more than
+ // one possible value.
+ pointHomogenous.append(*maybeLocalValues);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
- SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
pointHomogenous.push_back(1);
SmallVector<int64_t, 8> result =
output.postMultiplyWithColumn(pointHomogenous);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 0da5f8b369854..e7bbac3fc8ea0 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1187,3 +1187,18 @@ TEST(IntegerPolyhedronTest, computeVolume) {
parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"),
/*trueVolume=*/{}, /*resultBound=*/{});
}
+
+TEST(IntegerPolyhedronTest, containsPointNoLocal) {
+  // (x floordiv 2) == x holds exactly for x = 0 and x = -1; the negative
+  // solution checks that floordiv is modelled with floor (not truncating)
+  // semantics when the local is solved for.
+  IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)");
+  EXPECT_TRUE(poly1.containsPointNoLocal({0}));
+  EXPECT_TRUE(poly1.containsPointNoLocal({-1}));
+  EXPECT_FALSE(poly1.containsPointNoLocal({1}));
+  EXPECT_FALSE(poly1.containsPointNoLocal({-2}));
+
+  // x is even and x mod 4 == 2; -2 exercises negative floor division.
+  IntegerPolyhedron poly2 = parsePoly(
+      "(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)");
+  EXPECT_TRUE(poly2.containsPointNoLocal({6}));
+  EXPECT_TRUE(poly2.containsPointNoLocal({-2}));
+  EXPECT_FALSE(poly2.containsPointNoLocal({4}));
+
+  // Without locals, a contained point yields an empty assignment.
+  IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)");
+  auto noLocals = poly3.containsPointNoLocal({0, 0});
+  ASSERT_TRUE(noLocals.hasValue());
+  EXPECT_TRUE(noLocals->empty());
+  EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
+}
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
index 9ee2fdc0cae6d..79139de6fd011 100644
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -129,16 +129,31 @@ TEST(PWAFunctionTest, isEqual) {
}
TEST(PWMAFunction, valueAt) {
- PWMAFunction nonNegPWAF = parsePWMAF(
+ PWMAFunction nonNegPWMAF = parsePWMAF(
/*numInputs=*/2, /*numOutputs=*/2,
{
{"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
{"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
});
- EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
- EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
- EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
- EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+ EXPECT_THAT(*nonNegPWMAF.valueAt({2, 3}), ElementsAre(11, 23));
+ EXPECT_THAT(*nonNegPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+ EXPECT_THAT(*nonNegPWMAF.valueAt({2, -3}), ElementsAre(-1, -1));
+ EXPECT_FALSE(nonNegPWMAF.valueAt({-2, -3}).hasValue());
+
+ // Like nonNegPWMAF, but the first piece additionally requires x to be even.
+ // This introduces a local q = x floordiv 2, and that piece's output rows have
+ // columns (x, y, q, constant), so valueAt must solve for q.
+ PWMAFunction divPWMAF = parsePWMAF(
+ /*numInputs=*/2, /*numOutputs=*/2,
+ {
+ {"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
+ {{0, 2, 1, 3}, {0, 4, 3, 5}}}, // (2y + q + 3, 4y + 3q + 5).
+ {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (-x + 2y + 3, -3x + 4y + 5)
+ });
+ // At x = 4 (so q = 2) the first piece agrees with nonNegPWMAF at x = 2.
+ EXPECT_THAT(*divPWMAF.valueAt({4, 3}), ElementsAre(11, 23));
+ EXPECT_THAT(*divPWMAF.valueAt({4, -3}), ElementsAre(-1, -1));
+ // Odd non-negative x lies in neither piece.
+ EXPECT_FALSE(divPWMAF.valueAt({3, 3}).hasValue());
+ EXPECT_FALSE(divPWMAF.valueAt({3, -3}).hasValue());
+
+ // Negative x is handled by the second piece, which has no locals.
+ EXPECT_THAT(*divPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+ EXPECT_FALSE(divPWMAF.valueAt({-2, -3}).hasValue());
}
TEST(PWMAFunction, removeIdRangeRegressionTest) {
More information about the Mlir-commits
mailing list