[Mlir-commits] [mlir] [mlir] Allow unroll & jam on SCF loops with results (PR #98887)
Javier Setoain
llvmlistbot at llvm.org
Mon Jul 15 04:32:48 PDT 2024
https://github.com/jsetoain created https://github.com/llvm/llvm-project/pull/98887
Unlike the affine version, unroll & jam for SCF loops does not support loops with results/`iter_args`, but there's no real reason for this difference between the two.
Even though `iter_args` may indicate a loop-carried dependency and, therefore, that the loop is unsuitable for unroll & jam, many transformations materialize loops with "transient" `iter_args` that don't represent real dependencies and will eventually go away, e.g., `linalg::tileLinalgOp` applied to non-bufferized linalg ops.
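To make the "transient `iter_args`" case concrete, here is a plain C++ model (not MLIR; all names are hypothetical) of a loop nest shaped like the output of tiling on tensors. The output buffer is threaded through the outer loop like an `iter_arg`, but each outer iteration writes only its own row, so unrolling the outer loop by 2 and jamming the inner loops yields the same result.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Plain C++ model, not MLIR. The "tensor" is passed into and returned from
// every outer iteration like an scf.for iter_arg, but outer iteration i only
// writes row i, so no value actually flows between outer iterations.
using Tensor = std::vector<int>;
constexpr std::size_t kCols = 3;

// Hypothetical stand-in for a tile's body: writes element (i, j).
Tensor writeElement(Tensor t, std::size_t i, std::size_t j) {
  t[i * kCols + j] = static_cast<int>(10 * i + j);
  return t;
}

// Original nest: the tensor is threaded through both loops.
Tensor tiledOriginal(std::size_t rows) {
  Tensor acc(rows * kCols, 0);
  for (std::size_t i = 0; i < rows; ++i)
    for (std::size_t j = 0; j < kCols; ++j)
      acc = writeElement(std::move(acc), i, j);
  return acc;
}

// After unroll & jam by 2: the outer step doubles and the two copies of the
// inner loop are fused into one, still threading the same tensor.
Tensor tiledUnrollJam2(std::size_t rows) {
  assert(rows % 2 == 0 && "trip count must be divisible by the unroll factor");
  Tensor acc(rows * kCols, 0);
  for (std::size_t i = 0; i < rows; i += 2)
    for (std::size_t j = 0; j < kCols; ++j) {
      acc = writeElement(std::move(acc), i, j);
      acc = writeElement(std::move(acc), i + 1, j);
    }
  return acc;
}
```

Both functions return identical buffers for any even `rows`; the even-count restriction mirrors the requirement, also present in the SCF implementation, that the constant trip count be divisible by the unroll factor.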
Given that this transformation doesn't perform a full dependency analysis to ensure its safety, it's already up to the user to make sure that the loop is parallel before proceeding. Allowing loops with results makes this transformation more widely applicable without any real loss of safety.
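Conversely, the following minimal sketch (plain C++ with hypothetical names, not the MLIR transformation itself) shows why the user must establish parallelism first. The nest has a true dependency carried by the outer loop: row `i` reads a value that row `i - 1` produces in a later inner iteration. After unroll & jam by 2, iteration `i + 1` reads `a[i][j + 1]` before iteration `i` has written it, so the results diverge.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical nest with a true outer-loop-carried dependency:
//   a[i][j] = a[i-1][j+1] + 1
// Row 0 and column `cols` are zero-initialized padding.
using Grid = std::vector<std::vector<int>>;

Grid stencilOriginal(std::size_t rows, std::size_t cols) {
  Grid a(rows + 1, std::vector<int>(cols + 1, 0));
  for (std::size_t i = 1; i <= rows; ++i)
    for (std::size_t j = 0; j < cols; ++j)
      a[i][j] = a[i - 1][j + 1] + 1;
  return a;
}

// Same nest after unroll & jam by 2: at inner step j, the second copy reads
// a[i][j+1], which the first copy only writes at inner step j+1.
Grid stencilUnrollJam2(std::size_t rows, std::size_t cols) {
  assert(rows % 2 == 0 && "trip count must be divisible by the unroll factor");
  Grid a(rows + 1, std::vector<int>(cols + 1, 0));
  for (std::size_t i = 1; i <= rows; i += 2)
    for (std::size_t j = 0; j < cols; ++j) {
      a[i][j] = a[i - 1][j + 1] + 1;
      a[i + 1][j] = a[i][j + 1] + 1;
    }
  return a;
}
```

For example, with `rows = 2` and `cols = 2`, `a[2][0]` is 2 in the original nest but 1 after unroll & jam.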
From 989d3e8991c6a473a108b838d939b44218329a45 Mon Sep 17 00:00:00 2001
From: Javier Setoain <javier.setoain at gmail.com>
Date: Mon, 15 Jul 2024 05:18:18 -0600
Subject: [PATCH] [mlir] Allow unroll & jam on SCF loops with results
---
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 19 +++++++++----------
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c0ee9d2afe91c..b390658009587 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -486,12 +486,15 @@ LogicalResult mlir::loopUnrollByFactor(
}
/// Check if bounds of all inner loops are defined outside of `forOp`
-/// and return false if not.
+/// or defined by constants, and return false if not.
static bool areInnerBoundsInvariant(scf::ForOp forOp) {
auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
- if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
- !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+ if (!(forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+ isa<arith::ConstantOp>(innerForOp.getLowerBound().getDefiningOp())) ||
+ !(forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+ isa<arith::ConstantOp>(innerForOp.getUpperBound().getDefiningOp())) ||
+ !(forOp.isDefinedOutsideOfLoop(innerForOp.getStep()) ||
+ isa<arith::ConstantOp>(innerForOp.getStep().getDefiningOp())))
return WalkResult::interrupt();
return WalkResult::advance();
@@ -500,6 +503,8 @@ static bool areInnerBoundsInvariant(scf::ForOp forOp) {
}
/// Unrolls and jams this loop by the specified factor.
+/// This function doesn't verify that the loop is parallel; if there are true
+/// loop-carried dependencies, this function will produce invalid code.
LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
uint64_t unrollJamFactor) {
assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
@@ -514,12 +519,6 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
return failure();
}
- // Currently, for operations with results are not supported.
- if (forOp->getNumResults() > 0) {
- LDBG("failed to unroll and jam: unsupported loop with results");
- return failure();
- }
-
// Currently, only constant trip count that divided by the unroll factor is
// supported.
std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
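As a rough illustration of the updated `areInnerBoundsInvariant` predicate, the following MLIR-free C++ model accepts an inner-loop bound if it is defined outside `forOp` or produced by a constant. The `Op`, `Bound`, and `InnerLoop` types are hypothetical stand-ins, not MLIR classes. The model uses a null-tolerant constant test, analogous to `llvm::isa_and_nonnull`, because `Value::getDefiningOp()` returns null for block arguments (for example, an outer induction variable used as an inner bound) and `llvm::isa` asserts on a null pointer.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, MLIR-free model of a loop bound: it is either produced by an
// operation (defOp != nullptr) or is a block argument such as an enclosing
// loop's induction variable (defOp == nullptr, like getDefiningOp()).
struct Op {
  bool isConstant;
};
struct Bound {
  const Op *defOp;     // nullptr for block arguments
  bool definedOutside; // stand-in for forOp.isDefinedOutsideOfLoop(v)
};
struct InnerLoop {
  Bound lb, ub, step;
};

// Null-tolerant analogue of isa_and_nonnull<arith::ConstantOp>(defOp).
bool isConstantBound(const Bound &b) {
  return b.defOp != nullptr && b.defOp->isConstant;
}

// A bound is acceptable if it is loop-invariant or a constant.
bool isInvariantOrConstant(const Bound &b) {
  return b.definedOutside || isConstantBound(b);
}

// Mirrors areInnerBoundsInvariant: every inner loop's lb/ub/step must pass.
bool areInnerBoundsInvariant(const std::vector<InnerLoop> &inner) {
  for (const InnerLoop &l : inner)
    if (!isInvariantOrConstant(l.lb) || !isInvariantOrConstant(l.ub) ||
        !isInvariantOrConstant(l.step))
      return false;
  return true;
}
```

In this model, a triangular inner loop whose upper bound is the outer induction variable is rejected instead of dereferencing a null defining op.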