[llvm-branch-commits] [mlir] [MLIR][SCF] Update scf.parallel lowering to OpenMP (3/5) (PR #89212)
    via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Thu Apr 18 03:53:02 PDT 2024
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch updates the `scf.parallel` to `omp.parallel` + `omp.wsloop` lowering pass so that it also introduces a nested `omp.loop_nest` and follows the new loop wrapper role for `omp.wsloop`.
This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass tests.
---
Full diff: https://github.com/llvm/llvm-project/pull/89212.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+40-12) 
- (modified) mlir/test/Conversion/SCFToOpenMP/reductions.mlir (+5) 
- (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+23-8) 
``````````diff
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 7f91367ad427a2..f0b8d6c5309357 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -461,18 +461,51 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
       // Replace the loop.
       {
         OpBuilder::InsertionGuard allocaGuard(rewriter);
-        auto loop = rewriter.create<omp::WsloopOp>(
+        // Create worksharing loop wrapper.
+        auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc());
+        if (!reductionVariables.empty()) {
+          wsloopOp.setReductionsAttr(
+              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
+          wsloopOp.getReductionVarsMutable().append(reductionVariables);
+        }
+        rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator.
+
+        // The wrapper's entry block arguments will define the reduction
+        // variables.
+        llvm::SmallVector<mlir::Type> reductionTypes;
+        reductionTypes.reserve(reductionVariables.size());
+        llvm::transform(reductionVariables, std::back_inserter(reductionTypes),
+                        [](mlir::Value v) { return v.getType(); });
+        rewriter.createBlock(
+            &wsloopOp.getRegion(), {}, reductionTypes,
+            llvm::SmallVector<mlir::Location>(reductionVariables.size(),
+                                              parallelOp.getLoc()));
+
+        rewriter.setInsertionPoint(
+            rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()));
+
+        // Create loop nest and populate region with contents of scf.parallel.
+        auto loopOp = rewriter.create<omp::LoopNestOp>(
             parallelOp.getLoc(), parallelOp.getLowerBound(),
             parallelOp.getUpperBound(), parallelOp.getStep());
-        rewriter.create<omp::TerminatorOp>(loc);
 
-        rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.getRegion(),
-                                    loop.getRegion().begin());
+        rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(),
+                                    loopOp.getRegion().begin());
+
+        // Remove reduction-related block arguments from omp.loop_nest and
+        // redirect uses to the corresponding omp.wsloop block argument.
+        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
+        unsigned numLoops = parallelOp.getNumLoops();
+        rewriter.replaceAllUsesWith(
+            loopOpEntryBlock.getArguments().drop_front(numLoops),
+            wsloopOp.getRegion().getArguments());
+        loopOpEntryBlock.eraseArguments(
+            numLoops, loopOpEntryBlock.getNumArguments() - numLoops);
 
-        Block *ops = rewriter.splitBlock(&*loop.getRegion().begin(),
-                                         loop.getRegion().begin()->begin());
+        Block *ops = rewriter.splitBlock(&*loopOp.getRegion().begin(),
+                                         loopOp.getRegion().begin()->begin());
 
-        rewriter.setInsertionPointToStart(&*loop.getRegion().begin());
+        rewriter.setInsertionPointToStart(&*loopOp.getRegion().begin());
 
         auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(),
                                                             TypeRange());
@@ -481,11 +514,6 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         rewriter.mergeBlocks(ops, scopeBlock);
         rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
         rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange());
-        if (!reductionVariables.empty()) {
-          loop.setReductionsAttr(
-              ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols));
-          loop.getReductionVarsMutable().append(reductionVariables);
-        }
       }
     }
 
diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
index 3b6c145d62f1a8..fc6d56559c2618 100644
--- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir
@@ -28,6 +28,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.parallel
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF]] %[[BUF]] -> %[[PVT_BUF:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%zero) -> (f32) {
@@ -43,6 +44,7 @@ func.func @reduction1(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK:   omp.terminator
   // CHECK: omp.terminator
   // CHECK: llvm.load %[[BUF]]
   return
@@ -107,6 +109,7 @@ func.func @reduction_muli(%arg0 : index, %arg1 : index, %arg2 : index,
   %one = arith.constant 1 : i32
   // CHECK: %[[RED_VAR:.*]] = llvm.alloca %{{.*}} x i32 : (i64) -> !llvm.ptr
   // CHECK: omp.wsloop reduction(@[[$REDI]] %[[RED_VAR]] -> %[[RED_PVT_VAR:.*]] : !llvm.ptr)
+  // CHECK: omp.loop_nest
   scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                             step (%arg4, %step) init (%one) -> (i32) {
     // CHECK: %[[C2:.*]] = arith.constant 2 : i32
@@ -208,6 +211,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
   // CHECK: omp.wsloop
   // CHECK-SAME: reduction(@[[$REDF1]] %[[BUF1]] -> %[[PVT_BUF1:[a-z0-9]+]]
   // CHECK-SAME:           @[[$REDF2]] %[[BUF2]] -> %[[PVT_BUF2:[a-z0-9]+]]
+  // CHECK: omp.loop_nest
   // CHECK: memref.alloca_scope
   %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
                         step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
@@ -236,6 +240,7 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
     }
     // CHECK: omp.yield
   }
+  // CHECK:   omp.terminator
   // CHECK: omp.terminator
   // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
   // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index acd2690c56e2e6..b2f19d294cb5fe 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -2,10 +2,11 @@
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
-          %arg3: index, %arg4: index, %arg5: index) {
+                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
     // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> ()
@@ -13,6 +14,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
@@ -23,20 +26,26 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
-    // CHECK: memref.alloca_scope
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: omp.parallel
-    // CHECK: omp.wsloop for (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+    // CHECK: omp.wsloop {
+    // CHECK: omp.loop_nest (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: memref.alloca_scope
     scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
       // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()
       "test.payload"(%i, %j) : (index, index) -> ()
       // CHECK: }
     }
-    // CHECK:   omp.yield
+    // CHECK:     omp.yield
+    // CHECK:   }
+    // CHECK:   omp.terminator
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
@@ -47,7 +56,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> ()
@@ -55,12 +65,15 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
 
   // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
-  // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
     // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> ()
@@ -68,6 +81,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
     // CHECK:   omp.yield
     // CHECK: }
   }
+  // CHECK:     omp.terminator
+  // CHECK:   }
   // CHECK:   omp.terminator
   // CHECK: }
   return
``````````
</details>
https://github.com/llvm/llvm-project/pull/89212
    
    
More information about the llvm-branch-commits
mailing list