[Mlir-commits] [mlir] a873b6d - [MLIR] Generalize detecting mods during slice computing

Uday Bondhugula llvmlistbot at llvm.org
Wed Jun 23 00:01:02 PDT 2021


Author: Vinayaka Bandishti
Date: 2021-06-23T12:29:34+05:30
New Revision: a873b6d466f5c4b2e939eb02c38425e5f7ffa513

URL: https://github.com/llvm/llvm-project/commit/a873b6d466f5c4b2e939eb02c38425e5f7ffa513
DIFF: https://github.com/llvm/llvm-project/commit/a873b6d466f5c4b2e939eb02c38425e5f7ffa513.diff

LOG: [MLIR] Generalize detecting mods during slice computing

During slice computation of affine loop fusion, detect one id as the mod
of another id w.r.t. a constant in a more generic way. Restrictions on the
coefficients of the ids are removed. Also, information from the
previously calculated ids is used for simplification of affine
expressions, e.g.,

If `id1` = `id2`,
  `id_n - divisor * id_q - id_r + id1 - id2 = 0`, is simplified to:
  `id_n - divisor * id_q - id_r = 0`.

If `c` is a non-zero integer,
  `c*id_n - c*divisor * id_q - c*id_r = 0`, is simplified to:
  `id_n - divisor * id_q - id_r = 0`.

Reviewed By: bondhugula, ayzhuang

Differential Revision: https://reviews.llvm.org/D104614

Added: 
    

Modified: 
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/test/Transforms/loop-fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 75bc6b76f5ed..8a3b7b6b9a92 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -1430,88 +1430,123 @@ unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
   return posLimit - posStart;
 }
 
-// Detect the identifier at 'pos' (say id_r) as modulo of another identifier
-// (say id_n) w.r.t a constant. When this happens, another identifier (say id_q)
-// could be detected as the floordiv of n. For eg:
-// id_n - 4*id_q - id_r = 0, 0 <= id_r <= 3    <=>
-//                          id_r = id_n mod 4, id_q = id_n floordiv 4.
-// lbConst and ubConst are the constant lower and upper bounds for 'pos' -
-// pre-detected at the caller.
+// Determine whether the identifier at 'pos' (say id_r) can be expressed as
+// modulo of another known identifier (say id_n) w.r.t a constant. For example,
+// if the following constraints hold true:
+// ```
+// 0 <= id_r <= divisor - 1
+// id_n - (divisor * q_expr) = id_r
+// ```
+// where `id_n` is a known identifier (called dividend), and `q_expr` is an
+// `AffineExpr` (called the quotient expression), `id_r` can be written as:
+//
+// `id_r = id_n mod divisor`.
+//
+// Additionally, in a special case of the above constraints where `q_expr` is an
+// identifier itself that is not yet known (say `id_q`), it can be written as a
+// floordiv in the following way:
+//
+// `id_q = id_n floordiv divisor`.
+//
+// Returns true if the above mod or floordiv are detected, updating 'memo' with
+// these new expressions. Returns false otherwise.
 static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
                         int64_t lbConst, int64_t ubConst,
-                        SmallVectorImpl<AffineExpr> *memo) {
+                        SmallVectorImpl<AffineExpr> &memo,
+                        MLIRContext *context) {
   assert(pos < cst.getNumIds() && "invalid position");
 
-  // Check if 0 <= id_r <= divisor - 1 and if id_r is equal to
-  // id_n - divisor * id_q. If these are true, then id_n becomes the dividend
-  // and id_q the quotient when dividing id_n by the divisor.
-
+  // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
+  // be determined.
   if (lbConst != 0 || ubConst < 1)
     return false;
-
   int64_t divisor = ubConst + 1;
 
-  // Now check for: id_r =  id_n - divisor * id_q. As an example, we
-  // are looking r = d - 4q, i.e., either r - d + 4q = 0 or -r + d - 4q = 0.
-  unsigned seenQuotient = 0, seenDividend = 0;
-  int quotientPos = -1, dividendPos = -1;
-  for (unsigned r = 0, e = cst.getNumEqualities(); r < e; r++) {
-    // id_n should have coeff 1 or -1.
-    if (std::abs(cst.atEq(r, pos)) != 1)
+  // Check for the aforementioned conditions in each equality.
+  for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
+       curEquality < numEqualities; curEquality++) {
+    int64_t coefficientAtPos = cst.atEq(curEquality, pos);
+    // If current equality does not involve `id_r`, continue to the next
+    // equality.
+    if (coefficientAtPos == 0)
       continue;
-    // constant term should be 0.
-    if (cst.atEq(r, cst.getNumCols() - 1) != 0)
+
+    // Constant term should be 0 in this equality.
+    if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
       continue;
-    unsigned c, f;
-    int quotientSign = 1, dividendSign = 1;
-    for (c = 0, f = cst.getNumDimAndSymbolIds(); c < f; c++) {
-      if (c == pos)
+
+    // Traverse through the equality and construct the dividend expression
+    // `dividendExpr` from all the known identifiers whose coefficients are
+    // not divisible by `(coefficientAtPos * divisor)`. The hope is that
+    // `dividendExpr` simplifies into the single identifier `id_n` discussed
+    // above.
+    auto dividendExpr = getAffineConstantExpr(0, context);
+
+    // Track the terms that go into quotient expression, later used to detect
+    // additional floordiv.
+    unsigned quotientCount = 0;
+    int quotientPosition = -1;
+    int quotientSign = 1;
+
+    // Consider each term in the current equality.
+    unsigned curId, e;
+    for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
+      // Ignore id_r.
+      if (curId == pos)
+        continue;
+      int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
+      // Ignore ids that do not contribute to the current equality.
+      if (coefficientOfCurId == 0)
+        continue;
+      // Check if the current id goes into the quotient expression.
+      if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
+        quotientCount++;
+        quotientPosition = curId;
+        quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
         continue;
-      // The coefficient of the quotient should be +/-divisor.
-      // TODO: could be extended to detect an affine function for the quotient
-      // (i.e., the coeff could be a non-zero multiple of divisor).
-      int64_t v = cst.atEq(r, c) * cst.atEq(r, pos);
-      if (v == divisor || v == -divisor) {
-        seenQuotient++;
-        quotientPos = c;
-        quotientSign = v > 0 ? 1 : -1;
       }
-      // The coefficient of the dividend should be +/-1.
-      // TODO: could be extended to detect an affine function of the other
-      // identifiers as the dividend.
-      else if (v == -1 || v == 1) {
-        seenDividend++;
-        dividendPos = c;
-        dividendSign = v < 0 ? 1 : -1;
-      } else if (cst.atEq(r, c) != 0) {
-        // Cannot be inferred as a mod since the constraint has a coefficient
-        // for an identifier that's neither a unit nor the divisor (see TODOs
-        // above).
+      // Identifiers that are part of dividendExpr should be known.
+      if (!memo[curId])
         break;
-      }
+      // Append the current identifier to the dividend expression.
+      dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
     }
-    if (c < f)
-      // Cannot be inferred as a mod since the constraint has a coefficient for
-      // an identifier that's neither a unit nor the divisor (see TODOs above).
+
+    // Can't construct expression as it depends on a yet uncomputed id.
+    if (curId < e)
       continue;
 
-    // We are looking for exactly one identifier as the dividend.
-    if (seenDividend == 1 && seenQuotient >= 1) {
-      if (!(*memo)[dividendPos])
-        return false;
-      // Successfully detected a mod.
-      (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
-      auto ub = cst.getConstantUpperBound(dividendPos);
+    // Express `id_r` in terms of the other ids collected so far.
+    if (coefficientAtPos > 0)
+      dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
+    else
+      dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
+
+    // Simplify the expression.
+    dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
+                                      cst.getNumSymbolIds());
+    // Only if the final dividend expression is just a single id (which we call
+    // `id_n`), we can proceed.
+    // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
+    // to dims themselves.
+    auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
+    if (!dimExpr)
+      continue;
+
+    // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
+    if (quotientCount >= 1) {
+      auto ub = cst.getConstantUpperBound(dimExpr.getPosition());
+      // If `id_n` has an upperbound that is less than the divisor, mod can be
+      // eliminated altogether.
       if (ub.hasValue() && ub.getValue() < divisor)
-        // The mod can be optimized away.
-        (*memo)[pos] = (*memo)[dividendPos] * dividendSign;
+        memo[pos] = dimExpr;
       else
-        (*memo)[pos] = (*memo)[dividendPos] % divisor * dividendSign;
+        memo[pos] = dimExpr % divisor;
+      // If a unique quotient `id_q` was seen, it can be expressed as
+      // `id_n floordiv divisor`.
+      if (quotientCount == 1 && !memo[quotientPosition])
+        memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
 
-      if (seenQuotient == 1 && !(*memo)[quotientPos])
-        // Successfully detected a floordiv as well.
-        (*memo)[quotientPos] =
-            (*memo)[dividendPos].floorDiv(divisor) * quotientSign;
       return true;
     }
   }
@@ -1885,7 +1920,7 @@ void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
         // Detect an identifier as modulo of another identifier w.r.t a
         // constant.
         if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
-                        &memo)) {
+                        memo, context)) {
           changed = true;
           continue;
         }

diff  --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir
index 8a2f3358fe1d..14a2bf0223e2 100644
--- a/mlir/test/Transforms/loop-fusion.mlir
+++ b/mlir/test/Transforms/loop-fusion.mlir
@@ -3330,3 +3330,69 @@ func @fuse_large_number_of_loops(%arg0: memref<20x10xf32, 1>, %arg1: memref<20x1
 // CHECK:         affine.for
 // CHECK:         affine.for
 // CHECK-NOT:     affine.for
+
+// -----
+
+// Expects fusion of producer into consumer at depth 4 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten4d
+func @unflatten4d(%arg1: memref<7x8x9x10xf32>) {
+  %m = memref.alloc() : memref<5040xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.for %i2 = 0 to 9 {
+        affine.for %i3 = 0 to 10 {
+          affine.store %cf7, %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+        }
+      }
+    }
+  }
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.for %i2 = 0 to 9 {
+        affine.for %i3 = 0 to 10 {
+          %v0 = affine.load %m[720 * %i0 + 90 * %i1 + 10 * %i2 + %i3] : memref<5040xf32>
+          affine.store %v0, %arg1[%i0, %i1, %i2, %i3] : memref<7x8x9x10xf32>
+        }
+      }
+    }
+  }
+  return
+}
+
+// CHECK:        affine.for
+// CHECK-NEXT:     affine.for
+// CHECK-NEXT:       affine.for
+// CHECK-NEXT:         affine.for
+// CHECK-NOT:    affine.for
+// CHECK: return
+
+// -----
+
+// Expects fusion of producer into consumer at depth 2 and subsequent removal of
+// source loop.
+// CHECK-LABEL: func @unflatten2d_with_transpose
+func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) {
+  %m = memref.alloc() : memref<56xf32>
+  %cf7 = constant 7.0 : f32
+
+  affine.for %i0 = 0 to 7 {
+    affine.for %i1 = 0 to 8 {
+      affine.store %cf7, %m[8 * %i0 + %i1] : memref<56xf32>
+    }
+  }
+  affine.for %i0 = 0 to 8 {
+    affine.for %i1 = 0 to 7 {
+      %v0 = affine.load %m[%i0 + 8 * %i1] : memref<56xf32>
+      affine.store %v0, %arg1[%i0, %i1] : memref<8x7xf32>
+    }
+  }
+  return
+}
+
+// CHECK:        affine.for
+// CHECK-NEXT:     affine.for
+// CHECK-NOT:    affine.for
+// CHECK: return


        


More information about the Mlir-commits mailing list