[Mlir-commits] [mlir] [mlir] Extend SCF loopUnrollByFactor to return the result loops (PR #114573)
Hongtao Yu
llvmlistbot at llvm.org
Fri Nov 1 11:11:55 PDT 2024
https://github.com/htyu updated https://github.com/llvm/llvm-project/pull/114573
>From 6e779e649aee2ebcdf7594e469dc94da6d544380 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 09:52:11 -0700
Subject: [PATCH 1/2] [mlir] Extend SCF loopUnrollByFactor to return the result
loops
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 12 +++++++-----
.../SCF/TransformOps/SCFTransformOps.cpp | 6 ++++--
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 18 ++++++++++++------
3 files changed, 23 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
- if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
- result = loopUnrollByFactor(scfFor, getFactor());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+ auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+ result = resultLoops.empty() ? failure() : success();
+ }
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return success();
+ return {forOp};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return failure();
- return success();
+ return {};
+ return {forOp};
}
int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
+ SmallVector<scf::ForOp, 2> resultLoops;
+ resultLoops.push_back(forOp);
+
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
+ resultLoops.push_back(epilogueForOp);
}
// Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
- return success();
+ return resultLoops;
}
/// Check if bounds of all inner loops are defined outside of `forOp`
>From ec4600241bf8f1acac2c403b9c4b2a43a68802f7 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 11:06:51 -0700
Subject: [PATCH 2/2] make return value more structured.
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 17 +++++++++++------
.../SCF/TransformOps/SCFTransformOps.cpp | 5 ++---
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 17 ++++++++---------
3 files changed, 21 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index eda64ea69f81d1..c3bd6d86864186 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,13 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
+struct UnrolledLoopInfo {
+ scf::ForOp mainLoopOp;
+ scf::ForOp epilogueLoopOp;
+};
+
/// Unrolls this for operation by the specified unroll factor. Returns the
-/// unrolled main loop and the eplilog loop in sequence, if the loop is
-/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
-/// either due to restrictions or due to invalid unroll factors. Requires
-/// positive loop bounds and step. If specified, annotates the Ops in each
-/// unrolled iteration by applying `annotateFn`.
-SmallVector<scf::ForOp> loopUnrollByFactor(
+/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise
+/// returns a structure of null fields if the loop cannot be unrolled either due
+/// to restrictions or due to invalid unroll factors. Requires positive loop
+/// bounds and step. If specified, annotates the Ops in each unrolled iteration
+/// by applying `annotateFn`.
+UnrolledLoopInfo loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c84cb13f8b6bb2..cefd023c40d96c 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -355,9 +355,8 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
- result = resultLoops.empty() ? failure() : success();
- }
- else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+ result = resultLoops.mainLoopOp ? success() : failure();
+ } else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
return emitSilenceableError()
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 8394ac47888100..a50e90af3af658 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -373,16 +373,15 @@ static void generateUnrolledLoop(
}
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
-/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
-/// vector.
-SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
+/// epilogue loop, if the loop is unrolled. Otherwise returns null fields.
+UnrolledLoopInfo mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return {forOp};
+ return {forOp, nullptr};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -403,8 +402,8 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return {};
- return {forOp};
+ return {nullptr, nullptr};
+ return {forOp, nullptr};
}
int64_t tripCountEvenMultiple =
@@ -452,8 +451,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
- SmallVector<scf::ForOp, 2> resultLoops;
- resultLoops.push_back(forOp);
+ UnrolledLoopInfo resultLoops;
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
@@ -473,7 +471,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
- resultLoops.push_back(epilogueForOp);
+ resultLoops.epilogueLoopOp = epilogueForOp;
}
// Create unrolled loop.
@@ -496,6 +494,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
+ resultLoops.mainLoopOp = forOp;
return resultLoops;
}
More information about the Mlir-commits
mailing list