[Mlir-commits] [mlir] 8a09674 - Address arjun's comments
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 2 13:59:07 PST 2021
Author: Groverkss
Date: 2021-12-03T03:23:18+05:30
New Revision: 8a0967481f972fa2c273ad179632880e2f3acc56
URL: https://github.com/llvm/llvm-project/commit/8a0967481f972fa2c273ad179632880e2f3acc56
DIFF: https://github.com/llvm/llvm-project/commit/8a0967481f972fa2c273ad179632880e2f3acc56.diff
LOG: Address arjun's comments
Added:
Modified:
mlir/include/mlir/Analysis/AffineStructures.h
mlir/lib/Analysis/AffineStructures.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index 1a2766c020bc..a06f790a1afb 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -443,7 +443,7 @@ class FlatAffineConstraints {
/// Merges and aligns local ids of `this` and `other`. Local ids with
/// identical division representations are merged. The number of dimensions
- /// and symbol ids should match in `this` and `other`.
+ /// and symbol ids in `this` and `other` should match.
void mergeLocalIds(FlatAffineConstraints &other);
/// Removes all equalities and inequalities.
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 3bb97df5f656..d73541321dbc 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1918,9 +1918,16 @@ void FlatAffineConstraints::removeRedundantConstraints() {
equalities.resizeVertically(pos);
}
-/// Merge local identifer at `pos2` into local identifer at `pos1` in `fac`.
-static void mergeDivision(FlatAffineConstraints &fac, unsigned pos1,
- unsigned pos2) {
+/// Eliminate `pos2^th` local identifier, replacing its every instance with
+/// `pos1^th` local identifier. This function is intended to be used to remove
+/// redundancy when local variables at position `pos1` and `pos2` are restricted
+/// to have the same value.
+static void eleminateRedundantLocalId(FlatAffineConstraints &fac, unsigned pos1,
+ unsigned pos2) {
+
+ assert(pos1 <= fac.getNumLocalIds() && "Invalid local id position");
+ assert(pos2 <= fac.getNumLocalIds() && "Invalid local id position");
+
unsigned localOffset = fac.getNumDimAndSymbolIds();
pos1 += localOffset;
pos2 += localOffset;
@@ -1940,58 +1947,68 @@ void FlatAffineConstraints::mergeLocalIds(FlatAffineConstraints &other) {
FlatAffineConstraints &fac1 = *this;
FlatAffineConstraints &fac2 = other;
- // Get divisions inequality pairs from each FAC.
- std::vector<SmallVector<int64_t, 8>> divs1(fac1.getNumLocalIds()),
- divs2(fac2.getNumLocalIds());
- SmallVector<unsigned, 4> denoms1(fac1.getNumLocalIds()),
- denoms2(fac2.getNumLocalIds());
+ // Get divisions representations from each FAC.
+ std::vector<SmallVector<int64_t, 8>> divs1, divs2;
+ SmallVector<unsigned, 4> denoms1, denoms2;
fac1.getLocalReprs(divs1, denoms1);
fac2.getLocalReprs(divs2, denoms2);
// Merge local ids of fac1 and fac2 without using division information,
// i.e. append local ids of `fac2` to `fac1` and insert local ids of `fac1`
- // to `fac2` at start of its local ids.
+ // to `fac2` at start of its local ids. Also, insert these local ids in
+ // division representation.
unsigned initLocals = fac1.getNumLocalIds();
+ for (unsigned i = 0, e = divs1.size(); i < e; ++i)
+ if (denoms1[i] != 0)
+ divs1[i].insert(divs1[i].begin() + fac1.getNumIds(),
+ fac2.getNumLocalIds(), 0);
insertLocalId(fac1.getNumLocalIds(), fac2.getNumLocalIds());
+
+ for (unsigned i = 0, e = divs2.size(); i < e; ++i)
+ if (denoms2[i] != 0)
+ divs2[i].insert(divs2[i].begin() + fac2.getIdKindOffset(IdKind::Local),
+ initLocals, 0);
fac2.insertLocalId(0, initLocals);
// Merge division representation extracted from fac1 and fac2.
divs1.insert(divs1.end(), divs2.begin(), divs2.end());
denoms1.insert(denoms1.end(), denoms2.begin(), denoms2.end());
- auto dependsOnExist = [&](unsigned offset, SmallVector<int64_t, 8> &div) {
- for (unsigned i = offset, e = div.size(); i < e; ++i)
- if (div[i] != 0)
- return true;
- return false;
- };
-
// Find duplicate divisions and merge them.
// TODO: Add division normalization to support divisions that differ by
- // a constant
+ // a constant.
+ // TODO: Add division ordering such that a division representation for local
+ // identifier at position `i` only depends on local identifiers at position <
+ // `i`. This makes sure that all divisions depending on local variables that
+ // can be merged, are merged.
+ unsigned localOffset = getIdKindOffset(IdKind::Local);
for (unsigned i = 0; i < divs1.size(); ++i) {
- // Check if a division exists which is duplicate of division at `i`.
+ // Check if division representations exists `i^th` local id.
+ if (denoms1[i] == 0)
+ continue;
+ // Check if a division exists which is a duplicate of the division at `i`.
for (unsigned j = i + 1; j < divs1.size(); ++j) {
- // Check if division representation exists for both local ids.
- if (denoms1[i] == 0 || denoms1[j] == 0)
+ // Check if division representations exists for `j^th` local id.
+ if (denoms1[j] == 0)
continue;
- // Check if denominators match.
+ // Check if the denominators match.
if (denoms1[i] != denoms1[j])
continue;
- // Check if representation is equal.
+ // Check if the representations are equal.
if (!std::equal(divs1[i].begin(), divs1[i].end(), divs1[j].begin()))
continue;
- // If division representation contains a local variable, do not match.
- // TODO: Support divisions that depend on other local ids. This can
- // be done by ordering divisions such that a division representation
- // for local identifier at position `i` only depends on local identifiers
- // at position < `i`.
- if (dependsOnExist(fac1.getIdKindOffset(IdKind::Local), divs1[j]))
- continue;
// Merge divisions at position `j` into division at position `i`.
- mergeDivision(fac1, i, j);
- mergeDivision(fac2, i, j);
+ eleminateRedundantLocalId(fac1, i, j);
+ eleminateRedundantLocalId(fac2, i, j);
+ for (unsigned k = 0, g = divs1.size(); k < g; ++k) {
+ SmallVector<int64_t, 8> &div = divs1[k];
+ if (denoms1[k] != 0) {
+ div[localOffset + i] += div[localOffset + j];
+ div.erase(div.begin() + localOffset + j);
+ }
+ }
+
divs1.erase(divs1.begin() + j);
denoms1.erase(denoms1.begin() + j);
--j;
More information about the Mlir-commits
mailing list