[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add num-threads option (PR #74854)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Dec 8 08:05:15 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Pablo Antonio Martinez (pabloantoniom)
<details>
<summary>Changes</summary>
Add a `num-threads` option to the `-convert-scf-to-openmp` pass that allows the number of threads used by the generated `omp.parallel` operation to be set to a fixed value.
---
Full diff: https://github.com/llvm/llvm-project/pull/74854.diff
3 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+5)
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+13-5)
- (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+5-5)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb..f42f061492a34 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
+ let options = [
+ Option<"numThreads", "num-threads", "unsigned",
+ /*default=*/"0", "Number of threads to use">
+ ];
+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
"memref::MemRefDialect"];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e..06fac6763da35 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -340,8 +340,10 @@ namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
- ParallelOpLowering(MLIRContext *context)
- : OpRewritePattern<scf::ParallelOp>(context) {}
+ unsigned numThreads;
+
+ ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -390,6 +392,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Create the parallel wrapper.
auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+ if (numThreads > 1) {
+ rewriter.setInsertionPoint(ompParallel);
+ mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(numThreads));
+ ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
+ }
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +451,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
@@ -463,7 +471,7 @@ struct SCFToOpenMPPass
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), numThreads)))
signalPassFailure();
}
};
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b..09baa9edc03de 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +20,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @nested_loops
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +43,7 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @adjacent_loops
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +55,7 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.terminator
// CHECK: }
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/74854
More information about the Mlir-commits
mailing list