[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add automatic-num-threads option (PR #85771)

Pablo Antonio Martinez llvmlistbot at llvm.org
Tue Mar 19 04:06:21 PDT 2024


https://github.com/pabloantoniom created https://github.com/llvm/llvm-project/pull/85771

Add a new option to the `convert-scf-to-openmp` pass to infer the number of threads from the iteration space of the parallel loop.

The number of threads launched in the parallel region can be configured with the `num_threads` attribute. This is useful, for example, for small loops: launching only as many threads as will actually be needed avoids the overhead of creating threads that have no work to do. The number of threads can already be set manually with the `num-threads` option, but setting it by hand for every loop is impractical. This commit adds an option to infer the number of threads automatically from the iteration space of the `scf.parallel` loop.

>From def661ee7d5881398edfce01ebdf3877ef10e39c Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 19 Mar 2024 10:14:04 +0000
Subject: [PATCH] [MLIR][SCFToOpenMP] Add automatic-num-threads option

Add a new option to the convert-scf-to-openmp pass to infer the number
of threads from the iteration space of the parallel loop.

The number of threads launched in the parallel region can be configured
with the num_threads attribute. This is useful, for example, for small
loops: launching only as many threads as will actually be needed
effectively avoids the overhead of creating threads that have no work.
The number of threads can already be set manually (with the num-threads
option), but setting this option by hand for every loop is impractical.
This commit adds an option to infer the number of threads automatically
from the iteration space of the scf.parallel loop.
---
 mlir/include/mlir/Conversion/Passes.td        |  6 +-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 47 ++++++++++--
 .../SCFToOpenMP/automatic-num-threads.mlir    | 75 +++++++++++++++++++
 3 files changed, 121 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..bef9c3d1b56906 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -918,7 +918,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let options = [
     Option<"numThreads", "num-threads", "unsigned",
            /*default=kUseOpenMPDefaultNumThreads*/"0",
-           "Number of threads to use">
+           "Number of threads to use">,
+    Option<"automaticNumThreads", "automatic-num-threads", "bool",
+           /*default=*/"false",
+           "Automatically set the number of threads inferred from the "
+           "iteration space of the SCF parallel loop">
   ];

   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 464a647564aced..6dcbb8c950214e 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -356,10 +356,14 @@ namespace {
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
   static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
   unsigned numThreads;
+  /// When set, derive num_threads from the iteration space of the loop.
+  bool automaticNumThreads;

   ParallelOpLowering(MLIRContext *context,
-                     unsigned numThreads = kUseOpenMPDefaultNumThreads)
-      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
+                     unsigned numThreads = kUseOpenMPDefaultNumThreads,
+                     bool automaticNumThreads = false)
+      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads),
+        automaticNumThreads(automaticNumThreads) {}

   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -437,7 +441,54 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     rewriter.eraseOp(reduce);

     Value numThreadsVar;
-    if (numThreads > 0) {
+    // With automaticNumThreads, launch no more threads than the loop has
+    // iterations. This is best effort: if a bound or step is not a positive
+    // compile-time constant, fall back to the num-threads option (or the
+    // OpenMP default). Reporting an error instead is not possible here: the
+    // IR has already been rewritten above, and a pattern must not return
+    // failure after modifying the IR.
+    std::optional<int32_t> inferredNumThreads;
+    if (automaticNumThreads) {
+      // num_threads is an i32 operand; saturate rather than overflow.
+      constexpr uint64_t kMaxNumThreads = std::numeric_limits<int32_t>::max();
+      uint64_t totalTripCount = 1;
+      bool allConstant = true;
+      for (auto [lb, ub, step] :
+           llvm::zip_equal(parallelOp.getLowerBound(),
+                           parallelOp.getUpperBound(), parallelOp.getStep())) {
+        std::optional<int64_t> cstLb = getConstantIntValue(lb);
+        std::optional<int64_t> cstUb = getConstantIntValue(ub);
+        std::optional<int64_t> cstStep = getConstantIntValue(step);
+        if (!cstLb || !cstUb || !cstStep || *cstStep <= 0) {
+          allConstant = false;
+          break;
+        }
+        // Trip count of this dimension, rounded up: [0, 7) with step 2 runs
+        // 4 iterations, not 3. An empty range (ub <= lb) runs none.
+        uint64_t span = 0;
+        if (*cstUb > *cstLb)
+          span = static_cast<uint64_t>(*cstUb) - static_cast<uint64_t>(*cstLb);
+        uint64_t stepSize = static_cast<uint64_t>(*cstStep);
+        uint64_t tripCount = span / stepSize + (span % stepSize != 0 ? 1 : 0);
+        totalTripCount = (tripCount != 0 &&
+                          totalTripCount > kMaxNumThreads / tripCount)
+                             ? kMaxNumThreads
+                             : totalTripCount * tripCount;
+      }
+      if (allConstant) {
+        // num_threads must be positive, even when the loop has no
+        // iterations. An explicit num-threads value acts as an upper bound.
+        uint64_t count = std::max<uint64_t>(totalTripCount, 1);
+        if (numThreads > 0)
+          count = std::min<uint64_t>(count, numThreads);
+        inferredNumThreads = static_cast<int32_t>(count);
+      }
+    }
+
+    if (inferredNumThreads) {
+      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(*inferredNumThreads));
+    } else if (numThreads > 0) {
       numThreadsVar = rewriter.create<LLVM::ConstantOp>(
           loc, rewriter.getI32IntegerAttr(numThreads));
     }
@@ -504,14 +537,17 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };

-/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
+/// Applies the conversion patterns in the given module, forwarding the
+/// num-threads and automatic-num-threads pass options to the pattern.
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads,
+                                   bool automaticNumThreads) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();

   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
+  patterns.add<ParallelOpLowering>(module.getContext(), numThreads,
+                                   automaticNumThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
@@ -524,7 +560,7 @@ struct SCFToOpenMPPass

   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation(), numThreads)))
+    if (failed(applyPatterns(getOperation(), numThreads, automaticNumThreads)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
new file mode 100644
index 00000000000000..d498efdf8ae3a9
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -convert-scf-to-openmp='automatic-num-threads' -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @automatic(
+func.func @automatic(%arg0: memref<100xf32>) -> memref<100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1) = (%c0) to (%c8) step (%c1) {
+    %0 = memref.load %alloc[%arg1] : memref<100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1] : memref<100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_multiple_ub(
+func.func @automatic_multiple_ub(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c8 = arith.constant 8 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%c8, %c2, %c2) step (%c1, %c1, %c1) {
+    %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100x100x100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_nonzero_lb(
+func.func @automatic_nonzero_lb(%arg0: memref<100xf32>) -> memref<100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1) = (%c4) to (%c8) step (%c1) {
+    %0 = memref.load %alloc[%arg1] : memref<100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1] : memref<100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_steps(
+func.func @automatic_steps(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c12 = arith.constant 12 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1, %arg2, %arg3) = (%c4, %c8, %c12) to (%c16, %c16, %c16) step (%c4, %c4, %c2) {
+    %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100x100x100xf32>
+}



More information about the Mlir-commits mailing list