[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add num-threads option (PR #74854)
Pablo Antonio Martinez
llvmlistbot at llvm.org
Tue Dec 12 03:46:12 PST 2023
https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/74854
>From b3f2159e704ee24d9bca7fb4c8c5478900925305 Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Fri, 8 Dec 2023 15:58:53 +0000
Subject: [PATCH 1/3] [MLIR][SCFToOpenMP] Add num-threads option
---
mlir/include/mlir/Conversion/Passes.td | 5 +++++
.../lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 18 +++++++++++++-----
.../Conversion/SCFToOpenMP/scf-to-openmp.mlir | 10 +++++-----
3 files changed, 23 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 06756ff3df0bb..f42f061492a34 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -886,6 +886,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
"constructs.";
+ let options = [
+ Option<"numThreads", "num-threads", "unsigned",
+ /*default=*/"0", "Number of threads to use">
+ ];
+
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
"memref::MemRefDialect"];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ee3ee02cf535e..06fac6763da35 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -340,8 +340,10 @@ namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
- ParallelOpLowering(MLIRContext *context)
- : OpRewritePattern<scf::ParallelOp>(context) {}
+ unsigned numThreads;
+
+ ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -390,6 +392,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
// Create the parallel wrapper.
auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
+ if (numThreads > 1) {
+ rewriter.setInsertionPoint(ompParallel);
+ mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(numThreads));
+ ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
+ }
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -443,14 +451,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
@@ -463,7 +471,7 @@ struct SCFToOpenMPPass
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), numThreads)))
signalPassFailure();
}
};
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index e0fdcae1b896b..09baa9edc03de 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -1,9 +1,9 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='num-threads=4' %s | FileCheck %s
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +20,7 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @nested_loops
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +43,7 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @adjacent_loops
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +55,7 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.terminator
// CHECK: }
- // CHECK: omp.parallel {
+ // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
>From 051c2d72f4cf4e4edc2422f5d139bddab3c86c0d Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Mon, 11 Dec 2023 13:02:31 +0000
Subject: [PATCH 2/3] [MLIR][SCFToOpenMP] Fix default number of threads
condition. Pass the number of threads to the omp.parallel builder instead of
assigning it after the op is created. Set a default value for numThreads.
Remove the redundant mlir:: prefix
---
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 21 ++++++++++++-------
1 file changed, 14 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 06fac6763da35..983a0bdc0ab97 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -342,7 +342,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
unsigned numThreads;
- ParallelOpLowering(MLIRContext *context, unsigned numThreads)
+ ParallelOpLowering(MLIRContext *context, unsigned numThreads = 0)
: OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
@@ -390,14 +390,21 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
reduceOp, reduceOp.getOperand(), std::get<1>(pair));
}
- // Create the parallel wrapper.
- auto ompParallel = rewriter.create<omp::ParallelOp>(loc);
- if (numThreads > 1) {
- rewriter.setInsertionPoint(ompParallel);
- mlir::Value numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ Value numThreadsVar;
+ if (numThreads > 0) {
+ numThreadsVar = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32IntegerAttr(numThreads));
- ompParallel.getNumThreadsVarMutable().assign(numThreadsVar);
}
+ // Create the parallel wrapper.
+ auto ompParallel = rewriter.create<omp::ParallelOp>(
+ loc,
+ /* if_expr_var = */ Value{},
+ /* num_threads_var = */ numThreadsVar,
+ /* allocate_vars = */ llvm::SmallVector<Value>{},
+ /* allocators_vars = */ llvm::SmallVector<Value>{},
+ /* reduction_vars = */ llvm::SmallVector<Value>{},
+ /* reductions = */ ArrayAttr{},
+ /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
{
OpBuilder::InsertionGuard guard(rewriter);
>From 4214ded56ac73f565d9196826546d87e113291be Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 12 Dec 2023 11:43:59 +0000
Subject: [PATCH 3/3] [MLIR][SCFToOpenMP] Improve tests. Use a named constant
instead of a hard-coded 0 to avoid ambiguity
---
mlir/include/mlir/Conversion/Passes.td | 3 ++-
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +++--
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir | 12 ++++++++----
3 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index f42f061492a34..6193aeb545bc6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -888,7 +888,8 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let options = [
Option<"numThreads", "num-threads", "unsigned",
- /*default=*/"0", "Number of threads to use">
+ /*default=kUseOpenMPDefaultNumThreads*/"0",
+ "Number of threads to use">
];
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 983a0bdc0ab97..67033ba812946 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -339,10 +339,11 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
-
+ static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
unsigned numThreads;
- ParallelOpLowering(MLIRContext *context, unsigned numThreads = 0)
+ ParallelOpLowering(MLIRContext *context,
+ unsigned numThreads = kUseOpenMPDefaultNumThreads)
: OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
index 09baa9edc03de..acd2690c56e2e 100644
--- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
+++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
@@ -3,7 +3,8 @@
// CHECK-LABEL: @parallel
func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
@@ -20,7 +21,8 @@ func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @nested_loops
func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -43,7 +45,8 @@ func.func @nested_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK-LABEL: @adjacent_loops
func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
%arg3: index, %arg4: index, %arg5: index) {
- // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
// CHECK: memref.alloca_scope
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
@@ -55,7 +58,8 @@ func.func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index,
// CHECK: omp.terminator
// CHECK: }
- // CHECK: omp.parallel num_threads(%[[NTH:.*]] : i32) {
+ // CHECK: %[[FOUR:.+]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[FOUR]] : i32) {
// CHECK: omp.wsloop for (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
// CHECK: memref.alloca_scope
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
More information about the Mlir-commits
mailing list