[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add num-threads option (PR #74854)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 8 08:05:15 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openmp

Author: Pablo Antonio Martinez (pabloantoniom)

<details>
<summary>Changes</summary>

Add a `num-threads` option to the `-convert-scf-to-openmp` pass, allowing the number of threads used by the generated `omp.parallel` to be set to a fixed value.

---
Full diff: https://github.com/llvm/llvm-project/pull/74854.diff


3 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.td (+5) 
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+13-5) 
- (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+5-5) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb..f42f061492a34 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
                 "constructs.";
 
+  let options = [
+    Option<"numThreads", "num-threads", "unsigned",
+           /*default=*/"0", "Number of threads to use">
+  ];
+
   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
                            "memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e..06fac6763da35 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -340,8 +340,10 @@ namespace {
 
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
-  ParallelOpLowering(MLIRContext *context)
-      : OpRewritePattern<scf::ParallelOp>(context) {}
+  unsigned numThreads;
+
+  ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -390,6 +392,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
     // Create the parallel wrapper.
     auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+    if (numThreads > 1) {
+      rewriter.setInsertionPoint(ompParallel);
+      mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(numThreads));
+      ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
+    }
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +451,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };
 
 /// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();
 
   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext());
+  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
@@ -463,7 +471,7 @@ struct SCFToOpenMPPass
 
   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation())))
+    if (failed(applyPatterns(getOperation(), numThreads)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b..09baa9edc03de 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
           %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +20,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @nested_loops
 func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +43,7 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @adjacent_loops
 func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +55,7 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
   // CHECK:   omp.terminator
   // CHECK: }
 
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/74854


More information about the Mlir-commits mailing list