[Mlir-commits] [mlir] 3bc5353 - Implement division merging
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 2 13:59:00 PST 2021
Author: Groverkss
Date: 2021-12-03T03:23:16+05:30
New Revision: 3bc5353fc6f291d6ac12c256600b5b05d7de8f74
URL: https://github.com/llvm/llvm-project/commit/3bc5353fc6f291d6ac12c256600b5b05d7de8f74
DIFF: https://github.com/llvm/llvm-project/commit/3bc5353fc6f291d6ac12c256600b5b05d7de8f74.diff
LOG: Implement division merging
Added:
Modified:
mlir/include/mlir/Analysis/AffineStructures.h
mlir/lib/Analysis/AffineStructures.cpp
mlir/unittests/Analysis/AffineStructuresTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 53b3d043606e4..4db03582b80c3 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -530,6 +530,12 @@ class FlatAffineConstraints {
/// Normalized each constraints by the GCD of its coefficients.
void normalizeConstraintsByGCD();
+ /// Get the division representation of each local identifier into
+ /// `reprs[i]` / `denominators[i]`; both must be pre-sized to the number of
+ /// local ids. If the `i^th` local id has none, `denominators[i]` is set to 0.
+ void getLocalIdsReprs(std::vector<SmallVector<int64_t, 8>> &reprs,
+ SmallVector<unsigned, 8> &denominators);
+
/// Removes identifiers in the column range [idStart, idLimit), and copies any
/// remaining valid data into place, updates member variables, and resizes
/// arrays as needed.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index c0465a2532b46..30c8e0936d219 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1918,6 +1918,48 @@ void FlatAffineConstraints::removeRedundantConstraints() {
equalities.resizeVertically(pos);
}
+void FlatAffineConstraints::getLocalIdsReprs(
+ std::vector<SmallVector<int64_t, 8>> &reprs,
+ SmallVector<unsigned, 8> &denominators) {
+ assert(reprs.size() == getNumLocalIds() &&
+ "Size of reprs must be equal to number of local ids");
+ assert(denominators.size() == getNumLocalIds() &&
+ "Size of denominators must be equal to number of local ids");
+
+ // Get upper-lower bound inequality pairs for division representation.
+ std::vector<Optional<std::pair<unsigned, unsigned>>> divIneqPairs(
+ getNumLocalIds());
+ getLocalReprLbUbPairs(divIneqPairs);
+
+ for (unsigned i = 0, e = getNumLocalIds(); i < e; ++i) {
+ if (!divIneqPairs[i].hasValue()) {
+ denominators[i] = 0;
+ continue;
+ }
+
+ std::pair<unsigned, unsigned> divPair = divIneqPairs[i].getValue();
+ LogicalResult divExtracted =
+ getDivRepr(*this, i + getIdKindOffset(IdKind::Local), divPair.first,
+ divPair.second, reprs[i], denominators[i]);
+ assert(succeeded(divExtracted) &&
+ "Div should have been found since ub-lb pair exists");
+ (void)divExtracted; // Only read by the assert; avoid NDEBUG warning.
+ }
+}
+
+/// Merge local identifier at `pos2` into local identifier at `pos1` in `fac`.
+static void mergeDivision(FlatAffineConstraints &fac, unsigned pos1,
+ unsigned pos2) {
+ assert(pos1 != pos2 && "Cannot merge a local id into itself");
+ unsigned localOffset = fac.getNumDimAndSymbolIds();
+ unsigned col1 = pos1 + localOffset, col2 = pos2 + localOffset;
+ for (unsigned i = 0, e = fac.getNumInequalities(); i < e; ++i)
+ fac.atIneq(i, col1) += fac.atIneq(i, col2);
+ for (unsigned i = 0, e = fac.getNumEqualities(); i < e; ++i)
+ fac.atEq(i, col1) += fac.atEq(i, col2);
+ fac.removeId(col2);
+}
+
/// Merge local ids of `this` and `other`. This is done by appending local ids
/// of `other` to `this` and inserting local ids of `this` to `other` at start
/// of its local ids. Number of dimension and symbol ids should match in
@@ -1927,9 +1969,95 @@ void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
"Number of dimension ids should match");
assert(getNumSymbolIds() == other.getNumSymbolIds() &&
"Number of symbol ids should match");
- unsigned initLocals = getNumLocalIds();
- insertLocalId(getNumLocalIds(), other.getNumLocalIds());
- other.insertLocalId(0, initLocals);
+
+ FlatAffineConstraints &fac1 = *this;
+ FlatAffineConstraints &fac2 = other;
+
+ // Get division representations of the local ids of each FAC. Each one is
+ // expressed in the column space of the FAC it was extracted from.
+ std::vector<SmallVector<int64_t, 8>> divs1(fac1.getNumLocalIds()),
+ divs2(fac2.getNumLocalIds());
+ SmallVector<unsigned, 8> denoms1(fac1.getNumLocalIds()),
+ denoms2(fac2.getNumLocalIds());
+ fac1.getLocalIdsReprs(divs1, denoms1);
+ fac2.getLocalIdsReprs(divs2, denoms2);
+
+ // Merge local ids of fac1 and fac2 without using division information,
+ // i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1`
+ // to `fac2` at start of its local ids.
+ unsigned localOffset = fac1.getIdKindOffset(IdKind::Local);
+ unsigned numLocals1 = fac1.getNumLocalIds();
+ unsigned numLocals2 = fac2.getNumLocalIds();
+ fac1.insertLocalId(numLocals1, numLocals2);
+ fac2.insertLocalId(0, numLocals1);
+
+ // Re-embed the extracted representations into the merged column layout
+ // [dims, symbols, locals of fac1, locals of fac2, constant], mirroring the
+ // insertions above. Without this, representations taken from fac1 and
+ // fac2 disagree on column positions (and on length when the FACs had a
+ // different number of locals). Ids without a representation (denominator
+ // 0) have an empty vector and are skipped.
+ for (unsigned i = 0; i < numLocals1; ++i)
+ if (denoms1[i] != 0)
+ divs1[i].insert(divs1[i].begin() + localOffset + numLocals1,
+ numLocals2, 0);
+ for (unsigned i = 0; i < numLocals2; ++i)
+ if (denoms2[i] != 0)
+ divs2[i].insert(divs2[i].begin() + localOffset, numLocals1, 0);
+
+ // Merge division representation extracted from fac1 and fac2.
+ divs1.insert(divs1.end(), divs2.begin(), divs2.end());
+ denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end());
+
+ // Returns true if `div` has a non-zero coefficient for some local id. The
+ // last element of `div` is the constant term, which is not a local id.
+ auto dependsOnLocal = [&](const SmallVector<int64_t, 8> &div) {
+ for (unsigned k = localOffset, e = div.size() - 1; k < e; ++k)
+ if (div[k] != 0)
+ return true;
+ return false;
+ };
+
+ // Find duplicate divisions and merge them.
+ // TODO: Add division normalization to support divisions that differ by
+ // a constant.
+ for (unsigned i = 0; i < divs1.size(); ++i) {
+ // Local ids without a division representation cannot be matched.
+ if (denoms1[i] == 0)
+ continue;
+ // Check if a division exists which is duplicate of division at `i`.
+ for (unsigned j = i + 1; j < divs1.size(); ++j) {
+ // Check if denominators match. Since denoms1[i] != 0, this also skips
+ // local ids without a division representation.
+ if (denoms1[i] != denoms1[j])
+ continue;
+ // Check if representation is equal.
+ if (divs1[i] != divs1[j])
+ continue;
+ // If division representation contains a local variable, do not match.
+ // TODO: Support divisions that depend on other local ids. This can
+ // be done by ordering divisions such that a division representation
+ // for local identifier at position `i` only depends on local identifiers
+ // at position < `i`.
+ if (dependsOnLocal(divs1[j]))
+ continue;
+
+ // Merge divisions at position `j` into division at position `i`.
+ mergeDivision(fac1, i, j);
+ mergeDivision(fac2, i, j);
+ // Apply the same column merge to the remaining representations so that
+ // they stay aligned with the columns of fac1 and fac2.
+ for (SmallVector<int64_t, 8> &div : divs1) {
+ if (div.empty())
+ continue;
+ div[localOffset + i] += div[localOffset + j];
+ div.erase(div.begin() + localOffset + j);
+ }
+ divs1.erase(divs1.begin() + j);
+ denoms1.erase(denoms1.begin() + j);
+ --j;
+ }
+ }
}
/// Removes local variables using equalities. Each equality is checked if it
diff --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
index 4cb04710aba1d..8f30bead3fa3d 100644
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -809,4 +809,79 @@ TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
EXPECT_TRUE(fac3.isEmpty());
}
+TEST(FlatAffineConstraintsTest, mergeDivisionsSimple) {
+ {
+ // (x) : (exists z, y = [x / 2] : x = 3y and x + z + 1 >= 0).
+ FlatAffineConstraints fac1(1, 0, 1);
+ fac1.addLocalFloorDiv({1, 0, 0}, 2);
+ fac1.addEquality({1, 0, -3, 0});
+ fac1.addInequality({1, 1, 0, 1});
+
+ // (x) : (exists y = [x / 2], z : x = 5y).
+ FlatAffineConstraints fac2(1);
+ fac2.addLocalFloorDiv({1, 0}, 2);
+ fac2.addEquality({1, -5, 0});
+ fac2.appendLocalId();
+
+ fac1.mergeLocalIds(fac2);
+
+ // Local space should be same.
+ EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+ // 1 division matched + 2 unmatched local variables.
+ EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+ EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+ }
+
+ {
+ // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y).
+ FlatAffineConstraints fac1(1);
+ fac1.addLocalFloorDiv({1, 0}, 5);
+ fac1.addLocalFloorDiv({1, 0, 0}, 2);
+ fac1.addEquality({1, 0, -3, 0});
+
+ // (x) : (exists y = [x / 2], z = [x / 5] : x = 5z).
+ FlatAffineConstraints fac2(1);
+ fac2.addLocalFloorDiv({1, 0}, 2);
+ fac2.addLocalFloorDiv({1, 0, 0}, 5);
+ fac2.addEquality({1, 0, -5, 0});
+
+ fac1.mergeLocalIds(fac2);
+
+ // Local space should be same.
+ EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+ // 2 divisions matched.
+ EXPECT_EQ(fac1.getNumLocalIds(), 2u);
+ EXPECT_EQ(fac2.getNumLocalIds(), 2u);
+ }
+}
+
+TEST(FlatAffineConstraintsTest, mergeDivisionsUnsupported) {
+ // Merging divisions that depend on other local variables is not yet
+ // supported.
+
+ // (x) : (exists y = [x / 2], z = [(x + y) / 3] : y + z >= x).
+ FlatAffineConstraints fac1(1);
+ fac1.addLocalFloorDiv({1, 0}, 2);
+ fac1.addLocalFloorDiv({1, 1, 0}, 3);
+ fac1.addInequality({-1, 1, 1, 0});
+
+ // (x) : (exists y = [x / 2], z = [(x + y) / 3] : y + z <= x).
+ FlatAffineConstraints fac2(1);
+ fac2.addLocalFloorDiv({1, 0}, 2);
+ fac2.addLocalFloorDiv({1, 1, 0}, 3);
+ fac2.addInequality({1, -1, -1, 0});
+
+ fac1.mergeLocalIds(fac2);
+
+ // Local space should be same.
+ EXPECT_EQ(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+ // 1 division matched + 2 unmerged divisions, because those divisions
+ // depend on another local variable.
+ EXPECT_EQ(fac1.getNumLocalIds(), 3u);
+ EXPECT_EQ(fac2.getNumLocalIds(), 3u);
+}
+
} // namespace mlir
More information about the Mlir-commits
mailing list