[Mlir-commits] [mlir] 235fbe7 - [mlir] [sparse] [gpu] adding transpose support to spmm spmv

Kun Wu llvmlistbot at llvm.org
Fri May 26 10:09:51 PDT 2023


Author: Kun Wu
Date: 2023-05-26T17:07:09Z
New Revision: 235fbe792b4cb167403bba8770629a295ce81c6c

URL: https://github.com/llvm/llvm-project/commit/235fbe792b4cb167403bba8770629a295ce81c6c
DIFF: https://github.com/llvm/llvm-project/commit/235fbe792b4cb167403bba8770629a295ce81c6c.diff

LOG: [mlir] [sparse] [gpu] adding transpose support to spmm spmv

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D151259

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index fdcbf4d139bca..6b8ede2071af4 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1816,6 +1816,31 @@ def GPU_DestroySpMatOp : GPU_Op<"destroy_sp_mat", [GPU_AsyncOpInterface]> {
   }];
 }
 
+// To avoid coupling this dialect with cusparse.h specifics, we hardcoded magic 
+// literals in this enum. Note that this should be kept in sync with 
+// cusparseOperation_t in cusparse.h:
+// typedef enum {
+// CUSPARSE_OPERATION_NON_TRANSPOSE       = 0,
+// CUSPARSE_OPERATION_TRANSPOSE           = 1,
+// CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE = 2
+// } cusparseOperation_t;
+// TODO: find a proper way to keep them in sync?
+def GPU_TransposeMode : I32EnumAttr<"TransposeMode",
+    "transpose mode of sparse matrix supported by sparse tensor ops",
+    [
+      I32EnumAttrCase<"NON_TRANSPOSE", 0>, 
+      I32EnumAttrCase<"TRANSPOSE", 1>, 
+      I32EnumAttrCase<"CONJUGATE_TRANSPOSE", 2>,
+    ]> {
+      let genSpecializedAttr = 0;
+      let cppNamespace = GPU_Dialect.cppNamespace;
+}
+
+def GPU_TransposeModeAttr : EnumAttr<GPU_Dialect, GPU_TransposeMode,
+                                   "mat_transpose_mode">{
+  let defaultValue = "TransposeMode::NON_TRANSPOSE";
+}
+
 def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
   let summary = "Precompute buffersize for SpMV operation";
   let description = [{
@@ -1828,23 +1853,41 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
     it does not block until the execution has finished on the device). In
     that case, it returns a !gpu.async.token in addition to the environment.
 
+    The matrix arguments can also be associated with one of the following 
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
     Example:
 
     ```mlir
-    %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA, %dnX, %dnY
+    %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
     ```
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                        GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$bufferSz,
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$spmatA,
+      "::mlir::Value":$dnX,
+      "::mlir::Value":$dnY), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env, 
+                 modeA, spmatA, dnX, dnY);}]>
+  ];
+
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA `,` $dnX `,` $dnY attr-dict
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict
   }];
 }
 
@@ -1860,23 +1903,41 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
     it does not block until the execution has finished on the device). In
     that case, it returns a !gpu.async.token in addition to the environment.
 
+    The matrix arguments can also be associated with one of the following 
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
     Example:
 
     ```mlir
-    %token = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY : memref<?xf64>
+    %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY : memref<?xf64>
     ```
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                        GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$spmatA,
+      "::mlir::Value":$dnX,
+      "::mlir::Value":$dnY,
+      "::mlir::Value":$buffer), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
+                 spmatA, dnX, dnY, buffer);}]>
+  ];
+
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
   }];
 }
 
@@ -1892,24 +1953,44 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
     it does not block until the execution has finished on the device). In
     that case, it returns a !gpu.async.token in addition to the environment.
 
+    The matrix arguments can also be associated with one of the following 
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
     Example:
 
     ```mlir
-    %buffersz, %token = gpu.spmm_buffersize async [%dep] %env, %spmatA, %spmatB, %spmatC
+    %buffersz, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC
     ```
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                        GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$bufferSz,
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$spmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$dnmatC), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
+                 env, modeA, modeB, spmatA, dnmatB, dnmatC);}]>
+  ];
+
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA `,` $dnmatB `,` $dnmatC attr-dict
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict
   }];
 }
 
@@ -1925,24 +2006,44 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
     it does not block until the execution has finished on the device). In
     that case, it returns a !gpu.async.token in addition to the environment.
 
+    The matrix arguments can also be associated with one of the following 
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
     Example:
 
     ```mlir
-    %token = gpu.spmm async [%dep] %env, %spmatA, %spmatB, %spmatC, %buffer
+    %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer
     ```
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                        GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$spmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$dnmatC,
+      "::mlir::Value":$buffer), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, 
+                 modeB, spmatA, dnmatB, dnmatC, buffer);}]>
+  ];
+
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA `,` $dnmatB `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
   }];
 }
 

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8defd8970b900..029c1005e58ca 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -237,23 +237,26 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   FunctionCallBuilder spMVBufferSizeCallBuilder = {
       "mgpuSpMVBufferSize",
       llvmIntPtrType,
-      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmInt32Type, llvmPointerType /* void *stream */}};
+      {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
+       llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMVCallBuilder = {
       "mgpuSpMV",
       llvmVoidType,
-      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
+      {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
+       llvmPointerType, llvmInt32Type, llvmPointerType,
+       llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMBufferSizeCallBuilder = {
       "mgpuSpMMBufferSize",
       llvmIntPtrType,
-      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmInt32Type, llvmPointerType /* void *stream */}};
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type,
+       llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMCallBuilder = {
       "mgpuSpMM",
       llvmVoidType,
-      {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -1196,6 +1199,13 @@ static Type getSpMatElemType(Value spMat) {
   llvm_unreachable("cannot find spmat def");
 }
 
+static Value genConstFrom(OpBuilder &builder, Location loc,
+                          gpu::TransposeMode mode) {
+  Type llvmInt32Type = builder.getIntegerType(32);
+  return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                          static_cast<int32_t>(mode));
+}
+
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1389,6 +1399,7 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  auto modeA = genConstFrom(rewriter, loc, op.getModeA());
   Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
@@ -1396,8 +1407,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto bufferSize =
       spMVBufferSizeCallBuilder
           .create(loc, rewriter,
-                  {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnX(),
-                   adaptor.getDnY(), dw, stream})
+                  {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
+                   adaptor.getDnX(), adaptor.getDnY(), dw, stream})
           .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1411,6 +1422,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   Type dType = getSpMatElemType(op.getSpmatA());
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1419,7 +1431,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
   if (!getTypeConverter()->useOpaquePointers())
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMVCallBuilder.create(loc, rewriter,
-                         {adaptor.getEnv(), adaptor.getSpmatA(),
+                         {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
                           adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
                           stream});
   rewriter.replaceOp(op, {stream});
@@ -1434,14 +1446,16 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   Type dType = getSpMatElemType(op.getSpmatA());
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMMBufferSizeCallBuilder
           .create(loc, rewriter,
-                  {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnmatB(),
-                   adaptor.getDnmatC(), dw, stream})
+                  {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
+                   adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream})
           .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1457,13 +1471,15 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
   Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMMCallBuilder.create(loc, rewriter,
-                         {adaptor.getEnv(), adaptor.getSpmatA(),
+                         {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
                           adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
                           stream});
   rewriter.replaceOp(op, {stream});

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 246e5d985d87c..61ee115e879a9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -740,6 +740,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
     if (numLoops == 2 && numTensors == 3 &&
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isReductionIterator(iteratorTypes[1]) &&
+        // TODO: add transposed {i, j}
         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
       return rewriteSpMV(rewriter, op, enableRT);
     }
@@ -749,6 +750,8 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isParallelIterator(iteratorTypes[1]) &&
         linalg::isReductionIterator(iteratorTypes[2]) &&
+        // TODO: add transposed {i, k}, {k, j}
+        // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
       return rewriteSpMM(rewriter, op, enableRT);
     }

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index fa0fbebf212c1..c05785e9b576e 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -338,39 +338,45 @@ mgpuDestroySpMat(void *m, CUstream /*stream*/) {
   CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
-    void *h, void *a, void *x, void *y, int32_t dw, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
+                   CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
   ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
-      handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX, betap, vecY,
-      dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseSpMV_bufferSize(handle, modeA, &alpha, matA, vecX, &beta, vecY,
+                              dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, void *a, void *x,
-                                                   void *y, int32_t dw,
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a,
+                                                   void *x, void *y, int32_t dw,
                                                    void *buf,
                                                    CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
   cudaDataType_t dtp = dataTp(dw);
   ALPHABETA(dw, alpha, beta)
-  CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, vecX,
-                   betap, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, &alpha, matA, vecX,
+                                        &beta, vecY, dtp,
+                                        CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
-    void *h, void *a, void *b, void *c, int32_t dw, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+                   int32_t dw, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
@@ -378,24 +384,23 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
   ALPHABETA(dw, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
-      handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap, matC, dtp,
+      handle, modeA, modeB, &alpha, matA, matB, &beta, matC, dtp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, void *a, void *b,
-                                                   void *c, int32_t dw,
-                                                   void *buf,
-                                                   CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
+         void *buf, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
   cudaDataType_t dtp = dataTp(dw);
   ALPHABETA(dw, alpha, beta)
-  CUSPARSE_REPORT_IF_ERROR(
-      cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-                   CUSPARSE_OPERATION_NON_TRANSPOSE, alphap, matA, matB, betap,
-                   matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, &alpha, matA,
+                                        matB, &beta, matC, dtp,
+                                        CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }


        


More information about the Mlir-commits mailing list