[llvm-branch-commits] [mlir] 5535696 - [mlir] Add gpu.allocate, gpu.deallocate ops with LLVM lowering to runtime function calls.

Christian Sigg via llvm-branch-commits <llvm-branch-commits at lists.llvm.org>
Fri Nov 27 00:45:42 PST 2020


Author: Christian Sigg
Date: 2020-11-27T09:40:59+01:00
New Revision: 5535696c386ba89b66c1b5a72a2aa98783571cc9

URL: https://github.com/llvm/llvm-project/commit/5535696c386ba89b66c1b5a72a2aa98783571cc9
DIFF: https://github.com/llvm/llvm-project/commit/5535696c386ba89b66c1b5a72a2aa98783571cc9.diff

LOG: [mlir] Add gpu.allocate, gpu.deallocate ops with LLVM lowering to runtime function calls.

The ops closely mirror their std counterparts, but additionally support async GPU execution.

gpu.alloc does not currently support an alignment attribute, and the new ops do not yet have
canonicalizers/folders like their std siblings do.
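
As an illustration, the new ops compose as below — a minimal sketch mirroring the lit test
added in this patch (the function name and shapes are arbitrary). Each comment names the
runtime function the op lowers to under --gpu-to-llvm:

```mlir
module attributes {gpu.container_module} {
  func @alloc_and_dealloc() {
    // Creates a stream (mgpuStreamCreate) and returns a !gpu.async.token.
    %t0 = gpu.wait async
    // Allocates device memory on that stream (mgpuMemAlloc).
    %memref, %t1 = gpu.alloc async [%t0] () : memref<13xf32>
    // Frees the buffer on the same stream (mgpuMemFree).
    %t2 = gpu.dealloc async [%t1] %memref : memref<13xf32>
    // Synchronizes the host with the stream and destroys it
    // (mgpuStreamSynchronize, mgpuStreamDestroy).
    gpu.wait [%t2]
    return
  }
}
```

Threading the async token through the ops keeps the whole chain on one stream, so it executes
in order without blocking the host until the final gpu.wait.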

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D91698

Added: 
    mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/test/Dialect/GPU/ops.mlir
    mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
    mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index 9828af7a31b0..99f388b4db51 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/FunctionSupport.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 

diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index 593b735f01c9..33c00ca9b22c 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -804,4 +804,79 @@ def GPU_WaitOp : GPU_Op<"wait", [GPU_AsyncOpInterface]> {
   }];
 }
 
+def GPU_AllocOp : GPU_Op<"alloc", [
+    GPU_AsyncOpInterface,
+    AttrSizedOperandSegments,
+    MemoryEffects<[MemAlloc<DefaultResource>]>
+  ]> {
+
+  let summary = "GPU memory allocation operation";
+  let description = [{
+    The `gpu.alloc` operation allocates a region of memory on the GPU. It is
+    similar to the `std.alloc` op, but supports asynchronous GPU execution.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it also returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %memref, %token = gpu.alloc async [%dep] (%width) : memref<64x?xf32, 1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Variadic<Index>:$dynamicSizes, Variadic<Index>:$symbolOperands);
+  let results = (outs Res<AnyMemRef, "", [MemAlloc<DefaultResource>]>:$memref,
+                 Optional<GPU_AsyncToken>:$asyncToken);
+
+  let extraClassDeclaration = [{
+    MemRefType getType() { return memref().getType().cast<MemRefType>(); }
+  }];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) ` `
+    `(` $dynamicSizes `)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref)
+  }];
+}
+
+def GPU_DeallocOp : GPU_Op<"dealloc", [
+    GPU_AsyncOpInterface, MemoryEffects<[MemFree]>
+  ]> {
+
+  let summary = "GPU memory deallocation operation";
+
+  let description = [{
+    The `gpu.dealloc` operation frees the region of memory referenced by a
+    memref which was originally created by the `gpu.alloc` operation. It is
+    similar to the `std.dealloc` op, but supports asynchronous GPU execution.
+
+    The op does not execute before all async dependencies have finished
+    executing.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.dealloc async [%dep] %memref : memref<8x64xf32, 1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   Arg<AnyMemRef, "", [MemFree]>:$memref);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $memref attr-dict `:` type($memref)
+  }];
+}
+
 #endif // GPU_OPS

diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index a046bb068d12..d625db95e976 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -142,6 +142,15 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmIntPtrType /* intptr_t rank */,
        llvmPointerType /* void *memrefDesc */,
        llvmIntPtrType /* intptr_t elementSizeBytes */}};
+  FunctionCallBuilder allocCallBuilder = {
+      "mgpuMemAlloc",
+      llvmPointerType /* void * */,
+      {llvmIntPtrType /* intptr_t sizeBytes */,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder deallocCallBuilder = {
+      "mgpuMemFree",
+      llvmVoidType,
+      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -158,6 +167,34 @@ class ConvertHostRegisterOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertAllocOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
+public:
+  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertDeallocOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
+public:
+  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
 /// call. Currently it supports CUDA and ROCm (HIP).
 class ConvertWaitOpToGpuRuntimeCallPattern
@@ -231,7 +268,6 @@ class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
     return success();
   }
 };
-
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -260,17 +296,35 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
       builder.getSymbolRefAttr(function), arguments);
 }
 
-// Returns whether value is of LLVM type.
-static bool isLLVMType(Value value) {
-  return value.getType().isa<LLVM::LLVMType>();
+// Returns whether all operands are of LLVM type.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return value.getType().isa<LLVM::LLVMType>();
+      }))
+    return rewriter.notifyMatchFailure(
+        op, "Cannot convert if operands aren't of LLVM type.");
+  return success();
+}
+
+static LogicalResult
+isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
+                         gpu::AsyncOpInterface op) {
+  if (op.getAsyncDependencies().size() != 1)
+    return rewriter.notifyMatchFailure(
+        op, "Can only convert with exactly one async dependency.");
+
+  if (!op.getAsyncToken())
+    return rewriter.notifyMatchFailure(op, "Can convert only async version.");
+
+  return success();
 }
 
 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!llvm::all_of(operands, isLLVMType))
-    return rewriter.notifyMatchFailure(
-        op, "Cannot convert if operands aren't of LLVM type.");
+  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+    return failure();
 
   Location loc = op->getLoc();
 
@@ -287,6 +341,71 @@ LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto allocOp = cast<gpu::AllocOp>(op);
+  MemRefType memRefType = allocOp.getType();
+
+  if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
+      !isSupportedMemRefType(memRefType) ||
+      failed(
+          isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
+    return failure();
+
+  auto loc = op->getLoc();
+
+  // Get shape of the memref as values: static sizes are constant
+  // values and dynamic sizes are passed to 'alloc' as operands.
+  SmallVector<Value, 4> shape;
+  SmallVector<Value, 4> strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, shape, strides,
+                           sizeBytes);
+
+  // Allocate the underlying buffer and store a pointer to it in the MemRef
+  // descriptor.
+  Type elementPtrType = this->getElementPtrType(memRefType);
+  auto adaptor = gpu::AllocOpAdaptor(operands, op->getAttrDictionary());
+  auto stream = adaptor.asyncDependencies().front();
+  Value allocatedPtr =
+      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
+  allocatedPtr =
+      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
+
+  // No alignment.
+  Value alignedPtr = allocatedPtr;
+
+  // Create the MemRef descriptor.
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
+
+  rewriter.replaceOp(op, {memRefDescriptor, stream});
+
+  return success();
+}
+
+LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, operands, rewriter)) ||
+      failed(
+          isAsyncWithOneDependency(rewriter, cast<gpu::AsyncOpInterface>(op))))
+    return failure();
+
+  Location loc = op->getLoc();
+
+  auto adaptor = gpu::DeallocOpAdaptor(operands, op->getAttrDictionary());
+  Value pointer =
+      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
+  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
+  Value stream = adaptor.asyncDependencies().front();
+  deallocCallBuilder.create(loc, rewriter, {casted, stream});
+
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
 // Converts `gpu.wait` to runtime calls. The operands are all CUDA or ROCm
 // streams (i.e. void*). The converted op synchronizes the host with every
 // stream and then destroys it. That is, it assumes that the stream is not used
@@ -447,9 +566,8 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
     Operation *op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
-  if (!llvm::all_of(operands, isLLVMType))
-    return rewriter.notifyMatchFailure(
-        op, "Cannot convert if operands aren't of LLVM type.");
+  if (failed(areAllLLVMTypes(op, operands, rewriter)))
+    return failure();
 
   auto launchOp = cast<gpu::LaunchFuncOp>(op);
 
@@ -537,9 +655,11 @@ void mlir::populateGpuToLLVMConversionPatterns(
       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
         return LLVM::LLVMType::getInt8PtrTy(context);
       });
-  patterns.insert<ConvertHostRegisterOpToGpuRuntimeCallPattern,
-                  ConvertWaitOpToGpuRuntimeCallPattern,
-                  ConvertWaitAsyncOpToGpuRuntimeCallPattern>(converter);
+  patterns.insert<ConvertAllocOpToGpuRuntimeCallPattern,
+                  ConvertDeallocOpToGpuRuntimeCallPattern,
+                  ConvertHostRegisterOpToGpuRuntimeCallPattern,
+                  ConvertWaitAsyncOpToGpuRuntimeCallPattern,
+                  ConvertWaitOpToGpuRuntimeCallPattern>(converter);
   patterns.insert<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation);
   patterns.insert<EraseGpuModuleOpPattern>(&converter.getContext());

diff --git a/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
new file mode 100644
index 000000000000..06ccd1e8f4ee
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-alloc-to-gpu-runtime-calls.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+  func @main() {
+    // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate()
+    %0 = gpu.wait async
+    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: llvm.call @mgpuMemAlloc(%[[size_bytes]], %[[stream]])
+    %1, %2 = gpu.alloc async [%0] () : memref<13xf32>
+    // CHECK: %[[float_ptr:.*]] = llvm.extractvalue {{.*}}[0]
+    // CHECK: %[[void_ptr:.*]] = llvm.bitcast %[[float_ptr]]
+    // CHECK: llvm.call @mgpuMemFree(%[[void_ptr]], %[[stream]])
+    %3 = gpu.dealloc async [%2] %1 : memref<13xf32>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]])
+    gpu.wait [%3]
+    return
+  }
+}

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index a3b781afdfbc..aed4368c22a7 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -144,6 +144,23 @@ module attributes {gpu.container_module} {
     } ) {gpu.kernel, sym_name = "kernel_1", type = (f32, memref<?xf32>) -> (), workgroup_attributions = 1: i64} : () -> ()
   }
 
+  func @alloc() {
+    // CHECK-LABEL: func @alloc()
+
+    // CHECK: %[[m0:.*]] = gpu.alloc () : memref<13xf32, 1>
+    %m0 = gpu.alloc () : memref<13xf32, 1>
+    // CHECK: gpu.dealloc %[[m0]] : memref<13xf32, 1>
+    gpu.dealloc %m0 : memref<13xf32, 1>
+
+    %t0 = gpu.wait async
+    // CHECK: %[[m1:.*]], %[[t1:.*]] = gpu.alloc async [{{.*}}] () : memref<13xf32, 1>
+    %m1, %t1 = gpu.alloc async [%t0] () : memref<13xf32, 1>
+    // CHECK: gpu.dealloc async [%[[t1]]] %[[m1]] : memref<13xf32, 1>
+    %t2 = gpu.dealloc async [%t1] %m1 : memref<13xf32, 1>
+
+    return
+  }
+
   func @async_token(%arg0 : !gpu.async.token) -> !gpu.async.token {
     // CHECK-LABEL: func @async_token({{.*}}: !gpu.async.token)
     // CHECK: return {{.*}} : !gpu.async.token

diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
index 917d203298c5..a6729b1c0b7d 100644
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -107,6 +107,16 @@ extern "C" void mgpuEventRecord(CUevent event, CUstream stream) {
   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
 }
 
+extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
+  CUdeviceptr ptr;
+  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
+  return reinterpret_cast<void *>(ptr);
+}
+
+extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
+  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the CUDA runtime. Helpful until we have

diff --git a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
index 882a4a34d2dc..aad7ae27ff89 100644
--- a/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-rocm-runner/rocm-runtime-wrappers.cpp
@@ -108,6 +108,16 @@ extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
   HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
 }
 
+extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
+  void *ptr;
+  HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
+  return ptr;
+}
+
+extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
+  HIP_REPORT_IF_ERROR(hipFree(ptr));
+}
+
 /// Helper functions for writing mlir example code
 
 // Allows to register byte array with the ROCM runtime. Helpful until we have
