[Mlir-commits] [mlir] bf6477e - [MLIR][OpenMP] Place alloca scope within wsloop in scf.parallel to omp lowering

William S. Moses llvmlistbot at llvm.org
Wed Mar 2 09:47:06 PST 2022


Author: William S. Moses
Date: 2022-03-02T12:46:58-05:00
New Revision: bf6477ebebf82c4c914a116a3d1d673da6d2164e

URL: https://github.com/llvm/llvm-project/commit/bf6477ebebf82c4c914a116a3d1d673da6d2164e
DIFF: https://github.com/llvm/llvm-project/commit/bf6477ebebf82c4c914a116a3d1d673da6d2164e.diff

LOG: [MLIR][OpenMP] Place alloca scope within wsloop in scf.parallel to omp lowering

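https://reviews.llvm.org/D120423 replaced the use of stacksave/stackrestore with memref.alloca_scope, but kept the new scope at the same location as the old save/restore, i.e. enclosing the whole wsloop. This PR places the allocation scope inside the wsloop body instead. The allocation scope now matches that of the original scf.parallel, i.e. stack allocations made in one iteration are released at the end of that iteration rather than accumulating across iterations (no more over-allocating the stack).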

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120772

Added: 
    

Modified: 
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/test/Conversion/SCFToOpenMP/reductions.mlir
    mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 4f80e8a04d434..152183c172d9d 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -335,16 +335,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
-    // Replace SCF yield with OpenMP yield.
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToEnd(parallelOp.getBody());
-      assert(llvm::hasSingleElement(parallelOp.getRegion()) &&
-             "expected scf.parallel to have one block");
-      rewriter.replaceOpWithNewOp<omp::YieldOp>(
-          parallelOp.getBody()->getTerminator(), ValueRange());
-    }
-
     // Declare reductions.
     // TODO: consider checking it here is already a compatible reduction
     // declaration and use it instead of redeclaring.
@@ -394,22 +384,31 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
       OpBuilder::InsertionGuard guard(rewriter);
       rewriter.createBlock(&ompParallel.region());
 
+      // Replace the loop.
       {
-        auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
-                                                            TypeRange());
-        rewriter.create<omp::TerminatorOp>(loc);
         OpBuilder::InsertionGuard allocaGuard(rewriter);
-        rewriter.createBlock(&scope.getBodyRegion());
-        rewriter.setInsertionPointToStart(&scope.getBodyRegion().front());
-
-        // Replace the loop.
         auto loop = rewriter.create<omp::WsLoopOp>(
             parallelOp.getLoc(), parallelOp.getLowerBound(),
             parallelOp.getUpperBound(), parallelOp.getStep());
-        rewriter.create<memref::AllocaScopeReturnOp>(loc);
+        rewriter.create<omp::TerminatorOp>(loc);
 
         rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(),
                                     loop.region().begin());
+
+        Block *ops = rewriter.splitBlock(&*loop.region().begin(),
+                                         loop.region().begin()->begin());
+
+        rewriter.setInsertionPointToStart(&*loop.region().begin());
+
+        auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
+                                                            TypeRange());
+        rewriter.create<omp::YieldOp>(loc, ValueRange());
+        Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion());
+        rewriter.mergeBlocks(ops, scopeBlock);
+        auto oldYield = cast<scf::YieldOp>(scopeBlock->getTerminator());
+        rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
+        rewriter.replaceOpWithNewOp<memref::AllocaScopeReturnOp>(
+            oldYield, oldYield->getOperands());
         if (!reductionVariables.empty()) {
           loop.reductionsAttr(
               ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));

diff  --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 3e8881ff1d976..ee76868da88d8 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -26,9 +26,9 @@ func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
   %step = arith.constant 1 : index
   %zero = arith.constant 0.0 : f32
   // CHECK: omp.parallel
-  // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]]
+  // CHECK: memref.alloca_scope
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%zero) -> (f32) {
     // CHECK: %[[CST_INNER:.*]] = arith.constant 1.0
@@ -161,10 +161,10 @@ func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: llvm.store %[[IONE]], %[[BUF2]]
 
   // CHECK: omp.parallel
-  // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
   // CHECK-SAME:           @[[$REDF2]] -> %[[BUF2]]
+  // CHECK: memref.alloca_scope
   %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                         step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
     %one = arith.constant 1.0 : f32

diff  --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index 0c16b85dde48b..3b0b3fe38e2ae 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -4,8 +4,8 @@
 func @parallel(%arg0: index, %arg1: index, %arg2: index,
           %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: omp.parallel {
-  // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
     // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
     "test.payload"(%i, %j) : (index, index) -> ()
@@ -21,12 +21,12 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index,
 func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: omp.parallel {
-    // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+    // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: omp.parallel
-    // CHECK: memref.alloca_scope
     // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+    // CHECK: memref.alloca_scope
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
       "test.payload"(%i, %j) : (index, index) -> ()
@@ -44,8 +44,8 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
 func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: omp.parallel {
-  // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
     "test.payload1"(%i) : (index) -> ()
@@ -56,8 +56,8 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
   // CHECK: }
 
   // CHECK: omp.parallel {
-  // CHECK: memref.alloca_scope
   // CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
     "test.payload2"(%j) : (index) -> ()


        


More information about the Mlir-commits mailing list