[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)

Wenyi Zhao llvmlistbot at llvm.org
Tue Jun 11 02:33:53 PDT 2024


================
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, `for` operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only a constant trip count that is divisible by the unroll
+  // factor is supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
----------------
wyzero wrote:

Can `*tripCount` be zero? Do we need special handling for that case?
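
To illustrate the concern: if the trip count is zero, the clamp in the hunk above sets `unrollJamFactor` to 0, which contradicts the `unrollJamFactor > 0` assertion at the top of the function. Below is a minimal standalone sketch of one possible guard. The helper `clampUnrollJamFactor` is hypothetical and not part of this patch, and it uses plain C++ types rather than MLIR ones; returning `std::nullopt` stands in for whatever early exit is appropriate.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper (an assumption for illustration, not in the patch):
// returns the factor to use, or std::nullopt when unroll-and-jam should be
// skipped.
std::optional<uint64_t> clampUnrollJamFactor(std::optional<uint64_t> tripCount,
                                             uint64_t unrollJamFactor) {
  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");

  // Dynamic trip count: there is nothing to clamp against.
  if (!tripCount)
    return std::nullopt;

  // Zero-trip loop: clamping would yield a factor of 0, so bail out before
  // the clamp instead of producing an invalid factor.
  if (*tripCount == 0)
    return std::nullopt;

  // Same clamp as in the patch: never unroll by more than the trip count.
  if (unrollJamFactor > *tripCount)
    unrollJamFactor = *tripCount;
  return unrollJamFactor;
}
```

Whether a zero-trip loop should report `success()` (nothing to do) or `failure()` is a separate design choice; the sketch only shows that the zero case needs to be handled before the clamp.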

https://github.com/llvm/llvm-project/pull/94142

