[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add num-threads option (PR #74854)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Fri Dec 8 08:04:48 PST 2023
https://github.com/pabloantoniom created https://github.com/llvm/llvm-project/pull/74854
Add a `num-threads` option to the `-convert-scf-to-openmp` pass, allowing the number of threads used by the generated `omp.parallel` operations to be set to a fixed value.
>From b3f2159e704ee24d9bca7fb4c8c5478900925305 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 8 Dec 2023 15:58:53 +0000
Subject: [PATCH] [MLIR][SCFToOpenMP] Add num-threads option
---
mlir/include/mlir/Conversion/Passes.td | 5 +++++
.../lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 18 +++++++++++++-----
.../Conversion/SCFToOpenMP/scf-to-openmp.mlir | 10 +++++-----
3 files changed, 23 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb..f42f061492a34 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
+ let options = [
+ Option<"numThreads", "num-threads", "unsigned",
+ /*default=*/"0", "Number of threads to use (0: no num_threads clause)">
+ ];
+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
"memref::MemRefDialect"];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e..06fac6763da35 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -340,8 +340,12 @@ namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
- ParallelOpLowering(MLIRContext *context)
- : OpRewritePattern<scf::ParallelOp>(context) {}
+ /// Value for the `num_threads` clause of every generated omp.parallel.
+ /// Zero means "unset": no `num_threads` clause is emitted.
+ unsigned numThreads;
+
+ ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -390,6 +394,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Create the parallel wrapper.
auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+ // Zero leaves the clause unset; any positive value, including an explicit
+ // request for a single thread, is forwarded to the op.
+ if (numThreads > 0) {
+ // The constant must dominate `ompParallel`, so create it right before the
+ // op. The guard restores the previous insertion point (after the op) so
+ // that ops created later in this rewrite are not moved ahead of it.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(ompParallel);
+ mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(numThreads));
+ ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
+ }
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +459,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
@@ -463,7 +479,7 @@ struct SCFToOpenMPPass
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), numThreads)))
signalPassFailure();
}
};
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b..09baa9edc03de 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,12 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s --check-prefixes=CHECK,DEF
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s --check-prefixes=CHECK,NTH
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // DEF: omp.parallel {
+ // NTH: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // NTH: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +23,9 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @nested_loops
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // DEF: omp.parallel {
+ // NTH: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // NTH: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +48,9 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @adjacent_loops
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // DEF: omp.parallel {
+ // NTH: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // NTH: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +62,9 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.terminator
// CHECK: }
- // CHECK: omp.parallel {
+ // DEF: omp.parallel {
+ // NTH: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // NTH: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
More information about the Mlir-commits
mailing list