[Mlir-commits] [mlir] fa98bdb - [mlir][sparse][gpu] make computeType mandatory

Kun Wu <llvmlistbot@llvm.org>
Fri Jun 2 14:47:50 PDT 2023


Author: Kun Wu
Date: 2023-06-02T21:47:44Z
New Revision: fa98bdbd95d14959d3c6c09a4c29ba0d974883dd

URL: https://github.com/llvm/llvm-project/commit/fa98bdbd95d14959d3c6c09a4c29ba0d974883dd
DIFF: https://github.com/llvm/llvm-project/commit/fa98bdbd95d14959d3c6c09a4c29ba0d974883dd.diff

LOG: [mlir][sparse][gpu] make computeType mandatory

Differential Revision: https://reviews.llvm.org/D152018
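
With this change the compute type is no longer optional: the sparse GPU ops print and parse it as a mandatory `into` clause, and their C++ builders take it as an explicit `Type` parameter. For example, the updated round-trip test below now writes the SpMV pair as:

    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec into f64
    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64> into f64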

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/ops.mlir
    mlir/test/Dialect/GPU/sparse-roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 2fac955ce6a12..17bff31941579 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1869,7 +1869,7 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY,
-                       OptionalAttr<TypeAttr>:$computeType);
+                       TypeAttr:$computeType);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -1880,16 +1880,17 @@ def GPU_SpMVBufferSizeOp : GPU_Op<"spmv_buffer_size", [GPU_AsyncOpInterface]> {
       "Value":$env,
       "Value":$spmatA,
       "Value":$dnX,
-      "Value":$dnY)
+      "Value":$dnY,
+      "Type":$computeType)
       , [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
-                 env, modeA, spmatA, dnX, dnY, {});}]>
+                 env, modeA, spmatA, dnX, dnY, computeType);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict ( `into` $computeType^)?
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY attr-dict  `into` $computeType
   }];
 }
 
@@ -1921,7 +1922,7 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnVecHandle:$dnX,
                        GPU_SparseDnVecHandle:$dnY,
-                       OptionalAttr<TypeAttr>:$computeType,
+                       TypeAttr:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -1932,15 +1933,16 @@ def GPU_SpMVOp : GPU_Op<"spmv", [GPU_AsyncOpInterface]> {
       "Value":$spmatA,
       "Value":$dnX,
       "Value":$dnY,
+      "Type":$computeType,
       "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
-                 spmatA, dnX, dnY, {}, buffer);}]>
+                 spmatA, dnX, dnY, computeType, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnX `,` $dnY `,` $buffer attr-dict `:` type($buffer) `into` $computeType
   }];
 }
 
@@ -1974,7 +1976,7 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
-                       OptionalAttr<TypeAttr>:$computeType);
+                       TypeAttr:$computeType);
   let results = (outs Res<Index>:$bufferSz, 
                       Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -1985,16 +1987,17 @@ def GPU_SpMMBufferSizeOp : GPU_Op<"spmm_buffer_size", [GPU_AsyncOpInterface]> {
       "Value":$env,
       "Value":$spmatA,
       "Value":$dnmatB,
-      "Value":$dnmatC), [{
+      "Value":$dnmatC,
+      "Type":$computeType), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
-                 env, modeA, modeB, spmatA, dnmatB, dnmatC, {});}]>
+                 env, modeA, modeB, spmatA, dnmatB, dnmatC, computeType);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict ( `into` $computeType^)?
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC attr-dict `into` $computeType
   }];
 }
 
@@ -2028,7 +2031,7 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseDnMatHandle:$dnmatB,
                        GPU_SparseDnMatHandle:$dnmatC,
-                       OptionalAttr<TypeAttr>:$computeType,
+                       TypeAttr:$computeType,
                        AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
@@ -2039,16 +2042,17 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
       "Value":$spmatA,
       "Value":$dnmatB,
       "Value":$dnmatC,
+      "Type":$computeType,
       "Value":$buffer), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, 
-                 modeB, spmatA, dnmatB, dnmatC, {}, buffer);}]>
+                 modeB, spmatA, dnmatB, dnmatC, computeType, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
+    $env `,` $spmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $dnmatC `,` $buffer attr-dict `:` type($buffer) `into` $computeType
   }];
 }
 
@@ -2082,26 +2086,27 @@ def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]>
                    GPU_SparseDnMatHandle:$dnmatA,
                    GPU_SparseDnMatHandle:$dnmatB,
                    GPU_SparseSpMatHandle:$spmatC,
-                   OptionalAttr<TypeAttr>:$computeType);
+                   TypeAttr:$computeType);
   let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-      "::mlir::Type":$bufferSz,
-      "::mlir::Type":$asyncToken,
-      "::mlir::ValueRange":$asyncDependencies,
-      "::mlir::Value":$env,
-      "::mlir::Value":$dnmatA,
-      "::mlir::Value":$dnmatB,
-      "::mlir::Value":$spmatC), [{
+      "Type":$bufferSz,
+      "Type":$asyncToken,
+      "ValueRange":$asyncDependencies,
+      "Value":$env,
+      "Value":$dnmatA,
+      "Value":$dnmatB,
+      "Value":$spmatC,
+      "Type":$computeType), [{
     auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
     auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
     return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies, 
-                 env, modeA, modeB, dnmatA, dnmatB, spmatC, {});}]>
+                 env, modeA, modeB, dnmatA, dnmatB, spmatC, computeType);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict ( `into` $computeType^)?
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict `into` $computeType
   }];
 }
 
@@ -2135,27 +2140,28 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
                    GPU_SparseDnMatHandle:$dnmatA,
                    GPU_SparseDnMatHandle:$dnmatB,
                    GPU_SparseSpMatHandle:$spmatC,
-                   OptionalAttr<TypeAttr>:$computeType,
+                   TypeAttr:$computeType,
                    AnyMemRef:$buffer);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
-    "::mlir::Type":$asyncToken,
-    "::mlir::ValueRange":$asyncDependencies,
-    "::mlir::Value":$env,
-    "::mlir::Value":$dnmatA,
-    "::mlir::Value":$dnmatB,
-    "::mlir::Value":$spmatC,
-    "::mlir::Value":$buffer), [{
+    "Type":$asyncToken,
+    "ValueRange":$asyncDependencies,
+    "Value":$env,
+    "Value":$dnmatA,
+    "Value":$dnmatB,
+    "Value":$spmatC,
+    "Type":$computeType,
+    "Value":$buffer), [{
   auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
   auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
   return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA, 
-                modeB, dnmatA, dnmatB, spmatC, {}, buffer);}]>
+                modeB, dnmatA, dnmatB, spmatC, computeType, buffer);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) ( `into` $computeType^)?
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer) `into` $computeType
   }];
 }
 

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 5ec455a65bd65..023a52eeec138 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -1274,25 +1274,6 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-// Returns the element type of the defining spmat op.
-// TODO: safer and more flexible to store data type in actual op instead?
-static Type getSpMatElemType(Value spMat) {
-  if (auto op = spMat.getDefiningOp<gpu::CreateCooOp>())
-    return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
-  if (auto op = spMat.getDefiningOp<gpu::CreateCsrOp>())
-    return llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
-  llvm_unreachable("cannot find spmat def");
-}
-
-// Returns the element type of the defining dnmat or dnvec op.
-static Type getDnElemType(Value dn) {
-  if (auto op = dn.getDefiningOp<gpu::CreateDnMatOp>())
-    return op.getMemref().getType().getElementType();
-  if (auto op = dn.getDefiningOp<gpu::CreateDnVecOp>())
-    return op.getMemref().getType().getElementType();
-  llvm_unreachable("cannot find dn def");
-}
-
 template <typename T>
 static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
   Type llvmInt32Type = builder.getIntegerType(32);
@@ -1300,14 +1281,11 @@ static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
                                           static_cast<int32_t>(TValue));
 }
 
-static Value
-genConstInt32FromOptionalComputeMode(OpBuilder &builder, Location loc,
-                                     std::optional<Type> computeTypeOptional,
-                                     Type defaultType) {
-  auto computeTypeInt =
-      getCuSparseDataTypeFrom(computeTypeOptional.value_or(defaultType));
-  auto computeType = genConstInt32From(builder, loc, computeTypeInt);
-  return computeType;
+static Value genConstInt32FromComputeMode(OpBuilder &builder, Location loc,
+                                          Type computeType) {
+  auto computeTypeInt = getCuSparseDataTypeFrom(computeType);
+  auto computeTypeConst = genConstInt32From(builder, loc, computeTypeInt);
+  return computeTypeConst;
 }
 
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
@@ -1502,9 +1480,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
-  // retrieve the compute type, notice that it may be optional
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMVBufferSizeCallBuilder
@@ -1524,9 +1501,8 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
-  // retrieve the compute type, notice that it may be optional
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnY()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1550,9 +1526,8 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
-  // retrieve the compute type, notice that it may be optional
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
 
   auto bufferSize = spMMBufferSizeCallBuilder
                         .create(loc, rewriter,
@@ -1573,9 +1548,8 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(),
-      getSpMatElemType(op.getSpmatC()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize = SDDMMBufferSizeCallBuilder
                         .create(loc, rewriter,
@@ -1596,9 +1570,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  // retrieve the compute type, notice that it may be optional
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(), getDnElemType(op.getDnmatC()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
 
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
@@ -1628,9 +1601,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto computeType = genConstInt32FromOptionalComputeMode(
-      rewriter, loc, adaptor.getComputeType(),
-      getSpMatElemType(op.getSpmatC()));
+  auto computeType =
+      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 61ee115e879a9..a190ff6dacb92 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -462,9 +462,12 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
   Value dnY = dvecY.getResult(0);
   token = dvecY.getAsyncToken();
 
+  auto dnYType = llvm::cast<ShapedType>(y.getType()).getElementType();
+
   // Precompute buffersize for SpMV.
   auto bufferComp = rewriter.create<gpu::SpMVBufferSizeOp>(
-      loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY);
+      loc, indexTp, tokenTp, token, handle, spMatA, dnX, dnY,
+      /*computeType=*/dnYType);
   Value bufferSz = bufferComp.getResult(0);
   token = bufferComp.getAsyncToken();
   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
@@ -472,8 +475,9 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
   token = buf.getAsyncToken();
 
   // Perform the SpMV.
-  auto spmvComp = rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle,
-                                               spMatA, dnX, dnY, buffer);
+  auto spmvComp =
+      rewriter.create<gpu::SpMVOp>(loc, tokenTp, token, handle, spMatA, dnX,
+                                   dnY, /*computeType=*/dnYType, buffer);
   token = spmvComp.getAsyncToken();
 
   // Copy data back to host and free all the resoures.
@@ -565,18 +569,24 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();
 
+  auto dmatCType = llvm::cast<ShapedType>(c.getType()).getElementType();
+
   // Precompute buffersize for SpMM.
   auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
-      loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC);
+      loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC,
+      /*computeType=*/dmatCType);
   Value bufferSz = bufferComp.getResult(0);
   token = bufferComp.getAsyncToken();
   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
   Value buffer = buf.getResult(0);
   token = buf.getAsyncToken();
 
+  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
+
   // Perform the SpMM.
-  auto spmmComp = rewriter.create<gpu::SpMMOp>(loc, tokenTp, token, handle,
-                                               spMatA, dnB, dnC, buffer);
+  auto spmmComp =
+      rewriter.create<gpu::SpMMOp>(loc, tokenTp, token, handle, spMatA, dnB,
+                                   dnC, /*computeType=*/dnCType, buffer);
   token = spmmComp.getAsyncToken();
 
   // Copy data back to host and free all the resoures.

diff --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index 678842361b7a3..0c7f8dd20026d 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -23,8 +23,8 @@ module attributes {gpu.container_module} {
     %env, %token3 = gpu.create_sparse_env async [%token2]
     %spmat, %token4 = gpu.create_coo async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
     %dnvec, %token5 = gpu.create_dn_vec async [%token4] %mem2, %arg0 : memref<?xf64>
-    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec
-    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec  into f64
+    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64> into f64
     %token8 = gpu.destroy_sp_mat async [%token7] %spmat
     %token9 = gpu.destroy_dn_vec async [%token8] %dnvec
     %token10 = gpu.destroy_sparse_env async [%token9] %env
@@ -53,8 +53,8 @@ module attributes {gpu.container_module} {
     %env, %token3 = gpu.create_sparse_env async [%token2]
     %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
     %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
-    %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat
-    %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.spmm_buffer_size async [%token5] %env, %spmat, %dnmat, %dnmat into f64
+    %token7 = gpu.spmm async [%token6] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64> into f64
     %token8 = gpu.destroy_sp_mat async [%token7] %spmat
     %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
     %token10 = gpu.destroy_sparse_env async [%token9] %env
@@ -83,8 +83,8 @@ module attributes {gpu.container_module} {
     %env, %token3 = gpu.create_sparse_env async [%token2]
     %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
     %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
-    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
-    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat into f64
+    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64> into f64
     %token8 = gpu.destroy_sp_mat async [%token7] %spmat
     %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
     %token10 = gpu.destroy_sparse_env async [%token9] %env

diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 8900c5bfee581..972e467a6e0a2 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -335,19 +335,19 @@ module attributes {gpu.container_module} {
     // CHECK: gpu.create_dn_vec async
     %dnvec, %token6 = gpu.create_dn_vec async [%token5] %mem2, %arg0 : memref<?xf64>
     // CHECK: gpu.spmv_buffer_size async
-    %bufferSz, %token7 = gpu.spmv_buffer_size async [%token6] %env, %spmat, %dnvec, %dnvec
+    %bufferSz, %token7 = gpu.spmv_buffer_size async [%token6] %env, %spmat, %dnvec, %dnvec  into f64
     // CHECK: gpu.spmv async
-    %token8 = gpu.spmv async [%token7] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
+    %token8 = gpu.spmv async [%token7] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>  into f64
     // CHECK: gpu.create_dn_mat async
     %dnmat, %token9 = gpu.create_dn_mat async [%token8] %arg0, %arg0, %mem2 : memref<?xf64>
     // CHECK: gpu.spmm_buffer_size async
-    %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat
+    %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat  into f64
     // CHECK: gpu.spmm async
-    %token11 = gpu.spmm async [%token10] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>
+    %token11 = gpu.spmm async [%token10] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>  into f64
     // CHECK: gpu.sddmm_buffer_size async
-    %bufferSz3, %token12 = gpu.sddmm_buffer_size async [%token11] %env, %dnmat, %dnmat, %spmat
+    %bufferSz3, %token12 = gpu.sddmm_buffer_size async [%token11] %env, %dnmat, %dnmat, %spmat  into f64
     // CHECK: gpu.sddmm async
-    %token13 = gpu.sddmm async [%token12] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %token13 = gpu.sddmm async [%token12] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>  into f64
     // CHECK: gpu.destroy_dn_mat async
     %token14 = gpu.destroy_dn_mat async [%token13] %dnmat
     // CHECK: gpu.destroy_sp_mat async

diff --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
index 6465208791dd5..26dc223175980 100644
--- a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -9,8 +9,8 @@ module attributes {gpu.container_module} {
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_coo async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_vec async [%{{.*}}] %{{.*}}, %{{.*}} : memref<?xf64>
-  // CHECK: %{{.*}}, %{{.*}} = gpu.spmv_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
-  // CHECK: %{{.*}} = gpu.spmv async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.spmv_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} into f64
+  // CHECK: %{{.*}} = gpu.spmv async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64> into f64
   // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
   // CHECK: %{{.*}} = gpu.destroy_dn_vec async [%{{.*}}] %{{.*}}
   // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
@@ -23,8 +23,8 @@ module attributes {gpu.container_module} {
     %env, %token3 = gpu.create_sparse_env async [%token2]
     %spmat, %token4 = gpu.create_coo async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
     %dnvec, %token5 = gpu.create_dn_vec async [%token4] %mem2, %arg0 : memref<?xf64>
-    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec
-    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.spmv_buffer_size async [%token5] %env, %spmat, %dnvec, %dnvec into f64
+    %token7 = gpu.spmv async [%token6] %env, %spmat, %dnvec, %dnvec, %mem2 : memref<?xf64> into f64
     %token8 = gpu.destroy_sp_mat async [%token7] %spmat
     %token9 = gpu.destroy_dn_vec async [%token8] %dnvec
     %token10 = gpu.destroy_sparse_env async [%token9] %env
@@ -69,8 +69,8 @@ module attributes {gpu.container_module} {
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_sparse_env async [%{{.*}}]
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
   // CHECK: %{{.*}}, %{{.*}} = gpu.create_dn_mat async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
-  // CHECK: %{{.*}}, %{{.*}} = gpu.sddmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
-  // CHECK: %{{.*}} = gpu.sddmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>
+  // CHECK: %{{.*}}, %{{.*}} = gpu.sddmm_buffer_size async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}  into f64
+  // CHECK: %{{.*}} = gpu.sddmm async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf64>  into f64
   // CHECK: %{{.*}} = gpu.destroy_sp_mat async [%{{.*}}] %{{.*}}
   // CHECK: %{{.*}} = gpu.destroy_dn_mat async [%{{.*}}] %{{.*}}
   // CHECK: %{{.*}} = gpu.destroy_sparse_env async [%{{.*}}] %{{.*}}
@@ -83,8 +83,8 @@ module attributes {gpu.container_module} {
     %env, %token3 = gpu.create_sparse_env async [%token2]
     %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
     %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
-    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
-    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat into f64
+    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64> into f64
     %token8 = gpu.destroy_sp_mat async [%token7] %spmat
     %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
     %token10 = gpu.destroy_sparse_env async [%token9] %env


        

