[Mlir-commits] [mlir] 7f4f75c - [MLIR][SCFToOpenMP] Add num-threads option (#74854)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 14 01:07:21 PST 2023
Author: Pablo Antonio Martinez
Date: 2023-12-14T09:07:17Z
New Revision: 7f4f75c144c623bea5e6eb042940a5035c4ab826
URL: https://github.com/llvm/llvm-project/commit/7f4f75c144c623bea5e6eb042940a5035c4ab826
DIFF: https://github.com/llvm/llvm-project/commit/7f4f75c144c623bea5e6eb042940a5035c4ab826.diff
LOG: [MLIR][SCFToOpenMP] Add num-threads option (#74854)
Add a `num-threads` option to the `-convert-scf-to-openmp` pass, allowing
the number of threads used by the generated `omp.parallel` operations to be
set to a fixed value.
Added:
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb3..6193aeb545bc6b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,12 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
+ let options = [
+ Option<"numThreads", "num-threads", "unsigned",
+ /*default=kUseOpenMPDefaultNumThreads*/"0",
+ "Number of threads to use">
+ ];
+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
"memref::MemRefDialect"];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e1..67033ba812946f 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -339,9 +339,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
+ static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
+ unsigned numThreads;
- ParallelOpLowering(MLIRContext *context)
- : OpRewritePattern<scf::ParallelOp>(context) {}
+ ParallelOpLowering(MLIRContext *context,
+ unsigned numThreads = kUseOpenMPDefaultNumThreads)
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -388,8 +391,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
reduceOp, reduceOp.getOperand(), std::get<1>(pair));
}
+ Value numThreadsVar;
+ if (numThreads > 0) {
+ numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(numThreads));
+ }
// Create the parallel wrapper.
- auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+ auto ompParallel = rewriter.create<omp::ParallelOp>(
+ loc,
+ /* if_expr_var = */ Value{},
+ /* num_threads_var = */ numThreadsVar,
+ /* allocate_vars = */ llvm::SmallVector<Value>{},
+ /* allocators_vars = */ llvm::SmallVector<Value>{},
+ /* reduction_vars = */ llvm::SmallVector<Value>{},
+ /* reductions = */ ArrayAttr{},
+ /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +459,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
@@ -463,7 +479,7 @@ struct SCFToOpenMPPass
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), numThreads)))
signalPassFailure();
}
};
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b8..acd2690c56e2e6 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +21,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @nested_loops
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +45,8 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @adjacent_loops
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +58,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.terminator
// CHECK: }
- // CHECK: omp.parallel {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
More information about the Mlir-commits
mailing list