[Mlir-commits] [mlir] 86bf710 - [mlir] [gpu] [sparse] refined SparseHandle type

Aart Bik llvmlistbot at llvm.org
Wed May 24 10:16:16 PDT 2023


Author: Kun Wu
Date: 2023-05-24T10:16:07-07:00
New Revision: 86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4

URL: https://github.com/llvm/llvm-project/commit/86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4
DIFF: https://github.com/llvm/llvm-project/commit/86bf710cf750b387b5e5f4a0cf8fd3d8d7ec9dd4.diff

LOG: [mlir] [gpu] [sparse] refined SparseHandle type

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151014

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index e56af3e050486..ddef02095c64f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -109,10 +109,30 @@ class MMAMatrixOf<list<Type> allowedTypes> :
   "::llvm::cast<::mlir::gpu::MMAMatrixType>($_self).getElementType()",
   "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
 
-// Generic type for all sparse handles (could be refined).
-def GPU_SparseHandle : DialectType<
-  GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::SparseHandleType>()">, "sparse handle type">,
-             BuildableType<"mlir::gpu::SparseHandleType::get($_builder.getContext())">;
+// Types for all sparse handles.
+def GPU_SparseEnvHandle : 
+  DialectType<GPU_Dialect, 
+    CPred<"$_self.isa<::mlir::gpu::SparseEnvHandleType>()">, 
+    "sparse environment handle type">, 
+  BuildableType<"mlir::gpu::SparseEnvHandleType::get($_builder.getContext())">;
+
+def GPU_SparseDnVecHandle : 
+  DialectType<GPU_Dialect, 
+    CPred<"$_self.isa<::mlir::gpu::SparseDnVecHandleType>()">, 
+    "dense vector handle type">,
+  BuildableType<"mlir::gpu::SparseDnVecHandleType::get($_builder.getContext())">;
+
+def GPU_SparseDnMatHandle : 
+  DialectType<GPU_Dialect, 
+    CPred<"$_self.isa<::mlir::gpu::SparseDnMatHandleType>()">, 
+    "dense matrix handle type">,
+  BuildableType<"mlir::gpu::SparseDnMatHandleType::get($_builder.getContext())">;
+
+def GPU_SparseSpMatHandle : 
+  DialectType<GPU_Dialect, 
+    CPred<"$_self.isa<::mlir::gpu::SparseSpMatHandleType>()">, 
+    "sparse matrix handle type">,
+  BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
 
 //===----------------------------------------------------------------------===//
 // GPU Interfaces.

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 64b8f8f6e8b3a..a775066362cab 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -163,14 +163,23 @@ class MMAMatrixType
 // Adds a `gpu.async.token` to the front of the argument list.
 void addAsyncDependency(Operation *op, Value token);
 
-// Represents any sparse handle.
+// Handle types for sparse operations.
+enum class SparseHandleKind { Env, DnVec, DnMat, SpMat };
+
+template <SparseHandleKind K>
 class SparseHandleType
-    : public Type::TypeBase<SparseHandleType, Type, TypeStorage> {
+    : public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
 public:
-  // Used for generic hooks in TypeBase.
+  using Base =
+      typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
   using Base::Base;
 };
 
+using SparseEnvHandleType = SparseHandleType<SparseHandleKind::Env>;
+using SparseDnVecHandleType = SparseHandleType<SparseHandleKind::DnVec>;
+using SparseDnMatHandleType = SparseHandleType<SparseHandleKind::DnMat>;
+using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
+
 } // namespace gpu
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 77e65972038da..5160b6886817e 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1557,14 +1557,16 @@ def GPU_CreateSparseEnvOp : GPU_Op<"create_sparse_env", [GPU_AsyncOpInterface]>
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies);
-  let results = (outs Res<GPU_SparseHandle>:$env, Optional<GPU_AsyncToken>:$asyncToken);
-
+  let results = (outs Res<GPU_SparseEnvHandle>:$env, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) attr-dict
   }];
 }
 
-def GPU_DestroySparseEnvOp : GPU_Op<"destroy_sparse_env", [GPU_AsyncOpInterface]> {
+def GPU_DestroySparseEnvOp : GPU_Op<
+    "destroy_sparse_env", 
+    [GPU_AsyncOpInterface]> {
   let summary = "Destroy sparse environment operation";
   let description = [{
     The `gpu.destroy_sparse_env` operation releases all resources of a sparse
@@ -1583,11 +1585,12 @@ def GPU_DestroySparseEnvOp : GPU_Op<"destroy_sparse_env", [GPU_AsyncOpInterface]
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Arg<GPU_SparseHandle>:$env);
+                       Arg<GPU_SparseEnvHandle>:$env);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
-    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $env attr-dict
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) 
+    $env attr-dict
   }];
 }
 
@@ -1612,7 +1615,8 @@ def GPU_CreateDnVecOp : GPU_Op<"create_dn_vec", [GPU_AsyncOpInterface]> {
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                    AnyMemRef:$memref, Index:$size);
-  let results = (outs Res<GPU_SparseHandle>:$dvec, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseDnVecHandle>:$dvec, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1639,11 +1643,12 @@ def GPU_DestroyDnVecOp : GPU_Op<"destroy_dn_vec", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Arg<GPU_SparseHandle>:$dvec);
+                       Arg<GPU_SparseDnVecHandle>:$dvec);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
-    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $dvec attr-dict
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) 
+    $dvec attr-dict
   }];
 }
 
@@ -1667,10 +1672,10 @@ def GPU_CreateDnMatOp : GPU_Op<"create_dn_mat", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Index:$rows,
-                   Index:$cols,
-                   AnyMemRef:$memref);
-  let results = (outs Res<GPU_SparseHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
+                       Index:$rows,
+                       Index:$cols,
+                       AnyMemRef:$memref);
+  let results = (outs Res<GPU_SparseDnMatHandle>:$dmat, Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1697,11 +1702,12 @@ def GPU_DestroyDnMatOp : GPU_Op<"destroy_dn_mat", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Arg<GPU_SparseHandle>:$dmat);
+                       Arg<GPU_SparseDnMatHandle>:$dmat);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
-    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) $dmat attr-dict
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies) 
+    $dmat attr-dict
   }];
 }
 
@@ -1726,13 +1732,14 @@ def GPU_CreateCooOp : GPU_Op<"create_coo", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Index:$rows,
-                   Index:$cols,
-                   Index:$nnz,
-                   AnyMemRef:$rowIdxs,
-                   AnyMemRef:$colIdxs,
-                   AnyMemRef:$values);
-  let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+                       Index:$rows,
+                       Index:$cols,
+                       Index:$nnz,
+                       AnyMemRef:$rowIdxs,
+                       AnyMemRef:$colIdxs,
+                       AnyMemRef:$values);
+  let results = (outs Res<GPU_SparseSpMatHandle>:$spmat, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1769,7 +1776,8 @@ def GPU_CreateCsrOp : GPU_Op<"create_csr", [GPU_AsyncOpInterface]> {
                    AnyMemRef:$rowPos,
                    AnyMemRef:$colIdxs,
                    AnyMemRef:$values);
-  let results = (outs Res<GPU_SparseHandle>:$spmat, Optional<GPU_AsyncToken>:$asyncToken);
+  let results = (outs Res<GPU_SparseSpMatHandle>:$spmat, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1797,7 +1805,7 @@ def GPU_DestroySpMatOp : GPU_Op<"destroy_sp_mat", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   Arg<GPU_SparseHandle>:$spmat);
+                       Arg<GPU_SparseSpMatHandle>:$spmat);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1823,13 +1831,13 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
     %buffersz, %token = gpu.spmv_buffersize async [%dep] %env, %spmatA, %dnX, %dnY
     ```
   }];
-
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   GPU_SparseHandle:$env,
-                   GPU_SparseHandle:$spmatA,
-                   GPU_SparseHandle:$dnX,
-                   GPU_SparseHandle:$dnY);
-  let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnVecHandle:$dnX,
+                       GPU_SparseDnVecHandle:$dnY);
+  let results = (outs Res<Index>:$bufferSz, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1855,13 +1863,12 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
     %token = gpu.spmv async [%dep] %env, %spmatA, %dnX, %dnY : memref<?xf64>
     ```
   }];
-
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   GPU_SparseHandle:$env,
-                   GPU_SparseHandle:$spmatA,
-                   GPU_SparseHandle:$dnX,
-                   GPU_SparseHandle:$dnY,
-                   AnyMemRef:$buffer);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnVecHandle:$dnX,
+                       GPU_SparseDnVecHandle:$dnY,
+                       AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
@@ -1890,11 +1897,12 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   GPU_SparseHandle:$env,
-                   GPU_SparseHandle:$spmatA,
-                   GPU_SparseHandle:$dnmatB,
-                   GPU_SparseHandle:$dnmatC);
-  let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseDnMatHandle:$dnmatC);
+  let results = (outs Res<Index>:$bufferSz, 
+                      Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
@@ -1922,11 +1930,11 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                   GPU_SparseHandle:$env,
-                   GPU_SparseHandle:$spmatA,
-                   GPU_SparseHandle:$dnmatB,
-                   GPU_SparseHandle:$dnmatC,
-                   AnyMemRef:$buffer);
+                       GPU_SparseEnvHandle:$env,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseDnMatHandle:$dnmatC,
+                       AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let assemblyFormat = [{

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 600bd9152c436..1d1923d4d0c2b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1470,18 +1470,23 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+template <typename T>
+static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
+  converter.addConversion([&converter](T) -> Type {
+    return converter.getPointerType(
+        IntegerType::get(&converter.getContext(), 8));
+  });
+}
+
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
                                                bool kernelBarePtrCallConv) {
-  converter.addConversion([&converter](gpu::AsyncTokenType type) -> Type {
-    return converter.getPointerType(
-        IntegerType::get(&converter.getContext(), 8));
-  });
-  converter.addConversion([&converter](gpu::SparseHandleType type) -> Type {
-    return converter.getPointerType(
-        IntegerType::get(&converter.getContext(), 8));
-  });
+  addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
+  addOpaquePointerConversion<gpu::SparseDnVecHandleType>(converter);
+  addOpaquePointerConversion<gpu::SparseDnMatHandleType>(converter);
+  addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
+  addOpaquePointerConversion<gpu::SparseEnvHandleType>(converter);
 
   patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
                ConvertDeallocOpToGpuRuntimeCallPattern,

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index ce502401b86a7..add4b97c71fee 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -146,7 +146,10 @@ struct GPUInlinerInterface : public DialectInlinerInterface {
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
   addTypes<MMAMatrixType>();
-  addTypes<SparseHandleType>();
+  addTypes<SparseEnvHandleType>();
+  addTypes<SparseDnVecHandleType>();
+  addTypes<SparseDnMatHandleType>();
+  addTypes<SparseSpMatHandleType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
@@ -158,6 +161,19 @@ void GPUDialect::initialize() {
   addInterfaces<GPUInlinerInterface>();
 }
 
+static std::string getSparseHandleKeyword(SparseHandleKind kind) {
+  switch (kind) {
+  case SparseHandleKind::Env:
+    return "sparse.env_handle";
+  case SparseHandleKind::DnVec:
+    return "sparse.dnvec_handle";
+  case SparseHandleKind::DnMat:
+    return "sparse.dnmat_handle";
+  case SparseHandleKind::SpMat:
+    return "sparse.spmat_handle";
+  }
+};
+
 Type GPUDialect::parseType(DialectAsmParser &parser) const {
   // Parse the main keyword for the type.
   StringRef keyword;
@@ -201,17 +217,31 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
                                      shape, elementType, operand);
   }
 
-  if (keyword == "sparse.handle")
-    return SparseHandleType::get(context);
+  if (keyword == getSparseHandleKeyword(SparseHandleKind::Env))
+    return SparseEnvHandleType::get(context);
+  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnVec))
+    return SparseDnVecHandleType::get(context);
+  if (keyword == getSparseHandleKeyword(SparseHandleKind::DnMat))
+    return SparseDnMatHandleType::get(context);
+  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
+    return SparseSpMatHandleType::get(context);
 
   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
   return Type();
 }
-
+// TODO: print refined type here. Note that the output must correspond to the
+// keywords accepted by the parser.
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
-      .Case<SparseHandleType>([&](Type) { os << "sparse.handle"; })
+      .Case<SparseEnvHandleType>(
+          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::Env); })
+      .Case<SparseDnVecHandleType>(
+          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::DnVec); })
+      .Case<SparseDnMatHandleType>(
+          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::DnMat); })
+      .Case<SparseSpMatHandleType>(
+          [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cadbe86942274..53a11ad9ac440 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -436,22 +436,25 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
 
   // Create sparse environment and sparse matrix/dense vector handles.
   Type indexTp = rewriter.getIndexType();
-  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+  Type dnVecHandleTp = rewriter.getType<gpu::SparseDnVecHandleType>();
+  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   auto env =
-      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
   Value handle = env.getResult(0);
   token = env.getAsyncToken();
-  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
-                               szX, nseA, rowA, colA, valA, isCOO, enableRT);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
+               rowA, colA, valA, isCOO, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
-  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+  auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
                                                    token, vecX, szX);
   Value dnX = dvecX.getResult(0);
   token = dvecX.getAsyncToken();
-  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
+  auto dvecY = rewriter.create<gpu::CreateDnVecOp>(loc, dnVecHandleTp, tokenTp,
                                                    token, vecY, szY);
   Value dnY = dvecY.getResult(0);
   token = dvecY.getAsyncToken();
@@ -540,22 +543,25 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
 
   // Create sparse environment and sparse matrix/dense matrix handles.
   Type indexTp = rewriter.getIndexType();
-  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnMatHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   auto env =
-      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
   Value handle = env.getResult(0);
   token = env.getAsyncToken();
-  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm,
-                               szk, nseA, rowA, colA, valA, isCOO, enableRT);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
+               rowA, colA, valA, isCOO, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
-  auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+  auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
                                                    token, szk, szn, matB);
   Value dnB = dmatB.getResult(0);
   token = dmatB.getAsyncToken();
-  auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+  auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, dnMatHandleTp, tokenTp,
                                                    token, szm, szn, matC);
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();


        


More information about the Mlir-commits mailing list