[Mlir-commits] [mlir] d5a2944 - [MLIR][Presburger] Add support for piece-wise multi-affine functions
Arjun P
llvmlistbot at llvm.org
Mon Feb 7 11:14:04 PST 2022
Author: Arjun P
Date: 2022-02-08T00:43:59+05:30
New Revision: d5a29442191066b70176e59260288d8cb08bd602
URL: https://github.com/llvm/llvm-project/commit/d5a29442191066b70176e59260288d8cb08bd602
DIFF: https://github.com/llvm/llvm-project/commit/d5a29442191066b70176e59260288d8cb08bd602.diff
LOG: [MLIR][Presburger] Add support for piece-wise multi-affine functions
Add the class MultiAffineFunction which represents functions whose domain is an
IntegerPolyhedron and which produce an output given by a tuple of affine
expressions in the IntegerPolyhedron's ids.
Also add support for piece-wise MultiAffineFunctions, which are defined on a
union of IntegerPolyhedrons, and may have different output affine expressions
on each IntegerPolyhedron. Thus the function is affine on each individual
IntegerPolyhedron piece in the domain.
This is part of a series of patches leading up to parametric integer programming.
Depends on D118778.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D118779
Added:
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
mlir/lib/Analysis/Presburger/CMakeLists.txt
mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
mlir/unittests/Analysis/Presburger/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
index 4cdccef6db155..28fb6c09be6c2 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
@@ -56,6 +56,7 @@ class IntegerPolyhedron {
enum class Kind {
FlatAffineConstraints,
FlatAffineValueConstraints,
+ MultiAffineFunction,
IntegerPolyhedron
};
@@ -194,6 +195,11 @@ class IntegerPolyhedron {
/// Adds an equality from the coefficients specified in `eq`.
void addEquality(ArrayRef<int64_t> eq);
+ /// Eliminate the `posB^th` local identifier, replacing every instance of it
+ /// with the `posA^th` local identifier. This should be used when the two
+ /// local variables are known to always take the same values.
+ virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB);
+
/// Removes identifiers of the specified kind with the specified pos (or
/// within the specified range) from the system. The specified location is
/// relative to the first identifier of the specified kind.
@@ -273,6 +279,9 @@ class IntegerPolyhedron {
/// Returns true if the given point satisfies the constraints, or false
/// otherwise.
+ ///
+ /// Note: currently, if the polyhedron contains local ids, the values of
+ /// the local ids must also be provided.
bool containsPoint(ArrayRef<int64_t> point) const;
/// Find equality and pairs of inequality contraints identified by their
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
new file mode 100644
index 0000000000000..a01c3ef7a3e61
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -0,0 +1,195 @@
+//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Support for piece-wise multi-affine functions. These are functions that are
+// defined on a domain that is a union of IntegerPolyhedrons; on each
+// IntegerPolyhedron in the union, the value of the function is a tuple of
+// integers, with each value in the tuple being an affine expression in the ids
+// of that IntegerPolyhedron.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+
+#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+namespace mlir {
+
+/// This class represents a multi-affine function whose domain is given by an
+/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
+/// tuple of integer values attached to every point in the polyhedron, with the
+/// value of each element of the tuple given by an affine expression in the ids
+/// of the polyhedron. For example we could have the domain
+///
+/// (x, y) : (x >= 5, y >= x)
+///
+/// and a tuple of three integers defined at every point in the polyhedron:
+///
+/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
+///
+/// In this way every point in the polyhedron has a tuple of integers associated
+/// with it. If the integer polyhedron has local ids, then the output
+/// expressions can use them as well. The output expressions are represented as
+/// a matrix with one row for every element in the output vector, one column
+/// for each id, and an extra column at the end for the constant term.
+///
+/// Checking equality of two such functions is supported, as well as finding the
+/// value of the function at a specified point. Note that local ids in the
+/// domain are not yet supported for finding the value at a point.
+class MultiAffineFunction : protected IntegerPolyhedron {
+public:
+  /// We use protected inheritance to avoid inheriting the whole public
+  /// interface of IntegerPolyhedron. These using declarations explicitly make
+  /// only the relevant functions part of the public interface.
+  using IntegerPolyhedron::getNumDimAndSymbolIds;
+  using IntegerPolyhedron::getNumDimIds;
+  using IntegerPolyhedron::getNumIds;
+  using IntegerPolyhedron::getNumLocalIds;
+  using IntegerPolyhedron::getNumSymbolIds;
+
+  /// Construct a function defined on `domain` whose outputs are given by the
+  /// rows of `output`. `output` is expected to have one column per id of
+  /// `domain` plus one for the constant term (see isConsistent()).
+  MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
+      : IntegerPolyhedron(domain), output(output) {}
+  /// Construct a function whose domain is built from the given numbers of
+  /// dimension, symbol and local ids, with outputs given by `output`.
+  MultiAffineFunction(const Matrix &output, unsigned numDims,
+                      unsigned numSymbols = 0, unsigned numLocals = 0)
+      : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
+
+  ~MultiAffineFunction() override = default;
+  Kind getKind() const override { return Kind::MultiAffineFunction; }
+  /// LLVM-style RTTI hook. It must be static so that isa<>/dyn_cast<> can
+  /// call it without an instance of MultiAffineFunction.
+  static bool classof(const IntegerPolyhedron *poly) {
+    return poly->getKind() == Kind::MultiAffineFunction;
+  }
+
+  unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
+  unsigned getNumOutputs() const { return output.getNumRows(); }
+  /// The output matrix must have one column per id plus the constant column.
+  bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
+  const IntegerPolyhedron &getDomain() const { return *this; }
+
+  /// Return whether `this` and `f` have the same number of dimension ids,
+  /// symbol ids and outputs. The number of local ids is not compared.
+  bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
+
+  /// Insert `num` identifiers of the specified kind at position `pos`.
+  /// Positions are relative to the kind of identifier. The coefficient columns
+  /// corresponding to the added identifiers are initialized to zero. Return the
+  /// absolute column position (i.e., not relative to the kind of identifier)
+  /// of the first added identifier.
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Swap the posA^th identifier with the posB^th identifier.
+  void swapId(unsigned posA, unsigned posB) override;
+
+  /// Remove the specified range of ids.
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+
+  /// Eliminate the `posB^th` local identifier, replacing every instance of it
+  /// with the `posA^th` local identifier. This should be used when the two
+  /// local variables are known to always take the same values.
+  void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
+
+  /// Return whether the outputs of `this` and `other` agree wherever both
+  /// functions are defined, i.e., the outputs should be equal for all points in
+  /// the intersection of the domains.
+  ///
+  /// `other` is taken by value because the implementation merges its local ids
+  /// with those of `this`, which modifies it.
+  bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
+
+  /// Return whether `this` and `other` are equal. This is the case if
+  /// they lie in the same space, i.e. have the same dimensions, and their
+  /// domains are identical and their outputs are equal on their domain.
+  bool isEqual(const MultiAffineFunction &other) const;
+
+  /// Get the value of the function at the specified point. If the point lies
+  /// outside the domain, an empty optional is returned.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  void print(raw_ostream &os) const;
+
+  void dump() const;
+
+private:
+  /// The function's output is a tuple of integers, with the ith element of the
+  /// tuple defined by the affine expression given by the ith row of this output
+  /// matrix.
+  Matrix output;
+};
+
+/// This class represents a piece-wise MultiAffineFunction. This can be thought
+/// of as a list of MultiAffineFunctions with disjoint domains, with each having
+/// their own affine expressions for their output tuples. For example, we could
+/// have a function with two input variables (x, y), defined as
+///
+/// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0
+/// = (-2*x + y, y + 4) if x < 0, y < 0
+/// = (4, 1) if x < 0, y >= 0
+///
+/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
+/// this class is undefined. The domains need not cover all possible points;
+/// this represents a partial function and so could be undefined at some points.
+///
+/// As in PresburgerSets, the input ids are partitioned into dimension ids and
+/// symbolic ids.
+///
+/// Support is provided to compare equality of two such functions as well as
+/// finding the value of the function at a point. Note that local ids in the
+/// pieces are not supported for the latter.
+class PWMAFunction {
+public:
+  /// Create a function with no pieces (i.e., defined nowhere) taking
+  /// `numDims` dimension inputs and `numSymbols` symbol inputs, and producing
+  /// `numOutputs` outputs.
+  PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
+      : numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
+    assert(numOutputs >= 1 && "The function must output something!");
+  }
+
+  /// Add a piece. The piece's domain must be disjoint from the domains of all
+  /// existing pieces; this is checked by assertions.
+  void addPiece(const MultiAffineFunction &piece);
+  void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+
+  const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
+  MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
+  unsigned getNumPieces() const { return pieces.size(); }
+  unsigned getNumOutputs() const { return numOutputs; }
+  unsigned getNumInputs() const { return numDims + numSymbols; }
+  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumSymbolIds() const { return numSymbols; }
+
+  /// Return the domain of this piece-wise MultiAffineFunction. This is the
+  /// union of the domains of all the pieces.
+  PresburgerSet getDomain() const;
+
+  /// Check whether `this` and the given function have compatible
+  /// dimensions, i.e., the same number of dimension inputs, symbol inputs, and
+  /// outputs.
+  bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
+  bool hasCompatibleDimensions(const PWMAFunction &f) const;
+
+  /// Return the value at the specified point and an empty optional if the
+  /// point does not lie in the domain.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
+  /// they have the same dimensions, the same domain and they take the same
+  /// value at every point in the domain.
+  bool isEqual(const PWMAFunction &other) const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// The list of pieces in this piece-wise MultiAffineFunction.
+  SmallVector<MultiAffineFunction, 4> pieces;
+
+  /// The number of dimension ids in the domains.
+  unsigned numDims;
+  /// The number of symbol ids in the domains.
+  unsigned numSymbols;
+  /// The number of output ids.
+  unsigned numOutputs;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
diff --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
index c2458b124921a..313742f7e3d8b 100644
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
LinearTransform.cpp
Matrix.cpp
PresburgerSet.cpp
+ PWMAFunction.cpp
Simplex.cpp
Utils.cpp
diff --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
index 837830eeff7a0..e9a082e349e38 100644
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -1065,24 +1065,17 @@ void IntegerPolyhedron::removeRedundantConstraints() {
equalities.resizeVertically(pos);
}
-/// Eliminate `pos2^th` local identifier, replacing its every instance with
-/// `pos1^th` local identifier. This function is intended to be used to remove
-/// redundancy when local variables at position `pos1` and `pos2` are restricted
-/// to have the same value.
-static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
- unsigned pos2) {
-
- assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
- assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
-
- unsigned localOffset = poly.getNumDimAndSymbolIds();
- pos1 += localOffset;
- pos2 += localOffset;
- for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
- poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
- for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
- poly.atEq(i, pos1) += poly.atEq(i, pos2);
- poly.removeId(pos2);
+void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
+                                                  unsigned posB) {
+  assert(posA < getNumLocalIds() && "Invalid local id position");
+  assert(posB < getNumLocalIds() && "Invalid local id position");
+
+  // The positions are relative to the first local id; convert them to
+  // absolute column indices.
+  const unsigned firstLocalCol = getIdKindOffset(IdKind::Local);
+  const unsigned keptCol = firstLocalCol + posA;
+  const unsigned droppedCol = firstLocalCol + posB;
+
+  // The two locals always take the same value, so c_A * a + c_B * b can be
+  // rewritten as (c_A + c_B) * a: fold the dropped column into the kept one.
+  equalities.addToColumn(droppedCol, keptCol, /*scale=*/1);
+  inequalities.addToColumn(droppedCol, keptCol, /*scale=*/1);
+  removeId(droppedCol);
}
/// Adds additional local ids to the sets such that they both have the union
@@ -1129,8 +1122,8 @@ void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) {
// Merge function that merges the local variables in both sets by treating
// them as the same identifier.
auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
- eliminateRedundantLocalId(polyA, i, j);
- eliminateRedundantLocalId(polyB, i, j);
+ polyA.eliminateRedundantLocalId(i, j);
+ polyB.eliminateRedundantLocalId(i, j);
return true;
};
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
new file mode 100644
index 0000000000000..385f135767c82
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -0,0 +1,198 @@
+//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
+
+using namespace mlir;
+
+// Return the result of subtracting the two given vectors pointwise.
+// The vectors must be of the same size.
+// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
+static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
+                                        ArrayRef<int64_t> vecB) {
+  assert(vecA.size() == vecB.size() &&
+         "Cannot subtract vectors of differing lengths!");
+  SmallVector<int64_t, 8> result;
+  result.reserve(vecA.size());
+  for (unsigned i = 0, e = vecA.size(); i < e; ++i)
+    result.push_back(vecA[i] - vecB[i]);
+  return result;
+}
+
+PresburgerSet PWMAFunction::getDomain() const {
+  // Start from the empty set in this function's input space and add the
+  // domain of every piece to it.
+  PresburgerSet unionOfDomains =
+      PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
+  for (unsigned i = 0, e = pieces.size(); i < e; ++i)
+    unionOfDomains.unionPolyInPlace(pieces[i].getDomain());
+  return unionOfDomains;
+}
+
+Optional<SmallVector<int64_t, 8>>
+MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
+  assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
+
+  // The function has no value at points outside its domain.
+  if (!getDomain().containsPoint(point))
+    return {};
+
+  // Row i of `output` holds the coefficients of the i^th output expression,
+  // with the constant term in the last column. Appending a 1 to the point
+  // (homogeneous coordinates) lets one matrix-vector product evaluate every
+  // output expression at once.
+  SmallVector<int64_t, 8> homogeneousPoint(point.begin(), point.end());
+  homogeneousPoint.push_back(1);
+  SmallVector<int64_t, 8> values =
+      output.postMultiplyWithColumn(homogeneousPoint);
+  assert(values.size() == getNumOutputs());
+  return values;
+}
+
+Optional<SmallVector<int64_t, 8>>
+PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumInputs() &&
+         "Point has incorrect dimensionality!");
+  // Pieces are required to have disjoint domains, so return the value from the
+  // first piece whose domain contains the point, if any.
+  for (unsigned i = 0, e = getNumPieces(); i < e; ++i) {
+    Optional<SmallVector<int64_t, 8>> value = pieces[i].valueAt(point);
+    if (value)
+      return value;
+  }
+  return {};
+}
+
+void MultiAffineFunction::print(raw_ostream &os) const {
+  // First the domain polyhedron (printed by the base class) ...
+  os << "Domain:";
+  IntegerPolyhedron::print(os);
+
+  // ... then the output matrix.
+  os << "Output:\n";
+  output.print(os);
+  os << "\n";
+}
+
+void MultiAffineFunction::dump() const {
+  print(llvm::errs());
+}
+
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
+  // Functions living in different spaces are never equal.
+  if (!hasCompatibleDimensions(other))
+    return false;
+  // They must be defined on exactly the same set of points ...
+  if (!getDomain().isEqual(other.getDomain()))
+    return false;
+  // ... and agree at every one of those points.
+  return isEqualWhereDomainsOverlap(other);
+}
+
+unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
+                                       unsigned num) {
+  // `output` has one column per id, so insert `num` new columns at the same
+  // absolute position the new ids will occupy, keeping both in sync.
+  output.insertColumns(getIdKindOffset(kind) + pos, num);
+  return IntegerPolyhedron::insertId(kind, pos, num);
+}
+
+void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
+  // Swap the ids in the domain and the matching columns of the output matrix.
+  IntegerPolyhedron::swapId(posA, posB);
+  output.swapColumns(posA, posB);
+}
+
+void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
+  // Drop the output columns of the ids in [idStart, idLimit), then the ids
+  // themselves from the domain.
+  unsigned numRemoved = idLimit - idStart;
+  output.removeColumns(idStart, numRemoved);
+  IntegerPolyhedron::removeIdRange(idStart, idLimit);
+}
+
+void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
+                                                    unsigned posB) {
+  assert(posA < getNumLocalIds() && "Invalid local id position");
+  assert(posB < getNumLocalIds() && "Invalid local id position");
+
+  // `posA` and `posB` are relative to the first local id, but the columns of
+  // `output` are indexed by absolute id position. Translate them (as the base
+  // class does for the constraint matrices) before folding local `posB`'s
+  // output coefficients into local `posA`'s.
+  unsigned localOffset = getIdKindOffset(IdKind::Local);
+  output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1);
+  // NOTE: this relies on the base class's removeId() dispatching to the
+  // overridden removeIdRange(), so that the now-redundant column of `output`
+  // is removed along with the local id.
+  IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
+}
+
+bool MultiAffineFunction::isEqualWhereDomainsOverlap(
+    MultiAffineFunction other) const {
+  if (!hasCompatibleDimensions(other))
+    return false;
+
+  // `commonFunc` has the same output as `this`.
+  MultiAffineFunction commonFunc = *this;
+  // After merging, `commonFunc` and `other` have the same local ids in the
+  // same order, so their columns line up.
+  commonFunc.mergeLocalIds(other);
+  // After this, the domain of `commonFunc` will be the intersection of the
+  // domains of `this` and `other`.
+  commonFunc.IntegerPolyhedron::append(other);
+
+  // `commonDomainMatching` contains the subset of the common domain
+  // where the outputs of `this` and `other` match.
+  //
+  // We want to add constraints equating the outputs of `this` and `other`.
+  // However, `this` may have different local ids from `other`, whereas we
+  // need both to have the same locals. Accordingly, we use `commonFunc.output`
+  // in place of `this->output`, since `commonFunc` has the same output but also
+  // has its locals merged.
+  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
+  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
+    commonDomainMatching.addEquality(
+        subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
+
+  // If the whole common domain is a subset of commonDomainMatching, then they
+  // are equal and the two functions match on the whole common domain.
+  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
+}
+
+/// Two PWMAFunctions are equal when they live in the same space, are defined
+/// on the same set of points, and produce the same output at each such point.
+bool PWMAFunction::isEqual(const PWMAFunction &other) const {
+  if (!hasCompatibleDimensions(other) ||
+      !this->getDomain().isEqual(other.getDomain()))
+    return false;
+
+  // The domains are now known to be identical, so the functions are equal
+  // exactly when every pair of pieces agrees wherever their domains overlap.
+  for (const MultiAffineFunction &mine : this->pieces) {
+    for (const MultiAffineFunction &theirs : other.pieces) {
+      if (!mine.isEqualWhereDomainsOverlap(theirs))
+        return false;
+    }
+  }
+  return true;
+}
+
+void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
+  // Debug-only checks: the piece must live in this function's space, have an
+  // output matrix matching its domain, and not overlap any existing piece.
+  assert(hasCompatibleDimensions(piece) &&
+         "Piece to be added is not compatible with this PWMAFunction!");
+  assert(piece.isConsistent() && "Piece is internally inconsistent!");
+  assert(getDomain()
+             .intersect(PresburgerSet(piece.getDomain()))
+             .isIntegerEmpty() &&
+         "New piece's domain overlaps with that of existing pieces!");
+
+  pieces.emplace_back(piece);
+}
+
+void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
+                            const Matrix &output) {
+  // Bundle the domain and output into a piece; the other overload validates it.
+  MultiAffineFunction piece(domain, output);
+  addPiece(piece);
+}
+
+/// Print the number of pieces followed by each piece (domain and output).
+void PWMAFunction::print(raw_ostream &os) const {
+  os << pieces.size() << " pieces:\n";
+  for (const MultiAffineFunction &piece : pieces)
+    piece.print(os);
+}
+
+/// Print this function to stderr. Declared in PWMAFunction.h but previously
+/// never defined, so any call failed to link.
+void PWMAFunction::dump() const { print(llvm::errs()); }
+
+/// The hasCompatibleDimensions functions don't check the number of local ids;
+/// functions are still compatible if they have differing numbers of locals.
+bool MultiAffineFunction::hasCompatibleDimensions(
+    const MultiAffineFunction &f) const {
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds() &&
+         getNumOutputs() == f.getNumOutputs();
+}
+bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
+  // Only the input space (dims and symbols) and the output count must agree.
+  if (getNumOutputs() != f.getNumOutputs())
+    return false;
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds();
+}
+bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
+  // Only the input space (dims and symbols) and the output count must agree.
+  if (getNumOutputs() != f.getNumOutputs())
+    return false;
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds();
+}
diff --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index 2bdb2e7bd11c4..e7142a7f87509 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
LinearTransformTest.cpp
MatrixTest.cpp
PresburgerSetTest.cpp
+ PWMAFunctionTest.cpp
SimplexTest.cpp
../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
)
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
new file mode 100644
index 0000000000000..614f19cc58b0e
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -0,0 +1,183 @@
+//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for PWMAFunction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace mlir {
+using testing::ElementsAre;
+
+/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
+  FailureOr<IntegerPolyhedron> parsed = parseIntegerSetToFAC(str, context);
+  // A parse failure is reported to gtest; the caller is expected to pass
+  // well-formed strings.
+  EXPECT_TRUE(succeeded(parsed));
+  return *parsed;
+}
+
+/// Build a `numRow` x `numColumns` Matrix from a row-major nested array.
+static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<int64_t, 8>> matrix) {
+  assert(matrix.size() == numRow);
+  Matrix mat(numRow, numColumns);
+  for (unsigned row = 0; row < numRow; ++row) {
+    ArrayRef<int64_t> rowData = matrix[row];
+    assert(rowData.size() == numColumns &&
+           "Output expression has incorrect dimensionality!");
+    for (unsigned col = 0; col < numColumns; ++col)
+      mat(row, col) = rowData[col];
+  }
+  return mat;
+}
+
+/// Construct a PWMAFunction given the dimensionalities and an array describing
+/// the list of pieces. Each piece is given by a string describing the domain
+/// and a 2D array that represents the output.
+static PWMAFunction parsePWMAF(
+    unsigned numInputs, unsigned numOutputs,
+    ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+        data,
+    unsigned numSymbols = 0) {
+  static MLIRContext context;
+
+  // Inputs are split into dimensions followed by symbols.
+  unsigned numDims = numInputs - numSymbols;
+  PWMAFunction result(numDims, numSymbols, numOutputs);
+  for (const auto &piece : data) {
+    IntegerPolyhedron domain = parsePoly(piece.first, &context);
+    // One column per id of the domain plus one for the constant term.
+    unsigned numCols = domain.getNumIds() + 1;
+    result.addPiece(domain, makeMatrix(numOutputs, numCols, piece.second));
+  }
+  return result;
+}
+
+TEST(PWMAFunctionTest, isEqual) {
+  // The two functions below use different output expressions, but that
+  // doesn't matter because the expressions agree on the (identical) domains.
+  PWMAFunction idAtZeros = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},              // (x, y).
+          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}},  // (x, y).
+          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}  // (x, y).
+      });
+  PWMAFunction idAtZeros2 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
+          {"(x, y) : (y - 1 >= 0, x == 0)",
+           {{30, 0, 0}, {0, 1, 0}}}, // (30x, y).
+          {"(x, y) : (-y - 1 >= 0, x == 0)",
+           {{30, 0, 0}, {0, 1, 0}}} // (30x, y).
+      });
+  EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
+
+  PWMAFunction notIdAtZeros = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},              // (x, y).
+          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}},  // (x, 2y).
+          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y).
+      });
+  EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
+
+  // These match at their intersection but one has a bigger domain.
+  PWMAFunction idNoNegNegQuadrant = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
+          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
+      });
+  PWMAFunction idOnlyPosX =
+      parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                 {
+                     {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                 });
+  EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
+
+  // Different representations of the same domain.
+  PWMAFunction sumPlusOne = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/1,
+      {
+          {"(x, y) : (x >= 0)", {{1, 1, 1}}},                   // x + y + 1.
+          {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
+          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}}       // x + y + 1.
+      });
+  PWMAFunction sumPlusOne2 =
+      parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                 {
+                     {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
+                 });
+  EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
+
+  // Functions with zero input dimensions.
+  PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                      {
+                                          {"() : ()", {{1}}}, // 1.
+                                      });
+  PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                      {
+                                          {"() : ()", {{2}}}, // 2.
+                                      });
+  EXPECT_TRUE(noInputs1.isEqual(noInputs1));
+  EXPECT_FALSE(noInputs1.isEqual(noInputs2));
+
+  // Mismatched dimensionalities.
+  EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
+  EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
+
+  // Divisions.
+  // Domain is only multiples of 6; x = 6k for some k.
+  // x + 4(x/2) + 4(x/3) == 26k.
+  PWMAFunction mul2AndMul3 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
+           {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
+      });
+  PWMAFunction mul6 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
+      });
+  EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
+
+  PWMAFunction mul6diff = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/5).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
+
+  PWMAFunction mul5 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
+}
+
+TEST(PWMAFunctionTest, valueAt) {
+  PWMAFunction nonNegPWAF = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          // (x + 2y + 3, 3x + 4y + 5).
+          {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}},
+          // (-x + 2y + 3, -3x + 4y + 5).
+          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}}
+      });
+  // (2, 3) is in the first piece: (2 + 6 + 3, 6 + 12 + 5).
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
+  // (-2, 3) is in the second piece: (2 + 6 + 3, 6 + 12 + 5).
+  EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  // (2, -3) is in the first piece: (2 - 6 + 3, 6 - 12 + 5).
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
+  // (-2, -3) lies in neither piece, so the function is undefined there.
+  EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+}
+
+} // namespace mlir
More information about the Mlir-commits
mailing list