[Mlir-commits] [mlir] 6c1f655 - [mlir] Async: special handling for parallel loops with zero iterations
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Jul 23 01:23:09 PDT 2021
Author: Eugene Zhulenev
Date: 2021-07-23T01:22:59-07:00
New Revision: 6c1f65581891265154db4fb789a2d9cf4893b9bf
URL: https://github.com/llvm/llvm-project/commit/6c1f65581891265154db4fb789a2d9cf4893b9bf
DIFF: https://github.com/llvm/llvm-project/commit/6c1f65581891265154db4fb789a2d9cf4893b9bf.diff
LOG: [mlir] Async: special handling for parallel loops with zero iterations
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D106590
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
Removed:
################################################################################
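For context, a rough sketch of the IR shape this pass now produces (illustrative only, value names invented; the exact output is what the updated CHECK lines below verify): the computed trip count is compared against zero, and the whole dispatch is wrapped in an scf.if so that nothing is submitted for empty loops.

  %range      = subi %ub, %lb : index
  %trip_count = ceildivi_signed %range, %step : index
  %c0         = constant 0 : index
  %is_noop    = cmpi eq, %trip_count, %c0 : index
  scf.if %is_noop {
    // Zero iterations: do nothing.
  } else {
    // Compute block size / block count and dispatch the parallel compute
    // function (async or sequential, depending on the pass options).
  }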
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 521180cbd4b9c..a8858913cc1fb 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -650,52 +650,73 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
for (size_t i = 1; i < tripCounts.size(); ++i)
tripCount = b.create<MulIOp>(tripCount, tripCounts[i]);
- // With large number of threads the value of creating many compute blocks
- // is reduced because the problem typically becomes memory bound. For small
- // number of threads it helps with stragglers.
- float overshardingFactor = numWorkerThreads <= 4 ? 8.0
- : numWorkerThreads <= 8 ? 4.0
- : numWorkerThreads <= 16 ? 2.0
- : numWorkerThreads <= 32 ? 1.0
- : numWorkerThreads <= 64 ? 0.8
- : 0.6;
-
- // Do not overload worker threads with too many compute blocks.
- Value maxComputeBlocks = b.create<ConstantIndexOp>(
- std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
-
- // Target block size from the pass parameters.
- Value targetComputeBlockSize = b.create<ConstantIndexOp>(targetBlockSize);
-
- // Compute parallel block size from the parallel problem size:
- // blockSize = min(tripCount,
- // max(ceil_div(tripCount, maxComputeBlocks),
- // targetComputeBlockSize))
- Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
- Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlockSize);
- Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlockSize);
- Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
- Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
- Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
-
- // Compute balanced block size for the estimated block count.
- Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
- Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
-
- // Create a parallel compute function that takes a block id and computes the
- // parallel operation body for a subset of iteration space.
- ParallelComputeFunction parallelComputeFunction =
- createParallelComputeFunction(op, rewriter);
-
- // Dispatch parallel compute function using async recursive work splitting, or
- // by submitting compute task sequentially from a caller thread.
- if (asyncDispatch) {
- doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
- blockCount, tripCounts);
- } else {
- doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
- blockCount, tripCounts);
- }
+ // Short-circuit no-op parallel loops (zero iterations) that can arise from
+ // memrefs with dynamic dimension(s) equal to zero.
+ Value c0 = b.create<ConstantIndexOp>(0);
+ Value isZeroIterations = b.create<CmpIOp>(CmpIPredicate::eq, tripCount, c0);
+
+ // Do absolutely nothing if the trip count is zero.
+ auto noOp = [&](OpBuilder &nestedBuilder, Location loc) {
+ nestedBuilder.create<scf::YieldOp>(loc);
+ };
+
+ // Compute the parallel block size and dispatch concurrent tasks computing
+ // results for each block.
+ auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
+ ImplicitLocOpBuilder nb(loc, nestedBuilder);
+
+ // With large number of threads the value of creating many compute blocks
+ // is reduced because the problem typically becomes memory bound. For small
+ // number of threads it helps with stragglers.
+ float overshardingFactor = numWorkerThreads <= 4 ? 8.0
+ : numWorkerThreads <= 8 ? 4.0
+ : numWorkerThreads <= 16 ? 2.0
+ : numWorkerThreads <= 32 ? 1.0
+ : numWorkerThreads <= 64 ? 0.8
+ : 0.6;
+
+ // Do not overload worker threads with too many compute blocks.
+ Value maxComputeBlocks = b.create<ConstantIndexOp>(
+ std::max(1, static_cast<int>(numWorkerThreads * overshardingFactor)));
+
+ // Target block size from the pass parameters.
+ Value targetComputeBlock = b.create<ConstantIndexOp>(targetBlockSize);
+
+ // Compute parallel block size from the parallel problem size:
+ // blockSize = min(tripCount,
+ // max(ceil_div(tripCount, maxComputeBlocks),
+ // targetComputeBlock))
+ Value bs0 = b.create<SignedCeilDivIOp>(tripCount, maxComputeBlocks);
+ Value bs1 = b.create<CmpIOp>(CmpIPredicate::sge, bs0, targetComputeBlock);
+ Value bs2 = b.create<SelectOp>(bs1, bs0, targetComputeBlock);
+ Value bs3 = b.create<CmpIOp>(CmpIPredicate::sle, tripCount, bs2);
+ Value blockSize0 = b.create<SelectOp>(bs3, tripCount, bs2);
+ Value blockCount0 = b.create<SignedCeilDivIOp>(tripCount, blockSize0);
+
+ // Compute balanced block size for the estimated block count.
+ Value blockSize = b.create<SignedCeilDivIOp>(tripCount, blockCount0);
+ Value blockCount = b.create<SignedCeilDivIOp>(tripCount, blockSize);
+
+ // Create a parallel compute function that takes a block id and computes the
+ // parallel operation body for a subset of iteration space.
+ ParallelComputeFunction parallelComputeFunction =
+ createParallelComputeFunction(op, rewriter);
+
+ // Dispatch parallel compute function using async recursive work splitting,
+ // or by submitting compute task sequentially from a caller thread.
+ if (asyncDispatch) {
+ doAsyncDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+ blockCount, tripCounts);
+ } else {
+ doSequantialDispatch(b, rewriter, parallelComputeFunction, op, blockSize,
+ blockCount, tripCounts);
+ }
+
+ nb.create<scf::YieldOp>();
+ };
+
+ // Replace the `scf.parallel` operation with the parallel compute function.
+ b.create<scf::IfOp>(TypeRange(), isZeroIterations, noOp, dispatch);
// Parallel operation was replaced with a block iteration loop.
rewriter.eraseOp(op);
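As a worked example of the block-size formula above (numbers chosen purely for illustration, not taken from the pass defaults): with tripCount = 4096 and numWorkerThreads = 16, the oversharding factor is 2.0, so maxComputeBlocks = 32 and ceil_div(4096, 32) = 128. Assuming targetBlockSize = 1000, we get max(128, 1000) = 1000 and min(4096, 1000) = 1000, hence blockCount0 = ceil_div(4096, 1000) = 5, a balanced blockSize = ceil_div(4096, 5) = 820, and a final blockCount of 5.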
diff --git a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
index a6e308e422e20..7bc7edc27b486 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-async-dispatch.mlir
@@ -1,16 +1,25 @@
// RUN: mlir-opt %s -split-input-file -async-parallel-for=async-dispatch=true \
-// RUN: | FileCheck %s
+// RUN: | FileCheck %s --dump-input=always
-// CHECK-LABEL: @loop_1d
+// CHECK-LABEL: @loop_1d(
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index
func @loop_1d(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<?xf32>) {
- // CHECK: %[[C0:.*]] = constant 0 : index
- // CHECK: %[[GROUP:.*]] = async.create_group
- // CHECK: scf.if {{.*}} {
- // CHECK: call @parallel_compute_fn(%[[C0]]
- // CHECK: } else {
- // CHECK: call @async_dispatch_fn
- // CHECK: }
- // CHECK: async.await_all %[[GROUP]]
+ // CHECK: %[[C0:.*]] = constant 0 : index
+
+ // CHECK: %[[RANGE:.*]] = subi %[[UB]], %[[LB]]
+ // CHECK: %[[TRIP_CNT:.*]] = ceildivi_signed %[[RANGE]], %[[STEP]]
+ // CHECK: %[[IS_NOOP:.*]] = cmpi eq, %[[TRIP_CNT]], %[[C0]] : index
+
+ // CHECK: scf.if %[[IS_NOOP]] {
+ // CHECK-NEXT: } else {
+ // CHECK: %[[GROUP:.*]] = async.create_group
+ // CHECK: scf.if {{.*}} {
+ // CHECK: call @parallel_compute_fn(%[[C0]]
+ // CHECK: } else {
+ // CHECK: call @async_dispatch_fn
+ // CHECK: }
+ // CHECK: async.await_all %[[GROUP]]
+ // CHECK: }
scf.parallel (%i) = (%arg0) to (%arg1) step (%arg2) {
%one = constant 1.0 : f32
memref.store %one, %arg3[%i] : memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
index 75e4e63814892..b195863f959ce 100644
--- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
+++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir
@@ -5,6 +5,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void -O0 \
@@ -18,6 +19,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void -O0 \
@@ -34,6 +36,7 @@
// RUN: -convert-async-to-llvm \
// RUN: -convert-scf-to-std \
// RUN: -convert-memref-to-llvm \
+// RUN: -std-expand \
// RUN: -convert-std-to-llvm \
// RUN: | mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void -O0 \
@@ -41,6 +44,12 @@
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_async_runtime%shlibext\
// RUN: | FileCheck %s --dump-input=always
+// Suppress constant folding by introducing a "dynamic" zero value at runtime.
+func private @zero() -> index {
+ %0 = constant 0 : index
+ return %0 : index
+}
+
func @entry() {
%c0 = constant 0.0 : f32
%c1 = constant 1 : index
@@ -92,6 +101,15 @@ func @entry() {
// CHECK: [-20, 0, 0, -17, 0, 0, -14, 0, 0]
call @print_memref_f32(%U): (memref<*xf32>) -> ()
+ // 4. Check that a loop with zero iterations doesn't crash at runtime.
+ %lb1 = call @zero(): () -> (index)
+ %ub1 = call @zero(): () -> (index)
+
+ scf.parallel (%i) = (%lb1) to (%ub1) step (%c1) {
+ %false = constant 0 : i1
+ assert %false, "should never be executed"
+ }
+
memref.dealloc %A : memref<9xf32>
return
}
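For completeness, a hypothetical input (not part of this patch) that exercises the new zero-trip-count path in the way the comment in AsyncParallelFor.cpp describes: a loop whose upper bound comes from a dynamic value that may be zero at runtime.

  // %size may be 0 at runtime, e.g. a dynamic memref dimension; the lowered
  // code must then dispatch no work at all.
  func @maybe_empty(%size: index, %out: memref<?xf32>) {
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %one = constant 1.0 : f32
    scf.parallel (%i) = (%c0) to (%size) step (%c1) {
      memref.store %one, %out[%i] : memref<?xf32>
    }
    return
  }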