[Mlir-commits] [mlir] 6484567 - [MLIR][SCF] Find all innermost loops for parallel loop tiling
Frederik Gossen
llvmlistbot at llvm.org
Fri Nov 27 01:41:11 PST 2020
Author: Frederik Gossen
Date: 2020-11-27T10:08:56+01:00
New Revision: 6484567f14881003a7c46d1587dbb0cf8082282a
URL: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a
DIFF: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a.diff
LOG: [MLIR][SCF] Find all innermost loops for parallel loop tiling
Overcome the assumption that parallel loops are only nested in other parallel
loops.
Differential Revision: https://reviews.llvm.org/D92188
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 7bcc989a5b28..7bd589214f4c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -22,15 +22,15 @@ using namespace mlir::scf;
/// Tile a parallel loop of the form
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-/// step (%arg4, %arg5)
+/// step (%arg4, %arg5)
///
/// into
/// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-/// step (%arg4*tileSize[0],
-/// %arg5*tileSize[1])
+/// step (%arg4*tileSize[0],
+/// %arg5*tileSize[1])
/// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
-/// min(%arg5*tileSize[1], %arg3-%i1))
-/// step (%arg4, %arg5)
+/// min(%arg5*tileSize[1], %arg3-%i1))
+/// step (%arg4, %arg5)
///
/// where the uses of %i0 and %i1 in the loop body are replaced by
/// %i0 + j0 and %i1 + %j1.
@@ -126,17 +126,27 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
op.erase();
}
-/// Get a list of most nested parallel loops. Assumes that ParallelOps are
-/// only directly nested.
-static bool getInnermostNestedLoops(Block *block,
- SmallVectorImpl<ParallelOp> &loops) {
- bool hasInnerLoop = false;
- for (auto parallelOp : block->getOps<ParallelOp>()) {
- hasInnerLoop = true;
- if (!getInnermostNestedLoops(parallelOp.getBody(), loops))
- loops.push_back(parallelOp);
+/// Get a list of most nested parallel loops.
+static bool getInnermostPloops(Operation *rootOp,
+ SmallVectorImpl<ParallelOp> &result) {
+ assert(rootOp != nullptr && "Root operation must not be a nullptr.");
+ bool rootEnclosesPloops = false;
+ for (Region &region : rootOp->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ for (Operation &op : block) {
+ bool enclosesPloops = getInnermostPloops(&op, result);
+ rootEnclosesPloops |= enclosesPloops;
+ if (auto ploop = dyn_cast<ParallelOp>(op)) {
+ rootEnclosesPloops = true;
+
+ // Collect ploop if it is an innermost one.
+ if (!enclosesPloops)
+ result.push_back(ploop);
+ }
+ }
+ }
}
- return hasInnerLoop;
+ return rootEnclosesPloops;
}
namespace {
@@ -148,14 +158,12 @@ struct ParallelLoopTiling
}
void runOnFunction() override {
- SmallVector<ParallelOp, 2> mostNestedParallelOps;
- for (Block &block : getFunction()) {
- getInnermostNestedLoops(&block, mostNestedParallelOps);
- }
- for (ParallelOp pLoop : mostNestedParallelOps) {
+ SmallVector<ParallelOp, 2> innermostPloops;
+ getInnermostPloops(getFunction().getOperation(), innermostPloops);
+ for (ParallelOp ploop : innermostPloops) {
// FIXME: Add reduction support.
- if (pLoop.getNumReductions() == 0)
- tileParallelLoop(pLoop, tileSizes);
+ if (ploop.getNumReductions() == 0)
+ tileParallelLoop(ploop, tileSizes);
}
}
};
diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
index e0dc8344f14d..5d3a676f58ab 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
@@ -112,3 +112,29 @@ func @tile_nested_innermost() {
// CHECK: }
// CHECK: return
// CHECK: }
+
+// -----
+
+func @tile_nested_in_non_ploop() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ scf.for %i = %c0 to %c2 step %c1 {
+ scf.for %j = %c0 to %c2 step %c1 {
+ scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+ }
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @tile_nested_in_non_ploop
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.parallel
+// CHECK: scf.parallel
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
More information about the Mlir-commits
mailing list