[llvm-branch-commits] [mlir] 6484567 - [MLIR][SCF] Find all innermost loops for parallel loop tiling

Frederik Gossen via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Nov 27 01:45:49 PST 2020


Author: Frederik Gossen
Date: 2020-11-27T10:08:56+01:00
New Revision: 6484567f14881003a7c46d1587dbb0cf8082282a

URL: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a
DIFF: https://github.com/llvm/llvm-project/commit/6484567f14881003a7c46d1587dbb0cf8082282a.diff

LOG: [MLIR][SCF] Find all innermost loops for parallel loop tiling

Overcome the assumption that parallel loops are only nested in other parallel
loops.

Differential Revision: https://reviews.llvm.org/D92188

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
    mlir/test/Dialect/SCF/parallel-loop-tiling.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
index 7bcc989a5b28..7bd589214f4c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp
@@ -22,15 +22,15 @@ using namespace mlir::scf;
 
 /// Tile a parallel loop of the form
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-///                                             step (%arg4, %arg5)
+///                                            step (%arg4, %arg5)
 ///
 /// into
 ///   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
-///                                             step (%arg4*tileSize[0],
-///                                                   %arg5*tileSize[1])
+///                                            step (%arg4*tileSize[0],
+///                                                  %arg5*tileSize[1])
 ///     scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0)
-///                                           min(%arg5*tileSize[1], %arg3-%i1))
-///                                        step (%arg4, %arg5)
+///                                          min(%arg5*tileSize[1], %arg3-%i1))
+///                                      step (%arg4, %arg5)
 ///
 /// where the uses of %i0 and %i1 in the loop body are replaced by
 /// %i0 + j0 and %i1 + %j1.
@@ -126,17 +126,27 @@ void mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   op.erase();
 }
 
-/// Get a list of most nested parallel loops. Assumes that ParallelOps are
-/// only directly nested.
-static bool getInnermostNestedLoops(Block *block,
-                                    SmallVectorImpl<ParallelOp> &loops) {
-  bool hasInnerLoop = false;
-  for (auto parallelOp : block->getOps<ParallelOp>()) {
-    hasInnerLoop = true;
-    if (!getInnermostNestedLoops(parallelOp.getBody(), loops))
-      loops.push_back(parallelOp);
+/// Collect innermost ParallelOps under rootOp; return true if any exist.
+static bool getInnermostPloops(Operation *rootOp,
+                               SmallVectorImpl<ParallelOp> &result) {
+  assert(rootOp != nullptr && "Root operation must not be a nullptr.");
+  bool rootEnclosesPloops = false;
+  for (Region &region : rootOp->getRegions()) {
+    for (Block &block : region.getBlocks()) {
+      for (Operation &op : block) {
+        bool enclosesPloops = getInnermostPloops(&op, result);
+        rootEnclosesPloops |= enclosesPloops;
+        if (auto ploop = dyn_cast<ParallelOp>(op)) {
+          rootEnclosesPloops = true;
+
+          // Collect ploop only if no ParallelOp is nested inside it.
+          if (!enclosesPloops)
+            result.push_back(ploop);
+        }
+      }
+    }
   }
-  return hasInnerLoop;
+  return rootEnclosesPloops;
 }
 
 namespace {
@@ -148,14 +158,12 @@ struct ParallelLoopTiling
   }
 
   void runOnFunction() override {
-    SmallVector<ParallelOp, 2> mostNestedParallelOps;
-    for (Block &block : getFunction()) {
-      getInnermostNestedLoops(&block, mostNestedParallelOps);
-    }
-    for (ParallelOp pLoop : mostNestedParallelOps) {
+    SmallVector<ParallelOp, 2> innermostPloops;  // Never nested in each other.
+    getInnermostPloops(getFunction().getOperation(), innermostPloops);
+    for (ParallelOp ploop : innermostPloops) {
       // FIXME: Add reduction support.
-      if (pLoop.getNumReductions() == 0)
-        tileParallelLoop(pLoop, tileSizes);
+      if (ploop.getNumReductions() == 0)
+        tileParallelLoop(ploop, tileSizes);
     }
   }
 };

diff  --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
index e0dc8344f14d..5d3a676f58ab 100644
--- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
+++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir
@@ -112,3 +112,29 @@ func @tile_nested_innermost() {
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
+
+// -----
+
+func @tile_nested_in_non_ploop() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  scf.for %i = %c0 to %c2 step %c1 {
+    scf.for %j = %c0 to %c2 step %c1 {
+      scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      }
+    }
+  }
+  return
+}
+// Expect the ploop to be tiled inside the two enclosing scf.for loops.
+// CHECK-LABEL: func @tile_nested_in_non_ploop
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             scf.parallel
+// CHECK:               scf.parallel
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }


        


More information about the llvm-branch-commits mailing list