[Mlir-commits] [mlir] 361458b - [mlir] create gpu memset op
Christian Sigg
llvmlistbot at llvm.org
Fri Sep 3 23:13:18 PDT 2021
Author: Loren Maggiore
Date: 2021-09-04T08:13:04+02:00
New Revision: 361458b1ce890bd3f3575cca82b9bbb01d270e6f
URL: https://github.com/llvm/llvm-project/commit/361458b1ce890bd3f3575cca82b9bbb01d270e6f
DIFF: https://github.com/llvm/llvm-project/commit/361458b1ce890bd3f3575cca82b9bbb01d270e6f.diff
LOG: [mlir] create gpu memset op
Create a gpu memset op and corresponding CUDA and ROCm wrappers.
Reviewed By: herhut, lorenrose1013
Differential Revision: https://reviews.llvm.org/D107548
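For reference, a minimal sketch of the new op in both of its forms, mirroring the examples and tests in the diff below (`%dst`, `%value`, and `%dep` are placeholder SSA names):

```mlir
// Synchronous form: fill %dst with %value.
gpu.memset %dst, %value : memref<?xf32>, f32

// Asynchronous form: runs after %dep and returns a !gpu.async.token.
%token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
```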
Added:
mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
mlir/test/Dialect/GPU/canonicalize.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index facea7e0523b..9290e161058b 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -901,6 +901,42 @@ def GPU_MemcpyOp : GPU_Op<"memcpy", [GPU_AsyncOpInterface]> {
let hasFolder = 1;
}
+def GPU_MemsetOp : GPU_Op<"memset",
+ [GPU_AsyncOpInterface, AllElementTypesMatch<["dst", "value"]>]> {
+
+ let summary = "GPU memset operation";
+
+ let description = [{
+ The `gpu.memset` operation sets the content of a memref to a scalar value.
+
+ The op does not execute before all async dependencies have finished
+ executing.
+
+ If the `async` keyword is present, the op is executed asynchronously (i.e.
+ it does not block until the execution has finished on the device). In
+ that case, it returns a !gpu.async.token.
+
+ Example:
+
+ ```mlir
+ %token = gpu.memset async [%dep] %dst, %value : memref<?xf32, 1>, f32
+ ```
+ }];
+
+ let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+ Arg<AnyMemRef, "", [MemWrite]>:$dst,
+ Arg<AnyType, "">:$value);
+ let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+ let assemblyFormat = [{
+ custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+ $dst`,` $value `:` type($dst)`,` type($value) attr-dict
+ }];
+ // MemsetOp is fully verified by traits.
+ let verifier = [{ return success(); }];
+ let hasFolder = 1;
+}
+
def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
[MemoryEffects<[MemRead]>]>{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1234a9a14fbd..e4f7858d8d1d 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -79,6 +79,18 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
: ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
protected:
+ Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
+ MemRefType type, MemRefDescriptor desc) const {
+ return type.hasStaticShape()
+ ? ConvertToLLVMPattern::createIndexConstant(
+ rewriter, loc, type.getNumElements())
+ // For identity maps (verified by the caller), stride[0] is the
+ // product of the trailing dimension sizes, so the number of
+ // elements is stride[0] * size[0].
+ : rewriter.create<LLVM::MulOp>(loc,
+ desc.stride(rewriter, loc, 0),
+ desc.size(rewriter, loc, 0));
+ }
+
MLIRContext *context = &this->getTypeConverter()->getContext();
Type llvmVoidType = LLVM::LLVMVoidType::get(context);
@@ -165,6 +177,12 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
{llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
+ FunctionCallBuilder memsetCallBuilder = {
+ "mgpuMemset32",
+ llvmVoidType,
+ {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
llvmIntPtrType /* intptr_t count (of 32-bit values) */,
+ llvmPointerType /* void *stream */}};
};
/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -308,6 +326,20 @@ class ConvertMemcpyOpToGpuRuntimeCallPattern
matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
+
+/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
+/// call. Currently it supports CUDA and ROCm (HIP).
+class ConvertMemsetOpToGpuRuntimeCallPattern
+ : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
+public:
+ ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+ : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
+
+private:
+ LogicalResult
+ matchAndRewrite(gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
@@ -757,14 +789,7 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());
MemRefDescriptor srcDesc(adaptor.src());
-
- Value numElements =
- memRefType.hasStaticShape()
- ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
- // For identity layouts (verified above), the number of elements is
- // stride[0] * size[0].
- : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
- srcDesc.size(rewriter, loc, 0));
+ Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
Type elementPtrType = getElementPtrType(memRefType);
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
@@ -787,6 +812,40 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
return success();
}
+LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
+ gpu::MemsetOp memsetOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
+
+ if (failed(areAllLLVMTypes(memsetOp, operands, rewriter)) ||
+ !isConvertibleAndHasIdentityMaps(memRefType) ||
+ failed(isAsyncWithOneDependency(rewriter, memsetOp)))
+ return failure();
+
+ auto loc = memsetOp.getLoc();
+ auto adaptor = gpu::MemsetOpAdaptor(operands, memsetOp->getAttrDictionary());
+
+ Type valueType = adaptor.value().getType();
+ if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
+ return rewriter.notifyMatchFailure(memsetOp,
+ "value must be a 32 bit scalar");
+ }
+
+ MemRefDescriptor dstDesc(adaptor.dst());
+ Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
+
+ auto value =
+ rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
+ auto dst = rewriter.create<LLVM::BitcastOp>(
+ loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
+
+ auto stream = adaptor.asyncDependencies().front();
+ memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
+
+ rewriter.replaceOp(memsetOp, {stream});
+ return success();
+}
+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
return std::make_unique<GpuToLLVMConversionPass>();
@@ -803,6 +862,7 @@ void mlir::populateGpuToLLVMConversionPatterns(
ConvertDeallocOpToGpuRuntimeCallPattern,
ConvertHostRegisterOpToGpuRuntimeCallPattern,
ConvertMemcpyOpToGpuRuntimeCallPattern,
+ ConvertMemsetOpToGpuRuntimeCallPattern,
ConvertWaitAsyncOpToGpuRuntimeCallPattern,
ConvertWaitOpToGpuRuntimeCallPattern,
ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
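For orientation, the new pattern requires an identity-layout memref, a 32-bit integer or float fill value, and exactly one async dependency, whose lowered stream it reuses. Below is a sketch of the LLVM-dialect IR it produces for a static memref<7xf32>; the SSA names and the index width are illustrative, not taken verbatim from the pass output:

```mlir
%count  = llvm.mlir.constant(7 : index) : i64   // number of elements
%v32    = llvm.bitcast %value : f32 to i32      // reinterpret fill value as i32
%dstPtr = llvm.bitcast %aligned : !llvm.ptr<f32> to !llvm.ptr<i8>
llvm.call @mgpuMemset32(%dstPtr, %v32, %count, %stream)
    : (!llvm.ptr<i8>, i32, i64, !llvm.ptr<i8>) -> ()
```

The op's token result is replaced by the stream of its single dependency, so later gpu.wait ops synchronize on the same stream.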
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 166827672901..f148d6e4ad9e 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1079,6 +1079,11 @@ LogicalResult MemcpyOp::fold(ArrayRef<Attribute> operands,
return foldMemRefCast(*this);
}
+LogicalResult MemsetOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<::mlir::OpFoldResult> &results) {
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// GPU_AllocOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 8220f0305fee..9ee8057aa463 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -141,13 +141,19 @@ extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}
-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
CUstream stream) {
CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
reinterpret_cast<CUdeviceptr>(src),
sizeBytes, stream));
}
+extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
+ CUstream stream) {
+ CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
+ value, count, stream));
+}
+
/// Helper functions for writing mlir example code
// Allows to register byte array with the CUDA runtime. Helpful until we have
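As an aside, these wrappers are normally called from JIT-compiled IR rather than hand-written code, but a host-side sketch illustrates the intended call sequence. The declarations are copied to match the extern "C" wrappers in this file; it assumes the wrapper library is linked and a CUDA device is present:

```cpp
#include <cstddef>
#include <cstdint>
#include "cuda.h"

// Declarations matching the extern "C" wrappers in CudaRuntimeWrappers.cpp.
extern "C" CUstream mgpuStreamCreate();
extern "C" void mgpuStreamSynchronize(CUstream stream);
extern "C" void mgpuStreamDestroy(CUstream stream);
extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream stream);
extern "C" void mgpuMemFree(void *ptr, CUstream stream);
extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
                             CUstream stream);

int main() {
  CUstream stream = mgpuStreamCreate();
  void *dst = mgpuMemAlloc(7 * sizeof(float), stream);
  // Fill 7 32-bit words with the bit pattern of 1.0f -- the same semantics
  // the gpu.memset lowering gets by bitcasting an f32 fill value to i32.
  mgpuMemset32(dst, 0x3f800000u, /*count=*/7, stream);
  mgpuStreamSynchronize(stream);
  mgpuMemFree(dst, stream);
  mgpuStreamDestroy(stream);
  return 0;
}
```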
diff --git a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
index 399a37331060..92358ed38d9c 100644
--- a/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/RocmRuntimeWrappers.cpp
@@ -133,12 +133,17 @@ extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
HIP_REPORT_IF_ERROR(hipFree(ptr));
}
-extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes,
+extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
hipStream_t stream) {
HIP_REPORT_IF_ERROR(
hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
}
+extern "C" void mgpuMemset32(void *dst, int value, size_t count,
+ hipStream_t stream) {
+ HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
+ value, count, stream));
+}
/// Helper functions for writing mlir example code
// Allows to register byte array with the ROCM runtime. Helpful until we have
diff --git a/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
new file mode 100644
index 000000000000..1b786ed0f6e5
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-memset-to-gpu-runtime-calls.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+ // CHECK: func @foo
+ func @foo(%dst : memref<7xf32, 1>, %value : f32) {
+ // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
+ %t0 = gpu.wait async
+ // CHECK: %[[count:.*]] = llvm.mlir.constant
+ // CHECK: %[[value:.*]] = llvm.bitcast
+ // CHECK: %[[dst:.*]] = llvm.bitcast
+ // CHECK: llvm.call @mgpuMemset32(%[[dst]], %[[value]], %[[count]], %[[t0]])
+ %t1 = gpu.memset async [%t0] %dst, %value : memref<7xf32, 1>, f32
+ // CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
+ // CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
+ gpu.wait [%t1]
+ return
+ }
+}
diff --git a/mlir/test/Dialect/GPU/canonicalize.mlir b/mlir/test/Dialect/GPU/canonicalize.mlir
index 448e08be951b..38abfd172eea 100644
--- a/mlir/test/Dialect/GPU/canonicalize.mlir
+++ b/mlir/test/Dialect/GPU/canonicalize.mlir
@@ -6,7 +6,16 @@ func @memcpy_after_cast(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
// CHECK: gpu.memcpy
%0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
%1 = memref.cast %arg1 : memref<10xf32> to memref<?xf32>
- gpu.memcpy %0,%1 : memref<?xf32>, memref<?xf32>
+ gpu.memcpy %0, %1 : memref<?xf32>, memref<?xf32>
+ return
+}
+
+// CHECK-LABEL: @memset_after_cast
+func @memset_after_cast(%arg0: memref<10xf32>, %arg1: f32) {
+ // CHECK-NOT: memref.cast
+ // CHECK: gpu.memset
+ %0 = memref.cast %arg0 : memref<10xf32> to memref<?xf32>
+ gpu.memset %0, %arg1 : memref<?xf32>, f32
return
}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index a1ed9fa363a3..3c7a57d099e0 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -467,6 +467,13 @@ func @memcpy_incompatible_shape(%dst : memref<7xf32>, %src : memref<9xf32>) {
// -----
+func @memset_incompatible_type(%dst : memref<?xf32>, %value : i32) {
+ // expected-error @+1 {{'gpu.memset' op failed to verify that all of {dst, value} have same element type}}
+ gpu.memset %dst, %value : memref<?xf32>, i32
+}
+
+// -----
+
func @mmamatrix_invalid_shape(){
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
%i = constant 16 : index
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 1bed13c4b21a..2c4a13d96d6d 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -195,6 +195,17 @@ module attributes {gpu.container_module} {
return
}
+ func @memset(%dst : memref<3x7xf32>, %value : f32) {
+ // CHECK-LABEL: func @memset
+ // CHECK: gpu.memset {{.*}}, {{.*}} : memref<3x7xf32>, f32
+ gpu.memset %dst, %value : memref<3x7xf32>, f32
+ // CHECK: %[[t0:.*]] = gpu.wait async
+ %0 = gpu.wait async
+ // CHECK: {{.*}} = gpu.memset async [%[[t0]]] {{.*}}, {{.*}} : memref<3x7xf32>, f32
+ %1 = gpu.memset async [%0] %dst, %value : memref<3x7xf32>, f32
+ return
+ }
+
func @mmamatrix_valid_element_type(){
// CHECK-LABEL: func @mmamatrix_valid_element_type
%wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>