[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)

Aviad Cohen llvmlistbot at llvm.org
Sun Jun 16 20:28:12 PDT 2024


================
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Returns true when every `scf.for` visited while walking `forOp` takes its
+/// lower bound, upper bound and step from values defined outside of `forOp`.
+/// Returns false as soon as one visited loop violates that.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  // A control operand is acceptable only if it is defined outside `forOp`.
+  auto isInvariant = [&](auto controlOperand) {
+    return forOp.isDefinedOutsideOfLoop(controlOperand);
+  };
+
+  WalkResult status = forOp.walk([&](scf::ForOp nested) -> WalkResult {
+    // Checked in the order lower bound, upper bound, step; stops at the
+    // first operand that is not invariant.
+    bool allInvariant = isInvariant(nested.getLowerBound()) &&
+                        isInvariant(nested.getUpperBound()) &&
+                        isInvariant(nested.getStep());
+    return allInvariant ? WalkResult::advance() : WalkResult::interrupt();
+  });
+
+  // An interrupted walk means some loop had a non-invariant control operand.
+  return !status.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, `for` operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only a constant trip count that is divisible by the unroll
+  // factor is supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
----------------
AviadCo wrote:

Ack.

https://github.com/llvm/llvm-project/pull/94142


More information about the Mlir-commits mailing list