[Mlir-commits] [mlir] ac30f48 - [mlir][sparse][gpu] fix various cusparseLt bugs
Kun Wu
llvmlistbot at llvm.org
Mon Jun 12 16:48:58 PDT 2023
Author: Kun Wu
Date: 2023-06-12T23:48:49Z
New Revision: ac30f48e377b80a83953e346054ffe3288276713
URL: https://github.com/llvm/llvm-project/commit/ac30f48e377b80a83953e346054ffe3288276713
DIFF: https://github.com/llvm/llvm-project/commit/ac30f48e377b80a83953e346054ffe3288276713.diff
LOG: [mlir][sparse][gpu] fix various cusparseLt bugs
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D152489
Added:
Modified:
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 689a705350c75..8e922f54b5e27 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -263,36 +263,30 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
{llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
llvmPointerType, llvmInt32Type, llvmPointerType,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder spMMBufferSizeCallBuilder = {
+ FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
"mgpuSpMMBufferSize",
llvmIntPtrType,
{llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder spMMCallBuilder = {
+ FunctionCallBuilder createSpMMCallBuilder = {
"mgpuSpMM",
llvmVoidType,
{llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder SDDMMBufferSizeCallBuilder = {
+ FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
"mgpuSDDMMBufferSize",
llvmIntPtrType,
{llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder SDDMMCallBuilder = {
+ FunctionCallBuilder createSDDMMCallBuilder = {
"mgpuSDDMM",
llvmVoidType,
{llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
llvmPointerType /* void *stream */}};
- FunctionCallBuilder AssertSparseLTEnvHandleSizeCallBuilder = {
- "mgpuAssertSparseLTEnvHandleSize", llvmVoidType, {}};
- FunctionCallBuilder AssertSparseLTSpMatHandleSizeCallBuilder = {
- "mgpuAssertSparseLTSpMatHandleSize", llvmVoidType, {}};
- FunctionCallBuilder AssertSparseLTDnMatHandleSizeCallBuilder = {
- "mgpuAssertSparseLtDnMatHandleSize", llvmVoidType, {}};
FunctionCallBuilder createSparseLtEnvCallBuilder = {
"mgpuCreateSparseLtEnv",
llvmVoidType,
@@ -319,16 +313,17 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
llvmVoidType,
{llvmPointerType, llvmPointerType, llvmIntPtrType, llvmIntPtrType,
llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
- FunctionCallBuilder cuSparseLtSpmmBufferSizeBuilder = {
+ FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
"mgpuCuSparseLtSpMMBufferSize",
llvmVoidType,
- {llvmPointerType, llvmPointerType, llvmPointerType,
+ {llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
+ llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type,
llvmPointerType /*void *stream*/}};
- FunctionCallBuilder cuSparseLtSpmmBuilder = {
+ FunctionCallBuilder createCuSparseLtSpMMBuilder = {
"mgpuCuSparseLtSpMM",
llvmVoidType,
{llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
- llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
+ llvmPointerType, llvmPointerType, llvmPointerType,
llvmPointerType /*void *stream*/}};
};
@@ -1397,13 +1392,6 @@ static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
static_cast<int32_t>(TValue));
}
-static Value genConstInt32FromComputeMode(OpBuilder &builder, Location loc,
- Type computeType) {
- auto computeTypeInt = getCuSparseDataTypeFrom(computeType);
- auto computeTypeConst = genConstInt32From(builder, loc, computeTypeInt);
- return computeTypeConst;
-}
-
LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1416,8 +1404,7 @@ LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
// 2:4 sparsity
Value handle;
if (isSpMMCusparseLtOp(op.getEnv())) {
- // Assert the size is 11024 bytes
- AssertSparseLTEnvHandleSizeCallBuilder.create(loc, rewriter, {});
+ // CUDA runner asserts the size is 11024 bytes.
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(11024));
handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1484,7 +1471,6 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
if (dims.size() == 2) {
if (isSpMMCusparseLtOp(op.getDnTensor())) {
auto envHandle = adaptor.getEnv();
- AssertSparseLTDnMatHandleSizeCallBuilder.create(loc, rewriter, {});
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(11032));
handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1660,10 +1646,10 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
Type dType =
llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
- auto dtp = genConstInt32From(rewriter, loc, getCuSparseLtDataTypeFrom(dType));
+ auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto envHandle = adaptor.getEnv();
- AssertSparseLTSpMatHandleSizeCallBuilder.create(loc, rewriter, {});
+ // CUDA runner asserts the size is 44104 bytes.
auto handleSz = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(44104));
Value handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1707,8 +1693,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
auto bufferSize =
spMVBufferSizeCallBuilder
@@ -1728,8 +1714,8 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1753,10 +1739,10 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
Value bufferSize;
if (is2To4Sparsity(op.getSpmatA())) {
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(3));
bufferSize = rewriter.create<LLVM::AllocaOp>(loc, llvmInt64PointerType,
@@ -1764,13 +1750,17 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
bufferSize =
rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, bufferSize);
- cuSparseLtSpmmBufferSizeBuilder
+ createCuSparseLtSpMMBufferSizeBuilder
.create(loc, rewriter,
- {bufferSize, adaptor.getEnv(), adaptor.getSpmatA(), stream})
+ {bufferSize, adaptor.getEnv(), modeA, modeB,
+ adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
+ computeType, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
} else {
- bufferSize = spMMBufferSizeCallBuilder
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
+ bufferSize = createSpMMBufferSizeCallBuilder
.create(loc, rewriter,
{adaptor.getEnv(), modeA, modeB,
adaptor.getSpmatA(), adaptor.getDnmatB(),
@@ -1790,10 +1780,10 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
- auto bufferSize = SDDMMBufferSizeCallBuilder
+ auto bufferSize = createSDDMMBufferSizeCallBuilder
.create(loc, rewriter,
{adaptor.getEnv(), modeA, modeB,
adaptor.getDnmatA(), adaptor.getDnmatB(),
@@ -1812,8 +1802,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
Location loc = op.getLoc();
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto stream = adaptor.getAsyncDependencies().front();
@@ -1826,20 +1816,19 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
pBufs.push_back(pBuf);
}
- cuSparseLtSpmmBuilder.create(loc, rewriter,
- {adaptor.getEnv(), adaptor.getSpmatA(),
- adaptor.getDnmatB(), adaptor.getDnmatC(),
- computeType, pBufs[0], pBufs[1], pBufs[2],
- stream});
+ createCuSparseLtSpMMBuilder.create(
+ loc, rewriter,
+ {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnmatB(),
+ adaptor.getDnmatC(), pBufs[0], pBufs[1], pBufs[2], stream});
} else {
Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
.allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
- spMMCallBuilder.create(loc, rewriter,
- {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
- adaptor.getDnmatB(), adaptor.getDnmatC(),
- computeType, pBuf, stream});
+ createSpMMCallBuilder.create(
+ loc, rewriter,
+ {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
+ adaptor.getDnmatB(), adaptor.getDnmatC(), computeType, pBuf, stream});
}
rewriter.replaceOp(op, {stream});
return success();
@@ -1860,8 +1849,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- auto computeType =
- genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+ auto computeType = genConstInt32From(
+ rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
@@ -1869,10 +1858,10 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
- SDDMMCallBuilder.create(loc, rewriter,
- {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
- adaptor.getDnmatB(), adaptor.getSpmatC(),
- computeType, pBuf, stream});
+ createSDDMMCallBuilder.create(
+ loc, rewriter,
+ {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
+ adaptor.getSpmatC(), computeType, pBuf, stream});
rewriter.replaceOp(op, {stream});
return success();
}
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index c811b72a21403..47fee38e84d2c 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -468,13 +468,13 @@ mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
struct cusparseLtSpMatHandleAndData {
cusparseLtMatDescriptor_t mat;
- void *values{nullptr};
- // TODO: the following is associated with the SpMM operator rather than the
- // sparse matrix. Create workspace buffers and pass them to the SpMM
+ // TODO: the following three are associated with the SpMM operator rather than
+ // the sparse matrix. Create workspace buffers and pass them to the SpMM
// execution.
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulPlan_t plan;
cusparseLtMatmulDescriptor_t matmul;
+ void *values{nullptr};
};
struct cusparseLtDnMatHandleAndData {
@@ -482,24 +482,15 @@ struct cusparseLtDnMatHandleAndData {
void *values{nullptr};
};
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuAssertSparseLTEnvHandleSize() {
- assert(sizeof(cusparseLtHandle_t) == 11024);
-}
-
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuAssertSparseLtSpMatHandleSize() {
- return assert(sizeof(cusparseLtSpMatHandleAndData) == 44104);
-}
-
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSparseLtDnMatHandleSize() {
- return assert(sizeof(cusparseLtDnMatHandleAndData) == 11032);
-}
+static_assert(sizeof(cusparseLtHandle_t) == 11024);
+static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104);
+static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032);
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateSparseLtEnv(void *h, CUstream /*stream*/) {
// note that cuSparseLt still uses cusparseStatus_t
CUSPARSE_REPORT_IF_ERROR(
cusparseLtInit(reinterpret_cast<cusparseLtHandle_t *>(h)))
- return;
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
@@ -510,15 +501,16 @@ mgpuDestroySparseLtEnv(void *h, CUstream /*stream*/) {
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateCuSparseLtDnMat(void *dh, void *h, intptr_t rows, intptr_t cols,
- void *values, int32_t dw, CUstream /*stream*/) {
- cusparseLtMatDescriptor_t mat;
+ void *values, int32_t dtp, CUstream /*stream*/) {
auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+ // CusparseLt expects the descriptors to be zero-initialized.
+ memset(dh, 0, sizeof(cusparseLtDnMatHandleAndData));
auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
- cudaDataType_t dtp = dataTp(dw);
+ auto dTp = static_cast<cudaDataType_t>(dtp);
// assuming row-major when deciding lda
CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
- handle, &(dh->mat), rows, cols, /*lda=*/cols,
- /*alignment=*/16, dtp, CUSPARSE_ORDER_ROW))
+ handle, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
+ /*alignment=*/16, dTp, CUSPARSE_ORDER_ROW))
dnmat_handle->values = values;
}
@@ -526,94 +518,99 @@ mgpuCreateCuSparseLtDnMat(void *dh, void *h, intptr_t rows, intptr_t cols,
// cusparseLt
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtSpMat(void *m, CUstream /*stream*/) {
- auto matAndData = reinterpret_cast<cusparseLtSpMatHandleAndData>(m);
+ auto matAndData = reinterpret_cast<cusparseLtSpMatHandleAndData *>(m);
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtDnMat(void *m, CUstream /*stream*/) {
- auto matAndData = reinterpret_cast<cusparseLtDnMatHandleAndData>(m);
- CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(mat->mat)))
+ auto matAndData = reinterpret_cast<cusparseLtDnMatHandleAndData *>(m);
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCusparseLtCreate2To4SpMat(void *sh, void *h, intptr_t rows, intptr_t cols,
- void *values, int32_t dw, CUstream /*stream*/) {
+ void *values, int32_t dtp, CUstream /*stream*/) {
auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
+ // CusparseLt expects the descriptors to be zero-initialized.
+ memset(spmat_handle, 0, sizeof(cusparseLtSpMatHandleAndData));
spmat_handle->values = values;
auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
- cudaDataType_t dtp = dataTp_cusparseLt(dw);
+ auto dTp = static_cast<cudaDataType_t>(dtp);
// assuming row-major when deciding lda
CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
- handle, &(sh->mat), rows, cols, /*ld=*/cols, /*alignment=*/16, dtp,
- CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
+ handle, &(spmat_handle->mat), rows, cols, /*ld=*/cols, /*alignment=*/16,
+ dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
}
// Several things are being done in this stage, algorithm selection, planning,
// and returning workspace and compressed matrices data buffer sizes.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuCuSparseLtSpMMBufferSize(void *workspace_size, void *compressed_size,
- void *compressed_buffer_size, void *h, void *a,
+mgpuCuSparseLtSpMMBufferSize(void *bs, void *h, int32_t ma, int32_t mb, void *a,
+ void *b, void *c, int32_t ctp,
CUstream /*stream*/) {
// TODO: support more advanced settings, e.g., the input right operand is a
// sparse matrix assuming matA is the sparse matrix
auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
+ auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
+ auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
+ auto workspace_size = reinterpret_cast<size_t *>(bs);
+ auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
+ auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
+ auto cTp = static_cast<cusparseComputeType>(ctp);
- CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit(
- handle, &(matWithData.alg_sel), &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
+ cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+ cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
+ handle, &(matA->matmul), modeA, modeB, &(matA->mat), &(matB->mat),
+ &(matC->mat), &(matC->mat), cTp))
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
+ handle, &(matA->alg_sel), &(matA->matmul), CUSPARSELT_MATMUL_ALG_DEFAULT))
int alg = 0;
- CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
- handle, &(matWithData.alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
+ handle, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
sizeof(alg)))
- // TODO: add transpose support
- CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
- handle, &(matA.matmul), c, CUSPARSE_OPERATION_NON_TRANSPOSE, &(matA->mat),
- &matB, &matC, &matC, compute_type))
- CHECK_CUSPARSE(cusparseLtMatmulPlanInit(handle, &(matWithData.plan), &matmul,
- &(matWithData.alg_sel)))
-
- CHECK_CUSPARSE(
- cusparseLtMatmulGetWorkspace(handle, &(matA.plan), workspace_size))
- CHECK_CUSPARSE(cusparseLtSpMMACompressedSize(
- handle, &(matA.plan), compressed_size, compressed_buffer_size))
+
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
+ handle, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))
+
+ CUSPARSE_REPORT_IF_ERROR(
+ cusparseLtMatmulGetWorkspace(handle, &(matA->plan), workspace_size))
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
+ handle, &(matA->plan), compressed_size, compressed_buffer_size))
// avoid zero-alloc
*workspace_size = (*workspace_size == 0 ? 1 : *workspace_size);
*compressed_size = (*compressed_size == 0 ? 1 : *compressed_size);
*compressed_buffer_size =
(*compressed_buffer_size == 0 ? 1 : *compressed_buffer_size);
- return;
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuCuSparseLtSpMM(void *alg_sel, void *plan, void *matmul, void *h, void *a,
- void *b, void *c, int32_t dw, void *buf, void *dA_compressed,
- void *dA_compressedBuffer, CUstream stream) {
+mgpuCuSparseLtSpMM(void *h, void *a, void *b, void *c, void *d_workspace,
+ void *dA_compressed, void *dA_compressedBuffer,
+ CUstream stream) {
auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
- cusparseLtMatmulAlgSelection_t alg_sel;
- cusparseLtMatmulPlan_t plan;
- cusparseLtMatmulDescriptor_t matmul;
-
- ALPHABETA(dw, alpha, beta)
-
- CHECK_CUSPARSE(cusparseLtSpMMACompress(handle, &(matA->plan), &(matA->values),
- dA_compressed, dA_compressedBuffer,
- stream))
+ ALPHABETA(CUDA_R_32F, alpha, beta)
+ CUSPARSE_REPORT_IF_ERROR(
+ cusparseLtSpMMACompress(handle, &(matA->plan), (matA->values),
+ dA_compressed, dA_compressedBuffer, stream))
// TODO: add support to multi-stream execution
// Perform the matrix multiplication. D = A*B+C using C==D for now
- CHECK_CUSPARSE(
- cusparseLtMatmul(handle, reinterpret_cast<cusparseLtMatmulPlan_t *>(plan),
- &alpha, dA_compressed, dB, &beta, matC->values,
- /*dD*/ matC->values, d_workspace, &stream, 1))
+ CUSPARSE_REPORT_IF_ERROR(
+ cusparseLtMatmul(handle, &(matA->plan), alphap, dA_compressed,
+ matB->values, betap, matC->values,
+ /*dD*/ matC->values, d_workspace, nullptr, 0))
- CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(mat->mat)))
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
// destroy the plan associated with the sparse matrix
- CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(mat->plan)))
+ CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
}
#endif // MLIR_ENABLE_CUDA_CUSPARSELT
More information about the Mlir-commits
mailing list