[Mlir-commits] [mlir] 86bf710 - [mlir] [gpu] [sparse] refined SparseHandle type
Aart Bik
llvmlistbot at llvm.org
Wed May 24 10:16:16 PDT 2023
Author: Kun Wu
Date: 2023-05-24T10:16:07-07:00
New Revision: 86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4
URL: https://github.com/llvm/llvm-project/commit/86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4
DIFF: https://github.com/llvm/llvm-project/commit/86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4.diff
LOG: [mlir] [gpu] [sparse] refined SparseHandle type
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D151014
Added:
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index e56af3e050486..ddef02095c64f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -109,10 +109,30 @@ class MMAMatrixOf<list<Type> allowedTypes> :
"::llvm::cast<::mlir::gpu::MMAMatrixType>($_self).getElementType()",
"gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
-// Generic type for all sparse handles (could be refined).
-def GPU_SparseHandle : DialectType<
- GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::SparseHandleType>()">, "sparse handle type">,
- BuildableType<"mlir::gpu::SparseHandleType::get($_builder.getContext())">;
+// Opaque handle types used by the GPU sparse operations (one per handle kind).
+def GPU_SparseEnvHandle :
+ DialectType<GPU_Dialect,
+ CPred<"::llvm::isa<::mlir::gpu::SparseEnvHandleType>($_self)">,
+ "sparse environment handle type">,
+ BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
+
+def GPU_SparseDnVecHandle :
+ DialectType<GPU_Dialect,
+ CPred<"::llvm::isa<::mlir::gpu::SparseDnVecHandleType>($_self)">,
+ "dense vector handle type">,
+ BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
+
+def GPU_SparseDnMatHandle :
+ DialectType<GPU_Dialect,
+ CPred<"::llvm::isa<::mlir::gpu::SparseDnMatHandleType>($_self)">,
+ "dense matrix handle type">,
+ BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
+
+def GPU_SparseSpMatHandle :
+ DialectType<GPU_Dialect,
+ CPred<"::llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
+ "sparse matrix handle type">,
+ BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
//===----------------------------------------------------------------------===//
// GPU Interfaces.
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 64b8f8f6e8b3a..a775066362cab 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -163,14 +163,23 @@ class MMAMatrixType
// Adds a `gpu.async.token` to the front of the argument list.
void addAsyncDependency(Operation *op, Value token);
-// Represents any sparse handle.
+// Handle kinds for the GPU sparse ops; each kind gets its own handle type.
+enum class SparseHandleKind { Env, DnVec, DnMat, SpMat };
+
+template <SparseHandleKind K>
class SparseHandleType
- : public Type::TypeBase<SparseHandleType, Type, TypeStorage> {
+ : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
public:
- // Used for generic hooks in TypeBase.
+ using Base =
+ typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
using Base::Base;
};
+using SparseEnvHandleType = SparseHandleType<SparseHandleKind::Env>;
+using SparseDnVecHandleType = SparseHandleType<SparseHandleKind::DnVec>;
+using SparseDnMatHandleType = SparseHandleType<SparseHandleKind::DnMat>;
+using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
+
} // namespace gpu
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 77e65972038da..5160b6886817e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1557,14 +1557,16 @@ def GPU_CreateSparseEnvOp : GPU_Op<"create_sparse_env", [GPU_AsyncOpInterface]>
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies);
- let results = (outs Res<GPU_SparseHandle>:$env, Optional<GPU_AsyncToken>:$asyncToken);
-
+ let results = (outs Res<GPU_SparseEnvHandle>:$env,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict
}];
}
-def GPU_DestroySparseEnvOp : GPU_Op<"destroy_sparse_env", [GPU_AsyncOpInterface]> {
+def GPU_DestroySparseEnvOp : GPU_Op<
+ "destroy_sparse_env",
+ [GPU_AsyncOpInterface]> {
let summary = "Destroy sparse environment operation";
let description = [{
The `gpu.destroy_sparse_env` operation releases all resources of a sparse
@@ -1583,11 +1585,12 @@ def GPU_DestroySparseEnvOp : GPU_Op<"destroy_sparse_env", [GPU_AsyncOpInterface]
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Arg<GPU_SparseHandle>:$env);
+ Arg<GPU_SparseEnvHandle>:$env);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
- custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $env attr-dict
+ custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+ $env attr-dict
}];
}
@@ -1612,7 +1615,8 @@ def GPU_CreateDnVecOp : GPU_Op<"create_dn_vec", [GPU_AsyncOpInterface]> {
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
AnyMemRef:$memref, Index:$size);
- let results = (outs Res<GPU_SparseHandle>:$dvec, Optional<GPU_AsyncToken>:$asyncToken);
+ let results = (outs Res<GPU_SparseDnVecHandle>:$dvec,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1639,11 +1643,12 @@ def GPU_DestroyDnVecOp : GPU_Op<"destroy_dn_vec", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Arg<GPU_SparseHandle>:$dvec);
+ Arg<GPU_SparseDnVecHandle>:$dvec);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
- custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $dvec attr-dict
+ custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+ $dvec attr-dict
}];
}
@@ -1667,10 +1672,10 @@ def GPU_CreateDnMatOp : GPU_Op<"create_dn_mat", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Index:$rows,
- Index:$cols,
- AnyMemRef:$memref);
- let results = (outs Res<GPU_SparseHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
+ Index:$rows,
+ Index:$cols,
+ AnyMemRef:$memref);
+ let results = (outs Res<GPU_SparseDnMatHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1697,11 +1702,12 @@ def GPU_DestroyDnMatOp : GPU_Op<"destroy_dn_mat", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Arg<GPU_SparseHandle>:$dmat);
+ Arg<GPU_SparseDnMatHandle>:$dmat);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
- custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $dmat attr-dict
+ custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+ $dmat attr-dict
}];
}
@@ -1726,13 +1732,14 @@ def GPU_CreateCooOp : GPU_Op<"create_coo", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Index:$rows,
- Index:$cols,
- Index:$nnz,
- AnyMemRef:$rowIdxs,
- AnyMemRef:$colIdxs,
- AnyMemRef:$values);
- let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+ Index:$rows,
+ Index:$cols,
+ Index:$nnz,
+ AnyMemRef:$rowIdxs,
+ AnyMemRef:$colIdxs,
+ AnyMemRef:$values);
+ let results = (outs Res<GPU_SparseSpMatHandle>:$spmat,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1769,7 +1776,8 @@ def GPU_CreateCsrOp : GPU_Op<"create_csr", [GPU_AsyncOpInterface]> {
AnyMemRef:$rowPos,
AnyMemRef:$colIdxs,
AnyMemRef:$values);
- let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+ let results = (outs Res<GPU_SparseSpMatHandle>:$spmat,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1797,7 +1805,7 @@ def GPU_DestroySpMatOp : GPU_Op<"destroy_sp_mat", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- Arg<GPU_SparseHandle>:$spmat);
+ Arg<GPU_SparseSpMatHandle>:$spmat);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
@@ -1823,13 +1831,13 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
%buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA, %dnX, %dnY
```
}];
-
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- GPU_SparseHandle:$env,
- GPU_SparseHandle:$spmatA,
- GPU_SparseHandle:$dnX,
- GPU_SparseHandle:$dnY);
- let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+ GPU_SparseEnvHandle:$env,
+ GPU_SparseSpMatHandle:$spmatA,
+ GPU_SparseDnVecHandle:$dnX,
+ GPU_SparseDnVecHandle:$dnY);
+ let results = (outs Res<Index>:$bufferSz,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1855,13 +1863,12 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
%token = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY : memref<?xf64>
```
}];
-
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- GPU_SparseHandle:$env,
- GPU_SparseHandle:$spmatA,
- GPU_SparseHandle:$dnX,
- GPU_SparseHandle:$dnY,
- AnyMemRef:$buffer);
+ GPU_SparseEnvHandle:$env,
+ GPU_SparseSpMatHandle:$spmatA,
+ GPU_SparseDnVecHandle:$dnX,
+ GPU_SparseDnVecHandle:$dnY,
+ AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
@@ -1890,11 +1897,12 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- GPU_SparseHandle:$env,
- GPU_SparseHandle:$spmatA,
- GPU_SparseHandle:$dnmatB,
- GPU_SparseHandle:$dnmatC);
- let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+ GPU_SparseEnvHandle:$env,
+ GPU_SparseSpMatHandle:$spmatA,
+ GPU_SparseDnMatHandle:$dnmatB,
+ GPU_SparseDnMatHandle:$dnmatC);
+ let results = (outs Res<Index>:$bufferSz,
+ Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1922,11 +1930,11 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
}];
let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
- GPU_SparseHandle:$env,
- GPU_SparseHandle:$spmatA,
- GPU_SparseHandle:$dnmatB,
- GPU_SparseHandle:$dnmatC,
- AnyMemRef:$buffer);
+ GPU_SparseEnvHandle:$env,
+ GPU_SparseSpMatHandle:$spmatA,
+ GPU_SparseDnMatHandle:$dnmatB,
+ GPU_SparseDnMatHandle:$dnmatC,
+ AnyMemRef:$buffer);
let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 600bd9152c436..1d1923d4d0c2b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1470,18 +1470,23 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
return success();
}
+template <typename HandleT>
+static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
+ MLIRContext *ctx = &converter.getContext();
+ converter.addConversion([&converter, ctx](HandleT) -> Type {
+ return converter.getPointerType(IntegerType::get(ctx, /*width=*/8));
+ });
+}
+
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
StringRef gpuBinaryAnnotation,
bool kernelBarePtrCallConv) {
- converter.addConversion([&converter](gpu::AsyncTokenType type) -> Type {
- return converter.getPointerType(
- IntegerType::get(&converter.getContext(), 8));
- });
- converter.addConversion([&converter](gpu::SparseHandleType type) -> Type {
- return converter.getPointerType(
- IntegerType::get(&converter.getContext(), 8));
- });
+ addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
+ addOpaquePointerConversion<gpu::SparseDnVecHandleType>(converter);
+ addOpaquePointerConversion<gpu::SparseDnMatHandleType>(converter);
+ addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
+ addOpaquePointerConversion<gpu::SparseEnvHandleType>(converter);
patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
ConvertDeallocOpToGpuRuntimeCallPattern,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index ce502401b86a7..add4b97c71fee 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -146,7 +146,10 @@ struct GPUInlinerInterface : public DialectInlinerInterface {
void GPUDialect::initialize() {
addTypes<AsyncTokenType>();
addTypes<MMAMatrixType>();
- addTypes<SparseHandleType>();
+ addTypes<SparseEnvHandleType>();
+ addTypes<SparseDnVecHandleType>();
+ addTypes<SparseDnMatHandleType>();
+ addTypes<SparseSpMatHandleType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
@@ -158,6 +161,21 @@ void GPUDialect::initialize() {
addInterfaces<GPUInlinerInterface>();
}
+/// Returns the keyword that parseType/printType use for a sparse handle kind.
+static StringRef getSparseHandleKeyword(SparseHandleKind kind) {
+ switch (kind) {
+ case SparseHandleKind::Env:
+ return "sparse.env_handle";
+ case SparseHandleKind::DnVec:
+ return "sparse.dnvec_handle";
+ case SparseHandleKind::DnMat:
+ return "sparse.dnmat_handle";
+ case SparseHandleKind::SpMat:
+ return "sparse.spmat_handle";
+ }
+ llvm_unreachable("unknown sparse handle kind");
+}
+
Type GPUDialect::parseType(DialectAsmParser &parser) const {
// Parse the main keyword for the type.
StringRef keyword;
@@ -201,17 +217,31 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
shape, elementType, operand);
}
- if (keyword == "sparse.handle")
- return SparseHandleType::get(context);
+ if (keyword == getSparseHandleKeyword(SparseHandleKind::Env))
+ return SparseEnvHandleType::get(context);
+ if (keyword == getSparseHandleKeyword(SparseHandleKind::DnVec))
+ return SparseDnVecHandleType::get(context);
+ if (keyword == getSparseHandleKeyword(SparseHandleKind::DnMat))
+ return SparseDnMatHandleType::get(context);
+ if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
+ return SparseSpMatHandleType::get(context);
parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
return Type();
}
-
+// Sparse handle types are printed with the same keywords that parseType
+// accepts (via getSparseHandleKeyword), so printed types round-trip.
void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<AsyncTokenType>([&](Type) { os << "async.token"; })
- .Case<SparseHandleType>([&](Type) { os << "sparse.handle"; })
+ .Case<SparseEnvHandleType>(
+ [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::Env); })
+ .Case<SparseDnVecHandleType>(
+ [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::DnVec); })
+ .Case<SparseDnMatHandleType>(
+ [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::DnMat); })
+ .Case<SparseSpMatHandleType>(
+ [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
.Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
os << "mma_matrix<";
auto shape = fragTy.getShape();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cadbe86942274..53a11ad9ac440 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -436,22 +436,25 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
// Create sparse environment and sparse matrix/dense vector handles.
Type indexTp = rewriter.getIndexType();
- Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+ Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+ Type dnVecHandleTp = rewriter.getType<gpu::SparseDnVecHandleType>();
+ Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
Value token = genFirstWait(rewriter, loc);
auto env =
- rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+ rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
Value handle = env.getResult(0);
token = env.getAsyncToken();
- Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
- szX, nseA, rowA, colA, valA, isCOO, enableRT);
+ Operation *spGenA =
+ genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
+ rowA, colA, valA, isCOO, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
- auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+ auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
token, vecX, szX);
Value dnX = dvecX.getResult(0);
token = dvecX.getAsyncToken();
- auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+ auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
token, vecY, szY);
Value dnY = dvecY.getResult(0);
token = dvecY.getAsyncToken();
@@ -540,22 +543,25 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
// Create sparse environment and sparse matrix/dense matrix handles.
Type indexTp = rewriter.getIndexType();
- Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+ Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+ Type dnMatHandleTp = rewriter.getType<gpu::SparseDnMatHandleType>();
+ Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
Value token = genFirstWait(rewriter, loc);
auto env =
- rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+ rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
Value handle = env.getResult(0);
token = env.getAsyncToken();
- Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm,
- szk, nseA, rowA, colA, valA, isCOO, enableRT);
+ Operation *spGenA =
+ genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
+ rowA, colA, valA, isCOO, enableRT);
Value spMatA = spGenA->getResult(0);
token = spGenA->getResult(1);
- auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+ auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
token, szk, szn, matB);
Value dnB = dmatB.getResult(0);
token = dmatB.getAsyncToken();
- auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+ auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
token, szm, szn, matC);
Value dnC = dmatC.getResult(0);
token = dmatC.getAsyncToken();
More information about the Mlir-commits
mailing list