[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add num-threads option (PR #74854)

Pablo Antonio Martinez llvmlistbot at llvm.org
Mon Dec 11 05:03:24 PST 2023


https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/74854

>From b3f2159e704ee24d9bca7fb4c8c5478900925305 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 8 Dec 2023 15:58:53 +0000
Subject: [PATCH 1/2] [MLIR][SCFToOpenMP] Add num-threads option

---
 mlir/include/mlir/Conversion/Passes.td         |  5 +++++
 .../lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 18 +++++++++++++-----
 .../Conversion/SCFToOpenMP/scf-to-openmp.mlir  | 10 +++++-----
 3 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3..f42f061492a343 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
                 "constructs.";
 
+  let options = [
+    Option<"numThreads", "num-threads", "unsigned",
+           /*default=*/"0", "Number of threads to use">
+  ];
+
   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
                            "memref::MemRefDialect"];
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e1..06fac6763da350 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -340,8 +340,10 @@ namespace {
 
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
-  ParallelOpLowering(MLIRContext *context)
-      : OpRewritePattern<scf::ParallelOp>(context) {}
+  unsigned numThreads;
+
+  ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -390,6 +392,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
     // Create the parallel wrapper.
     auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+    if (numThreads > 1) {
+      rewriter.setInsertionPoint(ompParallel);
+      mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(numThreads));
+      ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
+    }
     {
 
       OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +451,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };
 
 /// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();
 
   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext());
+  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
@@ -463,7 +471,7 @@ struct SCFToOpenMPPass
 
   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation())))
+    if (failed(applyPatterns(getOperation(), numThreads)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b8..09baa9edc03de7 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
           %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +20,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @nested_loops
 func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
                    %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
     // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +43,7 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
 // CHECK-LABEL: @adjacent_loops
 func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
                      %arg3: index, %arg4: index, %arg5: index) {
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
   // CHECK: memref.alloca_scope
   scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +55,7 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
   // CHECK:   omp.terminator
   // CHECK: }
 
-  // CHECK: omp.parallel {
+  // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
   // CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
   // CHECK: memref.alloca_scope
   scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {

>From 051c2d72f4cf4e4edc2422f5d139bddab3c86c0d Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 11 Dec 2023 13:02:31 +0000
Subject: [PATCH 2/2] [MLIR][SCFToOpenMP] Fix the default number-of-threads
 condition. Do not assign the number of threads after creating the op; pass
 it to the op builder when the op is created instead. Give the numThreads
 constructor parameter a default value. Remove the mlir:: prefix.

---
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 21 ++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 06fac6763da350..983a0bdc0ab973 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -342,7 +342,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
   unsigned numThreads;
 
-  ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+  ParallelOpLowering(MLIRContext *context, unsigned numThreads = 0)
       : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
@@ -390,14 +390,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
           reduceOp, reduceOp.getOperand(), std::get<1>(pair));
     }
 
-    // Create the parallel wrapper.
-    auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
-    if (numThreads > 1) {
-      rewriter.setInsertionPoint(ompParallel);
-      mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+    Value numThreadsVar;
+    if (numThreads > 0) {
+      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
           loc, rewriter.getI32IntegerAttr(numThreads));
-      ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
     }
+    // Create the parallel wrapper.
+    auto ompParallel = rewriter.create<omp::ParallelOp>(
+        loc,
+        /* if_expr_var = */ Value{},
+        /* num_threads_var = */ numThreadsVar,
+        /* allocate_vars = */ llvm::SmallVector<Value>{},
+        /* allocators_vars = */ llvm::SmallVector<Value>{},
+        /* reduction_vars = */ llvm::SmallVector<Value>{},
+        /* reductions = */ ArrayAttr{},
+        /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
     {
 
       OpBuilder::InsertionGuard guard(rewriter);



More information about the Mlir-commits mailing list