[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)
Aviad Cohen
llvmlistbot at llvm.org
Sun Jun 16 20:28:12 PDT 2024
================
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
+/// Returns true when every `scf.for` nested in `forOp` has a lower bound,
+/// upper bound and step that are all defined outside of `forOp`, i.e. none of
+/// the nested iteration spaces depends on a value defined inside `forOp`.
+/// Returns false once the first nested loop violating this is found.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  // True if `v` is defined outside of the body of `forOp`.
+  auto isInvariant = [&](auto v) { return forOp.isDefinedOutsideOfLoop(v); };
+
+  WalkResult result = forOp.walk([&](scf::ForOp nestedLoop) {
+    // NOTE(review): the walk presumably also visits `forOp` itself; its own
+    // bounds are operands of the loop, so they should always pass this check.
+    bool boundsInvariant = isInvariant(nestedLoop.getLowerBound()) &&
+                           isInvariant(nestedLoop.getUpperBound()) &&
+                           isInvariant(nestedLoop.getStep());
+    return boundsInvariant ? WalkResult::advance() : WalkResult::interrupt();
+  });
+  // The walk is interrupted only when some nested loop had a variant bound.
+  return !result.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+ uint64_t unrollJamFactor) {
+ assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+ if (unrollJamFactor == 1)
+ return success();
+
+ // If any control operand of any inner loop of `forOp` is defined within
+ // `forOp`, no unroll jam.
+ if (!areInnerBoundsInvariant(forOp)) {
+ LDBG("failed to unroll and jam: inner bounds are not invariant");
+ return failure();
+ }
+
+ // Currently, for operations with results are not supported.
+ if (forOp->getNumResults() > 0) {
+ LDBG("failed to unroll and jam: unsupported loop with results");
+ return failure();
+ }
+
+ // Currently, only constant trip count that divided by the unroll factor is
+ // supported.
+ std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ if (!tripCount) {
+ // If the trip count is dynamic, do not unroll & jam.
+ LDBG("failed to unroll and jam: trip count could not be determined");
+ return failure();
+ }
+ if (unrollJamFactor > *tripCount) {
+ LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+ "count");
+ unrollJamFactor = *tripCount;
+ } else if (*tripCount % unrollJamFactor != 0) {
+ LDBG("failed to unroll and jam: unsupported trip count that is not a "
+ "multiple of unroll jam factor");
+ return failure();
+ }
+
+ // Nothing in the loop body other than the terminator.
+ if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+ return success();
+
+ // Gather all sub-blocks to jam upon the loop being unrolled.
+ JamBlockGatherer<scf::ForOp> jbg;
+ jbg.walk(forOp);
+ auto &subBlocks = jbg.subBlocks;
+
+ // Collect inner loops.
+ SmallVector<scf::ForOp, 4> innerLoops;
----------------
AviadCo wrote:
Ack.
https://github.com/llvm/llvm-project/pull/94142
More information about the Mlir-commits
mailing list