[Mlir-commits] [mlir] 0cbb6e7 - [mlir][scf] Expose isPerfectlyNestedForLoops (#152115)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Aug 26 01:05:23 PDT 2025
Author: Shay Kleiman
Date: 2025-08-26T11:05:19+03:00
New Revision: 0cbb6e7d6c30c4dd4d395941d332b0249088f9ff
URL: https://github.com/llvm/llvm-project/commit/0cbb6e7d6c30c4dd4d395941d332b0249088f9ff
DIFF: https://github.com/llvm/llvm-project/commit/0cbb6e7d6c30c4dd4d395941d332b0249088f9ff.diff
LOG: [mlir][scf] Expose isPerfectlyNestedForLoops (#152115)
The function `isPerfectlyNestedForLoops` is useful on its own, so I'm moving it
from `TileUsingInterface.cpp` into the SCF utilities (`SCF/Utils/Utils.h`) and
exposing it for downstream use.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index e620067c15be9..ecd829ed14add 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -213,6 +213,14 @@ scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
scf::ForallOp forallOp);
+/// Check if the provided loops are perfectly nested for-loops. `loops` must be
+/// non-empty and ordered from outermost to innermost. Perfect nesting means,
+/// for every adjacent (outer, inner) pair:
+/// 1. All loops are scf.for operations
+/// 2. Each outer loop's region iter args match the inner loop's init args
+///    (same number, same order)
+/// 3. Each outer loop's yields match the inner loop's results (same number,
+///    same order)
+/// 4. Each of the outer loop's region iter args and each of the inner loop's
+///    results has exactly one use
+/// A single loop is considered perfectly nested iff it is an scf.for.
+bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 250c413eff9e5..834c02126fa53 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1916,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}
-/// Check that the loops are perfectly nested.
-/// The loops are expected to be ordered from outermost to innermost.
-/// For example:
-/// ```
-/// %0 = scf.for()
-/// %1 = scf.for()
-/// %2 = scf.for()
-/// %3 = ...
-/// yield %3
-/// yield %2
-/// yield %1
-/// ```
-/// Here loops should be [%0, %1].
-static bool
-isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
- assert(!loops.empty() && "unexpected empty loop nest");
- if (loops.size() == 1) {
- return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
- }
- for (auto [outerLoop, innerLoop] :
- llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
- auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
- auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
- if (!outerFor || !innerFor) {
- return false;
- }
- auto outerBBArgs = outerFor.getRegionIterArgs();
- auto innerIterArgs = innerFor.getInitArgs();
- if (outerBBArgs.size() != innerIterArgs.size()) {
- return false;
- }
-
- for (auto [outerBBArg, innerIterArg] :
- llvm::zip_equal(outerBBArgs, innerIterArgs)) {
- if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
- innerIterArg != outerBBArg) {
- return false;
- }
- }
-
- ValueRange outerYields =
- cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
- ValueRange innerResults = innerFor.getResults();
- if (outerYields.size() != innerResults.size()) {
- return false;
- }
- for (auto [outerYield, innerResult] :
- llvm::zip_equal(outerYields, innerResults)) {
- if (!llvm::hasSingleElement(innerResult.getUses()) ||
- outerYield != innerResult) {
- return false;
- }
- }
- }
- return true;
-}
-
/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 49102583ec5e7..684dff8121de6 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1512,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
rewriter.replaceOp(forallOp, normalizedForallOp);
return normalizedForallOp;
}
+
+/// Returns true if `loops`, ordered from outermost to innermost, form a chain
+/// of perfectly nested scf.for operations. For every adjacent (outer, inner)
+/// pair this requires:
+///   1. both loops are scf.for operations;
+///   2. the inner loop is located directly in the outer loop's body block;
+///   3. the inner loop's init args are exactly the outer loop's region iter
+///      args (same count, same order), each with a single use;
+///   4. the outer loop's yielded values are exactly the inner loop's results
+///      (same count, same order), each with a single use.
+/// A single loop is perfectly nested iff it is an scf.for. `loops` must not
+/// be empty.
+bool mlir::isPerfectlyNestedForLoops(
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "unexpected empty loop nest");
+  // The assert is compiled out in release builds; bail out rather than call
+  // drop_back()/drop_front() on an empty range below.
+  if (loops.empty())
+    return false;
+
+  // Every loop must be an scf.for. Checking this up front also handles the
+  // single-loop case, where the pairwise walk below has nothing to visit.
+  if (!llvm::all_of(loops, [](LoopLikeOpInterface loop) {
+        return isa_and_nonnull<scf::ForOp>(loop.getOperation());
+      }))
+    return false;
+
+  for (auto [outerLoop, innerLoop] :
+       llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+    auto outerFor = cast<scf::ForOp>(outerLoop.getOperation());
+    auto innerFor = cast<scf::ForOp>(innerLoop.getOperation());
+
+    // With at least one iter arg, the SSA checks below already force the inner
+    // loop into the outer body. Without iter args they are vacuous, so check
+    // placement explicitly; otherwise unrelated sibling loops would pass.
+    if (innerFor->getBlock() != outerFor.getBody())
+      return false;
+
+    // The outer loop's region iter args must feed the inner loop's init args
+    // one-to-one and must not be used anywhere else.
+    auto outerBBArgs = outerFor.getRegionIterArgs();
+    auto innerIterArgs = innerFor.getInitArgs();
+    if (outerBBArgs.size() != innerIterArgs.size())
+      return false;
+    for (auto [outerBBArg, innerIterArg] :
+         llvm::zip_equal(outerBBArgs, innerIterArgs)) {
+      if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
+          innerIterArg != outerBBArg)
+        return false;
+    }
+
+    // The outer loop must yield exactly the inner loop's results, and those
+    // results must not be used anywhere else.
+    ValueRange outerYields =
+        cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
+    ValueRange innerResults = innerFor.getResults();
+    if (outerYields.size() != innerResults.size())
+      return false;
+    for (auto [outerYield, innerResult] :
+         llvm::zip_equal(outerYields, innerResults)) {
+      if (!llvm::hasSingleElement(innerResult.getUses()) ||
+          outerYield != innerResult)
+        return false;
+    }
+  }
+  return true;
+}
More information about the Mlir-commits
mailing list