[Mlir-commits] [mlir] c1194c2 - [mlir:Async] Change async-parallel-for block size/count calculation

Eugene Zhulenev llvmlistbot at llvm.org
Tue Jun 29 12:57:18 PDT 2021


Author: Eugene Zhulenev
Date: 2021-06-29T12:57:11-07:00
New Revision: c1194c2ec35029f96ce75ab54555dccf2b7e8681

URL: https://github.com/llvm/llvm-project/commit/c1194c2ec35029f96ce75ab54555dccf2b7e8681
DIFF: https://github.com/llvm/llvm-project/commit/c1194c2ec35029f96ce75ab54555dccf2b7e8681.diff

LOG: [mlir:Async] Change async-parallel-for block size/count calculation

Depends On D105037

Avoid creating too many tasks when the number of workers is large.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D105126
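A minimal host-side sketch of the oversharding table introduced in the diff below. It assumes a plain positive `int` worker count. The function names `overshardingFactor` and `maxComputeBlocks` are hypothetical wrappers around the patch's expressions and are not MLIR APIs.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical host-side copy of the oversharding table from the patch:
// fewer extra blocks per worker as the worker count grows.
float overshardingFactor(int numWorkerThreads) {
  assert(numWorkerThreads > 0 && "sketch assumes at least one worker");
  return numWorkerThreads <= 4    ? 8.0f
         : numWorkerThreads <= 8  ? 4.0f
         : numWorkerThreads <= 16 ? 2.0f
         : numWorkerThreads <= 32 ? 1.0f
         : numWorkerThreads <= 64 ? 0.8f
                                  : 0.6f;
}

// Upper bound on the number of compute blocks. The int cast truncates the
// product toward zero, and the result is clamped to at least one block, as in
// the patch's std::max(1, static_cast<int>(...)) expression.
int maxComputeBlocks(int numWorkerThreads) {
  return std::max(1, static_cast<int>(numWorkerThreads *
                                      overshardingFactor(numWorkerThreads)));
}
```

With this table the cap is 32 blocks at 4, 8, 16 and 32 workers, 51 blocks at 64 workers and 76 blocks at 128 workers. A 32x increase in workers (4 to 128) therefore raises the task cap by less than 2.5x, which matches the log message's goal of not creating too many tasks on wide machines.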

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index a104fb73571d..373ee8b01dca 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -653,9 +653,19 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
+  // With large number of threads the value of creating many compute blocks
+  // is reduced because the problem typically becomes memory bound. For small
+  // number of threads it helps with stragglers.
+  float overshardingFactor = numWorkerThreads <= 4    ? 8.0
+                             : numWorkerThreads <= 8  ? 4.0
+                             : numWorkerThreads <= 16 ? 2.0
+                             : numWorkerThreads <= 32 ? 1.0
+                             : numWorkerThreads <= 64 ? 0.8
+                                                      : 0.6;
+
   // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks =
-      b.create<ConstantIndexOp>(numWorkerThreads * kMaxOversharding);
+  Value maxComputeBlocks = b.create<ConstantIndexOp>(
+      std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
 
   // Target block size from the pass parameters.
   Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
@@ -668,7 +678,11 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
   Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
   Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
-  Value blockSize = b.create<SelectOp>(bs3, tripCount, bs2);
+  Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
+  Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
+
+  // Compute balanced block size for the estimated block count.
+  Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
   Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
 
   // Create a parallel compute function that takes a block id and computes the
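A scalar C++ sketch of the block-size arithmetic after this patch. It assumes all quantities are positive `int64_t` values. `ceilDiv`, `BlockPartition` and `computeBlocks` are hypothetical names. `ceilDiv` stands in for `SignedCeilDivIOp` only for positive operands. `bs0` is taken as a parameter because its definition lies outside the diff context shown above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Ceiling division for positive operands; matches SignedCeilDivIOp only in
// that range (the sketch does not model negative inputs).
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(a > 0 && b > 0 && "sketch handles positive operands only");
  return (a + b - 1) / b;
}

struct BlockPartition {
  int64_t blockSize;
  int64_t blockCount;
};

// Scalar mirror of the values built in matchAndRewrite after this patch.
// bs0 is a parameter because its definition is not part of the diff context.
BlockPartition computeBlocks(int64_t tripCount, int64_t bs0,
                             int64_t targetBlockSize) {
  // bs1/bs2: take the larger of bs0 and the target block size.
  int64_t bs2 = std::max(bs0, targetBlockSize);
  // bs3/blockSize0: a block never covers more than the whole trip count.
  int64_t blockSize0 = std::min(tripCount, bs2);
  int64_t blockCount0 = ceilDiv(tripCount, blockSize0);
  // New in this patch: spread the iterations evenly over blockCount0 blocks.
  int64_t blockSize = ceilDiv(tripCount, blockCount0);
  int64_t blockCount = ceilDiv(tripCount, blockSize);
  return {blockSize, blockCount};
}
```

Consider `tripCount = 1000` with `bs0 = 1` and a target block size of 300. Here `blockSize0` is 300 and `blockCount0` is 4. Before this patch, `blockSize0` was used directly, which gives four blocks of 300, 300, 300 and 100 iterations. The rebalanced size `ceil(1000 / 4) = 250` instead gives four equal blocks of 250 iterations, so no block is left with a short tail.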


        

