[Mlir-commits] [mlir] [mlir] Extend SCF loopUnrollByFactor to return the result loops (PR #114573)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 1 09:57:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Hongtao Yu (htyu)
<details>
<summary>Changes</summary>
There is a need to access the resulting epilogue loop from the SCF loop unroller. It would be clean and convenient to get it directly from the loop unroller instead of rescanning the whole function, as discussed in https://github.com/triton-lang/triton/pull/5027. I'm changing the result type of `loopUnrollByFactor` to support that.
---
Full diff: https://github.com/llvm/llvm-project/pull/114573.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+7-5)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+4-2)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+12-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the epilogue loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
- if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
- result = loopUnrollByFactor(scfFor, getFactor());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+ auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+ result = resultLoops.empty() ? failure() : success();
+ }
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// epilogue loop in sequence, if the loop is unrolled. Otherwise returns an
+/// empty vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return success();
+ return {forOp};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return failure();
- return success();
+ return {};
+ return {forOp};
}
int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
+ SmallVector<scf::ForOp, 2> resultLoops;
+ resultLoops.push_back(forOp);
+
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
+ resultLoops.push_back(epilogueForOp);
}
// Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
- return success();
+ return resultLoops;
}
/// Check if bounds of all inner loops are defined outside of `forOp`
``````````
</details>
https://github.com/llvm/llvm-project/pull/114573
More information about the Mlir-commits
mailing list