[Mlir-commits] [mlir] f47a38f - Add async dependencies support for gpu.launch op

Uday Bondhugula llvmlistbot at llvm.org
Thu Apr 21 03:56:19 PDT 2022


Author: Uday Bondhugula
Date: 2022-04-21T16:25:59+05:30
New Revision: f47a38f51724fab217838aa09cb029c7e0392285

URL: https://github.com/llvm/llvm-project/commit/f47a38f51724fab217838aa09cb029c7e0392285
DIFF: https://github.com/llvm/llvm-project/commit/f47a38f51724fab217838aa09cb029c7e0392285.diff

LOG: Add async dependencies support for gpu.launch op

Add async dependencies support for gpu.launch op: this allows specifying
a list of async tokens ("streams") as dependencies for the launch.

Update the GPU kernel outlining pass lowering to propagate async
dependencies from gpu.launch to gpu.launch_func op. Previously, a new
stream was being created and destroyed for a kernel launch. The async
deps support allows the kernel launch to be serialized on an existing
stream.

Differential Revision: https://reviews.llvm.org/D123499

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Dialect/GPU/outlining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 91111bb8a76bf..407ab3765770f 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -420,7 +420,9 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
   let builders = [
     OpBuilder<(ins "GPUFuncOp":$kernelFunc, "KernelDim3":$gridSize,
       "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
-      "ValueRange":$kernelOperands)>
+      "ValueRange":$kernelOperands,
+      CArg<"Type", "nullptr">:$asyncTokenType,
+      CArg<"ValueRange", "{}">:$asyncDependencies)>
   ];
 
   let extraClassDeclaration = [{
@@ -466,25 +468,32 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
   let hasVerifier = 1;
 }
 
-def GPU_LaunchOp : GPU_Op<"launch", [AutomaticAllocationScope]>,
-    Arguments<(ins Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
+def GPU_LaunchOp : GPU_Op<"launch",
+    [AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface]>,
+    Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+               Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
                Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
                Optional<I32>:$dynamicSharedMemorySize)>,
-    Results<(outs)> {
+    Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "GPU kernel launch operation";
 
   let description = [{
     Launch a kernel on the specified grid of thread blocks. The body of the
     kernel is defined by the single region that this operation contains. The
-    operation takes six operands followed by an optional operand: the first
-    three operands are grid sizes along the x,y,z dimensions and the following
-    three are block sizes along the x,y,z dimensions. The last operand is
-    optional and corresponds to the amount of dynamic shared memory a kernel's
-    workgroup should be allocated; when this operand is not present, a zero size
-    is assumed.
-
-    When a lower-dimensional kernel is required, unused sizes must
-    be explicitly set to `1`.
+    operation takes an optional list of async dependencies followed by six
+    operands and an optional operand.
+
+    The `async` keyword indicates the kernel should be launched asynchronously;
+    the operation returns a new !gpu.async.token when the keyword is specified.
+    The kernel launched does not start executing until the ops producing its
+    async dependencies (optional operands) have completed.
+
+    The first three operands (following any async dependencies) are grid sizes
+    along the x,y,z dimensions and the following three are block sizes along the
+    x,y,z dimensions. When a lower-dimensional kernel is required, unused sizes
+    must be explicitly set to `1`.  The last operand is optional and corresponds
+    to the amount of dynamic shared memory a kernel's workgroup should be
+    allocated; when this operand is not present, a zero size is assumed.
 
     The body region has _twelve_ arguments, grouped as follows:
 
@@ -496,7 +505,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [AutomaticAllocationScope]>,
     Syntax:
 
     ```
-    operation ::= `gpu.launch` `block` `(` ssa-id-list `)` `in` ssa-reassignment
+    operation ::= `gpu.launch` (`async` (`[` ssa-id-list `]`)? )?
+                             `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
                              `threads` `(` ssa-id-list `)` `in` ssa-reassignment
                              (dynamic_shared_memory_size ssa-use)?
                              region attr-dict?
@@ -548,7 +558,9 @@ def GPU_LaunchOp : GPU_Op<"launch", [AutomaticAllocationScope]>,
     OpBuilder<(ins "Value":$gridSizeX, "Value":$gridSizeY,
       "Value":$gridSizeZ, "Value":$blockSizeX, "Value":$blockSizeY,
       "Value":$blockSizeZ,
-      CArg<"Value", "nullptr">:$dynamic_shared_memory_size)>
+      CArg<"Value", "nullptr">:$dynamicSharedMemorySize,
+      CArg<"Type", "nullptr">:$asyncTokenType,
+      CArg<"ValueRange", "{}">:$asyncDependencies)>
   ];
 
   let extraClassDeclaration = [{

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 76e477c2dd740..b87256dd994ea 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -275,6 +275,44 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
   return walkResult.wasInterrupted() ? failure() : success();
 }
 
+/// Parses an optional list of async operands with an optional leading keyword.
+/// (`async`)? (`[` ssa-id-list `]`)?
+///
+/// This method is used by the tablegen assembly format for async ops as well.
+static ParseResult parseAsyncDependencies(
+    OpAsmParser &parser, Type &asyncTokenType,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
+  auto loc = parser.getCurrentLocation();
+  if (succeeded(parser.parseOptionalKeyword("async"))) {
+    if (parser.getNumResults() == 0)
+      return parser.emitError(loc, "needs to be named when marked 'async'");
+    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
+  }
+  return parser.parseOperandList(asyncDependencies,
+                                 OpAsmParser::Delimiter::OptionalSquare);
+}
+
+/// Prints optional async dependencies with its leading keyword.
+///   (`async`)? (`[` ssa-id-list `]`)?
+// Used by the tablegen assembly format for several async ops.
+static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
+                                   Type asyncTokenType,
+                                   OperandRange asyncDependencies) {
+  if (asyncTokenType)
+    printer << "async";
+  if (asyncDependencies.empty())
+    return;
+  if (asyncTokenType)
+    printer << ' ';
+  printer << '[';
+  llvm::interleaveComma(asyncDependencies, printer);
+  printer << ']';
+}
+
+//===----------------------------------------------------------------------===//
+// AllReduceOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult gpu::AllReduceOp::verifyRegions() {
   if (body().empty() != op().hasValue())
     return emitError("expected either an op attribute or a non-empty body");
@@ -358,7 +396,12 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
 void LaunchOp::build(OpBuilder &builder, OperationState &result,
                      Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                      Value blockSizeX, Value blockSizeY, Value blockSizeZ,
-                     Value dynamicSharedMemorySize) {
+                     Value dynamicSharedMemorySize, Type asyncTokenType,
+                     ValueRange asyncDependencies) {
+  result.addOperands(asyncDependencies);
+  if (asyncTokenType)
+    result.types.push_back(builder.getType<AsyncTokenType>());
+
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands(
       {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -373,6 +416,11 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
   for (unsigned i = 0; i < kNumConfigRegionAttributes; ++i)
     body->addArgument(builder.getIndexType(), result.location);
   kernelRegion->push_back(body);
+  SmallVector<int32_t, 8> segmentSizes(8, 1);
+  segmentSizes.front() = asyncDependencies.size();
+  segmentSizes.back() = dynamicSharedMemorySize ? 1 : 0;
+  result.addAttribute(getOperandSegmentSizeAttr(),
+                      builder.getI32VectorAttr(segmentSizes));
 }
 
 KernelDim3 LaunchOp::getBlockIds() {
@@ -400,11 +448,13 @@ KernelDim3 LaunchOp::getBlockSize() {
 }
 
 KernelDim3 LaunchOp::getGridSizeOperandValues() {
-  return KernelDim3{getOperand(0), getOperand(1), getOperand(2)};
+  auto operands = getOperands().drop_front(asyncDependencies().size());
+  return KernelDim3{operands[0], operands[1], operands[2]};
 }
 
 KernelDim3 LaunchOp::getBlockSizeOperandValues() {
-  return KernelDim3{getOperand(3), getOperand(4), getOperand(5)};
+  auto operands = getOperands().drop_front(asyncDependencies().size());
+  return KernelDim3{operands[3], operands[4], operands[5]};
 }
 
 LogicalResult LaunchOp::verifyRegions() {
@@ -412,9 +462,9 @@ LogicalResult LaunchOp::verifyRegions() {
   // sizes and transforms them into kNumConfigRegionAttributes region arguments
   // for block/thread identifiers and grid/block sizes.
   if (!body().empty()) {
-    if (body().getNumArguments() != LaunchOp::kNumConfigOperands +
-                                        getNumOperands() -
-                                        (dynamicSharedMemorySize() ? 1 : 0))
+    if (body().getNumArguments() !=
+        LaunchOp::kNumConfigOperands + getNumOperands() -
+            (dynamicSharedMemorySize() ? 1 : 0) - asyncDependencies().size())
       return emitOpError("unexpected number of region arguments");
   }
 
@@ -435,6 +485,9 @@ LogicalResult LaunchOp::verifyRegions() {
     }
   }
 
+  if (getNumResults() == 0 && asyncToken())
+    return emitOpError("needs to be named when async keyword is specified");
+
   return success();
 }
 
@@ -451,6 +504,11 @@ static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
 }
 
 void LaunchOp::print(OpAsmPrinter &p) {
+  if (asyncToken()) {
+    p << " async";
+    if (!asyncDependencies().empty())
+      p << " [" << asyncDependencies() << ']';
+  }
   // Print the launch configuration.
   p << ' ' << getBlocksKeyword();
   printSizeAssignment(p, getGridSize(), getGridSizeOperandValues(),
@@ -464,7 +522,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
 
   p << ' ';
   p.printRegion(body(), /*printEntryBlockArgs=*/false);
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
+                              LaunchOp::getOperandSegmentSizeAttr()});
 }
 
 // Parse the size assignment blocks for blocks and threads.  These have the form
@@ -498,11 +557,10 @@ parseSizeAssignment(OpAsmParser &parser,
 }
 
 /// Parses a Launch operation.
-/// operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in`
-/// ssa-reassignment
-///                           `threads` `(` ssa-id-list `)` `in`
-///                           ssa-reassignment
-///                            region attr-dict?
+/// operation ::= `gpu.launch` (`async` (`[` ssa-id-list `]`)?)?
+///       `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
+///       `threads` `(` ssa-id-list `)` `in` ssa-reassignment
+///       region attr-dict?
 /// ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
 ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   // Sizes of the grid and block.
@@ -518,6 +576,17 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
       LaunchOp::kNumConfigRegionAttributes);
   MutableArrayRef<OpAsmParser::UnresolvedOperand> regionArgsRef(regionArgs);
 
+  // Parse optional async dependencies.
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> asyncDependencies;
+  Type asyncTokenType;
+  if (failed(
+          parseAsyncDependencies(parser, asyncTokenType, asyncDependencies)) ||
+      parser.resolveOperands(asyncDependencies, asyncTokenType,
+                             result.operands))
+    return failure();
+  if (parser.getNumResults() > 0)
+    result.types.push_back(asyncTokenType);
+
   // Parse the size assignment segments: the first segment assigns grid sizes
   // and defines values for block identifiers; the second segment assigns block
   // sizes and defines values for thread identifiers.  In the region argument
@@ -536,13 +605,16 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
     return failure();
 
   OpAsmParser::UnresolvedOperand dynamicSharedMemorySize;
+  bool hasDynamicSharedMemorySize = false;
   if (!parser.parseOptionalKeyword(
-          LaunchOp::getDynamicSharedMemorySizeKeyword()))
+          LaunchOp::getDynamicSharedMemorySizeKeyword())) {
+    hasDynamicSharedMemorySize = true;
     if (parser.parseOperand(dynamicSharedMemorySize) ||
         parser.resolveOperand(dynamicSharedMemorySize,
                               parser.getBuilder().getI32Type(),
                               result.operands))
       return failure();
+  }
 
   // Introduce the body region and parse it. The region has
   // kNumConfigRegionAttributes arguments that correspond to
@@ -551,8 +623,16 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<Type, LaunchOp::kNumConfigRegionAttributes> dataTypes(
       LaunchOp::kNumConfigRegionAttributes, index);
   Region *body = result.addRegion();
-  return failure(parser.parseRegion(*body, regionArgs, dataTypes) ||
-                 parser.parseOptionalAttrDict(result.attributes));
+  if (parser.parseRegion(*body, regionArgs, dataTypes) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  SmallVector<int32_t, 8> segmentSizes(8, 1);
+  segmentSizes.front() = asyncDependencies.size();
+  segmentSizes.back() = hasDynamicSharedMemorySize ? 1 : 0;
+  result.addAttribute(LaunchOp::getOperandSegmentSizeAttr(),
+                      parser.getBuilder().getI32VectorAttr(segmentSizes));
+  return success();
 }
 
 /// Simplify the gpu.launch when the range of a thread or block ID is
@@ -602,7 +682,12 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
                          KernelDim3 blockSize, Value dynamicSharedMemorySize,
-                         ValueRange kernelOperands) {
+                         ValueRange kernelOperands, Type asyncTokenType,
+                         ValueRange asyncDependencies) {
+  result.addOperands(asyncDependencies);
+  if (asyncTokenType)
+    result.types.push_back(builder.getType<AsyncTokenType>());
+
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
                       blockSize.y, blockSize.z});
@@ -615,7 +700,7 @@ void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          {SymbolRefAttr::get(kernelFunc.getNameAttr())});
   result.addAttribute(getKernelAttrName(), kernelSymbol);
   SmallVector<int32_t, 9> segmentSizes(9, 1);
-  segmentSizes.front() = 0; // Initially no async dependencies.
+  segmentSizes.front() = asyncDependencies.size();
   segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
   segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
   result.addAttribute(getOperandSegmentSizeAttr(),
@@ -1039,36 +1124,6 @@ LogicalResult MemcpyOp::verify() {
   return success();
 }
 
-static ParseResult parseAsyncDependencies(
-    OpAsmParser &parser, Type &asyncTokenType,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &asyncDependencies) {
-  auto loc = parser.getCurrentLocation();
-  if (succeeded(parser.parseOptionalKeyword("async"))) {
-    if (parser.getNumResults() == 0)
-      return parser.emitError(loc, "needs to be named when marked 'async'");
-    asyncTokenType = parser.getBuilder().getType<AsyncTokenType>();
-  }
-  return parser.parseOperandList(asyncDependencies,
-                                 OpAsmParser::Delimiter::OptionalSquare);
-}
-
-/// Prints optional async dependencies with its leading keyword.
-///   (`async`)? (`[` ssa-id-list `]`)?
-// Used by the tablegen assembly format for several async ops.
-static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op,
-                                   Type asyncTokenType,
-                                   OperandRange asyncDependencies) {
-  if (asyncTokenType)
-    printer << "async";
-  if (asyncDependencies.empty())
-    return;
-  if (asyncTokenType)
-    printer << ' ';
-  printer << '[';
-  llvm::interleaveComma(asyncDependencies, printer);
-  printer << ']';
-}
-
 namespace {
 
 /// Erases a common case of copy ops where a destination value is used only by

diff  --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 5b15c1d0f42f4..3019cc8bcfa49 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -225,10 +225,13 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
   OpBuilder builder(launchOp);
   // The launch op has an optional dynamic shared memory size. If it doesn't
   // exist, we use zero.
-  builder.create<gpu::LaunchFuncOp>(
+  Value asyncToken = launchOp.asyncToken();
+  auto launchFunc = builder.create<gpu::LaunchFuncOp>(
       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
       launchOp.getBlockSizeOperandValues(), launchOp.dynamicSharedMemorySize(),
-      operands);
+      operands, asyncToken ? asyncToken.getType() : nullptr,
+      launchOp.asyncDependencies());
+  launchOp.replaceAllUsesWith(launchFunc);
   launchOp.erase();
 }
 

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 5360e8f7f8ced..fd94c81a05a25 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -4,7 +4,7 @@ func.func @not_enough_sizes(%sz : index) {
   // expected-error at +1 {{expected 6 or more operands, but found 5}}
   "gpu.launch"(%sz, %sz, %sz, %sz, %sz) ({
     gpu.return
-  }) : (index, index, index, index, index) -> ()
+  }) {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0]> : vector<8xi32>} : (index, index, index, index, index) -> ()
   return
 }
 
@@ -12,11 +12,11 @@ func.func @not_enough_sizes(%sz : index) {
 
 func.func @no_region_attrs(%sz : index) {
   // expected-error at +1 {{unexpected number of region arguments}}
- "gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
+  "gpu.launch"(%sz, %sz, %sz, %sz, %sz, %sz) ({
   ^bb1(%bx: index, %by: index, %bz: index,
        %tx: index, %ty: index, %tz: index):
     gpu.terminator
-  }) : (index, index, index, index, index, index) -> ()
+  }) {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0]> : vector<8xi32>} : (index, index, index, index, index, index) -> ()
   return
 }
 

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 541607ecbbf1c..d05785e2d001b 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -1,4 +1,8 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
+// Verify the printed output can be parsed.
+// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
 module attributes {gpu.container_module} {
 
@@ -26,6 +30,32 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL:func @launch_async(%{{.*}}: index, %{{.*}}: index) {
+  func @launch_async(%blk : index, %thrd : index) {
+    // CHECK: gpu.launch async [%{{.+}}] blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
+    %t = gpu.wait async
+    %name = gpu.launch async [%t] blocks(%arg0, %arg1, %arg2) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
+               threads(%arg3, %arg4, %arg5) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd) {
+      gpu.terminator
+    }
+    return
+  }
+
+  // CHECK-LABEL:func @launch_async_no_deps(%{{.*}}: index, %{{.*}}: index) {
+  func @launch_async_no_deps(%blk : index, %thrd : index) {
+    // CHECK: %{{.*}} = gpu.launch async blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
+    %t0 = gpu.launch async blocks(%arg0, %arg1, %arg2) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
+               threads(%arg3, %arg4, %arg5) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd) {
+      gpu.terminator
+    }
+    // CHECK: gpu.launch async blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}, %{{.*}} = %{{.*}})
+    %t1 = gpu.launch async [] blocks(%arg0, %arg1, %arg2) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk)
+               threads(%arg3, %arg4, %arg5) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd) {
+      gpu.terminator
+    }
+    return
+  }
+
   gpu.module @kernels {
     gpu.func @kernel_1(%arg0 : f32, %arg1 : memref<?xf32, 1>) kernel {
       %tIdX = gpu.thread_id x

diff  --git a/mlir/test/Dialect/GPU/outlining.mlir b/mlir/test/Dialect/GPU/outlining.mlir
index 4a07225ac2288..89080c5871fbc 100644
--- a/mlir/test/Dialect/GPU/outlining.mlir
+++ b/mlir/test/Dialect/GPU/outlining.mlir
@@ -80,6 +80,26 @@ func.func @multiple_launches() {
                                            %block_z2 = %cst) {
     gpu.terminator
   }
+
+  // With async and async deps.
+  // CHECK: %[[TOKEN:.*]] = gpu.wait async
+  // CHECK: gpu.launch_func async [%[[TOKEN]]] @multiple_launches_kernel_1::@multiple_launches_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]])
+  %t = gpu.wait async
+  %u = gpu.launch async [%t] blocks(%bx2, %by2, %bz2) in (%grid_x2 = %cst, %grid_y2 = %cst,
+                                          %grid_z2 = %cst)
+             threads(%tx2, %ty2, %tz2) in (%block_x2 = %cst, %block_y2 = %cst,
+                                           %block_z2 = %cst) {
+    gpu.terminator
+  }
+
+  // CHECK: gpu.launch_func async @multiple_launches_kernel_2::@multiple_launches_kernel blocks in (%[[CST]], %[[CST]], %[[CST]]) threads in (%[[CST]], %[[CST]], %[[CST]])
+  %v = gpu.launch async blocks(%bx2, %by2, %bz2) in (%grid_x2 = %cst, %grid_y2 = %cst,
+                                     %grid_z2 = %cst)
+             threads(%tx2, %ty2, %tz2) in (%block_x2 = %cst, %block_y2 = %cst,
+                                           %block_z2 = %cst) {
+    gpu.terminator
+  }
+
   return
 }
 


        


More information about the Mlir-commits mailing list