[Mlir-commits] [mlir] cc402de - [mlir][sparse][gpu] add result type to spmv and spmm gpu libgen path

Kun Wu llvmlistbot at llvm.org
Thu Jun 1 10:17:48 PDT 2023


Author: Kun Wu
Date: 2023-06-01T17:17:40Z
New Revision: cc402de0b13b8682fec5762b2cf5064f9c8297f8

URL: https://github.com/llvm/llvm-project/commit/cc402de0b13b8682fec5762b2cf5064f9c8297f8
DIFF: https://github.com/llvm/llvm-project/commit/cc402de0b13b8682fec5762b2cf5064f9c8297f8.diff

LOG: [mlir][sparse][gpu] add result type to spmv and spmm gpu libgen path

Differential Revision: https://reviews.llvm.org/D151592

Added: 
    mlir/test/Dialect/GPU/sparse-roundtrip.mlir

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a401feea3d075..2fac955ce6a12 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1860,7 +1860,7 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
+    %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY into f32
     ```
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -1868,26 +1868,28 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
                        GPU_TransposeModeAttr:$modeA,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
-                       GPU_SparseDnVecHandle:$dnY);
+                       GPU_SparseDnVecHandle:$dnY,
+                       OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$bufferSz,
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnX,
-      "::mlir::Value":$dnY), [{
+      "Type":$bufferSz,
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnX,
+      "Value":$dnY)
+      , [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
-    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env, 
-                 modeA, spmatA, dnX, dnY);}]>
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
+                 env, modeA, spmatA, dnX, dnY, {});}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict ( `into` $computeType^)?
   }];
 }
 
@@ -1910,7 +1912,7 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY : memref<?xf64>
+    %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY, %buffer : memref<?xf64> into bf16
     ```
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -1919,25 +1921,26 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY,
+                       OptionalAttr<TypeAttr>:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnX,
-      "::mlir::Value":$dnY,
-      "::mlir::Value":$buffer), [{
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnX,
+      "Value":$dnY,
+      "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
-                 spmatA, dnX, dnY, buffer);}]>
+                 spmatA, dnX, dnY, {}, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
   }];
 }
 
@@ -1960,7 +1963,7 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    %buffersz, %token = gpu.spmm_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC
+    %buffersz, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC into f32
     ```
   }];
 
@@ -1970,27 +1973,28 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
                        GPU_TransposeModeAttr:$modeB,
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
-                       GPU_SparseDnMatHandle:$dnmatC);
+                       GPU_SparseDnMatHandle:$dnmatC,
+                       OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$bufferSz,
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnmatB,
-      "::mlir::Value":$dnmatC), [{
+      "Type":$bufferSz,
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnmatB,
+      "Value":$dnmatC), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
-                 env, modeA, modeB, spmatA, dnmatB, dnmatC);}]>
+                 env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict ( `into` $computeType^)?
   }];
 }
 
@@ -2013,7 +2017,7 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer
+    %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer : memref<?xf64> into f32
     ```
   }];
 
@@ -2024,26 +2028,27 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
+                       OptionalAttr<TypeAttr>:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$spmatA,
-      "::mlir::Value":$dnmatB,
-      "::mlir::Value":$dnmatC,
-      "::mlir::Value":$buffer), [{
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$spmatA,
+      "Value":$dnmatB,
+      "Value":$dnmatC,
+      "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, 
-                 modeB, spmatA, dnmatB, dnmatC, buffer);}]>
+                 modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
   }];
 }
 
@@ -2062,7 +2067,7 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
     Example:
 
     ```mlir
-    %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC
+    %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC into f32
     ```
 
     The matrix arguments can also be associated with one of the following 
@@ -2076,7 +2081,8 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
                    GPU_TransposeModeAttr:$modeB,
                    GPU_SparseDnMatHandle:$dnmatA,
                    GPU_SparseDnMatHandle:$dnmatB,
-                   GPU_SparseSpMatHandle:$spmatC);
+                   GPU_SparseSpMatHandle:$spmatC,
+                   OptionalAttr<TypeAttr>:$computeType);
   let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
@@ -2090,12 +2096,12 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
-                 env, modeA, modeB, dnmatA, dnmatB, spmatC);}]>
+                 env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict ( `into` $computeType^)?
   }];
 }
 
@@ -2114,7 +2120,7 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer
+    %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer : memref<?xf64> into f32
     ```
 
     The matrix arguments can also be associated with one of the following 
@@ -2129,6 +2135,7 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
                    GPU_SparseDnMatHandle:$dnmatA,
                    GPU_SparseDnMatHandle:$dnmatB,
                    GPU_SparseSpMatHandle:$spmatC,
+                   OptionalAttr<TypeAttr>:$computeType,
                    AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -2143,12 +2150,12 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
   auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
   auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
   return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, 
-                modeB, dnmatA, dnmatB, spmatC, buffer);}]>
+                modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer)
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
   }];
 }
 

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 07ca1e51ed696..5ec455a65bd65 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -688,6 +688,53 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
   return builder.create<LLVM::CallOp>(loc, function, arguments);
 }
 
+// Maps an index element type to cusparseIndexType_t (see cusparse.h).
+static int32_t getCuSparseIndexTypeFrom(Type type) {
+  if (type.isInteger(16))
+    return 1; // CUSPARSE_INDEX_16U
+  if (type.isInteger(32))
+    return 2; // CUSPARSE_INDEX_32I
+  return 3; // CUSPARSE_INDEX_64I: `index`, i64, and anything else
+}
+
+// Maps an MLIR type to its cudaDataType_t value (CUDA library_types.h).
+static int32_t getCuSparseDataTypeFrom(Type type) {
+  if (llvm::isa<ComplexType>(type)) {
+    // Complex types select the CUDA_C_* value of their element type.
+    auto elementType = type.cast<ComplexType>().getElementType();
+    if (elementType.isBF16())
+      return 15; // CUDA_C_16BF
+    if (elementType.isF16())
+      return 6; // CUDA_C_16F
+    if (elementType.isF32())
+      return 4; // CUDA_C_32F
+    if (elementType.isF64())
+      return 5; // CUDA_C_64F
+    if (elementType.isInteger(8))
+      return 7; // CUDA_C_8I
+    if (elementType.isInteger(16))
+      return 21; // CUDA_C_16I
+    if (elementType.isInteger(32))
+      return 11; // CUDA_C_32I
+  }
+  if (type.isBF16())
+    return 14; // CUDA_R_16BF
+  if (type.isF16())
+    return 2; // CUDA_R_16F
+  if (type.isF32())
+    return 0; // CUDA_R_32F
+  if (type.isF64())
+    return 1; // CUDA_R_64F
+  if (type.isInteger(8))
+    return 3; // CUDA_R_8I
+  if (type.isInteger(16))
+    return 20; // CUDA_R_16I
+  if (type.isInteger(32))
+    return 10; // CUDA_R_32I
+
+  llvm_unreachable("unsupported element type");
+}
+
 // Returns whether all operands are of LLVM type.
 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter) {
@@ -1237,11 +1284,30 @@ static Type getSpMatElemType(Value spMat) {
   llvm_unreachable("cannot find spmat def");
 }
 
-static Value genConstFrom(OpBuilder &builder, Location loc,
-                          gpu::TransposeMode mode) {
+// Returns the memref element type of the CreateDnMatOp/CreateDnVecOp defining dn.
+static Type getDnElemType(Value dn) {
+  if (auto op = dn.getDefiningOp<gpu::CreateDnMatOp>())
+    return op.getMemref().getType().getElementType();
+  if (auto op = dn.getDefiningOp<gpu::CreateDnVecOp>())
+    return op.getMemref().getType().getElementType();
+  llvm_unreachable("cannot find dn def");
+}
+
+template <typename T>
+static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
   Type llvmInt32Type = builder.getIntegerType(32);
   return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                          static_cast<int32_t>(mode));
+                                          static_cast<int32_t>(TValue));
+}
+
+// Emits the i32 cudaDataType_t code of the compute type, falling back to
+// defaultType when no compute type is given.
+static Value genConstInt32FromOptionalComputeMode(
+    OpBuilder &builder, Location loc, std::optional<Type> computeTypeOptional,
+    Type defaultType) {
+  Type resolved = computeTypeOptional.value_or(defaultType);
+  int32_t code = getCuSparseDataTypeFrom(resolved);
+  return genConstInt32From(builder, loc, code);
+}
 
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -1283,13 +1349,11 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
-  Type dType =
-      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  Type dType = op.getMemref().getType().getElementType();
+  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
   auto handle =
       createDnVecCallBuilder
-          .create(loc, rewriter, {adaptor.getSize(), pVec, dw, stream})
+          .create(loc, rewriter, {adaptor.getSize(), pVec, dtp, stream})
           .getResult();
   rewriter.replaceOp(op, {handle, stream});
   return success();
@@ -1320,14 +1384,12 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
-  Type dType =
-      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  Type dType = op.getMemref().getType().getElementType();
+  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
   auto handle =
       createDnMatCallBuilder
           .create(loc, rewriter,
-                  {adaptor.getRows(), adaptor.getCols(), pMat, dw, stream})
+                  {adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
           .getResult();
   rewriter.replaceOp(op, {handle, stream});
   return success();
@@ -1369,15 +1431,13 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
       llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
   Type dType =
       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
-  auto iw = rewriter.create<LLVM::ConstantOp>(
-      loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
+  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
   auto handle =
       createCooCallBuilder
           .create(loc, rewriter,
                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
-                   pRowIdxs, pColIdxs, pValues, iw, dw, stream})
+                   pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
           .getResult();
   rewriter.replaceOp(op, {handle, stream});
   return success();
@@ -1408,17 +1468,14 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
       llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
   Type dType =
       llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
-  auto pw = rewriter.create<LLVM::ConstantOp>(
-      loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
-  auto iw = rewriter.create<LLVM::ConstantOp>(
-      loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
+  auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
+  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
   auto handle =
       createCsrCallBuilder
           .create(loc, rewriter,
                   {adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
-                   pRowPos, pColIdxs, pValues, pw, iw, dw, stream})
+                   pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
           .getResult();
   rewriter.replaceOp(op, {handle, stream});
   return success();
@@ -1444,16 +1501,16 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto modeA = genConstFrom(rewriter, loc, op.getModeA());
-  Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
+  // retrieve the compute type, notice that it may be optional
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMVBufferSizeCallBuilder
           .create(loc, rewriter,
                   {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
-                   adaptor.getDnX(), adaptor.getDnY(), dw, stream})
+                   adaptor.getDnX(), adaptor.getDnY(), computeType, stream})
           .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
@@ -1466,10 +1523,10 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  Type dType = getSpMatElemType(op.getSpmatA());
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  // retrieve the compute type, notice that it may be optional
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1477,7 +1534,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMVCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), modeA, adaptor.getSpmatA(),
-                          adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
+                          adaptor.getDnX(), adaptor.getDnY(), computeType, pBuf,
                           stream});
   rewriter.replaceOp(op, {stream});
   return success();
@@ -1490,18 +1547,19 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
-  Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize =
-      spMMBufferSizeCallBuilder
-          .create(loc, rewriter,
-                  {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
-                   adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream})
-          .getResult();
+  // retrieve the compute type, notice that it may be optional
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
+  auto bufferSize = spMMBufferSizeCallBuilder
+                        .create(loc, rewriter,
+                                {adaptor.getEnv(), modeA, modeB,
+                                 adaptor.getSpmatA(), adaptor.getDnmatB(),
+                                 adaptor.getDnmatC(), computeType, stream})
+                        .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
@@ -1513,18 +1571,18 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
-  Type dType = getSpMatElemType(op.getSpmatC());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(),
+      getSpMatElemType(op.getSpmatC()));
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize =
-      SDDMMBufferSizeCallBuilder
-          .create(loc, rewriter,
-                  {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
-                   adaptor.getDnmatB(), adaptor.getSpmatC(), dw, stream})
-          .getResult();
+  auto bufferSize = SDDMMBufferSizeCallBuilder
+                        .create(loc, rewriter,
+                                {adaptor.getEnv(), modeA, modeB,
+                                 adaptor.getDnmatA(), adaptor.getDnmatB(),
+                                 adaptor.getSpmatC(), computeType, stream})
+                        .getResult();
   rewriter.replaceOp(op, {bufferSize, stream});
   return success();
 }
@@ -1536,11 +1594,12 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
-  Type dType = getSpMatElemType(op.getSpmatA());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  // retrieve the compute type, notice that it may be optional
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1548,8 +1607,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   spMMCallBuilder.create(loc, rewriter,
                          {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
-                          adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
-                          stream});
+                          adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
+                          pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1569,11 +1628,11 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  Type dType = getSpMatElemType(op.getSpmatC());
-  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
-                                              dType.getIntOrFloatBitWidth());
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  auto computeType = genConstInt32FromOptionalComputeMode(
+      rewriter, loc, adaptor.getComputeType(),
+      getSpMatElemType(op.getSpmatC()));
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1581,8 +1640,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
   SDDMMCallBuilder.create(loc, rewriter,
                           {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
-                           adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf,
-                           stream});
+                           adaptor.getDnmatB(), adaptor.getSpmatC(),
+                           computeType, pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index f928f32425ea6..c7367a8a3893c 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -17,6 +17,8 @@
 #include <stdio.h>
 
 #include "cuda.h"
+#include "cuda_bf16.h"
+#include "cuda_fp16.h"
 #include "cusparse.h"
 
 #ifdef _WIN32
@@ -228,38 +230,32 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
 /// Wrapper methods for the cuSparse library.
 ///
 
-static inline cudaDataType_t dataTp(int32_t width) {
-  switch (width) {
-  case 32:
-    return CUDA_R_32F;
-  default:
-    return CUDA_R_64F;
-  }
-}
-
-static inline cusparseIndexType_t idxTp(int32_t width) {
-  switch (width) {
-  case 32:
-    return CUSPARSE_INDEX_32I;
-  default:
-    return CUSPARSE_INDEX_64I;
-  }
-}
-
 // Some macro magic to get float/double alpha and beta on host.
-#define ALPHABETA(w, alpha, beta)                                              \
+#define ALPHABETA(dtp, alpha, beta)                                            \
+  __nv_bfloat16(alpha##16bf) = 1.0f;                                           \
+  __nv_bfloat16(beta##16bf) = 1.0f;                                            \
+  __half(alpha##16f) = 1.0f;                                                   \
+  __half(beta##16f) = 1.0f;                                                    \
   float(alpha##f) = 1.0f;                                                      \
   float(beta##f) = 1.0f;                                                       \
   double(alpha##d) = 1.0;                                                      \
   double(beta##d) = 1.0;                                                       \
   const void *(alpha##p) = nullptr;                                            \
   const void *(beta##p) = nullptr;                                             \
-  if ((w) == 32) {                                                             \
+  if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                              \
+    (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                     \
+    (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                       \
+  } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                         \
+    (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                      \
+    (beta##p) = reinterpret_cast<void *>(&(beta##16f));                        \
+  } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                         \
     (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                        \
     (beta##p) = reinterpret_cast<void *>(&(beta##f));                          \
-  } else {                                                                     \
+  } else if (dtp == CUDA_R_64F || dtp == CUDA_C_64F) {                         \
     (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                        \
     (beta##p) = reinterpret_cast<void *>(&(beta##d));                          \
+  } else {                                                                     \
+    llvm_unreachable("Unsupported data type");                                 \
   }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
@@ -276,10 +272,10 @@ mgpuDestroySparseEnv(void *h, CUstream /*stream*/) {
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
-mgpuCreateDnVec(intptr_t size, void *values, int32_t dw, CUstream /*stream*/) {
+mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
   cusparseDnVecDescr_t vec = nullptr;
-  cudaDataType_t dtp = dataTp(dw);
-  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dtp))
+  auto dTp = static_cast<cudaDataType_t>(dtp);
+  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
   return reinterpret_cast<void *>(vec);
 }
 
@@ -290,12 +286,12 @@ mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
-mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dw,
+mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
                 CUstream /*stream*/) {
   cusparseDnMatDescr_t mat = nullptr;
-  cudaDataType_t dtp = dataTp(dw);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
-                                               values, dtp, CUSPARSE_ORDER_ROW))
+                                               values, dTp, CUSPARSE_ORDER_ROW))
   return reinterpret_cast<void *>(mat);
 }
 
@@ -307,27 +303,27 @@ mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
-              void *colIdxs, void *values, int32_t iw, int32_t dw,
+              void *colIdxs, void *values, int32_t itp, int32_t dtp,
               CUstream /*stream*/) {
   cusparseSpMatDescr_t mat = nullptr;
-  cusparseIndexType_t itp = idxTp(iw);
-  cudaDataType_t dtp = dataTp(dw);
+  auto iTp = static_cast<cusparseIndexType_t>(itp);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
-                                             colIdxs, values, itp,
-                                             CUSPARSE_INDEX_BASE_ZERO, dtp))
+                                             colIdxs, values, iTp,
+                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
   return reinterpret_cast<void *>(mat);
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
-              void *colIdxs, void *values, int32_t pw, int32_t iw, int32_t dw,
-              CUstream /*stream*/) {
+              void *colIdxs, void *values, int32_t ptp, int32_t itp,
+              int32_t dtp, CUstream /*stream*/) {
   cusparseSpMatDescr_t mat = nullptr;
-  cusparseIndexType_t ptp = idxTp(pw);
-  cusparseIndexType_t itp = idxTp(iw);
-  cudaDataType_t dtp = dataTp(dw);
+  auto pTp = static_cast<cusparseIndexType_t>(ptp);
+  auto iTp = static_cast<cusparseIndexType_t>(itp);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
-                                             colIdxs, values, ptp, itp,
-                                             CUSPARSE_INDEX_BASE_ZERO, dtp))
+                                             colIdxs, values, pTp, iTp,
+                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
   return reinterpret_cast<void *>(mat);
 }
@@ -339,102 +335,102 @@ mgpuDestroySpMat(void *m, CUstream /*stream*/) {
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
+mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t ctp,
                    CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY,
-                              dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+                              cTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a,
-                                                   void *x, void *y, int32_t dw,
-                                                   void *buf,
+                                                   void *x, void *y,
+                                                   int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
   cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX,
-                                        betap, vecY, dtp,
+                                        betap, vecY, cTp,
                                         CUSPARSE_SPMV_ALG_DEFAULT, buf))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
-                   int32_t dw, CUstream /*stream*/) {
+                   int32_t ctp, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
-      handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+      handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
-         void *buf, CUstream /*stream*/) {
+mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+         int32_t ctp, void *buf, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA,
-                                        matB, betap, matC, dtp,
+                                        matB, betap, matC, cTp,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
-                    int32_t dw, CUstream /*stream*/) {
+                    int32_t ctp, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   size_t bufferSize = 0;
   CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
-      handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+      handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
   return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
-          int32_t dw, void *buf, CUstream /*stream*/) {
+          int32_t ctp, void *buf, CUstream /*stream*/) {
   cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
   cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
   cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
-  cudaDataType_t dtp = dataTp(dw);
-  ALPHABETA(dw, alpha, beta)
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  ALPHABETA(cTp, alpha, beta)
   CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(handle, modeA, modeB, alphap, matA,
-                                         matB, betap, matC, dtp,
+                                         matB, betap, matC, cTp,
                                          CUSPARSE_SDDMM_ALG_DEFAULT, buf))
 }

diff  --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
new file mode 100644
index 0000000000000..6465208791dd5
--- /dev/null
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // CHECK-LABEL: func @matvec
+  // CHECK: %{{.*}} = gpu.wait async
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_coo async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_vec async [%{{.*}}] %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.spmv_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+  // CHECK: %{{.*}} = gpu.spmv async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_dn_vec async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+  // CHECK: gpu.wait [%{{.*}}]
+  // CHECK: return
+  func.func @matvec(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_coo async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %dnvec, %token5 = gpu.create_dn_vec async [%token4] %mem2, %arg0 : memref<?xf64>
+    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec
+    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_vec async [%token8] %dnvec
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
+  // CHECK-LABEL: func @matmul
+  // CHECK: %{{.*}} = gpu.wait async
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_mat async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.spmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} into f64
+  // CHECK: %{{.*}} = gpu.spmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64> into f64
+  // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_dn_mat async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+  // CHECK: gpu.wait [%{{.*}}]
+  // CHECK: return
+  func.func @matmul(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat into f64
+    %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64> into f64
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
+  // CHECK-LABEL: func @sddmm
+  // CHECK: %{{.*}} = gpu.wait async
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_mat async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.sddmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+  // CHECK: %{{.*}} = gpu.sddmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_dn_mat async [%{{.*}}] %{{.*}}
+  // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+  // CHECK: gpu.wait [%{{.*}}]
+  // CHECK: return
+  func.func @sddmm(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
+    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
+}
+
+


        


More information about the Mlir-commits mailing list