[Mlir-commits] [mlir] 3ac561d - [mlir][gpu] Add lowering to LLVM for `gpu.wait` and `gpu.wait async`.
Christian Sigg
llvmlistbot at llvm.org
Wed Oct 21 09:20:49 PDT 2020
Author: Christian Sigg
Date: 2020-10-21T18:20:42+02:00
New Revision: 3ac561d8c348a7bdc0313a268d5b3b4dcac118a2
URL: https://github.com/llvm/llvm-project/commit/3ac561d8c348a7bdc0313a268d5b3b4dcac118a2
DIFF: https://github.com/llvm/llvm-project/commit/3ac561d8c348a7bdc0313a268d5b3b4dcac118a2.diff
LOG: [mlir][gpu] Add lowering to LLVM for `gpu.wait` and `gpu.wait async`.
Reviewed By: herhut
Differential Revision: https://reviews.llvm.org/D89686
Added:
mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir
Modified:
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 36618384bb39..63a58fbc53f4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -91,6 +91,7 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
  let summary = "Convert GPU dialect to LLVM dialect with GPU runtime calls";
  let constructor = "mlir::createGpuToLLVMConversionPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
  let options = [
    Option<"gpuBinaryAnnotation", "gpu-binary-annotation", "std::string",
           "", "Annotation attribute string for GPU binary">,
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index f7f5834e6351..9d4c0c32dc82 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -157,6 +157,34 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
                  ConversionPatternRewriter &rewriter) const override;
};
+/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertWaitOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
+public:
+  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertWaitAsyncOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
+public:
+  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
@@ -257,6 +285,69 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
  return success();
}
+// Converts `gpu.wait` to runtime calls. The operands are all CUDA or ROCm
+// streams (i.e. void*). The converted op synchronizes the host with every
+// stream and then destroys it. That is, it assumes that the stream is not
+// used afterwards. If that assumption is violated, we will get a runtime
+// error. Eventually, a pass will guarantee this property.
+LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (cast<gpu::WaitOp>(op).asyncToken())
+    return failure(); // The gpu.wait is async.
+
+  Location loc = op->getLoc();
+
+  for (auto asyncDependency : operands)
+    streamSynchronizeCallBuilder.create(loc, rewriter, {asyncDependency});
+  for (auto asyncDependency : operands)
+    streamDestroyCallBuilder.create(loc, rewriter, {asyncDependency});
+
+  rewriter.eraseOp(op);
+  return success();
+}
+
+// Converts `gpu.wait async` to runtime calls. The result is a new stream
+// that is synchronized with all operands, which are CUDA or ROCm streams
+// (i.e. void*). For each operand stream, we create and record an event right
+// after the stream's definition, make the new stream wait on that event, and
+// then destroy the event again. This assumes that there is no other use of
+// the stream between its definition and this op; the plan is to have a pass
+// that guarantees this property.
+LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!cast<gpu::WaitOp>(op).asyncToken())
+    return failure(); // The gpu.wait is not async.
+
+  Location loc = op->getLoc();
+
+  auto insertionPoint = rewriter.saveInsertionPoint();
+  SmallVector<Value, 1> events;
+  for (auto pair : llvm::zip(op->getOperands(), operands)) {
+    auto token = std::get<0>(pair);
+    if (auto *defOp = token.getDefiningOp()) {
+      rewriter.setInsertionPointAfter(defOp);
+    } else {
+      // If we can't find the defining op, we record the event at block start,
+      // which is late and therefore misses parallelism, but still valid.
+      rewriter.setInsertionPointToStart(op->getBlock());
+    }
+    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+    auto stream = std::get<1>(pair);
+    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
+    events.push_back(event);
+  }
+  rewriter.restoreInsertionPoint(insertionPoint);
+  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
+  for (auto event : events)
+    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
+  for (auto event : events)
+    eventDestroyCallBuilder.create(loc, rewriter, {event});
+  rewriter.replaceOp(op, {stream});
+
+  return success();
+}
+
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
@@ -411,7 +502,13 @@ mlir::createGpuToLLVMConversionPass(StringRef gpuBinaryAnnotation) {
void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    StringRef gpuBinaryAnnotation) {
-  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern>(converter);
+  converter.addConversion(
+      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
+        return LLVM::LLVMType::getInt8PtrTy(context);
+      });
+  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern,
+                  ConvertWaitOpToGpuRuntimeCallPattern,
+                  ConvertWaitAsyncOpToGpuRuntimeCallPattern>(converter);
  patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
      converter, gpuBinaryAnnotation);
  patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());
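The streamSynchronizeCallBuilder, streamDestroyCallBuilder, eventCreateCallBuilder, etc. used above emit calls to mgpu* functions that the host program links against. A minimal sketch of what such wrappers might look like on top of the CUDA driver API (an assumed illustration mirroring the calls emitted above, not the exact wrappers shipped with the runner; error handling is elided, and the ROCm/HIP variants are analogous):

#include "cuda.h"

extern "C" CUstream mgpuStreamCreate() {
  CUstream stream = nullptr;
  cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
  return stream;
}

extern "C" void mgpuStreamDestroy(CUstream stream) { cuStreamDestroy(stream); }

extern "C" void mgpuStreamSynchronize(CUstream stream) {
  cuStreamSynchronize(stream);
}

extern "C" void mgpuStreamWaitEvent(CUstream stream, CUevent event) {
  // Makes subsequent work on `stream` wait until `event` has been reached.
  cuStreamWaitEvent(stream, event, /*Flags=*/0);
}

extern "C" CUevent mgpuEventCreate() {
  CUevent event = nullptr;
  cuEventCreate(&event, CU_EVENT_DISABLE_TIMING);
  return event;
}

extern "C" void mgpuEventDestroy(CUevent event) { cuEventDestroy(event); }

extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
  // Snapshots the work enqueued on `stream` so far.
  cuEventRecord(event, stream);
}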
diff --git a/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir
new file mode 100644
index 000000000000..b6eacfb969dd
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-wait-to-gpu-runtime-calls.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  func @foo() {
+    // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+    // CHECK: %[[e0:.*]] = llvm.call @mgpuEventCreate
+    // CHECK: llvm.call @mgpuEventRecord(%[[e0]], %[[t0]])
+    %t0 = gpu.wait async
+    // CHECK: %[[t1:.*]] = llvm.call @mgpuStreamCreate
+    // CHECK: llvm.call @mgpuStreamWaitEvent(%[[t1]], %[[e0]])
+    // CHECK: llvm.call @mgpuEventDestroy(%[[e0]])
+    %t1 = gpu.wait async [%t0]
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[t1]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[t1]])
+    gpu.wait [%t0, %t1]
+    return
+  }
+}
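The new test exercises the lowering through mlir-opt. To run the same lowering programmatically, a hypothetical pipeline sketch (the header path and the empty annotation argument are assumptions):

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/PassManager.h"

// Runs the same lowering as `mlir-opt --gpu-to-llvm` on a module.
static mlir::LogicalResult lowerGpuHostCode(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createGpuToLLVMConversionPass(/*gpuBinaryAnnotation=*/""));
  return pm.run(module);
}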