[Mlir-commits] [mlir] 4418669 - [MLIR][Presburger] PWMAFunction::valueAt: support local ids
    Arjun P 
    llvmlistbot at llvm.org
       
    Wed Mar 23 17:42:11 PDT 2022
    
    
  
Author: Arjun P
Date: 2022-03-24T00:42:21Z
New Revision: 4418669f1e6c429b679a942f971a7ae148cdccc8
URL: https://github.com/llvm/llvm-project/commit/4418669f1e6c429b679a942f971a7ae148cdccc8
DIFF: https://github.com/llvm/llvm-project/commit/4418669f1e6c429b679a942f971a7ae148cdccc8.diff
LOG: [MLIR][Presburger] PWMAFunction::valueAt: support local ids
Add a baseline implementation of support for local ids in `PWMAFunction::valueAt`. This can be made more efficient later, if needed, by handling locals with known division representations separately.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D122144
Added: 
    
Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
    mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
Removed: 
    
################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 41c0500e52367..2bfcea29af316 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -285,11 +285,12 @@ class IntegerRelation : public PresburgerLocalSpace {
   Optional<uint64_t> computeVolume() const;
 
   /// Returns true if the given point satisfies the constraints, or false
-  /// otherwise.
-  ///
-  /// Note: currently, if the relation contains local ids, the values of
-  /// the local ids must also be provided.
+  /// otherwise. Takes the values of all ids including locals.
   bool containsPoint(ArrayRef<int64_t> point) const;
+  /// Given the values of the non-local ids, return a satisfying assignment to
+  /// the local ids if one exists, or an empty optional otherwise.
+  Optional<SmallVector<int64_t, 8>>
+  containsPointNoLocal(ArrayRef<int64_t> point) const;
 
   /// Find equality and pairs of inequality contraints identified by their
   /// position indices, using which an explicit representation for each local
diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 66b66940fd72a..f0519acaab279 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -41,8 +41,7 @@ namespace presburger {
 /// each id, and an extra column at the end for the constant term.
 ///
 /// Checking equality of two such functions is supported, as well as finding the
-/// value of the function at a specified point. Note that local ids in the
-/// domain are not yet supported for finding the value at a point.
+/// value of the function at a specified point.
 class MultiAffineFunction : protected IntegerPolyhedron {
 public:
   /// We use protected inheritance to avoid inheriting the whole public
@@ -114,8 +113,6 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 
   /// Get the value of the function at the specified point. If the point lies
   /// outside the domain, an empty optional is returned.
-  ///
-  /// Note: domains with local ids are not yet supported, and will assert-fail.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
   void print(raw_ostream &os) const;
@@ -145,8 +142,7 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 /// symbolic ids.
 ///
 /// Support is provided to compare equality of two such functions as well as
-/// finding the value of the function at a point. Note that local ids in the
-/// piece are not supported for the latter.
+/// finding the value of the function at a point.
 class PWMAFunction : public PresburgerSpace {
 public:
   PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
@@ -170,8 +166,6 @@ class PWMAFunction : public PresburgerSpace {
 
   /// Return the value at the specified point and an empty optional if the
   /// point does not lie in the domain.
-  ///
-  /// Note: domains with local ids are not yet supported, and will assert-fail.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 4b41c23c0475c..d0c3744fa3cab 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -784,6 +784,25 @@ bool IntegerRelation::containsPoint(ArrayRef<int64_t> point) const {
   return true;
 }
 
+/// Substitute the given values for the non-local ids and check whether an
+/// integer sample exists for the local ids.
+///
+/// TODO: this could be made more efficient by handling divisions separately.
+/// Instead of finding an integer sample over all the locals, we can first
+/// compute the values of the locals that have division representations and
+/// only use the integer emptiness check for the locals that don't have this.
+/// Handling this correctly requires ordering the divs, though.
+Optional<SmallVector<int64_t, 8>>
+IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumIds() - getNumLocalIds() &&
+         "Point should contain all ids except locals!");
+  assert(getIdKindOffset(IdKind::Local) == getNumIds() - getNumLocalIds() &&
+         "This function depends on locals being stored last!");
+  IntegerRelation copy = *this;
+  copy.setAndEliminate(0, point);
+  return copy.findIntegerSample();
+}
+
 void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
   std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalIds());
   SmallVector<unsigned, 4> denominators(getNumLocalIds());
diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 4ce7882ef5466..41ea913c4f292 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -36,19 +36,26 @@ PresburgerSet PWMAFunction::getDomain() const {
 
 Optional<SmallVector<int64_t, 8>>
 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
-  assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
-  assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
+  assert(point.size() == getNumDimAndSymbolIds() &&
+         "Point has incorrect dimensionality!");
 
-  if (!getDomain().containsPoint(point))
+  Optional<SmallVector<int64_t, 8>> maybeLocalValues =
+      getDomain().containsPointNoLocal(point);
+  if (!maybeLocalValues)
     return {};
 
   // The point lies in the domain, so we need to compute the output value.
+  SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
+  // The given point didn't include the values of locals which the output is a
+  // function of; we have computed one possible set of values and use them
+  // here. The function is not allowed to have local ids that take more than
+  // one possible value.
+  pointHomogenous.append(*maybeLocalValues);
   // The matrix `output` has an affine expression in the ith row, corresponding
   // to the expression for the ith value in the output vector. The last column
   // of the matrix contains the constant term. Let v be the input point with
   // a 1 appended at the end. We can see that output * v gives the desired
   // output vector.
-  SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
   pointHomogenous.push_back(1);
   SmallVector<int64_t, 8> result =
       output.postMultiplyWithColumn(pointHomogenous);
diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 0da5f8b369854..e7bbac3fc8ea0 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1187,3 +1187,18 @@ TEST(IntegerPolyhedronTest, computeVolume) {
       parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)"),
       /*trueVolume=*/{}, /*resultBound=*/{});
 }
+
+TEST(IntegerPolyhedronTest, containsPointNoLocal) {
+  IntegerPolyhedron poly1 = parsePoly("(x) : ((x floordiv 2) - x == 0)");
+  EXPECT_TRUE(poly1.containsPointNoLocal({0}));
+  EXPECT_FALSE(poly1.containsPointNoLocal({1}));
+
+  IntegerPolyhedron poly2 = parsePoly(
+      "(x) : (x - 2*(x floordiv 2) == 0, x - 4*(x floordiv 4) - 2 == 0)");
+  EXPECT_TRUE(poly2.containsPointNoLocal({6}));
+  EXPECT_FALSE(poly2.containsPointNoLocal({4}));
+
+  IntegerPolyhedron poly3 = parsePoly("(x, y) : (2*x - y >= 0, y - 3*x >= 0)");
+  EXPECT_TRUE(poly3.containsPointNoLocal({0, 0}));
+  EXPECT_FALSE(poly3.containsPointNoLocal({1, 0}));
+}
diff  --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
index 9ee2fdc0cae6d..79139de6fd011 100644
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -129,16 +129,31 @@ TEST(PWAFunctionTest, isEqual) {
 }
 
 TEST(PWMAFunction, valueAt) {
-  PWMAFunction nonNegPWAF = parsePWMAF(
+  PWMAFunction nonNegPWMAF = parsePWMAF(
       /*numInputs=*/2, /*numOutputs=*/2,
       {
           {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
           {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
       });
-  EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
-  EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
-  EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
-  EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+  EXPECT_THAT(*nonNegPWMAF.valueAt({2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWMAF.valueAt({2, -3}), ElementsAre(-1, -1));
+  EXPECT_FALSE(nonNegPWMAF.valueAt({-2, -3}).hasValue());
+
+  PWMAFunction divPWMAF = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
+           {{0, 2, 1, 3}, {0, 4, 3, 5}}}, // (x, y).
+          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
+      });
+  EXPECT_THAT(*divPWMAF.valueAt({4, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*divPWMAF.valueAt({4, -3}), ElementsAre(-1, -1));
+  EXPECT_FALSE(divPWMAF.valueAt({3, 3}).hasValue());
+  EXPECT_FALSE(divPWMAF.valueAt({3, -3}).hasValue());
+
+  EXPECT_THAT(*divPWMAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_FALSE(divPWMAF.valueAt({-2, -3}).hasValue());
 }
 
 TEST(PWMAFunction, removeIdRangeRegressionTest) {
        
    
    
More information about the Mlir-commits
mailing list