[Mlir-commits] [mlir] 84718d3 - [MLIR][GPU] Add gpu.set_default_device op
Krzysztof Drewniak
llvmlistbot at llvm.org
Thu Feb 17 13:30:20 PST 2022
Author: Krzysztof Drewniak
Date: 2022-02-17T21:30:09Z
New Revision: 84718d37db577f57514df6ac544e3db88aa75684
URL: https://github.com/llvm/llvm-project/commit/84718d37db577f57514df6ac544e3db88aa75684
DIFF: https://github.com/llvm/llvm-project/commit/84718d37db577f57514df6ac544e3db88aa75684.diff
LOG: [MLIR][GPU] Add gpu.set_default_device op
This op is added to allow MLIR code running on multi-GPU systems to
select the GPU on which to execute operations when no GPU is
otherwise specified.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D119883
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
mlir/test/Dialect/GPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index d8236eafa9cf1..5d25892175b90 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -273,7 +273,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
/// Returns the type of this function.
/// FIXME: We should drive this via the ODS `type` param.
- FunctionType getType() {
+ FunctionType getType() {
return getTypeAttr().getValue().cast<FunctionType>();
}
@@ -1006,6 +1006,18 @@ def GPU_MemsetOp : GPU_Op<"memset",
let hasFolder = 1;
}
+def GPU_SetDefaultDeviceOp : GPU_Op<"set_default_device",
+ [MemoryEffects<[MemWrite]>]>,
+ Arguments<(ins I32:$devIndex)> {
+ let summary = "Set default GPU for operations after this by index";
+ let description = [{
+ Operation that sets the current default GPU, using a zero-based index
+ into the set of GPUs on the system. The default GPU setting may be
+ thread-local.
+ }];
+ let assemblyFormat = "attr-dict $devIndex";
+}
+
def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
[MemoryEffects<[MemRead]>]>{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index a30bfaf5cce4d..1aa12500c5716 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -185,6 +185,10 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
{llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
+ FunctionCallBuilder setDefaultDeviceCallBuilder = {
+ "mgpuSetDefaultDevice",
+ llvmVoidType,
+ {llvmInt32Type /* uint32_t devIndex */}};
};
/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -342,6 +346,21 @@ class ConvertMemsetOpToGpuRuntimeCallPattern
matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
+
+/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
+/// Currently supports CUDA and ROCm (HIP).
+class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
+ : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
+public:
+ ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
+ LLVMTypeConverter &typeConverter)
+ : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
+ typeConverter) {}
+
+ LogicalResult
+ matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
@@ -844,6 +863,15 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
return success();
}
+LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
+ gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()});
+ rewriter.replaceOp(op, {});
+ return success();
+}
+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
return std::make_unique<GpuToLLVMConversionPass>();
@@ -861,6 +889,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
ConvertHostRegisterOpToGpuRuntimeCallPattern,
ConvertMemcpyOpToGpuRuntimeCallPattern,
ConvertMemsetOpToGpuRuntimeCallPattern,
+ ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
ConvertWaitAsyncOpToGpuRuntimeCallPattern,
ConvertWaitOpToGpuRuntimeCallPattern,
ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 9fb3c100feaed..dd66056289cec 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -35,16 +35,20 @@
fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
}(expr)
-// Make the primary context of device 0 current for the duration of the instance
-// and restore the previous context on destruction.
+thread_local static int32_t defaultDevice = 0;
+
+// Make the primary context of the current default device current for the
+// duration of the instance and restore the previous context on
+// destruction.
class ScopedContext {
public:
ScopedContext() {
- // Static reference to CUDA primary context for device ordinal 0.
+ // Static reference to CUDA primary context for device ordinal
+ // defaultDevice.
static CUcontext context = [] {
CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
CUdevice device;
- CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
+ CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
CUcontext ctx;
// Note: this does not affect the current context.
CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
@@ -187,3 +191,8 @@ mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
mgpuMemHostRegister(ptr, sizeBytes);
}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
+ defaultDevice = device;
+ CUDA_REPORT_IF_ERROR(cudaSetDevice(device));
+}
diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
index 92358ed38d9cb..34363ccc61416 100644
--- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
@@ -30,16 +30,18 @@
fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
}(expr)
+thread_local static int32_t defaultDevice = 0;
+
// Sets the `Context` for the duration of the instance and restores the previous
// context on destruction.
class ScopedContext {
public:
ScopedContext() {
- // Static reference to HIP primary context for device ordinal 0.
+ // Static reference to HIP primary context for device ordinal defaultDevice.
static hipCtx_t context = [] {
HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
hipDevice_t device;
- HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/0));
+ HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/defaultDevice));
hipCtx_t ctx;
HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&ctx, device));
return ctx;
@@ -199,3 +201,8 @@ mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
mgpuMemGetDevicePointer(aligned, &devicePtr);
return {devicePtr, devicePtr, offset, {size}, {stride}};
}
+
+extern "C" void mgpuSetDefaultDevice(int32_t device) {
+ defaultDevice = device;
+ HIP_REPORT_IF_ERROR(hipSetDevice(device));
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index c1c5ff5570832..c317dbc930480 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -252,4 +252,11 @@ module attributes {gpu.container_module} {
gpu.device_async_wait %token {numGroups = 1 : i32}
return
}
+
+ // CHECK-LABEL: func @set_default_device
+ func @set_default_device(%arg0: i32) {
+ // CHECK: gpu.set_default_device
+ gpu.set_default_device %arg0
+ return
+ }
}
More information about the Mlir-commits
mailing list