[Mlir-commits] [mlir] beff16f - [mlir] Async: update condition for dispatching block-aligned compute function
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Feb 23 10:30:03 PST 2022
Author: Eugene Zhulenev
Date: 2022-02-23T10:29:55-08:00
New Revision: beff16f7bd6353054ee7dbf43c6f35082ad61577
URL: https://github.com/llvm/llvm-project/commit/beff16f7bd6353054ee7dbf43c6f35082ad61577
DIFF: https://github.com/llvm/llvm-project/commit/beff16f7bd6353054ee7dbf43c6f35082ad61577.diff
LOG: [mlir] Async: update condition for dispatching block-aligned compute function
+ compare the block size with the number of iterations in the unrollable inner loops
+ reduce nesting in the code and slightly simplify IR building
Reviewed By: cota
Differential Revision: https://reviews.llvm.org/D120075
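For context on the diff below: the rewrite used to compare the computed block size against the arbitrary maxIterations constant (512) when deciding whether to dispatch the block-aligned compute function; it now compares it against the statically known iteration count of the unrollable inner loops and rounds the block size up to a multiple of that count. A minimal standalone sketch of that arithmetic, using hypothetical helper names (ceilDiv, alignBlockSize, shouldUseBlockAligned) that are not part of the MLIR sources:

#include <cstdint>

// Hypothetical helpers mirroring the arithmetic built with arith::CeilDivSIOp
// and arith::MulIOp in the patch; names are illustrative only.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Round the block size up to a multiple of the statically known number of
// iterations in the unrollable inner loops.
static int64_t alignBlockSize(int64_t blockSize, int64_t innerIterations) {
  return ceilDiv(blockSize, innerIterations) * innerIterations;
}

// New dispatch condition: use the block-aligned compute function only when a
// block covers at least one full pass over the unrollable inner loops, since
// aligning smaller blocks would reduce the available parallelism.
static bool shouldUseBlockAligned(int64_t blockSize, int64_t innerIterations) {
  return innerIterations > 0 && blockSize >= innerIterations;
}

Since an unrollable inner loop has at most 512 iterations by definition, the new comparison fires at least as often as the old blockSize >= 512 check, while the explicit guard keeps alignment from inflating blocks that are smaller than a single inner pass.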
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index e596fc3e73488..c4ba141b9bca0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -779,10 +779,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// and we can elide dynamic loop boundaries, and give LLVM an opportunity to
// unroll the loops. The constant `512` is arbitrary, it should depend on
// how many iterations LLVM will typically decide to unroll.
- static constexpr int64_t maxIterations = 512;
+ static constexpr int64_t maxUnrollableIterations = 512;
// The number of inner loops with statically known number of iterations less
- // than the `maxIterations` value.
+ // than the `maxUnrollableIterations` value.
int numUnrollableLoops = 0;
auto getInt = [](IntegerAttr attr) { return attr ? attr.getInt() : 0; };
@@ -796,7 +796,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
numIterations[i] = tripCount * innerIterations;
// Update the number of inner loops that we can potentially unroll.
- if (innerIterations > 0 && innerIterations <= maxIterations)
+ if (innerIterations > 0 && innerIterations <= maxUnrollableIterations)
numUnrollableLoops++;
}
@@ -856,9 +856,6 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize);
Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1);
- ParallelComputeFunction notUnrollableParallelComputeFunction =
- createParallelComputeFunction(op, staticBounds, 0, rewriter);
-
// Dispatch parallel compute function using async recursive work splitting,
// or by submitting compute task sequentially from a caller thread.
auto doDispatch = asyncDispatch ? doAsyncDispatch : doSequentialDispatch;
@@ -869,42 +866,47 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// Compute the number of parallel compute blocks.
Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize);
- // Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
- bool staticShouldUnroll = numUnrollableLoops > 0;
- auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
+ // Dispatch parallel compute function without hints to unroll inner loops.
+ auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) {
+ ParallelComputeFunction compute =
+ createParallelComputeFunction(op, staticBounds, 0, rewriter);
+
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
+ doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts);
+ b.create<scf::YieldOp>();
+ };
+
+ // Dispatch parallel compute function with hints for unrolling inner loops.
+ auto dispatchBlockAligned = [&](OpBuilder &nestedBuilder, Location loc) {
+ ParallelComputeFunction compute = createParallelComputeFunction(
+ op, staticBounds, numUnrollableLoops, rewriter);
+
ImplicitLocOpBuilder b(loc, nestedBuilder);
- doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
- blockSize, blockCount, tripCounts);
+ // Align the block size to be a multiple of the statically known
+ // number of iterations in the inner loops.
+ Value numIters = b.create<arith::ConstantIndexOp>(
+ numIterations[op.getNumLoops() - numUnrollableLoops]);
+ Value alignedBlockSize = b.create<arith::MulIOp>(
+ b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+ doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount,
+ tripCounts);
b.create<scf::YieldOp>();
};
- if (staticShouldUnroll) {
- Value dynamicShouldUnroll = b.create<arith::CmpIOp>(
- arith::CmpIPredicate::sge, blockSize,
- b.create<arith::ConstantIndexOp>(maxIterations));
-
- ParallelComputeFunction unrollableParallelComputeFunction =
- createParallelComputeFunction(op, staticBounds, numUnrollableLoops,
- rewriter);
-
- auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder b(loc, nestedBuilder);
- // Align the block size to be a multiple of the statically known
- // number of iterations in the inner loops.
- Value numIters = b.create<arith::ConstantIndexOp>(
- numIterations[op.getNumLoops() - numUnrollableLoops]);
- Value alignedBlockSize = b.create<arith::MulIOp>(
- b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
- doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
- alignedBlockSize, blockCount, tripCounts);
- b.create<scf::YieldOp>();
- };
-
- b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
- dispatchNotUnrollable);
+ // Dispatch to block aligned compute function only if the computed block
+ // size is larger than the number of iterations in the unrollable inner
+ // loops, because otherwise it can reduce the available parallelism.
+ if (numUnrollableLoops > 0) {
+ Value numIters = b.create<arith::ConstantIndexOp>(
+ numIterations[op.getNumLoops() - numUnrollableLoops]);
+ Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>(
+ arith::CmpIPredicate::sge, blockSize, numIters);
+
+ b.create<scf::IfOp>(TypeRange(), useBlockAlignedComputeFn,
+ dispatchBlockAligned, dispatchDefault);
b.create<scf::YieldOp>();
} else {
- dispatchNotUnrollable(b, loc);
+ dispatchDefault(b, loc);
}
};
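A concrete worked example of the new condition (the numbers are hypothetical, chosen to match the memref<?x10xf32> inner dimension used in the test below): with an inner trip count of 10, the block-aligned compute function is chosen whenever the computed block size is at least 10, and the block size is then padded up to the next multiple of 10. A small self-contained illustration:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative numbers only: a static inner trip count of 10 and a
  // scheduler-computed block size of 12.
  int64_t innerIterations = 10;
  int64_t blockSize = 12;

  // Old condition: blockSize >= 512 is false, so the default function ran.
  // New condition: blockSize >= innerIterations is true, so the
  // block-aligned function runs.
  bool useBlockAligned = blockSize >= innerIterations;

  // ceilDiv(12, 10) * 10 == 20: the block is padded to whole inner passes.
  int64_t alignedBlockSize =
      (blockSize + innerIterations - 1) / innerIterations * innerIterations;

  std::printf("useBlockAligned=%d alignedBlockSize=%lld\n", useBlockAligned,
              static_cast<long long>(alignedBlockSize));
  return 0;
}

A block size of 4 would fail the new comparison (4 >= 10 is false), so the default compute function is dispatched and alignment never reduces the number of parallel blocks below what the scheduler computed.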
diff --git a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
index 217e63bd67adf..8fc1c66e554fb 100644
--- a/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
+++ b/mlir/test/Dialect/Async/async-parallel-for-compute-fn.mlir
@@ -87,7 +87,7 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
return
}
-// CHECK-LABEL: func private @parallel_compute_fn(
+// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -100,12 +100,14 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
// CHECK: scf.for %[[I:arg[0-9]+]]
-// CHECK: arith.select
-// CHECK: scf.for %[[J:arg[0-9]+]]
-// CHECK: memref.store
+// CHECK-NOT: arith.select
+// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
-// CHECK-LABEL: func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-LABEL: func private @parallel_compute_fn(
// CHECK-SAME: %[[BLOCK_INDEX:arg[0-9]+]]: index,
// CHECK-SAME: %[[BLOCK_SIZE:arg[0-9]+]]: index,
// CHECK-SAME: %[[TRIP_COUNT0:arg[0-9]+]]: index,
@@ -118,9 +120,7 @@ func @sink_constant_step(%arg0: memref<?x10xf32>, %lb: index, %ub: index) {
// CHECK-SAME: %[[STEP1:arg[0-9]+]]: index,
// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<?x10xf32>
// CHECK-SAME: ) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C10:.*]] = arith.constant 10 : index
// CHECK: scf.for %[[I:arg[0-9]+]]
-// CHECK-NOT: arith.select
-// CHECK: scf.for %[[J:arg[0-9]+]] = %c0 to %c10 step %c1
+// CHECK: arith.select
+// CHECK: scf.for %[[J:arg[0-9]+]]
+// CHECK: memref.store