[Mlir-commits] [mlir] 235fbe7 - [mlir] [sparse] [gpu] adding transpose support to spmm spmv
Kun Wu
llvmlistbot at llvm.org
Fri May 26 10:09:51 PDT 2023
Author: Kun Wu
Date: 2023-05-26T17:07:09Z
New Revision: 235fbe792b4cb167403bba8770629a295ce81c6c
URL: https://github.com/llvm/llvm-project/commit/235fbe792b4cb167403bba8770629a295ce81c6c
DIFF: https://github.com/llvm/llvm-project/commit/235fbe792b4cb167403bba8770629a295ce81c6c.diff
LOG: [mlir] [sparse] [gpu] adding transpose support to spmm spmv
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D151259
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index fdcbf4d139bca..6b8ede2071af4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1816,6 +1816,31 @@ def GPU_DestroySpMatOp : GPU_Op<"destroy_sp_mat", [GPU_AsyncOpInterface]> {
}];
}
+// To avoid coupling this dialect with cusparse.h specifics, we hardcoded magic
+// literals in this enum. Note that this should be kept in sync with
+// cusparseOperation_t in cusparse.h:
+// typedef enum {
+// CUSPARSE_OPERATION_NON_TRANSPOSE = 0,
+// CUSPARSE_OPERATION_TRANSPOSE = 1,
+// CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE = 2
+// } cusparseOperation_t;
+// TODO: find a proper way to keep them in sync?
+def GPU_TransposeMode : I32EnumAttr<"TransposeMode",
+ "transpose mode of sparse matrix supported by sparse tensor ops",
+ [
+ I32EnumAttrCase<"NON_TRANSPOSE", 0>,
+ I32EnumAttrCase<"TRANSPOSE", 1>,
+ I32EnumAttrCase<"CONJUGATE_TRANSPOSE", 2>,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = GPU_Dialect.cppNamespace;
+}
+
+def GPU_TransposeModeAttr : EnumAttr<GPU_Dialect, GPU_TransposeMode,
+ "mat_transpose_mode">{
+ let defaultValue = "TransposeMode::NON_TRANSPOSE";
+}
+
def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
let summary = "Precompute buffersize for SpMV operation";
let description = [{
@@ -1828,23 +1853,41 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
it does not block until the execution has finished on the device). In
that case, it returns a !gpu.async.token in addition to the environment.
+ The matrix arguments can also be associated with one of the following
+ operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+ is NON_TRANSPOSE.
+
Example:
```mlir
- %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA, %dnX, %dnY
+ %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
GPU_SparseEnvHandle:$env,
+ GPU_TransposeModeAttr:$modeA,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnVecHandle:$dnX,
GPU_SparseDnVecHandle:$dnY);
let results = (outs Res<Index>:$bufferSz,
Optional<GPU_AsyncToken>:$asyncToken);
+ let builders = [OpBuilder<(ins
+ "::mlir::Type":$bufferSz,
+ "::mlir::Type":$asyncToken,
+ "::mlir::ValueRange":$asyncDependencies,
+ "::mlir::Value":$env,
+ "::mlir::Value":$spmatA,
+ "::mlir::Value":$dnX,
+ "::mlir::Value":$dnY), [{
+ auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+ return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env,
+ modeA, spmatA, dnX, dnY);}]>
+ ];
+
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA `,` $dnX `,` $dnY attr-dict
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict
}];
}
@@ -1860,23 +1903,41 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
it does not block until the execution has finished on the device). In
that case, it returns a !gpu.async.token in addition to the environment.
+ The matrix arguments can also be associated with one of the following
+ operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+ is NON_TRANSPOSE.
+
Example:
```mlir
- %token = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY : memref<?xf64>
+ %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY : memref<?xf64>
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
GPU_SparseEnvHandle:$env,
+ GPU_TransposeModeAttr:$modeA,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnVecHandle:$dnX,
GPU_SparseDnVecHandle:$dnY,
AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+ let builders = [OpBuilder<(ins
+ "::mlir::Type":$asyncToken,
+ "::mlir::ValueRange":$asyncDependencies,
+ "::mlir::Value":$env,
+ "::mlir::Value":$spmatA,
+ "::mlir::Value":$dnX,
+ "::mlir::Value":$dnY,
+ "::mlir::Value":$buffer), [{
+ auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+ return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
+ spmatA, dnX, dnY, buffer);}]>
+ ];
+
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
}];
}
@@ -1892,24 +1953,44 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
it does not block until the execution has finished on the device). In
that case, it returns a !gpu.async.token in addition to the environment.
+ The matrix arguments can also be associated with one of the following
+ operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+ is NON_TRANSPOSE.
+
Example:
```mlir
- %buffersz, %token = gpu.spmm_buffersize async [%dep] %env, %spmatA, %spmatB, %spmatC
+ %buffersz, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
GPU_SparseEnvHandle:$env,
+ GPU_TransposeModeAttr:$modeA,
+ GPU_TransposeModeAttr:$modeB,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnMatHandle:$dnmatB,
GPU_SparseDnMatHandle:$dnmatC);
let results = (outs Res<Index>:$bufferSz,
Optional<GPU_AsyncToken>:$asyncToken);
+ let builders = [OpBuilder<(ins
+ "::mlir::Type":$bufferSz,
+ "::mlir::Type":$asyncToken,
+ "::mlir::ValueRange":$asyncDependencies,
+ "::mlir::Value":$env,
+ "::mlir::Value":$spmatA,
+ "::mlir::Value":$dnmatB,
+ "::mlir::Value":$dnmatC), [{
+ auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+ auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+ return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
+ env, modeA, modeB, spmatA, dnmatB, dnmatC);}]>
+ ];
+
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA `,` $dnmatB `,` $dnmatC attr-dict
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict
}];
}
@@ -1925,24 +2006,44 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
it does not block until the execution has finished on the device). In
that case, it returns a !gpu.async.token in addition to the environment.
+ The matrix arguments can also be associated with one of the following
+ operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+ is NON_TRANSPOSE.
+
Example:
```mlir
- %token = gpu.spmm async [%dep] %env, %spmatA, %spmatB, %spmatC, %buffer
+ %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer : memref<?xf64>
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
GPU_SparseEnvHandle:$env,
+ GPU_TransposeModeAttr:$modeA,
+ GPU_TransposeModeAttr:$modeB,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnMatHandle:$dnmatB,
GPU_SparseDnMatHandle:$dnmatC,
AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+ let builders = [OpBuilder<(ins
+ "::mlir::Type":$asyncToken,
+ "::mlir::ValueRange":$asyncDependencies,
+ "::mlir::Value":$env,
+ "::mlir::Value":$spmatA,
+ "::mlir::Value":$dnmatB,
+ "::mlir::Value":$dnmatC,
+ "::mlir::Value":$buffer), [{
+ auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+ auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+ return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
+ modeB, spmatA, dnmatB, dnmatC, buffer);}]>
+ ];
+
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA `,` $dnmatB `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
}];
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8defd8970b900..029c1005e58ca 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -237,23 +237,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
FunctionCallBuilder spMVBufferSizeCallBuilder = {
"mgpuSpMVBufferSize",
llvmIntPtrType,
- {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
- llvmInt32Type, llvmPointerType /* void *stream */}};
+ {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
+ llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
FunctionCallBuilder spMVCallBuilder = {
"mgpuSpMV",
llvmVoidType,
- {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
- llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
+ {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
+ llvmPointerType, llvmInt32Type, llvmPointerType,
+ llvmPointerType /* void *stream */}};
FunctionCallBuilder spMMBufferSizeCallBuilder = {
"mgpuSpMMBufferSize",
llvmIntPtrType,
- {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
- llvmInt32Type, llvmPointerType /* void *stream */}};
+ {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+ llvmPointerType, llvmPointerType, llvmInt32Type,
+ llvmPointerType /* void *stream */}};
FunctionCallBuilder spMMCallBuilder = {
"mgpuSpMM",
llvmVoidType,
- {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
- llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
+ {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+ llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
+ llvmPointerType /* void *stream */}};
};
/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -1196,6 +1199,13 @@ static Type getSpMatElemType(Value spMat) {
llvm_unreachable("cannot find spmat def");
}
+static Value genConstFrom(OpBuilder &builder, Location loc,
+ gpu::TransposeMode mode) {
+ Type llvmInt32Type = builder.getIntegerType(32);
+ return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+ static_cast<int32_t>(mode));
+}
+
LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1389,6 +1399,7 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
+ auto modeA = genConstFrom(rewriter, loc, op.getModeA());
Type dType = getSpMatElemType(op.getSpmatA());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
@@ -1396,8 +1407,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
auto bufferSize =
spMVBufferSizeCallBuilder
.create(loc, rewriter,
- {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnX(),
- adaptor.getDnY(), dw, stream})
+ {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
+ adaptor.getDnX(), adaptor.getDnY(), dw, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
@@ -1411,6 +1422,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();
Location loc = op.getLoc();
Type dType = getSpMatElemType(op.getSpmatA());
+ auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto stream = adaptor.getAsyncDependencies().front();
@@ -1419,7 +1431,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
if (!getTypeConverter()->useOpaquePointers())
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
spMVCallBuilder.create(loc, rewriter,
- {adaptor.getEnv(), adaptor.getSpmatA(),
+ {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
stream});
rewriter.replaceOp(op, {stream});
@@ -1434,14 +1446,16 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
return failure();
Location loc = op.getLoc();
Type dType = getSpMatElemType(op.getSpmatA());
+ auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
auto stream = adaptor.getAsyncDependencies().front();
auto bufferSize =
spMMBufferSizeCallBuilder
.create(loc, rewriter,
- {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnmatB(),
- adaptor.getDnmatC(), dw, stream})
+ {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
+ adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
@@ -1457,13 +1471,15 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
Type dType = getSpMatElemType(op.getSpmatA());
auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
dType.getIntOrFloatBitWidth());
+ auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
spMMCallBuilder.create(loc, rewriter,
- {adaptor.getEnv(), adaptor.getSpmatA(),
+ {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
stream});
rewriter.replaceOp(op, {stream});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 246e5d985d87c..61ee115e879a9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -740,6 +740,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
if (numLoops == 2 && numTensors == 3 &&
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isReductionIterator(iteratorTypes[1]) &&
+ // TODO: add transposed {i, j}
maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
return rewriteSpMV(rewriter, op, enableRT);
}
@@ -749,6 +750,8 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
linalg::isParallelIterator(iteratorTypes[0]) &&
linalg::isParallelIterator(iteratorTypes[1]) &&
linalg::isReductionIterator(iteratorTypes[2]) &&
+ // TODO: add transposed {i, k}, {k, j}
+ // TODO: maybe add transposed {i, j} in future
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
return rewriteSpMM(rewriter, op, enableRT);
}
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index fa0fbebf212c1..c05785e9b576e 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -338,39 +338,45 @@ mgpuDestroySpMat(void *m, CUstream /*stream*/) {
CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
}
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
- void *h, void *a, void *x, void *y, int32_t dw, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
+ CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+ cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
cudaDataType_t dtp = dataTp(dw);
ALPHABETA(dw, alpha, beta)
size_t bufferSize = 0;
- CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
- handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX, betap, vecY,
- dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+ CUSPARSE_REPORT_IF_ERROR(
+ cusparseSpMV_bufferSize(handle, modeA, &alpha, matA, vecX, &beta, vecY,
+ dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
}
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, void *a, void *x,
- void *y, int32_t dw,
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a,
+ void *x, void *y, int32_t dw,
void *buf,
CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+ cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
cudaDataType_t dtp = dataTp(dw);
ALPHABETA(dw, alpha, beta)
- CUSPARSE_REPORT_IF_ERROR(
- cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX,
- betap, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+ CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, &alpha, matA, vecX,
+ &beta, vecY, dtp,
+ CUSPARSE_SPMV_ALG_DEFAULT, buf))
}
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
- void *h, void *a, void *b, void *c, int32_t dw, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+ int32_t dw, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+ cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+ cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
@@ -378,24 +384,23 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
ALPHABETA(dw, alpha, beta)
size_t bufferSize = 0;
CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
- handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
- CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap, matC, dtp,
+ handle, modeA, modeB, &alpha, matA, matB, &beta, matC, dtp,
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
}
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, void *a, void *b,
- void *c, int32_t dw,
- void *buf,
- CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
+ void *buf, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+ cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+ cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
cudaDataType_t dtp = dataTp(dw);
ALPHABETA(dw, alpha, beta)
- CUSPARSE_REPORT_IF_ERROR(
- cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
- CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap,
- matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
+ CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, &alpha, matA,
+ matB, &beta, matC, dtp,
+ CUSPARSE_SPMM_ALG_DEFAULT, buf))
}
More information about the Mlir-commits
mailing list