[Mlir-commits] [mlir] a873b6d - [MLIR] Generalize detecting mods during slice computing
Uday Bondhugula
llvmlistbot at llvm.org
Wed Jun 23 00:01:02 PDT 2021
Author: Vinayaka Bandishti
Date: 2021-06-23T12:29:34+05:30
New Revision: a873b6d466f5c4b2e939eb02c38425e5f7ffa513
URL: https://github.com/llvm/llvm-project/commit/a873b6d466f5c4b2e939eb02c38425e5f7ffa513
DIFF: https://github.com/llvm/llvm-project/commit/a873b6d466f5c4b2e939eb02c38425e5f7ffa513.diff
LOG: [MLIR] Generalize detecting mods during slice computing
During slice computation of affine loop fusion, detect one id as the mod
of another id w.r.t a constant in a more generic way. Restrictions on
coefficients of the ids are removed. Also, information from the
previously calculated ids is used for simplification of affine
expressions, e.g.,
If `id1` = `id2`,
`id_n - divisor * id_q - id_r + id1 - id2 = 0`, is simplified to:
`id_n - divisor * id_q - id_r = 0`.
If `c` is a non-zero integer,
`c*id_n - c*divisor * id_q - c*id_r = 0`, is simplified to:
`id_n - divisor * id_q - id_r = 0`.
Reviewed By: bondhugula, ayzhuang
Differential Revision: https://reviews.llvm.org/D104614
Added:
Modified:
mlir/lib/Analysis/AffineStructures.cpp
mlir/test/Transforms/loop-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 75bc6b76f5ed..8a3b7b6b9a92 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1430,88 +1430,123 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
return posLimit - posStart;
}
-// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
-// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
-// could be detected as the floordiv of n. For eg:
-// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3 <=>
-// id_r = id_n mod 4, id_q = id_n floordiv 4.
-// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
-// pre-detected at the caller.
+// Determine whether the identifier at 'pos' (say id_r) can be expressed as
+// modulo of another known identifier (say id_n) w.r.t a constant. For example,
+// if the following constraints hold true:
+// ```
+// 0 <= id_r <= divisor - 1
+// id_n - (divisor * q_expr) = id_r
+// ```
+// where `id_n` is a known identifier (called dividend), and `q_expr` is an
+// `AffineExpr` (called the quotient expression), `id_r` can be written as:
+//
+// `id_r = id_n mod divisor`.
+//
+// Additionally, in a special case of the above constraints where `q_expr` is
+// an identifier itself that is not yet known (say `id_q`), it can be written
+// as a floordiv in the following way:
+//
+// `id_q = id_n floordiv divisor`.
+//
+// Returns true if the above mod or floordiv are detected, updating 'memo' with
+// these new expressions. Returns false otherwise.
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
int64_t lbConst, int64_t ubConst,
- SmallVectorImpl<AffineExpr> *memo) {
+ SmallVectorImpl<AffineExpr> &memo,
+ MLIRContext *context) {
assert(pos < cst.getNumIds() && "invalid position");
- // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
- // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
- // and id_q the quotient when dividing id_n by the divisor.
-
+ // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
+ // be determined.
if (lbConst != 0 || ubConst < 1)
return false;
-
int64_t divisor = ubConst + 1;
- // Now check for: id_r = id_n - divisor * id_q. As an example, we
- // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
- unsigned seenQuotient = 0, seenDividend = 0;
- int quotientPos = -1, dividendPos = -1;
- for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
- // id_n should have coeff 1 or -1.
- if (std::abs(cst.atEq(r, pos)) != 1)
+ // Check for the aforementioned conditions in each equality.
+ for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
+ curEquality < numEqualities; curEquality++) {
+ int64_t coefficientAtPos = cst.atEq(curEquality, pos);
+ // If current equality does not involve `id_r`, continue to the next
+ // equality.
+ if (coefficientAtPos == 0)
continue;
- // constant term should be 0.
- if (cst.atEq(r, cst.getNumCols() - 1) != 0)
+
+ // Constant term should be 0 in this equality.
+ if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
continue;
- unsigned c, f;
- int quotientSign = 1, dividendSign = 1;
- for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
- if (c == pos)
+
+ // Traverse through the equality and construct the dividend expression
+ // `dividendExpr`, to contain all the identifiers which are known and are
+ // not divisible by `(coefficientAtPos * divisor)`. The hope is that the
+ // `dividendExpr` gets simplified into a single identifier `id_n` discussed
+ // above.
+ auto dividendExpr = getAffineConstantExpr(0, context);
+
+ // Track the terms that go into quotient expression, later used to detect
+ // additional floordiv.
+ unsigned quotientCount = 0;
+ int quotientPosition = -1;
+ int quotientSign = 1;
+
+ // Consider each term in the current equality.
+ unsigned curId, e;
+ for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
+ // Ignore id_r.
+ if (curId == pos)
+ continue;
+ int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
+ // Ignore ids that do not contribute to the current equality.
+ if (coefficientOfCurId == 0)
+ continue;
+ // Check if the current id goes into the quotient expression.
+ if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
+ quotientCount++;
+ quotientPosition = curId;
+ quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
continue;
- // The coefficient of the quotient should be +/-divisor.
- // TODO: could be extended to detect an affine function for the quotient
- // (i.e., the coeff could be a non-zero multiple of divisor).
- int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
- if (v == divisor || v == -divisor) {
- seenQuotient++;
- quotientPos = c;
- quotientSign = v > 0 ? 1 : -1;
}
- // The coefficient of the dividend should be +/-1.
- // TODO: could be extended to detect an affine function of the other
- // identifiers as the dividend.
- else if (v == -1 || v == 1) {
- seenDividend++;
- dividendPos = c;
- dividendSign = v < 0 ? 1 : -1;
- } else if (cst.atEq(r, c) != 0) {
- // Cannot be inferred as a mod since the constraint has a coefficient
- // for an identifier that's neither a unit nor the divisor (see TODOs
- // above).
+ // Identifiers that are part of dividendExpr should be known.
+ if (!memo[curId])
break;
- }
+ // Append the current identifier to the dividend expression.
+ dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
}
- if (c < f)
- // Cannot be inferred as a mod since the constraint has a coefficient for
- // an identifier that's neither a unit nor the divisor (see TODOs above).
+
+ // Can't construct expression as it depends on a yet uncomputed id.
+ if (curId < e)
continue;
- // We are looking for exactly one identifier as the dividend.
- if (seenDividend == 1 && seenQuotient >= 1) {
- if (!(*memo)[dividendPos])
- return false;
- // Successfully detected a mod.
- (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
- auto ub = cst.getConstantUpperBound(dividendPos);
+ // Express `id_r` in terms of the other ids collected so far.
+ if (coefficientAtPos > 0)
+ dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
+ else
+ dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
+
+ // Simplify the expression.
+ dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
+ cst.getNumSymbolIds());
+ // Only if the final dividend expression is just a single id (which we call
+ // `id_n`), we can proceed.
+ // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
+ // to dims themselves.
+ auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
+ if (!dimExpr)
+ continue;
+
+ // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
+ if (quotientCount >= 1) {
+ auto ub = cst.getConstantUpperBound(dimExpr.getPosition());
+ // If `id_n` has an upperbound that is less than the divisor, mod can be
+ // eliminated altogether.
if (ub.hasValue() && ub.getValue() < divisor)
- // The mod can be optimized away.
- (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
+ memo[pos] = dimExpr;
else
- (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
+ memo[pos] = dimExpr % divisor;
+ // If a unique quotient `id_q` was seen, it can be expressed as
+ // `id_n floordiv divisor`.
+ if (quotientCount == 1 && !memo[quotientPosition])
+ memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
- if (seenQuotient == 1 && !(*memo)[quotientPos])
- // Successfully detected a floordiv as well.
- (*memo)[quotientPos] =
- (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
return true;
}
}
@@ -1885,7 +1920,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
// Detect an identifier as modulo of another identifier w.r.t a
// constant.
if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
- &memo)) {
+ memo, context)) {
changed = true;
continue;
}
diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 8a2f3358fe1d..14a2bf0223e2 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3330,3 +3330,69 @@ func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x1
// CHECK: affine.for
// CHECK: affine.for
// CHECK-NOT: affine.for
+
+// -----
+
+// Expects fusion of producer into consumer at depth 4 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten4d
+func @unflatten4d(%arg1: memref<7x8x9x10xf32>) {
+ %m = memref.alloc() : memref<5040xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 7 {
+ affine.for %i1 = 0 to 8 {
+ affine.for %i2 = 0 to 9 {
+ affine.for %i3 = 0 to 10 {
+ affine.store %cf7, %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+ }
+ }
+ }
+ }
+ affine.for %i0 = 0 to 7 {
+ affine.for %i1 = 0 to 8 {
+ affine.for %i2 = 0 to 9 {
+ affine.for %i3 = 0 to 10 {
+ %v0 = affine.load %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+ affine.store %v0, %arg1[%i0, %i1, %i2, %i3] : memref<7x8x9x10xf32>
+ }
+ }
+ }
+ }
+ return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NOT: affine.for
+// CHECK: return
+
+// -----
+
+// Expects fusion of producer into consumer at depth 2 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten2d_with_transpose
+func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) {
+ %m = memref.alloc() : memref<56xf32>
+ %cf7 = constant 7.0 : f32
+
+ affine.for %i0 = 0 to 7 {
+ affine.for %i1 = 0 to 8 {
+ affine.store %cf7, %m[8 * %i0 + %i1] : memref<56xf32>
+ }
+ }
+ affine.for %i0 = 0 to 8 {
+ affine.for %i1 = 0 to 7 {
+ %v0 = affine.load %m[%i0 + 8 * %i1] : memref<56xf32>
+ affine.store %v0, %arg1[%i0, %i1] : memref<8x7xf32>
+ }
+ }
+ return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NOT: affine.for
+// CHECK: return
More information about the Mlir-commits
mailing list