[Mlir-commits] [mlir] 7f4f75c - [MLIR][SCFToOpenMP] Add num-threads option (#74854)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 14 01:07:21 PST 2023


Author: Pablo Antonio Martinez
Date: 2023-12-14T09:07:17Z
New Revision: 7f4f75c144c623bea5e6eb042940a5035c4ab826

URL: https://github.com/llvm/llvm-project/commit/7f4f75c144c623bea5e6eb042940a5035c4ab826
DIFF: https://github.com/llvm/llvm-project/commit/7f4f75c144c623bea5e6eb042940a5035c4ab826.diff

LOG: [MLIR][SCFToOpenMP] Add num-threads option (#74854)

Add a `num-threads` option to the `-convert-scf-to-openmp` pass that sets
the number of threads used by the generated `omp.parallel` op to a fixed
value. With the default value of 0, no `num_threads` clause is emitted.

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3..6193aeb545bc6b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,12 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
                 "constructs.";
 
+  let options = [
+    Option<"numThreads", "num-threads", "unsigned",
+           /*default=kUseOpenMPDefaultNumThreads*/"0",
+           "Number of threads to use">
+  ];
+
   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
                            "memref::MemRefDialect"];
 }

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e1..67033ba812946f 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -339,9 +339,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
 namespace {
 
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
+  static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
+  unsigned numThreads;
 
-  ParallelOpLowering(MLIRContext *context)
-      : OpRewritePattern<scf::ParallelOp>(context) {}
+  ParallelOpLowering(MLIRContext *context,
+                     unsigned numThreads = kUseOpenMPDefaultNumThreads)
+      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -388,8 +391,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
           reduceOp, reduceOp.getOperand(), std::get<1>(pair));
     }
 
+    Value numThreadsVar;
+    if (numThreads > 0) {
+      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(numThreads));
+    }
     // Create the parallel wrapper.
-    auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+    auto ompParallel = rewriter.create<omp::ParallelOp>(
+        loc,
+        /* if_expr_var = */ Value{},
+        /* num_threads_var = */ numThreadsVar,
+        /* allocate_vars = */ llvm::SmallVector<Value>{},
+        /* allocators_vars = */ llvm::SmallVector<Value>{},
+        /* reduction_vars = */ llvm::SmallVector<Value>{},
+        /* reductions = */ ArrayAttr{},
+        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +459,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };
 
 /// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();
 
   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext());
+  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
@@ -463,7 +479,7 @@ struct SCFToOpenMPPass
 
   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation())))
+    if (failed(applyPatterns(getOperation(), numThreads)))
       signalPassFailure();
   }
 };

diff  --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b8..acd2690c56e2e6 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
           %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +21,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @nested_loops
 func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +45,8 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @adjacent_loops
 func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +58,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
   // CHECK:   omp.terminator
   // CHECK: }
 
-  // CHECK: omp.parallel {
+  // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {


        


More information about the Mlir-commits mailing list