[Mlir-commits] [mlir] fce99e5 - [mlir][gpu] Handle async in gpu.launch_func lowering.

Christian Sigg llvmlistbot at llvm.org
Thu Oct 29 14:16:51 PDT 2020


Author: Christian Sigg
Date: 2020-10-29T22:16:42+01:00
New Revision: fce99e5f739eddce3e52a9c4967fc8a993c5772d

URL: https://github.com/llvm/llvm-project/commit/fce99e5f739eddce3e52a9c4967fc8a993c5772d
DIFF: https://github.com/llvm/llvm-project/commit/fce99e5f739eddce3e52a9c4967fc8a993c5772d.diff

LOG: [mlir][gpu] Handle async in gpu.launch_func lowering.

For the synchronous case, destroy the stream after synchronization.

Sneak in an unrelated change to report why the gpu.wait conversion pattern didn't match.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D89933

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index a8a416d7843d..ac58b7a7d7f1 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -297,7 +297,7 @@ LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   if (cast<gpu::WaitOp>(op).asyncToken())
-    return failure(); // The gpu.wait is async.
+    return rewriter.notifyMatchFailure(op, "Cannot convert async op.");
 
   Location loc = op->getLoc();
 
@@ -320,7 +320,7 @@ LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   if (!cast<gpu::WaitOp>(op).asyncToken())
-    return failure(); // The gpu.wait is not async.
+    return rewriter.notifyMatchFailure(op, "Can only convert async op.");
 
   Location loc = op->getLoc();
 
@@ -440,6 +440,11 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
 // %5 = <see generateParamsArray>
 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
 // call %streamSynchronize(%4)
+// call %streamDestroy(%4)
+// call %moduleUnload(%1)
+//
+// If the op is async, the stream corresponds to the (single) async dependency
+// as well as the async token the op produces.
 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -448,6 +453,18 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
         op, "Cannot convert if operands aren't of LLVM type.");
 
   auto launchOp = cast<gpu::LaunchFuncOp>(op);
+
+  if (launchOp.asyncDependencies().size() > 1)
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert with more than one async dependency.");
+
+  // Fail when the synchronous version of the op has async dependencies. The
+  // lowering destroys the stream, and we do not want to check that there is no
+  // use of the stream after this op.
+  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert non-async op with async dependencies.");
+
   Location loc = launchOp.getLoc();
 
   // Create an LLVM global with CUBIN extracted from the kernel annotation and
@@ -478,8 +495,11 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, rewriter, {module.getResult(0), kernelName});
   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                 rewriter.getI32IntegerAttr(0));
-  // Grab the global stream needed for execution.
-  auto stream = streamCreateCallBuilder.create(loc, rewriter, {});
+  auto adaptor = gpu::LaunchFuncOpAdaptor(operands, op->getAttrDictionary());
+  Value stream =
+      adaptor.asyncDependencies().empty()
+          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
+          : adaptor.asyncDependencies().front();
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
@@ -487,15 +507,22 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, rewriter,
       {function.getResult(0), launchOp.gridSizeX(), launchOp.gridSizeY(),
        launchOp.gridSizeZ(), launchOp.blockSizeX(), launchOp.blockSizeY(),
-       launchOp.blockSizeZ(), zero, /* sharedMemBytes */
-       stream.getResult(0),         /* stream */
-       kernelParams,                /* kernel params */
-       nullpointer /* extra */});
-  streamSynchronizeCallBuilder.create(loc, rewriter, stream.getResult(0));
-  streamDestroyCallBuilder.create(loc, rewriter, stream.getResult(0));
+       launchOp.blockSizeZ(), /*sharedMemBytes=*/zero, stream, kernelParams,
+       /*extra=*/nullpointer});
+
+  if (launchOp.asyncToken()) {
+    // Async launch: make dependent ops use the same stream.
+    rewriter.replaceOp(op, {stream});
+  } else {
+    // Synchronize with host and destroy stream. This must be the stream created
+    // above (with no other uses) because we check that the synchronous version
+    // does not have any async dependencies.
+    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
+    streamDestroyCallBuilder.create(loc, rewriter, stream);
+    rewriter.eraseOp(op);
+  }
   moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
 
-  rewriter.eraseOp(op);
   return success();
 }
 


        


More information about the Mlir-commits mailing list