[Mlir-commits] [mlir] [MLIR][SCFToOpenMP] Add automatic-num-threads option (PR #85771)

Pablo Antonio Martinez llvmlistbot at llvm.org
Tue Mar 26 07:58:32 PDT 2024


https://github.com/pabloantoniom updated https://github.com/llvm/llvm-project/pull/85771

>From def661ee7d5881398edfce01ebdf3877ef10e39c Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 19 Mar 2024 10:14:04 +0000
Subject: [PATCH 1/2] [MLIR][SCFToOpenMP] Add automatic-num-threads option

Add a new option to the convert-scf-to-openmp pass to infer the number
of threads from the iteration space of the parallel loop.

The number of threads launched in the parallel region can be configured
with the num_threads clause. This can be useful if, for example, we
have small loops and we want to launch only the number of threads that
will be needed, effectively avoiding the overhead of creating useless
threads. To do this, we can set the number of threads manually
(with the num-threads option). However, setting this value manually
each time may not be ideal. This commit adds an option to automatically
infer the number of threads from the iteration space of the scf.parallel
loop, which must have constant lower bounds, upper bounds and steps.
---
 mlir/include/mlir/Conversion/Passes.td        |  6 +-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    | 47 ++++++++++--
 .../SCFToOpenMP/automatic-num-threads.mlir    | 75 +++++++++++++++++++
 3 files changed, 121 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bd81cc6d5323bf..bef9c3d1b56906 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -918,7 +918,11 @@ def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let options = [
     Option<"numThreads", "num-threads", "unsigned",
            /*default=kUseOpenMPDefaultNumThreads*/"0",
-           "Number of threads to use">
+           "Number of threads to use">,
+    Option<"automaticNumThreads", "automatic-num-threads", "bool",
+           /*default=*/"false",
+           "Automatically set the number of threads inferred from the "
+           "iteration space of the SCF parallel loop">
   ];
 
   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 464a647564aced..6dcbb8c950214e 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -356,10 +356,15 @@ namespace {
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
   static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
   unsigned numThreads;
+  /// When true, num_threads is inferred from the constant iteration space of
+  /// the scf.parallel op and takes precedence over `numThreads`.
+  bool automaticNumThreads;
 
   ParallelOpLowering(MLIRContext *context,
-                     unsigned numThreads = kUseOpenMPDefaultNumThreads)
-      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}
+                     unsigned numThreads = kUseOpenMPDefaultNumThreads,
+                     bool automaticNumThreads = false)
+      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads),
+        automaticNumThreads(automaticNumThreads) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -437,7 +440,59 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     rewriter.eraseOp(reduce);
 
     Value numThreadsVar;
-    if (numThreads > 0) {
+    if (automaticNumThreads) {
+      // The num_threads operand is an i32, so cap the inferred value there.
+      // Capping every partial product also keeps the int64_t multiplication
+      // below from overflowing.
+      constexpr int64_t kMaxInferredNumThreads =
+          std::numeric_limits<int32_t>::max();
+      // Number of iterations of one dimension with constant bounds. The range
+      // is rounded up when the step does not divide it evenly (e.g. 0 to 7
+      // step 2 runs 4 iterations) and an empty range yields 0. The
+      // scf.parallel verifier rejects non-positive constant steps.
+      auto getTripCount = [&](int64_t lower, int64_t upper,
+                              int64_t stride) -> int64_t {
+        if (upper <= lower)
+          return 0;
+        return std::min((upper - lower - 1) / stride + 1,
+                        kMaxInferredNumThreads);
+      };
+      int64_t inferredNumThreads = 1;
+      for (auto [lb, ub, step] :
+           llvm::zip_equal(parallelOp.getLowerBound(),
+                           parallelOp.getUpperBound(), parallelOp.getStep())) {
+        std::optional<int64_t> cstLb = getConstantIntValue(lb);
+        std::optional<int64_t> cstUb = getConstantIntValue(ub);
+        std::optional<int64_t> cstStep = getConstantIntValue(step);
+
+        if (!cstLb.has_value())
+          return emitError(loc)
+                 << "Expected a parallel loop with constant lower bounds when "
+                    "trying to automatically infer number of threads";
+
+        if (!cstUb.has_value())
+          return emitError(loc)
+                 << "Expected a parallel loop with constant upper bounds when "
+                    "trying to automatically infer number of threads\n";
+
+        if (!cstStep.has_value())
+          return emitError(loc)
+                 << "Expected a forall with constant steps when trying to "
+                    "automatically infer number of threads\n";
+
+        inferredNumThreads =
+            inferredNumThreads *
+            getTripCount(cstLb.value(), cstUb.value(), cstStep.value());
+        inferredNumThreads =
+            std::min(inferredNumThreads, kMaxInferredNumThreads);
+      }
+      // An empty iteration space executes nothing, but OpenMP requires the
+      // num_threads value to be positive.
+      inferredNumThreads = std::max<int64_t>(inferredNumThreads, 1);
+      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(
+                   static_cast<int32_t>(inferredNumThreads)));
+    } else if (numThreads > 0) {
       numThreadsVar = rewriter.create<LLVM::ConstantOp>(
           loc, rewriter.getI32IntegerAttr(numThreads));
     }
@@ -504,14 +537,17 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };
 
-/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
+/// Applies the conversion patterns to the given module. `numThreads` and
+/// `automaticNumThreads` are forwarded to the ParallelOpLowering pattern.
+static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads,
+                                   bool automaticNumThreads) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();
 
   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
+  patterns.add<ParallelOpLowering>(module.getContext(), numThreads,
+                                   automaticNumThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
@@ -524,7 +559,7 @@ struct SCFToOpenMPPass
 
   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation(), numThreads)))
+    if (failed(applyPatterns(getOperation(), numThreads, automaticNumThreads)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
new file mode 100644
index 00000000000000..d498efdf8ae3a9
--- /dev/null
+++ b/mlir/test/Conversion/SCFToOpenMP/automatic-num-threads.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -convert-scf-to-openmp='automatic-num-threads' -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @automatic(
+func.func @automatic(%arg0: memref<100xf32>) -> memref<100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c8 = arith.constant 8 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1) = (%c0) to (%c8) step (%c1) {
+    %0 = memref.load %alloc[%arg1] : memref<100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1] : memref<100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_multiple_ub(
+func.func @automatic_multiple_ub(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c8 = arith.constant 8 : index
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(32 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1, %arg2, %arg3) = (%c0, %c0, %c0) to (%c8, %c2, %c2) step (%c1, %c1, %c1) {
+    %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100x100x100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_nonzero_lb(
+func.func @automatic_nonzero_lb(%arg0: memref<100xf32>) -> memref<100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1) = (%c4) to (%c8) step (%c1) {
+    %0 = memref.load %alloc[%arg1] : memref<100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1] : memref<100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100xf32>
+}
+
+// -----
+// CHECK-LABEL: func.func @automatic_steps(
+func.func @automatic_steps(%arg0: memref<100x100x100xf32>) -> memref<100x100x100xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %c12 = arith.constant 12 : index
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<100x100x100xf32>
+  // CHECK: %[[NTH:.*]] = llvm.mlir.constant(12 : i32) : i32
+  // CHECK: omp.parallel num_threads(%[[NTH]] : i32)
+  scf.parallel (%arg1, %arg2, %arg3) = (%c4, %c8, %c12) to (%c16, %c16, %c16) step (%c4, %c4, %c2) {
+    %0 = memref.load %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    %1 = arith.addf %0, %cst : f32
+    memref.store %1, %alloc[%arg1, %arg2, %arg3] : memref<100x100x100xf32>
+    scf.reduce
+  }
+  return %alloc : memref<100x100x100xf32>
+}

>From 86d5e339f53d0b1be4bd8ae9a57c598b11e4661b Mon Sep 17 00:00:00 2001
From: Pablo Antonio Martinez <pablo.antonio.martinez at huawei.com>
Date: Tue, 26 Mar 2024 14:58:08 +0000
Subject: [PATCH 2/2] Fix error message

---
 mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6dcbb8c950214e..fb805263770161 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -457,12 +457,12 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         if (!cstUb.has_value())
           return emitError(loc)
                  << "Expected a parallel loop with constant upper bounds when "
-                    "trying to automatically infer number of threads\n";
+                    "trying to automatically infer number of threads";
 
         if (!cstStep.has_value())
-          return emitError(loc)
-                 << "Expected a forall with constant steps when trying to "
-                    "automatically infer number of threads\n";
+          return emitError(loc) << "Expected a parallel loop with constant "
+                                   "steps when trying to automatically infer "
+                                   "number of threads";
 
         inferredNumThreads =
             inferredNumThreads *



More information about the Mlir-commits mailing list