[Mlir-commits] [mlir] 6c1f655 - [mlir] Async: special handling for parallel loops with zero iterations

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jul 23 01:23:09 PDT 2021


Author: Eugene Zhulenev
Date: 2021-07-23T01:22:59-07:00
New Revision: 6c1f65581891265154db4fb789a2d9cf4893b9bf

URL: https://github.com/llvm/llvm-project/commit/6c1f65581891265154db4fb789a2d9cf4893b9bf
DIFF: https://github.com/llvm/llvm-project/commit/6c1f65581891265154db4fb789a2d9cf4893b9bf.diff

LOG: [mlir] Async: special handling for parallel loops with zero iterations

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D106590
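
For context, the situation this change guards against is an scf.parallel op
whose trip count turns out to be zero at runtime, for example when a loop bound
comes from a dynamic memref dimension that happens to be zero. A minimal MLIR
sketch of such a loop (illustrative only, not part of this commit; the function
and value names are made up):

// %buf may have a dynamic size of 0 at runtime, in which case the scf.parallel
// below has zero iterations and its body must never execute.
func @fill(%buf: memref<?xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %one = constant 1.0 : f32
  %dim = memref.dim %buf, %c0 : memref<?xf32>
  scf.parallel (%i) = (%c0) to (%dim) step (%c1) {
    memref.store %one, %buf[%i] : memref<?xf32>
  }
  return
}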

Added: 
    

Modified: 
    mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
    mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
    mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 521180cbd4b9c..a8858913cc1fb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -650,52 +650,73 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
   for (size_t i = 1; i < tripCounts.size(); ++i)
     tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
 
-  // With large number of threads the value of creating many compute blocks
-  // is reduced because the problem typically becomes memory bound. For small
-  // number of threads it helps with stragglers.
-  float overshardingFactor = numWorkerThreads <= 4    ? 8.0
-                             : numWorkerThreads <= 8  ? 4.0
-                             : numWorkerThreads <= 16 ? 2.0
-                             : numWorkerThreads <= 32 ? 1.0
-                             : numWorkerThreads <= 64 ? 0.8
-                                                      : 0.6;
-
-  // Do not overload worker threads with too many compute blocks.
-  Value maxComputeBlocks = b.create<ConstantIndexOp>(
-      std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
-
-  // Target block size from the pass parameters.
-  Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
-
-  // Compute parallel block size from the parallel problem size:
-  //   blockSize = min(tripCount,
-  //                   max(ceil_div(tripCount, maxComputeBlocks),
-  //                       targetComputeBlockSize))
-  Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
-  Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
-  Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
-  Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
-  Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
-  Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
-
-  // Compute balanced block size for the estimated block count.
-  Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
-  Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
-
-  // Create a parallel compute function that takes a block id and computes the
-  // parallel operation body for a subset of iteration space.
-  ParallelComputeFunction parallelComputeFunction =
-      createParallelComputeFunction(op, rewriter);
-
-  // Dispatch parallel compute function using async recursive work splitting, or
-  // by submitting compute task sequentially from a caller thread.
-  if (asyncDispatch) {
-    doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                    blockCount, tripCounts);
-  } else {
-    doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
-                         blockCount, tripCounts);
-  }
+  // Short-circuit no-op parallel loops (zero iterations) that can arise from
+  // memrefs with dynamic dimension(s) equal to zero.
+  Value c0 = b.create<ConstantIndexOp>(0);
+  Value isZeroIterations = b.create<CmpIOp>(CmpIPredicate::eq, tripCount, c0);
+
+  // Do absolutely nothing if the trip count is zero.
+  auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
+    nestedBuilder.create<scf::YieldOp>(loc);
+  };
+
+  // Compute the parallel block size and dispatch concurrent tasks computing
+  // results for each block.
+  auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+    ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+    // With large number of threads the value of creating many compute blocks
+    // is reduced because the problem typically becomes memory bound. For small
+    // number of threads it helps with stragglers.
+    float overshardingFactor = numWorkerThreads <= 4    ? 8.0
+                               : numWorkerThreads <= 8  ? 4.0
+                               : numWorkerThreads <= 16 ? 2.0
+                               : numWorkerThreads <= 32 ? 1.0
+                               : numWorkerThreads <= 64 ? 0.8
+                                                        : 0.6;
+
+    // Do not overload worker threads with too many compute blocks.
+    Value maxComputeBlocks = b.create<ConstantIndexOp>(
+        std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
+
+    // Target block size from the pass parameters.
+    Value targetComputeBlock = b.create<ConstantIndexOp>(targetBlockSize);
+
+    // Compute parallel block size from the parallel problem size:
+    //   blockSize = min(tripCount,
+    //                   max(ceil_div(tripCount, maxComputeBlocks),
+    //                       targetComputeBlock))
+    Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
+    Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlock);
+    Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlock);
+    Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
+    Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
+    Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
+
+    // Compute balanced block size for the estimated block count.
+    Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
+    Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
+
+    // Create a parallel compute function that takes a block id and computes the
+    // parallel operation body for a subset of iteration space.
+    ParallelComputeFunction parallelComputeFunction =
+        createParallelComputeFunction(op, rewriter);
+
+    // Dispatch parallel compute function using async recursive work splitting,
+    // or by submitting compute task sequentially from a caller thread.
+    if (asyncDispatch) {
+      doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+                      blockCount, tripCounts);
+    } else {
+      doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+                           blockCount, tripCounts);
+    }
+
+    nb.create<scf::YieldOp>();
+  };
+
+  // Replace the `scf.parallel` operation with the parallel compute function.
+  b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
 
   // Parallel operation was replaced with a block iteration loop.
   rewriter.eraseOp(op);
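
In IR terms, a rough sketch of the structure the rewrite now produces for a 1-D
loop (illustrative only: the constants are placeholders, and the real pass
dispatches the blocks through outlined @parallel_compute_fn / @async_dispatch_fn
functions, as the updated test below checks). Without the scf.if guard, a zero
trip count makes the computed block size zero and the subsequent
ceildivi_signed would divide by zero.

func @guarded_dispatch(%lb: index, %ub: index, %step: index) {
  %c0 = constant 0 : index
  %range = subi %ub, %lb : index
  %tripCount = ceildivi_signed %range, %step : index
  %isZero = cmpi eq, %tripCount, %c0 : index
  scf.if %isZero {
    // Zero iterations: nothing to do.
  } else {
    // blockSize = min(tripCount,
    //                 max(ceil_div(tripCount, maxComputeBlocks), targetBlockSize))
    %maxComputeBlocks = constant 32 : index  // placeholder for the thread-based limit
    %targetBlockSize = constant 1000 : index // placeholder for the pass parameter
    %bs0 = ceildivi_signed %tripCount, %maxComputeBlocks : index
    %sge = cmpi sge, %bs0, %targetBlockSize : index
    %bs2 = select %sge, %bs0, %targetBlockSize : index
    %sle = cmpi sle, %tripCount, %bs2 : index
    %blockSize0 = select %sle, %tripCount, %bs2 : index
    %blockCount0 = ceildivi_signed %tripCount, %blockSize0 : index
    // Rebalance the block size for the estimated block count.
    %blockSize = ceildivi_signed %tripCount, %blockCount0 : index
    %blockCount = ceildivi_signed %tripCount, %blockSize : index
    // ... dispatch %blockCount compute blocks of size %blockSize, either
    // asynchronously or sequentially from the caller thread ...
  }
  return
}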

diff  --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
index a6e308e422e20..7bc7edc27b486 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -1,16 +1,25 @@
 // RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=true  \
-// RUN: | FileCheck %s
+// RUN: | FileCheck %s --dump-input=always
 
-// CHECK-LABEL: @loop_1d
+// CHECK-LABEL: @loop_1d(
+// CHECK-SAME:    %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
 func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[GROUP:.*]] = async.create_group
-  // CHECK: scf.if {{.*}} {
-  // CHECK:   call @parallel_compute_fn(%[[C0]]
-  // CHECK: } else {
-  // CHECK:   call @async_dispatch_fn
-  // CHECK: }
-  // CHECK: async.await_all %[[GROUP]]
+  // CHECK:      %[[C0:.*]] = constant 0 : index
+
+  // CHECK:      %[[RANGE:.*]] = subi %[[UB]], %[[LB]]
+  // CHECK:      %[[TRIP_CNT:.*]] = ceildivi_signed %[[RANGE]], %[[STEP]]
+  // CHECK:      %[[IS_NOOP:.*]] = cmpi eq, %[[TRIP_CNT]], %[[C0]] : index
+
+  // CHECK:      scf.if %[[IS_NOOP]] {
+  // CHECK-NEXT: } else {
+  // CHECK:        %[[GROUP:.*]] = async.create_group
+  // CHECK:        scf.if {{.*}} {
+  // CHECK:          call @parallel_compute_fn(%[[C0]]
+  // CHECK:        } else {
+  // CHECK:          call @async_dispatch_fn
+  // CHECK:        }
+  // CHECK:        async.await_all %[[GROUP]]
+  // CHECK:      }
   scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
     %one = constant 1.0 : f32
     memref.store %one, %arg3[%i] : memref<?xf32>

diff  --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 75e4e63814892..b195863f959ce 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -5,6 +5,7 @@
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \
+// RUN:               -std-expand                                              \
 // RUN:               -convert-std-to-llvm                                     \
 // RUN: | mlir-cpu-runner                                                      \
 // RUN:  -e entry -entry-point-result=void -O0                                 \
@@ -18,6 +19,7 @@
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \
+// RUN:               -std-expand                                              \
 // RUN:               -convert-std-to-llvm                                     \
 // RUN: | mlir-cpu-runner                                                      \
 // RUN:  -e entry -entry-point-result=void -O0                                 \
@@ -34,6 +36,7 @@
 // RUN:               -convert-async-to-llvm                                   \
 // RUN:               -convert-scf-to-std                                      \
 // RUN:               -convert-memref-to-llvm                                  \
+// RUN:               -std-expand                                              \
 // RUN:               -convert-std-to-llvm                                     \
 // RUN: | mlir-cpu-runner                                                      \
 // RUN:  -e entry -entry-point-result=void -O0                                 \
@@ -41,6 +44,12 @@
 // RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
 // RUN: | FileCheck %s --dump-input=always
 
+// Suppress constant folding by introducing a "dynamic" zero value at runtime.
+func private @zero() -> index {
+  %0 = constant 0 : index
+  return %0 : index
+}
+
 func @entry() {
   %c0 = constant 0.0 : f32
   %c1 = constant 1 : index
@@ -92,6 +101,15 @@ func @entry() {
   // CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
   call @print_memref_f32(%U): (memref<*xf32>) -> ()
 
+  // 4. Check that a loop with zero iterations doesn't crash at runtime.
+  %lb1 = call @zero(): () -> (index)
+  %ub1 = call @zero(): () -> (index)
+
+  scf.parallel (%i) = (%lb1) to (%ub1) step (%c1) {
+    %false = constant 0 : i1
+    assert %false, "should never be executed"
+  }
+
   memref.dealloc %A : memref<9xf32>
   return
 }
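
A note on the integration test changes: -std-expand is presumably needed
because the pipeline now reaches LLVM conversion with ops such as
ceildivi_signed, which -convert-std-to-llvm has no direct lowering for;
-std-expand rewrites them into primitive std ops first. The private @zero()
function keeps the loop bounds opaque to the compiler so the zero-iteration
path is actually exercised at runtime, and the assert inside the new loop fails
loudly if the body ever executes. A stand-alone sketch of the kind of op
-std-expand handles (illustrative, not from the commit):

// Assumed here: the signed ceiling division has no direct LLVM lowering and is
// expanded by -std-expand before -convert-std-to-llvm runs.
func @block_count(%tripCount: index, %blockSize: index) -> index {
  %0 = ceildivi_signed %tripCount, %blockSize : index
  return %0 : index
}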

