[Mlir-commits] [mlir] cc402de - [mlir][sparse][gpu] add result type to spmv and spmm gpu libgen path
Kun Wu
llvmlistbot at llvm.org
Thu Jun 1 10:17:48 PDT 2023
Author: Kun Wu
Date: 2023-06-01T17:17:40Z
New Revision: cc402de0b13b8682fec5762b2cf5064f9c8297f8
URL: https://github.com/llvm/llvm-project/commit/cc402de0b13b8682fec5762b2cf5064f9c8297f8
DIFF: https://github.com/llvm/llvm-project/commit/cc402de0b13b8682fec5762b2cf5064f9c8297f8.diff
LOG: [mlir][sparse][gpu] add result type to spmv and spmm gpu libgen path
Differential Revision: https://reviews.llvm.org/D151592
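The patch threads an optional compute type through the sparse GPU ops
(spmv/spmm/sddmm and their buffer-size queries) so the sparse-compiler libgen
path can tell cuSPARSE which precision to compute in. A minimal sketch of the
new surface syntax (SSA names illustrative, not from the patch):

```mlir
// Before: the compute type was inferred from operand bit widths.
%sz0, %t0 = gpu.spmv_buffer_size async [%dep] %env, %spmatA, %dnX, %dnY
// After: an optional trailing `into <type>` names it explicitly.
%sz1, %t1 = gpu.spmv_buffer_size async [%dep] %env, %spmatA, %dnX, %dnY into f32
```

When `into` is omitted, the lowering falls back to the element type of the
output operand, as the conversion changes below show.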
Added:
mlir/test/Dialect/GPU/sparse-roundtrip.mlir
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a401feea3d075..2fac955ce6a12 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1860,7 +1860,7 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
Example:
```mlir
- %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY
+ %buffersz, %token = gpu.spmv_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY into f32
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -1868,26 +1868,28 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
GPU_TransposeModeAttr:$modeA,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnVecHandle:$dnX,
- GPU_SparseDnVecHandle:$dnY);
+ GPU_SparseDnVecHandle:$dnY,
+ OptionalAttr<TypeAttr>:$computeType);
let results = (outs Res<Index>:$bufferSz,
Optional<GPU_AsyncToken>:$asyncToken);
let builders = [OpBuilder<(ins
- "::mlir::Type":$bufferSz,
- "::mlir::Type":$asyncToken,
- "::mlir::ValueRange":$asyncDependencies,
- "::mlir::Value":$env,
- "::mlir::Value":$spmatA,
- "::mlir::Value":$dnX,
- "::mlir::Value":$dnY), [{
+ "Type":$bufferSz,
+ "Type":$asyncToken,
+ "ValueRange":$asyncDependencies,
+ "Value":$env,
+ "Value":$spmatA,
+ "Value":$dnX,
+      "Value":$dnY), [{
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
- return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, env,
- modeA, spmatA, dnX, dnY);}]>
+ return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
+ env, modeA, spmatA, dnX, dnY, {});}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict ( `into` $computeType^)?
}];
}
@@ -1910,7 +1912,7 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
Example:
```mlir
- %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY : memref<?xf64>
+ %token = gpu.spmv async [%dep] %env, %spmatA{TRANSPOSE}, %dnX, %dnY : memref<?xf64> into bf16
```
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -1919,25 +1921,26 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnVecHandle:$dnX,
GPU_SparseDnVecHandle:$dnY,
+ OptionalAttr<TypeAttr>:$computeType,
AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let builders = [OpBuilder<(ins
- "::mlir::Type":$asyncToken,
- "::mlir::ValueRange":$asyncDependencies,
- "::mlir::Value":$env,
- "::mlir::Value":$spmatA,
- "::mlir::Value":$dnX,
- "::mlir::Value":$dnY,
- "::mlir::Value":$buffer), [{
+ "Type":$asyncToken,
+ "ValueRange":$asyncDependencies,
+ "Value":$env,
+ "Value":$spmatA,
+ "Value":$dnX,
+ "Value":$dnY,
+ "Value":$buffer), [{
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
- spmatA, dnX, dnY, buffer);}]>
+ spmatA, dnX, dnY, {}, buffer);}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer)
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
}];
}
@@ -1960,7 +1963,7 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
Example:
```mlir
- %buffersz, %token = gpu.spmm_buffersize async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC
+ %buffersz, %token = gpu.spmm_buffer_size async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC into f32
```
}];
@@ -1970,27 +1973,28 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
GPU_TransposeModeAttr:$modeB,
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnMatHandle:$dnmatB,
- GPU_SparseDnMatHandle:$dnmatC);
+ GPU_SparseDnMatHandle:$dnmatC,
+ OptionalAttr<TypeAttr>:$computeType);
let results = (outs Res<Index>:$bufferSz,
Optional<GPU_AsyncToken>:$asyncToken);
let builders = [OpBuilder<(ins
- "::mlir::Type":$bufferSz,
- "::mlir::Type":$asyncToken,
- "::mlir::ValueRange":$asyncDependencies,
- "::mlir::Value":$env,
- "::mlir::Value":$spmatA,
- "::mlir::Value":$dnmatB,
- "::mlir::Value":$dnmatC), [{
+ "Type":$bufferSz,
+ "Type":$asyncToken,
+ "ValueRange":$asyncDependencies,
+ "Value":$env,
+ "Value":$spmatA,
+ "Value":$dnmatB,
+ "Value":$dnmatC), [{
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
- env, modeA, modeB, spmatA, dnmatB, dnmatC);}]>
+ env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict ( `into` $computeType^)?
}];
}
@@ -2013,7 +2017,7 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
Example:
```mlir
- %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer
+ %token = gpu.spmm async [%dep] %env, %spmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %dnmatC, %buffer into f32
```
}];
@@ -2024,26 +2028,27 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
GPU_SparseSpMatHandle:$spmatA,
GPU_SparseDnMatHandle:$dnmatB,
GPU_SparseDnMatHandle:$dnmatC,
+ OptionalAttr<TypeAttr>:$computeType,
AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let builders = [OpBuilder<(ins
- "::mlir::Type":$asyncToken,
- "::mlir::ValueRange":$asyncDependencies,
- "::mlir::Value":$env,
- "::mlir::Value":$spmatA,
- "::mlir::Value":$dnmatB,
- "::mlir::Value":$dnmatC,
- "::mlir::Value":$buffer), [{
+ "Type":$asyncToken,
+ "ValueRange":$asyncDependencies,
+ "Value":$env,
+ "Value":$spmatA,
+ "Value":$dnmatB,
+ "Value":$dnmatC,
+ "Value":$buffer), [{
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
- modeB, spmatA, dnmatB, dnmatC, buffer);}]>
+ modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer)
+ $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
}];
}
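Both SpMM ops accept the same optional clause; a minimal sketch, assuming the
environment and matrix handles were created earlier and %buf is a workspace of
the size returned by the query:

```mlir
%sz, %t1 = gpu.spmm_buffer_size async [%dep] %env, %spmatA, %dnmatB, %dnmatC into f16
%t2 = gpu.spmm async [%t1] %env, %spmatA, %dnmatB, %dnmatC, %buf : memref<?xf16> into f16
```

Nothing in the parser ties the two attributes together, so keeping the query
and the computation consistent is up to the producer of the IR.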
@@ -2062,7 +2067,7 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
Example:
```mlir
- %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC
+ %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC into f32
```
The matrix arguments can also be associated with one of the following
@@ -2076,7 +2081,8 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
GPU_TransposeModeAttr:$modeB,
GPU_SparseDnMatHandle:$dnmatA,
GPU_SparseDnMatHandle:$dnmatB,
- GPU_SparseSpMatHandle:$spmatC);
+ GPU_SparseSpMatHandle:$spmatC,
+ OptionalAttr<TypeAttr>:$computeType);
let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
let builders = [OpBuilder<(ins
@@ -2090,12 +2096,12 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
- env, modeA, modeB, dnmatA, dnmatB, spmatC);}]>
+ env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict
+ $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict ( `into` $computeType^)?
}];
}
@@ -2114,7 +2120,7 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
Example:
```mlir
- %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer
+ %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer into f32
```
The matrix arguments can also be associated with one of the following
@@ -2129,6 +2135,7 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
GPU_SparseDnMatHandle:$dnmatA,
GPU_SparseDnMatHandle:$dnmatB,
GPU_SparseSpMatHandle:$spmatC,
+ OptionalAttr<TypeAttr>:$computeType,
AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
@@ -2143,12 +2150,12 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
- modeB, dnmatA, dnmatB, spmatC, buffer);}]>
+ modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]>
];
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
- $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer)
+ $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
}];
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 07ca1e51ed696..5ec455a65bd65 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -688,6 +688,53 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
return builder.create<LLVM::CallOp>(loc, function, arguments);
}
+// Corresponding to cusparseIndexType_t defined in cusparse.h.
+static int32_t getCuSparseIndexTypeFrom(Type type) {
+ if (type.isa<IndexType>())
+ return 3; // CUSPARSE_INDEX_64I
+ else
+ return 2; // CUSPARSE_INDEX_32I
+  // TODO: add support for CUSPARSE_INDEX_16U (1).
+}
+
+// Corresponding to cudaDataType_t defined in CUDA library_types.h.
+static int32_t getCuSparseDataTypeFrom(Type type) {
+ if (llvm::isa<ComplexType>(type)) {
+    // Get the element type.
+ auto elementType = type.cast<ComplexType>().getElementType();
+ if (elementType.isBF16())
+ return 15; // CUDA_C_16BF
+ if (elementType.isF16())
+ return 6; // CUDA_C_16F
+ if (elementType.isF32())
+ return 4; // CUDA_C_32F
+ if (elementType.isF64())
+ return 5; // CUDA_C_64F
+ if (elementType.isInteger(8))
+ return 7; // CUDA_C_8I
+ if (elementType.isInteger(16))
+ return 21; // CUDA_C_16I
+ if (elementType.isInteger(32))
+ return 11; // CUDA_C_32I
+ }
+ if (type.isBF16())
+ return 14; // CUDA_R_16BF
+ if (type.isF16())
+ return 2; // CUDA_R_16F
+ if (type.isF32())
+ return 0; // CUDA_R_32F
+ if (type.isF64())
+ return 1; // CUDA_R_64F
+ if (type.isInteger(8))
+ return 3; // CUDA_R_8I
+ if (type.isInteger(16))
+ return 20; // CUDA_R_16I
+ if (type.isInteger(32))
+ return 10; // CUDA_R_32I
+
+ llvm_unreachable("unsupported element type");
+}
+
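With these helpers the element or compute type is encoded directly as the
numeric cudaDataType_t value in the lowered IR. A hypothetical sketch of the
constant emitted for an f64 compute type:

```mlir
%ct = llvm.mlir.constant(1 : i32) : i32  // 1 == CUDA_R_64F
```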
// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
ConversionPatternRewriter &rewriter) {
@@ -1237,11 +1284,30 @@ static Type getSpMatElemType(Value spMat) {
llvm_unreachable("cannot find spmat def");
}
-static Value genConstFrom(OpBuilder &builder, Location loc,
- gpu::TransposeMode mode) {
+// Returns the element type of the defining dnmat or dnvec op.
+static Type getDnElemType(Value dn) {
+ if (auto op = dn.getDefiningOp<gpu::CreateDnMatOp>())
+ return op.getMemref().getType().getElementType();
+ if (auto op = dn.getDefiningOp<gpu::CreateDnVecOp>())
+ return op.getMemref().getType().getElementType();
+ llvm_unreachable("cannot find dn def");
+}
+
+template <typename T>
+static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
Type llvmInt32Type = builder.getIntegerType(32);
return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- static_cast<int32_t>(mode));
+ static_cast<int32_t>(TValue));
+}
+
+static Value
+genConstInt32FromOptionalComputeMode(OpBuilder &builder, Location loc,
+ std::optional<Type> computeTypeOptional,
+ Type defaultType) {
+ auto computeTypeInt =
+ getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType));
+ auto computeType = genConstInt32From(builder, loc, computeTypeInt);
+ return computeType;
}
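The patterns below pass the element type of their output handle as the
default, so omitting `into` behaves exactly like spelling it out with that
element type. A sketch of the equivalence, assuming %dnY was created from a
memref<?xf64>:

```mlir
%t1 = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY, %buf : memref<?xf64>
%t2 = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY, %buf : memref<?xf64> into f64
```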
LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -1283,13 +1349,11 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
- Type dType =
- llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ Type dType = op.getMemref().getType().getElementType();
+ auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createDnVecCallBuilder
- .create(loc, rewriter, {adaptor.getSize(), pVec, dw, stream})
+ .create(loc, rewriter, {adaptor.getSize(), pVec, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
@@ -1320,14 +1384,12 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
if (!getTypeConverter()->useOpaquePointers())
pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
- Type dType =
- llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ Type dType = op.getMemref().getType().getElementType();
+ auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createDnMatCallBuilder
.create(loc, rewriter,
- {adaptor.getRows(), adaptor.getCols(), pMat, dw, stream})
+ {adaptor.getRows(), adaptor.getCols(), pMat, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
@@ -1369,15 +1431,13 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
- auto iw = rewriter.create<LLVM::ConstantOp>(
- loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
+ auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCooCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
- pRowIdxs, pColIdxs, pValues, iw, dw, stream})
+ pRowIdxs, pColIdxs, pValues, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
@@ -1408,17 +1468,14 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
Type dType =
llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
- auto pw = rewriter.create<LLVM::ConstantOp>(
- loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
- auto iw = rewriter.create<LLVM::ConstantOp>(
- loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto ptp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(pType));
+ auto itp = genConstInt32From(rewriter, loc, getCuSparseIndexTypeFrom(iType));
+ auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
auto handle =
createCsrCallBuilder
.create(loc, rewriter,
{adaptor.getRows(), adaptor.getCols(), adaptor.getNnz(),
- pRowPos, pColIdxs, pValues, pw, iw, dw, stream})
+ pRowPos, pColIdxs, pValues, ptp, itp, dtp, stream})
.getResult();
rewriter.replaceOp(op, {handle, stream});
return success();
@@ -1444,16 +1501,16 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- auto modeA = genConstFrom(rewriter, loc, op.getModeA());
- Type dType = getSpMatElemType(op.getSpmatA());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
+  // Retrieve the compute type; note that it may be optional.
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
auto stream = adaptor.getAsyncDependencies().front();
auto bufferSize =
spMVBufferSizeCallBuilder
.create(loc, rewriter,
{adaptor.getEnv(), modeA, adaptor.getSpmatA(),
- adaptor.getDnX(), adaptor.getDnY(), dw, stream})
+ adaptor.getDnX(), adaptor.getDnY(), computeType, stream})
.getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
@@ -1466,10 +1523,10 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- Type dType = getSpMatElemType(op.getSpmatA());
- auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  // Retrieve the compute type; note that it may be optional.
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1477,7 +1534,7 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
spMVCallBuilder.create(loc, rewriter,
{adaptor.getEnv(), modeA, adaptor.getSpmatA(),
- adaptor.getDnX(), adaptor.getDnY(), dw, pBuf,
+ adaptor.getDnX(), adaptor.getDnY(), computeType, pBuf,
stream});
rewriter.replaceOp(op, {stream});
return success();
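The net effect on the lowered IR is one extra i32 operand on each runtime
call. A rough sketch of what the SpMV pattern now produces (SSA names and the
call's type signature are illustrative, assuming opaque pointers):

```mlir
%mode = llvm.mlir.constant(0 : i32) : i32  // NON_TRANSPOSE
%ct   = llvm.mlir.constant(1 : i32) : i32  // CUDA_R_64F
llvm.call @mgpuSpMV(%env, %mode, %spmat, %x, %y, %ct, %buf, %stream)
    : (!llvm.ptr, i32, !llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr, !llvm.ptr) -> ()
```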
@@ -1490,18 +1547,19 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
- auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
- Type dType = getSpMatElemType(op.getSpmatA());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
- auto bufferSize =
- spMMBufferSizeCallBuilder
- .create(loc, rewriter,
- {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
- adaptor.getDnmatB(), adaptor.getDnmatC(), dw, stream})
- .getResult();
+  // Retrieve the compute type; note that it may be optional.
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
+ auto bufferSize = spMMBufferSizeCallBuilder
+ .create(loc, rewriter,
+ {adaptor.getEnv(), modeA, modeB,
+ adaptor.getSpmatA(), adaptor.getDnmatB(),
+ adaptor.getDnmatC(), computeType, stream})
+ .getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
}
@@ -1513,18 +1571,18 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
- auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
- Type dType = getSpMatElemType(op.getSpmatC());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(),
+ getSpMatElemType(op.getSpmatC()));
auto stream = adaptor.getAsyncDependencies().front();
- auto bufferSize =
- SDDMMBufferSizeCallBuilder
- .create(loc, rewriter,
- {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
- adaptor.getDnmatB(), adaptor.getSpmatC(), dw, stream})
- .getResult();
+ auto bufferSize = SDDMMBufferSizeCallBuilder
+ .create(loc, rewriter,
+ {adaptor.getEnv(), modeA, modeB,
+ adaptor.getDnmatA(), adaptor.getDnmatB(),
+ adaptor.getSpmatC(), computeType, stream})
+ .getResult();
rewriter.replaceOp(op, {bufferSize, stream});
return success();
}
@@ -1536,11 +1594,12 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
- auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
- Type dType = getSpMatElemType(op.getSpmatA());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
+ auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  // Retrieve the compute type; note that it may be optional.
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1548,8 +1607,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
spMMCallBuilder.create(loc, rewriter,
{adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
- adaptor.getDnmatB(), adaptor.getDnmatC(), dw, pBuf,
- stream});
+ adaptor.getDnmatB(), adaptor.getDnmatC(), computeType,
+ pBuf, stream});
rewriter.replaceOp(op, {stream});
return success();
}
@@ -1569,11 +1628,11 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
failed(isAsyncWithOneDependency(rewriter, op)))
return failure();
Location loc = op.getLoc();
- Type dType = getSpMatElemType(op.getSpmatC());
- auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
- dType.getIntOrFloatBitWidth());
- auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
- auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+ auto computeType = genConstInt32FromOptionalComputeMode(
+ rewriter, loc, adaptor.getComputeType(),
+ getSpMatElemType(op.getSpmatC()));
+ auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+ auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
auto stream = adaptor.getAsyncDependencies().front();
Value pBuf =
MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1581,8 +1640,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
SDDMMCallBuilder.create(loc, rewriter,
{adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
- adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf,
- stream});
+ adaptor.getDnmatB(), adaptor.getSpmatC(),
+ computeType, pBuf, stream});
rewriter.replaceOp(op, {stream});
return success();
}
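Note the asymmetry in the defaults: SpMV and SpMM fall back to the dense
output's element type, while SDDMM falls back to the element type of its
sparse output %spmatC. A sketch, assuming %spmatC was built from
memref<?xf64> values:

```mlir
// Defaults to an f64 compute type because no `into` clause is given.
%t = gpu.sddmm async [%dep] %env, %dnmatA, %dnmatB, %spmatC, %buf : memref<?xf64>
```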
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index f928f32425ea6..c7367a8a3893c 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -17,6 +17,8 @@
#include <stdio.h>
#include "cuda.h"
+#include "cuda_bf16.h"
+#include "cuda_fp16.h"
#include "cusparse.h"
#ifdef _WIN32
@@ -228,38 +230,32 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
/// Wrapper methods for the cuSparse library.
///
-static inline cudaDataType_t dataTp(int32_t width) {
- switch (width) {
- case 32:
- return CUDA_R_32F;
- default:
- return CUDA_R_64F;
- }
-}
-
-static inline cusparseIndexType_t idxTp(int32_t width) {
- switch (width) {
- case 32:
- return CUSPARSE_INDEX_32I;
- default:
- return CUSPARSE_INDEX_64I;
- }
-}
-
// Some macro magic to get float/double alpha and beta on host.
-#define ALPHABETA(w, alpha, beta) \
+#define ALPHABETA(dtp, alpha, beta) \
+  __nv_bfloat16(alpha##16bf) = 1.0f;                                          \
+  __nv_bfloat16(beta##16bf) = 1.0f;                                           \
+  __half(alpha##16f) = 1.0f;                                                  \
+  __half(beta##16f) = 1.0f;                                                   \
float(alpha##f) = 1.0f; \
float(beta##f) = 1.0f; \
double(alpha##d) = 1.0; \
double(beta##d) = 1.0; \
const void *(alpha##p) = nullptr; \
const void *(beta##p) = nullptr; \
- if ((w) == 32) { \
+ if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) { \
+ (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf)); \
+ (beta##p) = reinterpret_cast<void *>(&(beta##16bf)); \
+ } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) { \
+ (alpha##p) = reinterpret_cast<void *>(&(alpha##16f)); \
+ (beta##p) = reinterpret_cast<void *>(&(beta##16f)); \
+ } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) { \
(alpha##p) = reinterpret_cast<void *>(&(alpha##f)); \
(beta##p) = reinterpret_cast<void *>(&(beta##f)); \
- } else { \
+ } else if (dtp == CUDA_R_64F || dtp == CUDA_C_64F) { \
(alpha##p) = reinterpret_cast<void *>(&(alpha##d)); \
(beta##p) = reinterpret_cast<void *>(&(beta##d)); \
+ } else { \
+ llvm_unreachable("Unsupported data type"); \
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
@@ -276,10 +272,10 @@ mgpuDestroySparseEnv(void *h, CUstream /*stream*/) {
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
-mgpuCreateDnVec(intptr_t size, void *values, int32_t dw, CUstream /*stream*/) {
+mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
cusparseDnVecDescr_t vec = nullptr;
- cudaDataType_t dtp = dataTp(dw);
- CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dtp))
+ auto dTp = static_cast<cudaDataType_t>(dtp);
+ CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
return reinterpret_cast<void *>(vec);
}
@@ -290,12 +286,12 @@ mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
-mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dw,
+mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
CUstream /*stream*/) {
cusparseDnMatDescr_t mat = nullptr;
- cudaDataType_t dtp = dataTp(dw);
+ auto dTp = static_cast<cudaDataType_t>(dtp);
CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
- values, dtp, CUSPARSE_ORDER_ROW))
+ values, dTp, CUSPARSE_ORDER_ROW))
return reinterpret_cast<void *>(mat);
}
@@ -307,27 +303,26 @@ mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
- void *colIdxs, void *values, int32_t iw, int32_t dw,
+ void *colIdxs, void *values, int32_t itp, int32_t dtp,
CUstream /*stream*/) {
cusparseSpMatDescr_t mat = nullptr;
- cusparseIndexType_t itp = idxTp(iw);
- cudaDataType_t dtp = dataTp(dw);
+ auto iTp = static_cast<cusparseIndexType_t>(itp);
+ auto dTp = static_cast<cudaDataType_t>(dtp);
CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
- colIdxs, values, itp,
- CUSPARSE_INDEX_BASE_ZERO, dtp))
+ colIdxs, values, iTp,
+ CUSPARSE_INDEX_BASE_ZERO, dTp))
return reinterpret_cast<void *>(mat);
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
- void *colIdxs, void *values, int32_t pw, int32_t iw, int32_t dw,
- CUstream /*stream*/) {
+ void *colIdxs, void *values, int32_t ptp, int32_t itp,
+ int32_t dtp, CUstream /*stream*/) {
cusparseSpMatDescr_t mat = nullptr;
- cusparseIndexType_t ptp = idxTp(pw);
- cusparseIndexType_t itp = idxTp(iw);
- cudaDataType_t dtp = dataTp(dw);
+  auto pTp = static_cast<cusparseIndexType_t>(ptp);
+  auto iTp = static_cast<cusparseIndexType_t>(itp);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
-                                             colIdxs, values, ptp, itp,
-                                             CUSPARSE_INDEX_BASE_ZERO, dtp))
+                                             colIdxs, values, pTp, iTp,
+                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
return reinterpret_cast<void *>(mat);
}
@@ -339,102 +334,102 @@ mgpuDestroySpMat(void *m, CUstream /*stream*/) {
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
-mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t dw,
+mgpuSpMVBufferSize(void *h, int32_t ma, void *a, void *x, void *y, int32_t ctp,
CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
size_t bufferSize = 0;
CUSPARSE_REPORT_IF_ERROR(
cusparseSpMV_bufferSize(handle, modeA, alphap, matA, vecX, betap, vecY,
- dtp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
+ cTp, CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(void *h, int32_t ma, void *a,
- void *x, void *y, int32_t dw,
- void *buf,
+ void *x, void *y,
+ int32_t ctp, void *buf,
CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(handle, modeA, alphap, matA, vecX,
- betap, vecY, dtp,
+ betap, vecY, cTp,
CUSPARSE_SPMV_ALG_DEFAULT, buf))
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
- int32_t dw, CUstream /*stream*/) {
+ int32_t ctp, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
size_t bufferSize = 0;
CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
- handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+ handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
- void *buf, CUstream /*stream*/) {
+mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+ int32_t ctp, void *buf, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(handle, modeA, modeB, alphap, matA,
- matB, betap, matC, dtp,
+ matB, betap, matC, cTp,
CUSPARSE_SPMM_ALG_DEFAULT, buf))
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
- int32_t dw, CUstream /*stream*/) {
+ int32_t ctp, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ auto cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
size_t bufferSize = 0;
CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
- handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+ handle, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
}
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
- int32_t dw, void *buf, CUstream /*stream*/) {
+ int32_t ctp, void *buf, CUstream /*stream*/) {
cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
- cudaDataType_t dtp = dataTp(dw);
- ALPHABETA(dw, alpha, beta)
+ auto cTp = static_cast<cudaDataType_t>(ctp);
+ ALPHABETA(cTp, alpha, beta)
CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(handle, modeA, modeB, alphap, matA,
- matB, betap, matC, dtp,
+ matB, betap, matC, cTp,
CUSPARSE_SDDMM_ALG_DEFAULT, buf))
}
diff --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
new file mode 100644
index 0000000000000..6465208791dd5
--- /dev/null
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+ // CHECK-LABEL: func @matvec
+ // CHECK: %{{.*}} = gpu.wait async
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_coo async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_vec async [%{{.*}}] %{{.*}}, %{{.*}} : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.spmv_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ // CHECK: %{{.*}} = gpu.spmv async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+ // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_dn_vec async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+ // CHECK: gpu.wait [%{{.*}}]
+ // CHECK: return
+ func.func @matvec(%arg0: index) {
+ %token0 = gpu.wait async
+ %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+ %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+ %env, %token3 = gpu.create_sparse_env async [%token2]
+ %spmat, %token4 = gpu.create_coo async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ %dnvec, %token5 = gpu.create_dn_vec async [%token4] %mem2, %arg0 : memref<?xf64>
+ %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec
+ %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
+ %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+ %token9 = gpu.destroy_dn_vec async [%token8] %dnvec
+ %token10 = gpu.destroy_sparse_env async [%token9] %env
+ gpu.wait [%token10]
+ return
+ }
+
+ // CHECK-LABEL: func @matmul
+ // CHECK: %{{.*}} = gpu.wait async
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_mat async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.spmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} into f64
+ // CHECK: %{{.*}} = gpu.spmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64> into f64
+ // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_dn_mat async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+ // CHECK: gpu.wait [%{{.*}}]
+ // CHECK: return
+ func.func @matmul(%arg0: index) {
+ %token0 = gpu.wait async
+ %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+ %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+ %env, %token3 = gpu.create_sparse_env async [%token2]
+ %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+ %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat into f64
+ %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64> into f64
+ %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+ %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+ %token10 = gpu.destroy_sparse_env async [%token9] %env
+ gpu.wait [%token10]
+ return
+ }
+
+ // CHECK-LABEL: func @sddmm
+ // CHECK: %{{.*}} = gpu.wait async
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_mat async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+ // CHECK: %{{.*}}, %{{.*}} = gpu.sddmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+ // CHECK: %{{.*}} = gpu.sddmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+ // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_dn_mat async [%{{.*}}] %{{.*}}
+ // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
+ // CHECK: gpu.wait [%{{.*}}]
+ // CHECK: return
+ func.func @sddmm(%arg0: index) {
+ %token0 = gpu.wait async
+ %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+ %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+ %env, %token3 = gpu.create_sparse_env async [%token2]
+ %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+ %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+ %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
+ %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+ %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+ %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+ %token10 = gpu.destroy_sparse_env async [%token9] %env
+ gpu.wait [%token10]
+ return
+ }
+
+}
+
+