[Mlir-commits] [mlir] fce99e5 - [mlir][gpu] Handle async in gpu.launch_func lowering.
Christian Sigg
llvmlistbot at llvm.org
Thu Oct 29 14:16:51 PDT 2020
Author: Christian Sigg
Date: 2020-10-29T22:16:42+01:00
New Revision: fce99e5f739eddce3e52a9c4967fc8a993c5772d
URL: https://github.com/llvm/llvm-project/commit/fce99e5f739eddce3e52a9c4967fc8a993c5772d
DIFF: https://github.com/llvm/llvm-project/commit/fce99e5f739eddce3e52a9c4967fc8a993c5772d.diff
LOG: [mlir][gpu] Handle async in gpu.launch_func lowering.
For the synchronous case, destroy the stream after synchronization.
Sneak in an unrelated change that reports why the gpu.wait conversion pattern didn't match.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D89933
Added:
Modified:
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index a8a416d7843d..ac58b7a7d7f1 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -297,7 +297,7 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (cast<gpu::WaitOp>(op).asyncToken())
- return failure(); // The gpu.wait is async.
+ return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
Location loc = op->getLoc();
@@ -320,7 +320,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
if (!cast<gpu::WaitOp>(op).asyncToken())
- return failure(); // The gpu.wait is not async.
+ return rewriter.notifyMatchFailure(op, "Can only convert async op.");
Location loc = op->getLoc();
@@ -440,6 +440,11 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
+// call %streamDestroy(%4)
+// call %moduleUnload(%1)
+//
+// If the op is async, the stream corresponds to the (single) async dependency
+// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -448,6 +453,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
op, "Cannot convert if operands aren't of LLVM type.");
auto launchOp = cast<gpu::LaunchFuncOp>(op);
+
+ if (launchOp.asyncDependencies().size() > 1)
+ return rewriter.notifyMatchFailure(
+ op, "Cannot convert with more than one async dependency.");
+
+ // Fail when the synchronous version of the op has async dependencies. The
+ // lowering destroys the stream, and we do not want to check that there is no
+ // use of the stream after this op.
+ if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
+ return rewriter.notifyMatchFailure(
+ op, "Cannot convert non-async op with async dependencies.");
+
Location loc = launchOp.getLoc();
// Create an LLVM global with CUBIN extracted from the kernel annotation and
@@ -478,8 +495,11 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter, {module.getResult(0), kernelName});
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
rewriter.getI32IntegerAttr(0));
- // Grab the global stream needed for execution.
- auto stream = streamCreateCallBuilder.create(loc, rewriter, {});
+ auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+ Value stream =
+ adaptor.asyncDependencies().empty()
+ ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+ : adaptor.asyncDependencies().front();
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
@@ -487,15 +507,22 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
loc, rewriter,
{function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
- launchOp.blockSizeZ(), zero, /* sharedMemBytes */
- stream.getResult(0), /* stream */
- kernelParams, /* kernel params */
- nullpointer /* extra */});
- streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0));
- streamDestroyCallBuilder.create(loc, rewriter, stream.getResult(0));
+ launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
+ /*extra=*/nullpointer});
+
+ if (launchOp.asyncToken()) {
+ // Async launch: make dependent ops use the same stream.
+ rewriter.replaceOp(op, {stream});
+ } else {
+ // Synchronize with host and destroy stream. This must be the stream created
+ // above (with no other uses) because we check that the synchronous version
+ // does not have any async dependencies.
+ streamSynchronizeCallBuilder.create(loc, rewriter, stream);
+ streamDestroyCallBuilder.create(loc, rewriter, stream);
+ rewriter.eraseOp(op);
+ }
moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
- rewriter.eraseOp(op);
return success();
}
More information about the Mlir-commits
mailing list