[Mlir-commits] [mlir] 86888e4 - [mlir][sparse][gpu] generate proper memcpy in/out host and device
Aart Bik
llvmlistbot at llvm.org
Fri Apr 21 09:30:52 PDT 2023
Author: Aart Bik
Date: 2023-04-21T09:30:42-07:00
New Revision: 86888e420c41ebb07fa1a8818ea9af218b015fe3
URL: https://github.com/llvm/llvm-project/commit/86888e420c41ebb07fa1a8818ea9af218b015fe3
DIFF: https://github.com/llvm/llvm-project/commit/86888e420c41ebb07fa1a8818ea9af218b015fe3.diff
LOG: [mlir][sparse][gpu] generate proper memcpy in/out host and device
The host registration is a convenient way to get CUDA kernels
running, but it may be slow and does not work for all buffers
(like global constants). This revision uses proper alloc/copy/dealloc
chains for buffers, using asynchronous chains
to increase overlap. The host registration mechanism is
kept under a flag for the output, just for experimentation
purposes while this project ramps up.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D148682
Added:
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-const.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 96346d97ebd00..c99c26b9c98cd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -76,24 +76,28 @@ static gpu::GPUFuncOp genGPUFunc(OpBuilder &builder, gpu::GPUModuleOp gpuModule,
}
/// Constructs code to launch GPU kernel.
-static void genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
- SmallVectorImpl<Value> &args,
- unsigned numThreads) {
+static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc,
+ SmallVectorImpl<Value> &args,
+ SmallVectorImpl<Value> &tokens,
+ unsigned numThreads) {
Location loc = gpuFunc->getLoc();
Value none = TypedValue<::mlir::IntegerType>{};
Value one = constantIndex(builder, loc, 1);
Value numT = constantIndex(builder, loc, numThreads);
gpu::KernelDim3 gridSize = {one, one, one};
gpu::KernelDim3 blckSize = {numT, one, one};
- builder.create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
- /*dynSharedMemSz*/ none, args);
+ return builder
+ .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize,
+ /*dynSharedMemSz*/ none, args,
+ builder.getType<gpu::AsyncTokenType>(), tokens)
+ .getAsyncToken();
}
/// Maps the provided ranked host buffer into the device address space.
/// Writes from the host are guaranteed to be visible to device kernels
/// that are launched afterwards. Writes from the device are guaranteed
/// to be visible on the host after synchronizing with the device kernel
-/// completion.
+/// completion. The buffer is cast to an unranked buffer, which is returned.
static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
Value mem) {
MemRefType memTp = mem.getType().cast<MemRefType>();
@@ -101,7 +105,122 @@ static Value genHostRegisterMemref(OpBuilder &builder, Location loc,
UnrankedMemRefType::get(memTp.getElementType(), /*memorySpace=*/0);
Value cast = builder.create<memref::CastOp>(loc, resTp, mem);
builder.create<gpu::HostRegisterOp>(loc, cast);
- return mem; // convenience pass-through
+ return cast;
+}
+
+/// Unmaps the provided buffer, expecting the cast (unranked) buffer.
+static void genHostUnregisterMemref(OpBuilder &builder, Location loc,
+ Value cast) {
+ builder.create<gpu::HostUnregisterOp>(loc, cast);
+}
+
+/// Generates first wait in an asynchronous chain.
+static Value genFirstWait(OpBuilder &builder, Location loc) {
+ Type tokenType = builder.getType<gpu::AsyncTokenType>();
+ return builder.create<gpu::WaitOp>(loc, tokenType, ValueRange())
+ .getAsyncToken();
+}
+
+/// Generates last, blocking wait in an asynchronous chain.
+static void genBlockingWait(OpBuilder &builder, Location loc,
+ ValueRange operands) {
+ builder.create<gpu::WaitOp>(loc, Type(), operands);
+}
+
+/// Allocates memory on the device.
+/// TODO: A `host_shared` attribute could be used to indicate that
+/// the buffer is visible to both host and device, but lowering
+/// that feature does not seem to be fully supported yet.
+static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
+ Value token) {
+ auto tp = mem.getType().cast<ShapedType>();
+ auto elemTp = tp.getElementType();
+ auto shape = tp.getShape();
+ auto memTp = MemRefType::get(shape, elemTp);
+ SmallVector<Value> dynamicSizes;
+ for (unsigned r = 0, rank = tp.getRank(); r < rank; r++) {
+ if (shape[r] == ShapedType::kDynamic) {
+ Value dim = constantIndex(builder, loc, r);
+ Value dimOp = builder.create<memref::DimOp>(loc, mem, dim);
+ dynamicSizes.push_back(dimOp);
+ }
+ }
+ return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+ token, dynamicSizes, ValueRange());
+}
+
+/// Deallocates memory from the device.
+static Value genDeallocMemRef(OpBuilder &builder, Location loc, Value mem,
+ Value token) {
+ return builder.create<gpu::DeallocOp>(loc, token.getType(), token, mem)
+ .getAsyncToken();
+}
+
+/// Copies memory between host and device (direction is implicit).
+static Value genCopyMemRef(OpBuilder &builder, Location loc, Value dst,
+ Value src, Value token) {
+ return builder.create<gpu::MemcpyOp>(loc, token.getType(), token, dst, src)
+ .getAsyncToken();
+}
+
+/// Prepares the outlined arguments, passing scalars and buffers in. Here we
+/// assume that the first buffer is the one allocated for output. We create
+/// a set of properly chained asynchronous allocation/copy pairs to increase
+/// overlap before launching the kernel.
+/// TODO: the output assumption may be a bit too brittle
+static Value genParametersIn(OpBuilder &builder, Location loc,
+ SmallVectorImpl<Value> &scalars,
+ SmallVectorImpl<Value> &buffers,
+ SmallVectorImpl<Value> &args,
+ SmallVectorImpl<Value> &tokens,
+ bool useHostRegistrationForOut) {
+ Value out;
+ // Scalars are passed by value.
+ for (Value s : scalars)
+ args.push_back(s);
+ // Buffers need to be made visible on device.
+ for (Value b : buffers) {
+ if (useHostRegistrationForOut) {
+ out = genHostRegisterMemref(builder, loc, b);
+ args.push_back(b);
+ useHostRegistrationForOut = false;
+ continue;
+ }
+ Value firstToken = genFirstWait(builder, loc);
+ auto alloc = genAllocMemRef(builder, loc, b, firstToken);
+ Value devMem = alloc.getResult(0);
+ Value depToken = alloc.getAsyncToken(); // copy-after-alloc
+ args.push_back(devMem);
+ tokens.push_back(genCopyMemRef(builder, loc, devMem, b, depToken));
+ }
+ return out;
+}
+
+/// Finalizes the outlined arguments. The output buffer is copied depending
+/// on the kernel token and then deallocated. All other buffers are simply
+/// deallocated. Then we wait for all operations to complete.
+static void genParametersOut(OpBuilder &builder, Location loc, Value out,
+ Value kernelToken, SmallVectorImpl<Value> &scalars,
+ SmallVectorImpl<Value> &buffers,
+ SmallVectorImpl<Value> &args,
+ SmallVectorImpl<Value> &tokens) {
+ unsigned base = scalars.size();
+ for (unsigned i = base, e = args.size(); i < e; i++) {
+ Value firstToken;
+ if (i == base) {
+ // Assumed output parameter: unregister or copy-out.
+ if (out) {
+ genHostUnregisterMemref(builder, loc, out);
+ out = Value();
+ continue;
+ }
+ firstToken =
+ genCopyMemRef(builder, loc, buffers[0], args[i], kernelToken);
+ } else {
+ firstToken = genFirstWait(builder, loc);
+ }
+ tokens.push_back(genDeallocMemRef(builder, loc, args[i], firstToken));
+ }
}
/// Constructs code for new GPU kernel.
@@ -158,10 +277,8 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
/// Proof-of-concept rewriter. This rule generates a CUDA implementation
/// for each outermost forall loop generated by the sparse compiler.
-//
-// TODO: right works with parallelization-strategy=dense-outer-loop
-// but give this its own flags in the future
-//
+/// TODO: right now works with parallelization-strategy=dense-outer-loop
+/// but give this its own flags in the future
struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
@@ -211,22 +328,34 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
else
return failure(); // don't know how to share
}
- // Prepare the outlined arguments, register buffers.
+ // Pass outlined non-constant values.
+ // TODO: Experiment with `useHostRegistrationForOut` to see if we want to
+ // keep the feature at all (either through a heuristic or compiler
+ // option for gpu codegen).
Location loc = forallOp->getLoc();
SmallVector<Value> args;
- for (Value s : scalars)
- args.push_back(s);
- for (Value b : buffers)
- args.push_back(genHostRegisterMemref(rewriter, loc, b));
- auto saveIp = rewriter.saveInsertionPoint();
+ SmallVector<Value> tokens;
+ Value out = genParametersIn(rewriter, loc, scalars, buffers, args, tokens,
+ /*useHostRegistrationForOut=*/false);
// Set up GPU module and construct GPU function.
+ auto saveIp = rewriter.saveInsertionPoint();
ModuleOp topModule = forallOp->getParentOfType<ModuleOp>();
auto gpuModule = genGPUModule(rewriter, topModule);
auto gpuFunc = genGPUFunc(rewriter, gpuModule, args);
genGPUCode(rewriter, gpuFunc, forallOp, constants, scalars, buffers);
- // Generate code that launches the kernel.
+ // Generate code that launches the kernel asynchronously, blocking on all
+ // open tokens and yielding a new token for the output.
+ // TODO: Passing in tokens to launch op does not seem to be properly lowered
+ // by cubin yet, hence the current blocking wait.
rewriter.restoreInsertionPoint(saveIp);
- genLaunchGPUFunc(rewriter, gpuFunc, args, numThreads);
+ genBlockingWait(rewriter, loc, tokens);
+ tokens.clear();
+ Value kernelToken =
+ genLaunchGPUFunc(rewriter, gpuFunc, args, tokens, numThreads);
+ // Finalize the outlined arguments.
+ genParametersOut(rewriter, loc, out, kernelToken, scalars, buffers, args,
+ tokens);
+ genBlockingWait(rewriter, loc, tokens);
rewriter.eraseOp(forallOp);
return success();
}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
index ec7c30e9468a2..07d8c1ccf9d63 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_combi.mlir
@@ -7,12 +7,46 @@
//
// CHECK-LABEL: gpu.module @sparse_kernels
-// CHECK-DAG: gpu.func @kernel0
-// CHECK-DAG: gpu.func @kernel1
+// CHECK: gpu.func @kernel1
+// CHECK: gpu.func @kernel0
//
// CHECK-LABEL: func.func @matmuls
-// CHECK-DAG: gpu.launch_func @sparse_kernels::@kernel0 blocks
-// CHECK-DAG: gpu.launch_func @sparse_kernels::@kernel1 blocks
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: %[[T1:.*]] = gpu.launch_func async @sparse_kernels::@kernel1 blocks
+// CHECK: gpu.memcpy async [%[[T1]]]
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.wait
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: gpu.alloc async
+// CHECK: gpu.memcpy async
+// CHECK: %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
+// CHECK: gpu.memcpy async [%[[T0]]]
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.dealloc async
+// CHECK: gpu.wait
//
func.func @matmuls(%A: tensor<1024x8xf64>,
%B: tensor<8x1024xf64, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
index 92d59416b32b5..f770a941c6174 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul.mlir
@@ -47,12 +47,34 @@
//
//
// CHECK-LABEL: func.func @matmul
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.launch_func @sparse_kernels::@kernel0 blocks
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S0:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S1:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S2:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S3:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S4:.*]] = gpu.memcpy async
+// CHECK: gpu.wait [%[[S0]], %[[S1]], %[[S2]], %[[S3]], %[[S4]]
+// CHECK: %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
+// CHECK: %[[M0:.*]] = gpu.memcpy async [%[[T0]]]
+// CHECK: %[[M1:.*]] = gpu.dealloc async [%[[M0]]]
+// CHECK: %[[M2:.*]] = gpu.wait async
+// CHECK: %[[M3:.*]] = gpu.dealloc async [%[[M2]]]
+// CHECK: %[[M4:.*]] = gpu.wait async
+// CHECK: %[[M5:.*]] = gpu.dealloc async [%[[M4]]]
+// CHECK: %[[M6:.*]] = gpu.wait async
+// CHECK: %[[M7:.*]] = gpu.dealloc async [%[[M6]]]
+// CHECK: %[[M8:.*]] = gpu.wait async
+// CHECK: %[[M9:.*]] = gpu.dealloc async [%[[M8]]]
+// CHECK: gpu.wait [%[[M1]], %[[M3]], %[[M5]], %[[M7]], %[[M9]]
//
func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
%C_out = linalg.matmul
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
index 05dfc5829c8c6..dd6f377f44db4 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matvec.mlir
@@ -43,12 +43,34 @@
// CHECK: }
//
// CHECK-LABEL: func.func @matvec
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.host_register
-// CHECK: gpu.launch_func @sparse_kernels::@kernel0 blocks
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S0:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S1:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S2:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S3:.*]] = gpu.memcpy async
+// CHECK: gpu.wait async
+// CHECK: gpu.alloc async
+// CHECK: %[[S4:.*]] = gpu.memcpy async
+// CHECK: gpu.wait [%[[S0]], %[[S1]], %[[S2]], %[[S3]], %[[S4]]
+// CHECK: %[[T0:.*]] = gpu.launch_func async @sparse_kernels::@kernel0 blocks
+// CHECK: %[[M0:.*]] = gpu.memcpy async [%[[T0]]]
+// CHECK: %[[M1:.*]] = gpu.dealloc async [%[[M0]]]
+// CHECK: %[[M2:.*]] = gpu.wait async
+// CHECK: %[[M3:.*]] = gpu.dealloc async [%[[M2]]]
+// CHECK: %[[M4:.*]] = gpu.wait async
+// CHECK: %[[M5:.*]] = gpu.dealloc async [%[[M4]]]
+// CHECK: %[[M6:.*]] = gpu.wait async
+// CHECK: %[[M7:.*]] = gpu.dealloc async [%[[M6]]]
+// CHECK: %[[M8:.*]] = gpu.wait async
+// CHECK: %[[M9:.*]] = gpu.dealloc async [%[[M8]]]
+// CHECK: gpu.wait [%[[M1]], %[[M3]], %[[M5]], %[[M7]], %[[M9]]
//
func.func @matvec(%A: tensor<?x?xf64, #CSR>, %x: tensor<?xf64>, %y_in: tensor<?xf64>) -> tensor<?xf64> {
%y_out = linalg.matvec
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-const.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-const.mlir
new file mode 100644
index 0000000000000..213f3b7890d6c
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matvec-const.mlir
@@ -0,0 +1,65 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// RUN: mlir-opt %s \
+// RUN: --sparse-compiler="enable-runtime-library=false parallelization-strategy=dense-outer-loop gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71" \
+// RUN: | mlir-cpu-runner \
+// RUN: --shared-libs=%mlir_cuda_runtime \
+// RUN: --shared-libs=%mlir_runner_utils \
+// RUN: --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+module {
+ // Compute matrix vector y = Ax
+ func.func @matvec(%A: tensor<1024x64xf64, #CSR>, %x: tensor<64xf64>, %y_in: tensor<1024xf64>) -> tensor<1024xf64> {
+ %y_out = linalg.matvec
+ ins(%A, %x: tensor<1024x64xf64, #CSR>, tensor<64xf64>)
+ outs(%y_in: tensor<1024xf64>) -> tensor<1024xf64>
+ return %y_out : tensor<1024xf64>
+ }
+
+ memref.global "private" constant @__constant_64xf64 : memref<64xf64> = dense<1.000000e+00> {alignment = 64 : i64}
+
+ func.func @main() {
+ %f0 = arith.constant 0.0 : f64
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+
+ // Stress test with a dense matrix DA.
+ %DA = tensor.generate {
+ ^bb0(%i: index, %j: index):
+ %k = arith.addi %i, %j : index
+ %l = arith.index_cast %k : index to i64
+ %f = arith.uitofp %l : i64 to f64
+ tensor.yield %f : f64
+ } : tensor<1024x64xf64>
+
+ // Convert to a "sparse" 1024 x 64 matrix A.
+ %A = sparse_tensor.convert %DA : tensor<1024x64xf64> to tensor<1024x64xf64, #CSR>
+
+ // Initialize dense vector to 1024 zeros.
+ %y = tensor.generate {
+ ^bb0(%i : index):
+ tensor.yield %f0 : f64
+ } : tensor<1024xf64>
+
+ // Call the kernel with a vector taken from global memory.
+ %xbuf = memref.get_global @__constant_64xf64 : memref<64xf64>
+ %x = bufferization.to_tensor %xbuf restrict : memref<64xf64>
+ %0 = call @matvec(%A, %x, %y) : (tensor<1024x64xf64, #CSR>, tensor<64xf64>, tensor<1024xf64>) -> tensor<1024xf64>
+
+ //
+ // Sanity check on results.
+ //
+ // CHECK: ( 2016, 2080, 2144, 2208, 2272, 2336, 2400, 2464, 2528, 2592, 2656, 2720, 2784, 2848, 2912, 2976, 3040, 3104, 3168, 3232, 3296, 3360, 3424, 3488, 3552, 3616, 3680, 3744, 3808, 3872, 3936, 4000, 4064, 4128, 4192, 4256, 4320, 4384, 4448, 4512, 4576, 4640, 4704, 4768, 4832, 4896, 4960, 5024, 5088, 5152, 5216, 5280, 5344, 5408, 5472, 5536, 5600, 5664, 5728, 5792, 5856, 5920, 5984, 6048 )
+ //
+ %pb0 = vector.transfer_read %0[%c0], %f0 : tensor<1024xf64>, vector<64xf64>
+ vector.print %pb0 : vector<64xf64>
+
+ // Release the resources.
+ bufferization.dealloc_tensor %A : tensor<1024x64xf64, #CSR>
+ return
+ }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
index f2047b679eace..17f8928d2196f 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -342,7 +342,7 @@ module attributes {gpu.container_module} {
vector.print %pb0 : vector<32xf16>
}
- // Maps the provided host buffer into the device address space.
+ // Maps the provided host buffers into the device address space.
// Writes from the host are guaranteed to be visible to device
// kernels that are launched afterwards. Writes from the device
// are guaranteed to be visible on the host after synchronizing
@@ -368,6 +368,12 @@ module attributes {gpu.container_module} {
%b : memref<8x32xf16>,
%c : memref<16x8xf16>)
+ // Unmaps the host buffers.
+ gpu.host_unregister %cast_a : memref<*xf16>
+ gpu.host_unregister %cast_m : memref<*xi16>
+ gpu.host_unregister %cast_b : memref<*xf16>
+ gpu.host_unregister %cast_c : memref<*xf16>
+
//
// Verify computed matrix C.
//
More information about the Mlir-commits
mailing list