[Mlir-commits] [mlir] 149311b - [async] Get the number of worker threads from the runtime.

Eugene Zhulenev llvmlistbot at llvm.org
Mon Jan 31 12:06:07 PST 2022


Author: bakhtiyar
Date: 2022-01-31T12:06:01-08:00
New Revision: 149311b4055a5f836b8f61ad66e700a93e86ab18

URL: https://github.com/llvm/llvm-project/commit/149311b4055a5f836b8f61ad66e700a93e86ab18
DIFF: https://github.com/llvm/llvm-project/commit/149311b4055a5f836b8f61ad66e700a93e86ab18.diff

LOG: [async] Get the number of worker threads from the runtime.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D117751

Added: 
    mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir

Modified: 
    mlir/include/mlir/Dialect/Async/IR/Async.h
    mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
    mlir/include/mlir/Dialect/Async/Passes.td
    mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Dialect/Async/IR/CMakeLists.txt
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
index 0783009d2855c..0c60a3c06c131 100644
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
index eef3b6178abd2..5a7ec4eb8c503 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/Async/IR/AsyncDialect.td"
 include "mlir/Dialect/Async/IR/AsyncTypes.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -529,4 +530,17 @@ def Async_RuntimeDropRefOp : Async_Op<"runtime.drop_ref"> {
   }];
 }
 
+def Async_RuntimeNumWorkerThreadsOp :
+  Async_Op<"runtime.num_worker_threads",
+           [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "gets the number of threads in the threadpool from the runtime";
+  let description = [{
+    The `async.runtime.num_worker_threads` operation gets the number of threads
+    in the threadpool from the runtime.
+  }];
+
+  let results = (outs Index:$result);
+  let assemblyFormat = "attr-dict `:` type($result)";
+}
+
 #endif // ASYNC_OPS
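
For reference, with the declared assembly format the new op takes no operands and produces a single `index` result. A minimal usage sketch (the wrapping function is illustrative, not part of this change):

    func @worker_count_example() -> index {
      // Ask the async runtime how many worker threads its threadpool has.
      %num_workers = async.runtime.num_worker_threads : index
      return %num_workers : index
    }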

diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
index 2d92903e4c80b..8eb1eef9b71fd 100644
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -25,7 +25,8 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
 
     Option<"numWorkerThreads", "num-workers",
       "int32_t", /*default=*/"8",
-      "The number of available workers to execute async operations.">,
+      "The number of available workers to execute async operations. If `-1` "
+      "the value will be retrieved from the runtime.">,
 
     Option<"minTaskSize", "min-task-size",
       "int32_t", /*default=*/"1000",

diff --git a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
index a681c47200daf..ba59c39fac569 100644
--- a/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/AsyncRuntime.h
@@ -123,6 +123,9 @@ extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *, CoroHandle,
 extern "C" void
 mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *, CoroHandle, CoroResume);
 
+// Returns the current number of available worker threads in the threadpool.
+extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads();
+
 //===----------------------------------------------------------------------===//
 // Small async runtime support library for testing.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 812fbd60473ab..a709a97c2fe0a 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -59,6 +59,8 @@ static constexpr const char *kAwaitValueAndExecute =
     "mlirAsyncRuntimeAwaitValueAndExecute";
 static constexpr const char *kAwaitAllAndExecute =
     "mlirAsyncRuntimeAwaitAllInGroupAndExecute";
+static constexpr const char *kGetNumWorkerThreads =
+    "mlirAsyncRuntimGetNumWorkerThreads";
 
 namespace {
 /// Async Runtime API function types.
@@ -181,6 +183,10 @@ struct AsyncAPI {
     return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
   }
 
+  static FunctionType getNumWorkerThreads(MLIRContext *ctx) {
+    return FunctionType::get(ctx, {}, {IndexType::get(ctx)});
+  }
+
   // Auxiliary coroutine resume intrinsic wrapper.
   static Type resumeFunctionType(MLIRContext *ctx) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
@@ -226,6 +232,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
               AsyncAPI::awaitValueAndExecuteFunctionType(ctx));
   addFuncDecl(kAwaitAllAndExecute,
               AsyncAPI::awaitAllAndExecuteFunctionType(ctx));
+  addFuncDecl(kGetNumWorkerThreads, AsyncAPI::getNumWorkerThreads(ctx));
 }
 
 //===----------------------------------------------------------------------===//
@@ -879,6 +886,30 @@ class RuntimeAddToGroupOpLowering
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Convert async.runtime.num_worker_threads to the corresponding runtime API
+// call.
+//===----------------------------------------------------------------------===//
+
+namespace {
+class RuntimeNumWorkerThreadsOpLowering
+    : public OpConversionPattern<RuntimeNumWorkerThreadsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(RuntimeNumWorkerThreadsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    // Replace with a runtime API function call.
+    rewriter.replaceOpWithNewOp<CallOp>(op, kGetNumWorkerThreads,
+                                        rewriter.getIndexType());
+
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Async reference counting ops lowering (`async.runtime.add_ref` and
 // `async.runtime.drop_ref` to the corresponding API calls).
@@ -984,8 +1015,9 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
   patterns.add<RuntimeSetAvailableOpLowering, RuntimeSetErrorOpLowering,
                RuntimeIsErrorOpLowering, RuntimeAwaitOpLowering,
                RuntimeAwaitAndResumeOpLowering, RuntimeResumeOpLowering,
-               RuntimeAddToGroupOpLowering, RuntimeAddRefOpLowering,
-               RuntimeDropRefOpLowering>(converter, ctx);
+               RuntimeAddToGroupOpLowering, RuntimeNumWorkerThreadsOpLowering,
+               RuntimeAddRefOpLowering, RuntimeDropRefOpLowering>(converter,
+                                                                  ctx);
 
   // Lower async.runtime operations that rely on LLVM type converter to convert
   // from async value payload type to the LLVM type.
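
In IR terms, the new RuntimeNumWorkerThreadsOpLowering pattern replaces the op with a call to the runtime API function declared above. A rough sketch of the result, before the `index` type itself is lowered by the LLVM type converter:

    func private @mlirAsyncRuntimGetNumWorkerThreads() -> index
    ...
    %num_workers = call @mlirAsyncRuntimGetNumWorkerThreads() : () -> index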

diff --git a/mlir/lib/Dialect/Async/IR/CMakeLists.txt b/mlir/lib/Dialect/Async/IR/CMakeLists.txt
index 87946f715a0aa..9d27d82e130ed 100644
--- a/mlir/lib/Dialect/Async/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Async/IR/CMakeLists.txt
@@ -9,5 +9,6 @@ add_mlir_dialect_library(MLIRAsync
 
   LINK_LIBS PUBLIC
   MLIRDialect
+  MLIRInferTypeOpInterface
   MLIRIR
   )

diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index edcbefc8e977a..e87b5157ed059 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -23,6 +23,7 @@
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 
@@ -799,19 +800,53 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
         numUnrollableLoops++;
     }
 
+    Value numWorkerThreadsVal;
+    if (numWorkerThreads >= 0)
+      numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads);
+    else
+      numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>();
+
     // With large number of threads the value of creating many compute blocks
-    // is reduced because the problem typically becomes memory bound. For small
-    // number of threads it helps with stragglers.
-    float overshardingFactor = numWorkerThreads <= 4    ? 8.0
-                               : numWorkerThreads <= 8  ? 4.0
-                               : numWorkerThreads <= 16 ? 2.0
-                               : numWorkerThreads <= 32 ? 1.0
-                               : numWorkerThreads <= 64 ? 0.8
-                                                        : 0.6;
-
-    // Do not overload worker threads with too many compute blocks.
-    Value maxComputeBlocks = b.create<arith::ConstantIndexOp>(
-        std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
+    // is reduced because the problem typically becomes memory bound. For this
+    // reason we scale the number of workers using an equivalent to the
+    // following logic:
+    //   float overshardingFactor = numWorkerThreads <= 4    ? 8.0
+    //                              : numWorkerThreads <= 8  ? 4.0
+    //                              : numWorkerThreads <= 16 ? 2.0
+    //                              : numWorkerThreads <= 32 ? 1.0
+    //                              : numWorkerThreads <= 64 ? 0.8
+    //                                                       : 0.6;
+
+    // Pairs of non-inclusive lower end of the bracket and factor that the
+    // number of workers needs to be scaled with if it falls in that bucket.
+    const SmallVector<std::pair<int, float>> overshardingBrackets = {
+        {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}};
+    const float initialOvershardingFactor = 8.0f;
+
+    Value scalingFactor = b.create<arith::ConstantFloatOp>(
+        llvm::APFloat(initialOvershardingFactor), b.getF32Type());
+    for (const std::pair<int, float> &p : overshardingBrackets) {
+      Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first);
+      Value inBracket = b.create<arith::CmpIOp>(
+          arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin);
+      Value bracketScalingFactor = b.create<arith::ConstantFloatOp>(
+          llvm::APFloat(p.second), b.getF32Type());
+      scalingFactor =
+          b.create<SelectOp>(inBracket, bracketScalingFactor, scalingFactor);
+    }
+    Value numWorkersIndex =
+        b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+    Value numWorkersFloat =
+        b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+    Value scaledNumWorkers =
+        b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
+    Value scaledNumInt =
+        b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+    Value scaledWorkers =
+        b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+
+    Value maxComputeBlocks = b.create<arith::MaxSIOp>(
+        b.create<arith::ConstantIndexOp>(1), scaledWorkers);
 
     // Compute parallel block size from the parallel problem size:
     //   blockSize = min(tripCount,
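
Worked example of the select chain built above: with 16 worker threads the `sgt` comparisons against 4 and 8 succeed but the one against 16 does not, so the scaling factor is 2.0 and maxComputeBlocks = max(1, fptosi(2.0 * 16.0)) = 32, matching the old static `numWorkerThreads <= 16 ? 2.0` branch. At the boundary of 4 threads no bracket matches, the initial factor of 8.0 is kept, and the result is again max(1, 32) = 32.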

diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 03243700e9ecb..42ec1ce02005d 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -438,6 +438,10 @@ extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
   }
 }
 
+extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
+  return getDefaultAsyncRuntime()->getThreadPool().getThreadCount();
+}
+
 //===----------------------------------------------------------------------===//
 // Small async runtime support library for testing.
 //===----------------------------------------------------------------------===//
@@ -515,6 +519,8 @@ void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
+               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
 }

diff --git a/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
new file mode 100644
index 0000000000000..04cfeabfb5828
--- /dev/null
+++ b/mlir/test/Dialect/Async/async-parallel-for-num-worker-threads.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -split-input-file -async-parallel-for=num-workers=-1  \
+// RUN: | FileCheck %s --dump-input=always
+
+// CHECK-LABEL: @num_worker_threads(
+// CHECK:       %[[MEMREF:.*]]: memref<?xf32>
+func @num_worker_threads(%arg0: memref<?xf32>) {
+
+  // CHECK: %[[scalingCstInit:.*]] = arith.constant 8.000000e+00 : f32
+  // CHECK: %[[bracketLowerBound4:.*]] = arith.constant 4 : index
+  // CHECK: %[[scalingCst4:.*]] = arith.constant 4.000000e+00 : f32
+  // CHECK: %[[bracketLowerBound8:.*]] = arith.constant 8 : index
+  // CHECK: %[[scalingCst8:.*]] = arith.constant 2.000000e+00 : f32
+  // CHECK: %[[bracketLowerBound16:.*]] = arith.constant 16 : index
+  // CHECK: %[[scalingCst16:.*]] = arith.constant 1.000000e+00 : f32
+  // CHECK: %[[bracketLowerBound32:.*]] = arith.constant 32 : index
+  // CHECK: %[[scalingCst32:.*]] = arith.constant 8.000000e-01 : f32
+  // CHECK: %[[bracketLowerBound64:.*]] = arith.constant 64 : index
+  // CHECK: %[[scalingCst64:.*]] = arith.constant 6.000000e-01 : f32
+  // CHECK:   %[[workersIndex:.*]] = async.runtime.num_worker_threads : index
+  // CHECK:   %[[inBracket4:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound4]] : index
+  // CHECK:   %[[scalingFactor4:.*]] = select %[[inBracket4]], %[[scalingCst4]], %[[scalingCstInit]] : f32
+  // CHECK:   %[[inBracket8:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound8]] : index
+  // CHECK:   %[[scalingFactor8:.*]] = select %[[inBracket8]], %[[scalingCst8]], %[[scalingFactor4]] : f32
+  // CHECK:   %[[inBracket16:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound16]] : index
+  // CHECK:   %[[scalingFactor16:.*]] = select %[[inBracket16]], %[[scalingCst16]], %[[scalingFactor8]] : f32
+  // CHECK:   %[[inBracket32:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound32]] : index
+  // CHECK:   %[[scalingFactor32:.*]] = select %[[inBracket32]], %[[scalingCst32]], %[[scalingFactor16]] : f32
+  // CHECK:   %[[inBracket64:.*]] = arith.cmpi sgt, %[[workersIndex]], %[[bracketLowerBound64]] : index
+  // CHECK:   %[[scalingFactor64:.*]] = select %[[inBracket64]], %[[scalingCst64]], %[[scalingFactor32]] : f32
+  // CHECK:   %[[workersInt:.*]] = arith.index_cast %[[workersIndex]] : index to i32
+  // CHECK:   %[[workersFloat:.*]] = arith.sitofp %[[workersInt]] : i32 to f32
+  // CHECK:   %[[scaledFloat:.*]] = arith.mulf %[[scalingFactor64]], %[[workersFloat]] : f32
+  // CHECK:   %[[scaledInt:.*]] = arith.fptosi %[[scaledFloat]] : f32 to i32
+  // CHECK:   %[[scaledIndex:.*]] = arith.index_cast %[[scaledInt]] : i32 to index
+
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 100 : index
+  %st = arith.constant 1 : index
+  scf.parallel (%i) = (%lb) to (%ub) step (%st) {
+    %one = arith.constant 1.0 : f32
+    memref.store %one, %arg0[%i] : memref<?xf32>
+  }
+
+  return
+}

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 15aeed24c0245..5e0e3a69ece86 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1240,6 +1240,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
     ],
@@ -2140,6 +2141,7 @@ cc_library(
         ":ControlFlowInterfaces",
         ":Dialect",
         ":IR",
+        ":InferTypeOpInterface",
         ":SideEffectInterfaces",
         ":StandardOps",
         ":Support",


        

