[Mlir-commits] [mlir] [mlir] Extend SCF loopUnrollByFactor to return the result loops (PR #114573)
Hongtao Yu
llvmlistbot at llvm.org
Sun Nov 3 12:08:08 PST 2024
https://github.com/htyu updated https://github.com/llvm/llvm-project/pull/114573
>From 6e779e649aee2ebcdf7594e469dc94da6d544380 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 09:52:11 -0700
Subject: [PATCH 1/3] [mlir] Extend SCF loopUnrollByFactor to return the result
loops
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 12 +++++++-----
.../SCF/TransformOps/SCFTransformOps.cpp | 6 ++++--
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 18 ++++++++++++------
3 files changed, 23 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
- if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
- result = loopUnrollByFactor(scfFor, getFactor());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+ auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+ result = resultLoops.empty() ? failure() : success();
+ }
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return success();
+ return {forOp};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return failure();
- return success();
+ return {};
+ return {forOp};
}
int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
+ SmallVector<scf::ForOp, 2> resultLoops;
+ resultLoops.push_back(forOp);
+
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
+ resultLoops.push_back(epilogueForOp);
}
// Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
- return success();
+ return resultLoops;
}
/// Check if bounds of all inner loops are defined outside of `forOp`
>From ec4600241bf8f1acac2c403b9c4b2a43a68802f7 Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Fri, 1 Nov 2024 11:06:51 -0700
Subject: [PATCH 2/3] make return value more structured.
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 17 +++++++++++------
.../SCF/TransformOps/SCFTransformOps.cpp | 5 ++---
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 17 ++++++++---------
3 files changed, 21 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index eda64ea69f81d1..c3bd6d86864186 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,13 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
+struct UnrolledLoopInfo {
+ scf::ForOp mainLoopOp;
+ scf::ForOp epilogueLoopOp;
+};
+
/// Unrolls this for operation by the specified unroll factor. Returns the
-/// unrolled main loop and the eplilog loop in sequence, if the loop is
-/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
-/// either due to restrictions or due to invalid unroll factors. Requires
-/// positive loop bounds and step. If specified, annotates the Ops in each
-/// unrolled iteration by applying `annotateFn`.
-SmallVector<scf::ForOp> loopUnrollByFactor(
+/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise
+/// returns a structure of null fields if the loop cannot be unrolled either due
+/// to restrictions or due to invalid unroll factors. Requires positive loop
+/// bounds and step. If specified, annotates the Ops in each unrolled iteration
+/// by applying `annotateFn`.
+UnrolledLoopInfo loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c84cb13f8b6bb2..cefd023c40d96c 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -355,9 +355,8 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
- result = resultLoops.empty() ? failure() : success();
- }
- else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+ result = resultLoops.mainLoopOp ? success() : failure();
+ } else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
return emitSilenceableError()
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 8394ac47888100..a50e90af3af658 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -373,16 +373,15 @@ static void generateUnrolledLoop(
}
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
-/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
-/// vector.
-SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
+/// eplilog loop, if the loop is unrolled. Otherwise return null.
+UnrolledLoopInfo mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return {forOp};
+ return {forOp, nullptr};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -403,8 +402,8 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return {};
- return {forOp};
+ return {nullptr, nullptr};
+ return {forOp, nullptr};
}
int64_t tripCountEvenMultiple =
@@ -452,8 +451,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
- SmallVector<scf::ForOp, 2> resultLoops;
- resultLoops.push_back(forOp);
+ UnrolledLoopInfo resultLoops;
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
@@ -473,7 +471,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
- resultLoops.push_back(epilogueForOp);
+ resultLoops.epilogueLoopOp = epilogueForOp;
}
// Create unrolled loop.
@@ -496,6 +494,7 @@ SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
+ resultLoops.mainLoopOp = forOp;
return resultLoops;
}
>From 69e5cbe2623c6cbe6e0a4ec3be2d6cc753baf4fc Mon Sep 17 00:00:00 2001
From: Hongtao Yu <hoy at fb.com>
Date: Sun, 3 Nov 2024 12:06:24 -0800
Subject: [PATCH 3/3] Changed return type to FailureOr<UnrolledLoopInfo>
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 6 +++---
.../SCF/TransformOps/SCFTransformOps.cpp | 3 +--
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 18 +++++++++---------
3 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index c3bd6d86864186..9c41fc6a30b809 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -112,8 +112,8 @@ void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
struct UnrolledLoopInfo {
- scf::ForOp mainLoopOp;
- scf::ForOp epilogueLoopOp;
+ scf::ForOp mainLoopOp = nullptr;
+ scf::ForOp epilogueLoopOp = nullptr;
};
/// Unrolls this for operation by the specified unroll factor. Returns the
@@ -122,7 +122,7 @@ struct UnrolledLoopInfo {
/// to restrictions or due to invalid unroll factors. Requires positive loop
/// bounds and step. If specified, annotates the Ops in each unrolled iteration
/// by applying `annotateFn`.
-UnrolledLoopInfo loopUnrollByFactor(
+FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index cefd023c40d96c..21455bd7251130 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -354,8 +354,7 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
- auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
- result = resultLoops.mainLoopOp ? success() : failure();
+ result = loopUnrollByFactor(scfFor, getFactor());
} else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a50e90af3af658..e591ca49bccb8f 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -373,15 +373,15 @@ static void generateUnrolledLoop(
}
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
-/// eplilog loop, if the loop is unrolled. Otherwise return null.
-UnrolledLoopInfo mlir::loopUnrollByFactor(
+/// epilogue loop, if the loop is unrolled.
+FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return {forOp, nullptr};
+ return UnrolledLoopInfo{forOp, nullptr};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,8 +402,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return {nullptr, nullptr};
- return {forOp, nullptr};
+ return failure();
+ return UnrolledLoopInfo{forOp, nullptr};
}
int64_t tripCountEvenMultiple =
@@ -470,8 +470,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
}
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
- (void)epilogueForOp.promoteIfSingleIteration(rewriter);
- resultLoops.epilogueLoopOp = epilogueForOp;
+ if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
+ resultLoops.epilogueLoopOp = epilogueForOp;
}
// Create unrolled loop.
@@ -493,8 +493,8 @@ UnrolledLoopInfo mlir::loopUnrollByFactor(
},
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
- (void)forOp.promoteIfSingleIteration(rewriter);
- resultLoops.mainLoopOp = forOp;
+ if (forOp.promoteIfSingleIteration(rewriter).failed())
+ resultLoops.mainLoopOp = forOp;
return resultLoops;
}
More information about the Mlir-commits
mailing list