[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add automatic-num-threads option (PR #85771)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 19 04:06:48 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Pablo Antonio Martinez (pabloantoniom)
<details>
<summary>Changes</summary>
Add a new option, `automatic-num-threads`, to the `convert-scf-to-openmp` pass that infers the number of threads from the iteration space of the `scf.parallel` loop being converted.
The number of threads launched in the parallel region can be configured with the `num_threads` attribute. This is useful when, for example, a loop is small and we want to launch only as many threads as are needed, avoiding the overhead of creating threads that would have no work to do. The number of threads can already be set manually with the `num-threads` option, but choosing that value by hand for every loop is impractical. This commit adds an option that automatically derives the number of threads from the `scf.parallel` loop: it is the product, over all loop dimensions, of (upper bound - lower bound) / step. The option therefore requires every lower bound, upper bound, and step of the loop to be a constant.
---
Full diff: https://github.com/llvm/llvm-project/pull/85771.diff
3 Files Affected:
- (modified) mlir/include/mlir/Conversion/Passes.td (+5-1)
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+41-6)
- (added) mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir (+75)
``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..bef9c3d1b56906 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -918,7 +918,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
let options = [
Option<"numThreads", "num-threads", "unsigned",
/*default=kUseOpenMPDefaultNumThreads*/"0",
- "Number of threads to use">
+ "Number of threads to use">,
+ Option<"automaticNumThreads", "automatic-num-threads", "bool",
+ /*default=*/"false",
+ "Automatically set the number of threads inferred from the iteration"
+ "space of the SCF parallel loop">
];
let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 464a647564aced..6dcbb8c950214e 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -356,10 +356,13 @@ namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
unsigned numThreads;
+ bool automaticNumThreads;
ParallelOpLowering(MLIRContext *context,
- unsigned numThreads = kUseOpenMPDefaultNumThreads)
- : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
+ unsigned numThreads = kUseOpenMPDefaultNumThreads,
+ bool automaticNumThreads = false)
+ : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads),
+ automaticNumThreads(automaticNumThreads) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
@@ -437,7 +440,37 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.eraseOp(reduce);
Value numThreadsVar;
- if (numThreads > 0) {
+ if (automaticNumThreads) {
+ unsigned inferredNumThreads = 1;
+ for (auto [lb, ub, step] :
+ llvm::zip_equal(parallelOp.getLowerBound(),
+ parallelOp.getUpperBound(), parallelOp.getStep())) {
+ std::optional<int64_t> cstLb = getConstantIntValue(lb);
+ std::optional<int64_t> cstUb = getConstantIntValue(ub);
+ std::optional<int64_t> cstStep = getConstantIntValue(step);
+
+ if (!cstLb.has_value())
+ return emitError(loc)
+ << "Expected a parallel loop with constant lower bounds when "
+ "trying to automatically infer number of threads";
+
+ if (!cstUb.has_value())
+ return emitError(loc)
+ << "Expected a parallel loop with constant upper bounds when "
+ "trying to automatically infer number of threads\n";
+
+ if (!cstStep.has_value())
+ return emitError(loc)
+ << "Expected a forall with constant steps when trying to "
+ "automatically infer number of threads\n";
+
+ inferredNumThreads =
+ inferredNumThreads *
+ ((cstUb.value() - cstLb.value()) / cstStep.value());
+ }
+ numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(inferredNumThreads));
+ } else if (numThreads > 0) {
numThreadsVar = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32IntegerAttr(numThreads));
}
@@ -504,14 +537,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads,
+ bool automaticNumThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
+ patterns.add<ParallelOpLowering>(module.getContext(), numThreads,
+ automaticNumThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
@@ -524,7 +559,7 @@ struct SCFToOpenMPPass
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation(), numThreads)))
+ if (failed(applyPatterns(getOperation(), numThreads, automaticNumThreads)))
signalPassFailure();
}
};
diff --git a/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
new file mode 100644
index 00000000000000..d498efdf8ae3a9
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-scf-to-openmp='automatic-num-threads' -split-input-file %s | FileCheck %s
+
+func.func @automatic(%arg0: memref<100xf32>) -> memref<100xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c8 = arith.constant 8 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+ // CHECK: %[[NTH:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+ scf.parallel (%arg1) = (%c0) to (%c8) step (%c1) {
+ %0 = memref.load %alloc[%arg1] : memref<100xf32>
+ %1 = arith.addf %0, %cst : f32
+ memref.store %1, %alloc[%arg1] : memref<100xf32>
+ scf.reduce
+ }
+ return %alloc : memref<100xf32>
+}
+
+// -----
+func.func @automatic_multiple_ub(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c8 = arith.constant 8 : index
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+ // CHECK: %[[NTH:.*]] = llvm.mlir.constant(32 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+ scf.parallel (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%c8, %c2, %c2) step (%c1, %c1, %c1) {
+ %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+ %1 = arith.addf %0, %cst : f32
+ memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+ scf.reduce
+ }
+ return %alloc : memref<100x100x100xf32>
+}
+
+// -----
+func.func @automatic_nonzero_lb(%arg0: memref<100xf32>) -> memref<100xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+ %c4 = arith.constant 4 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+ // CHECK: %[[NTH:.*]] = llvm.mlir.constant(4 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+ scf.parallel (%arg1) = (%c4) to (%c8) step (%c1) {
+ %0 = memref.load %alloc[%arg1] : memref<100xf32>
+ %1 = arith.addf %0, %cst : f32
+ memref.store %1, %alloc[%arg1] : memref<100xf32>
+ scf.reduce
+ }
+ return %alloc : memref<100xf32>
+}
+
+// -----
+func.func @automatic_steps(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+ %cst = arith.constant 1.000000e+00 : f32
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c12 = arith.constant 12 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+ // CHECK: %[[NTH:.*]] = llvm.mlir.constant(12 : i32) : i32
+ // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+ scf.parallel (%arg1, %arg2, %arg3) = (%c4, %c8, %c12) to (%c16, %c16, %c16) step (%c4, %c4, %c2) {
+ %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+ %1 = arith.addf %0, %cst : f32
+ memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+ scf.reduce
+ }
+ return %alloc : memref<100x100x100xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/85771
More information about the Mlir-commits
mailing list