[Mlir-commits] [mlir] a8e1f80 - [mlir][sparse][gpu] derive type of cuSparse op

Aart Bik llvmlistbot at llvm.org
Fri May 19 17:08:02 PDT 2023


Author: Aart Bik
Date: 2023-05-19T17:07:52-07:00
New Revision: a8e1f80f8b68736af15708f04681e7830058e1a9

URL: https://github.com/llvm/llvm-project/commit/a8e1f80f8b68736af15708f04681e7830058e1a9
DIFF: https://github.com/llvm/llvm-project/commit/a8e1f80f8b68736af15708f04681e7830058e1a9.diff

LOG: [mlir][sparse][gpu] derive type of cuSparse op

This no longer assumes F64 data only; the element type is now derived from the sparse-matrix operand.

Note, however, that it would be cleaner to carry the data type on the corresponding operation itself (rather than tracking it through the operands). That would also allow mixed-type cases, where the operand and result types differ.

This will be done in a follow-up revision, where the result type is carried by the SpMV/SpMM ops themselves (and friends).

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D151005
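
For context: the runtime-wrapper hunks below pass the new int32_t dw (element bit width) through a dataTp helper to obtain a cudaDataType_t. That helper is not visible in the excerpted hunks; a minimal sketch of such a width-to-type mapping, assuming only the F32/F64 cases exercised here, would be:

    // Sketch only: maps an element bit width to a cuSPARSE data type.
    // Assumes floating-point data (the only paths exercised by this patch);
    // the actual helper in CudaRuntimeWrappers.cpp may differ.
    static cudaDataType_t dataTp(int32_t width) {
      switch (width) {
      case 32:
        return CUDA_R_32F;
      default:
        return CUDA_R_64F;
      }
    }

Note that a bare bit width cannot distinguish, e.g., F32 from I32, which is exactly why the log above flags carrying the data type on the op itself as the cleaner follow-up.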

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8c79ee3745d9b..600bd9152c436 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -238,22 +238,22 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       "mgpuSpMVBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMVCallBuilder = {
       "mgpuSpMV",
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMBufferSizeCallBuilder = {
       "mgpuSpMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType /* void *stream */}};
   FunctionCallBuilder spMMCallBuilder = {
       "mgpuSpMM",
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmPointerType, llvmPointerType /* void *stream */}};
+       llvmInt32Type, llvmPointerType, llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -1186,6 +1186,16 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+// Returns the element type of the defining spmat op.
+// TODO: safer and more flexible to store data type in actual op instead?
+static Type getSpMatElemType(Value spMat) {
+  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
+    return op.getValues().getType().cast<MemRefType>().getElementType();
+  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
+    return op.getValues().getType().cast<MemRefType>().getElementType();
+  llvm_unreachable("cannot find spmat def");
+}
+
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1379,12 +1389,16 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatElemType(op.getSpmatA());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize = spMVBufferSizeCallBuilder
-                        .create(loc, rewriter,
-                                {adaptor.getEnv(), adaptor.getSpmatA(),
-                                 adaptor.getDnX(), adaptor.getDnY(), stream})
-                        .getResult();
+  auto bufferSize =
+      spMVBufferSizeCallBuilder
+          .create(loc, rewriter,
+                  {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnX(),
+                   adaptor.getDnY(), dw, stream})
+          .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
@@ -1396,6 +1410,9 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatElemType(op.getSpmatA());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1403,7 +1420,8 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMVCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), adaptor.getSpmatA(),
-                          adaptor.getDnX(), adaptor.getDnY(), pBuf, stream});
+                          adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
+                          stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1415,12 +1433,15 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatElemType(op.getSpmatA());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMMBufferSizeCallBuilder
           .create(loc, rewriter,
                   {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnmatB(),
-                   adaptor.getDnmatC(), stream})
+                   adaptor.getDnmatC(), dw, stream})
           .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1433,6 +1454,9 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  Type dType = getSpMatElemType(op.getSpmatA());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1440,7 +1464,7 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMMCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), adaptor.getSpmatA(),
-                          adaptor.getDnmatB(), adaptor.getDnmatC(), pBuf,
+                          adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
                           stream});
   rewriter.replaceOp(op, {stream});
   return success();

diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index e5d4cdd738847..d5eb9ca1731b9 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -322,60 +322,68 @@ mgpuDestroySpMat(void *m, CUstream /*stream*/) {
   CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMVBufferSize(void *h, void *a, void *x, void *y, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
+    void *h, void *a, void *x, void *y, int32_t dw, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
+  cudaDataType_t dtp = dataTp(dw);
   double alpha = 1.0;
   double beta = 1.0;
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
       handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX, &beta, vecY,
-      CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+      dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMV(void *h, void *a, void *x, void *y, void *buf, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, void *a, void *x,
+                                                   void *y, int32_t dw,
+                                                   void *buf,
+                                                   CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
+  cudaDataType_t dtp = dataTp(dw);
   double alpha = 1.0;
   double beta = 1.0;
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, vecX,
-                   &beta, vecY, CUDA_R_64F, CUSPARSE_SPMV_ALG_DEFAULT, buf))
+                   &beta, vecY, dtp, CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMMBufferSize(void *h, void *a, void *b, void *c, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMMBufferSize(
+    void *h, void *a, void *b, void *c, int32_t dw, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
+  cudaDataType_t dtp = dataTp(dw);
   double alpha = 1.0;
   double beta = 1.0;
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
       handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
-      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC,
-      CUDA_R_64F, CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
+      CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta, matC, dtp,
+      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMM(void *h, void *a, void *b, void *c, void *buf, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(void *h, void *a, void *b,
+                                                   void *c, int32_t dw,
+                                                   void *buf,
+                                                   CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
+  cudaDataType_t dtp = dataTp(dw);
   double alpha = 1.0;
   double beta = 1.0;
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                    CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
-                   matC, CUDA_R_64F, CUSPARSE_SPMM_ALG_DEFAULT, buf))
+                   matC, dtp, CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
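
To make the ABI change concrete, a direct host-side caller of the updated entry point now supplies the element bit width as an extra i32 argument. A hedged sketch follows (runSpMV and its parameters are hypothetical; descriptor creation via the other mgpu* wrappers is elided, and MLIR_CUDA_WRAPPERS_EXPORT is omitted for brevity):

    #include <cstdint>
    #include <cuda.h> // CUstream

    // Declaration matching the updated wrapper signature above.
    extern "C" void mgpuSpMV(void *h, void *a, void *x, void *y, int32_t dw,
                             void *buf, CUstream stream);

    // Hypothetical caller: env/spmat/dnX/dnY/buf are assumed to come from
    // the corresponding mgpu* creation wrappers (not shown).
    void runSpMV(void *env, void *spmat, void *dnX, void *dnY, void *buf,
                 CUstream stream) {
      // Passing 64 selects the F64 path (CUDA_R_64F); 32 would select F32.
      mgpuSpMV(env, spmat, dnX, dnY, /*dw=*/64, buf, stream);
    }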


        

