[Mlir-commits] [mlir] abe2dee - [mlir] NFC Async: always use 'b' for the current builder
Eugene Zhulenev
llvmlistbot at llvm.org
Wed Feb 16 21:21:00 PST 2022
Author: Eugene Zhulenev
Date: 2022-02-16T21:20:53-08:00
New Revision: abe2dee5ebb97403a953a8b71f8ffa8b72cff861
URL: https://github.com/llvm/llvm-project/commit/abe2dee5ebb97403a953a8b71f8ffa8b72cff861
DIFF: https://github.com/llvm/llvm-project/commit/abe2dee5ebb97403a953a8b71f8ffa8b72cff861.diff
LOG: [mlir] NFC Async: always use 'b' for the current builder
Currently, some of the nested IR building code inconsistently uses `nb` and `b`, and it is very easy to accidentally call the wrong builder from outside the current scope. For simplicity, all builders are now named `b`; in nested IR building regions they simply shadow the "parent" builder.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D120003
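To illustrate the convention, here is a minimal C++ sketch (illustrative only, not part of the patch; the lambda name is made up, while ImplicitLocOpBuilder and the arith/scf builders are the same APIs used in the diff below):

  auto bodyBuilder = [&](OpBuilder &nestedBuilder, Location loc) {
    // Shadows the outer `b`: inside this region, `b.create<...>` can only
    // refer to the nested builder, so ops cannot accidentally be inserted
    // through a stale builder from an enclosing scope.
    ImplicitLocOpBuilder b(loc, nestedBuilder);
    Value c0 = b.create<arith::ConstantIndexOp>(0);
    b.create<scf::YieldOp>(c0); // the terminator is also built through `b`
  };

Previously the nested builder was named `nb`, so writing `b.create<...>` inside the region still compiled but silently used the insertion point of the outer builder.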
Added:
Modified:
mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index cdd85e5c5b40..e596fc3e7348 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -373,23 +373,23 @@ static ParallelComputeFunction createParallelComputeFunction(
LoopNestBuilder workLoopBuilder = [&](size_t loopIdx) -> LoopBodyBuilder {
return [&, loopIdx](OpBuilder &nestedBuilder, Location loc, Value iv,
ValueRange args) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
// Compute induction variable for `loopIdx`.
- computeBlockInductionVars[loopIdx] = nb.create<arith::AddIOp>(
- lowerBounds[loopIdx], nb.create<arith::MulIOp>(iv, steps[loopIdx]));
+ computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>(
+ lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx]));
// Check if we are inside first or last iteration of the loop.
- isBlockFirstCoord[loopIdx] = nb.create<arith::CmpIOp>(
+ isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>(
arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]);
- isBlockLastCoord[loopIdx] = nb.create<arith::CmpIOp>(
+ isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>(
arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]);
// Check if the previous loop is in its first or last iteration.
if (loopIdx > 0) {
- isBlockFirstCoord[loopIdx] = nb.create<arith::AndIOp>(
+ isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>(
isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]);
- isBlockLastCoord[loopIdx] = nb.create<arith::AndIOp>(
+ isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>(
isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]);
}
@@ -398,24 +398,24 @@ static ParallelComputeFunction createParallelComputeFunction(
if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) {
// For block aligned loops we always iterate starting from 0 up to
// the loop trip counts.
- nb.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
- workLoopBuilder(loopIdx + 1));
+ b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(),
+ workLoopBuilder(loopIdx + 1));
} else {
// Select nested loop lower/upper bounds depending on our position in
// the multi-dimensional iteration space.
- auto lb = nb.create<arith::SelectOp>(
- isBlockFirstCoord[loopIdx], blockFirstCoord[loopIdx + 1], c0);
+ auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx],
+ blockFirstCoord[loopIdx + 1], c0);
- auto ub = nb.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
- blockEndCoord[loopIdx + 1],
- tripCounts[loopIdx + 1]);
+ auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx],
+ blockEndCoord[loopIdx + 1],
+ tripCounts[loopIdx + 1]);
- nb.create<scf::ForOp>(lb, ub, c1, ValueRange(),
- workLoopBuilder(loopIdx + 1));
+ b.create<scf::ForOp>(lb, ub, c1, ValueRange(),
+ workLoopBuilder(loopIdx + 1));
}
- nb.create<scf::YieldOp>(loc);
+ b.create<scf::YieldOp>(loc);
return;
}
@@ -425,7 +425,7 @@ static ParallelComputeFunction createParallelComputeFunction(
mapping.map(computeFuncType.captures, captures);
for (auto &bodyOp : op.getLoopBody().getOps())
- nb.clone(bodyOp, mapping);
+ b.clone(bodyOp, mapping);
};
};
@@ -602,38 +602,38 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1);
auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
// Call parallel compute function for the single block.
SmallVector<Value> operands = {c0, blockSize};
appendBlockComputeOperands(operands);
- nb.create<CallOp>(parallelComputeFunction.func.sym_name(),
- parallelComputeFunction.func.getCallableResults(),
- operands);
- nb.create<scf::YieldOp>();
+ b.create<CallOp>(parallelComputeFunction.func.sym_name(),
+ parallelComputeFunction.func.getCallableResults(),
+ operands);
+ b.create<scf::YieldOp>();
};
auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
// Create an async.group to wait on all async tokens from the concurrent
// execution of multiple parallel compute functions. The first block will be
// executed synchronously in the caller thread.
- Value groupSize = nb.create<arith::SubIOp>(blockCount, c1);
- Value group = nb.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
+ Value groupSize = b.create<arith::SubIOp>(blockCount, c1);
+ Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize);
// Launch async dispatch function for [0, blockCount) range.
SmallVector<Value> operands = {group, c0, blockCount, blockSize};
appendBlockComputeOperands(operands);
- nb.create<CallOp>(asyncDispatchFunction.sym_name(),
- asyncDispatchFunction.getCallableResults(), operands);
+ b.create<CallOp>(asyncDispatchFunction.sym_name(),
+ asyncDispatchFunction.getCallableResults(), operands);
// Wait for the completion of all parallel compute operations.
- nb.create<AwaitAllOp>(group);
+ b.create<AwaitAllOp>(group);
- nb.create<scf::YieldOp>();
+ b.create<scf::YieldOp>();
};
// Dispatch either single block compute function, or launch async dispatch.
@@ -680,7 +680,7 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
// Induction variable is the index of the block: [0, blockCount).
LoopBodyBuilder loopBuilder = [&](OpBuilder &loopBuilder, Location loc,
Value iv, ValueRange args) {
- ImplicitLocOpBuilder nb(loc, loopBuilder);
+ ImplicitLocOpBuilder b(loc, loopBuilder);
// Call parallel compute function inside the async.execute region.
auto executeBodyBuilder = [&](OpBuilder &executeBuilder,
@@ -692,10 +692,10 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter,
};
// Create async.execute operation to launch the parallel compute function.
- auto execute = nb.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
- executeBodyBuilder);
- nb.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
- nb.create<scf::YieldOp>();
+ auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(),
+ executeBodyBuilder);
+ b.create<AddToGroupOp>(rewriter.getIndexType(), execute.token(), group);
+ b.create<scf::YieldOp>();
};
// Iterate over all compute blocks and launch parallel compute operations.
@@ -758,7 +758,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// Compute the parallel block size and dispatch concurrent tasks computing
// results for each block.
auto dispatch = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
// Collect statically known constants defining the loop nest in the parallel
// compute function. LLVM can't always push constants across the non-trivial
@@ -872,10 +872,10 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
// Unroll when numUnrollableLoops > 0 && blockSize >= maxIterations.
bool staticShouldUnroll = numUnrollableLoops > 0;
auto dispatchNotUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
doDispatch(b, rewriter, notUnrollableParallelComputeFunction, op,
blockSize, blockCount, tripCounts);
- nb.create<scf::YieldOp>();
+ b.create<scf::YieldOp>();
};
if (staticShouldUnroll) {
@@ -888,23 +888,23 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
rewriter);
auto dispatchUnrollable = [&](OpBuilder &nestedBuilder, Location loc) {
- ImplicitLocOpBuilder nb(loc, nestedBuilder);
+ ImplicitLocOpBuilder b(loc, nestedBuilder);
// Align the block size to be a multiple of the statically known
// number of iterations in the inner loops.
- Value numIters = nb.create<arith::ConstantIndexOp>(
+ Value numIters = b.create<arith::ConstantIndexOp>(
numIterations[op.getNumLoops() - numUnrollableLoops]);
- Value alignedBlockSize = nb.create<arith::MulIOp>(
- nb.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
+ Value alignedBlockSize = b.create<arith::MulIOp>(
+ b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters);
doDispatch(b, rewriter, unrollableParallelComputeFunction, op,
alignedBlockSize, blockCount, tripCounts);
- nb.create<scf::YieldOp>();
+ b.create<scf::YieldOp>();
};
b.create<scf::IfOp>(TypeRange(), dynamicShouldUnroll, dispatchUnrollable,
dispatchNotUnrollable);
- nb.create<scf::YieldOp>();
+ b.create<scf::YieldOp>();
} else {
- dispatchNotUnrollable(nb, loc);
+ dispatchNotUnrollable(b, loc);
}
};