[Mlir-commits] [mlir] 9167dd4 - [mlir][sparse][gpu] recognizing sddmm pattern in GPU libgen path

Kun Wu llvmlistbot at llvm.org
Thu Jun 15 16:48:20 PDT 2023


Author: Kun Wu
Date: 2023-06-15T23:48:11Z
New Revision: 9167dd46ba52d0e6626e4cf1b931910ee15010b6

URL: https://github.com/llvm/llvm-project/commit/9167dd46ba52d0e6626e4cf1b931910ee15010b6
DIFF: https://github.com/llvm/llvm-project/commit/9167dd46ba52d0e6626e4cf1b931910ee15010b6.diff

LOG: [mlir][sparse][gpu] recognizing sddmm pattern in GPU libgen path

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D151582

Added: 
    mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index fbe948da7445e..86c429adf1d04 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -329,6 +329,58 @@ static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
   return false;
 }
 
+// Helper to detect c = c \spy (a * b), i.e. a body of the exact form
+//   %u = sparse_tensor.unary %s present={ yield a * b } absent={}
+//   %r = sparse_tensor.reduce %s, %u, %id { yield p + q }
+//   linalg.yield %r
+// with a, b and s the block arguments of the two inputs and the output.
+static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
+  // Tests whether `def` is a binary op with operands {x, y} in any order.
+  auto hasOperandPair = [](Operation *def, Value x, Value y) {
+    return def && def->getNumOperands() == 2 &&
+           ((def->getOperand(0) == x && def->getOperand(1) == y) ||
+            (def->getOperand(0) == y && def->getOperand(1) == x));
+  };
+  Block *block = op.getBlock();
+  if (block->getNumArguments() != 3)
+    return false;
+  Value a = block->getArgument(0);
+  Value b = block->getArgument(1);
+  Value s = block->getArgument(2);
+  // The yielded value must be a reduction whose first operand is s.
+  Value yielded = block->getTerminator()->getOperand(0);
+  auto reduceOp = yielded.getDefiningOp<sparse_tensor::ReduceOp>();
+  if (!reduceOp || reduceOp->getOperand(0) != s)
+    return false;
+  // The reduction region must be a plain sum of its two block arguments.
+  Block &redBlock = reduceOp.getRegion().front();
+  auto redYield = cast<sparse_tensor::YieldOp>(redBlock.getTerminator());
+  Operation *add = redYield.getOperand(0).getDefiningOp();
+  if (!hasOperandPair(add, redBlock.getArgument(0), redBlock.getArgument(1)) ||
+      !isa<arith::AddFOp, arith::AddIOp>(add))
+    return false;
+  // The second reduction operand must be a unary op on s, used only there.
+  auto unaryOp =
+      reduceOp->getOperand(1).getDefiningOp<sparse_tensor::UnaryOp>();
+  if (!unaryOp || unaryOp->getOperand(0) != s ||
+      llvm::any_of(unaryOp->getUsers(),
+                   [&](Operation *u) { return u != reduceOp.getOperation(); }))
+    return false;
+  // present={ yield a * b } and absent={} (no value where s has no entry);
+  // an empty present region must be rejected before calling front() on it.
+  Region &present = unaryOp.getRegion(0);
+  if (present.empty() || !unaryOp.getRegion(1).empty())
+    return false;
+  auto mulYield = cast<sparse_tensor::YieldOp>(present.front().getTerminator());
+  Operation *mul = mulYield.getOperand(0).getDefiningOp();
+  if (!hasOperandPair(mul, a, b) || !isa<arith::MulFOp, arith::MulIOp>(mul))
+    return false;
+  // s may only feed the matched unary and reduce operations.
+  return llvm::none_of(s.getUsers(), [&](Operation *u) {
+    return u != unaryOp.getOperation() && u != reduceOp.getOperation();
+  });
+}
+
 /// Test for sorted COO with suitable data and coordinates types.
 static bool isAdmissibleCOO(SparseTensorType &aTp) {
   return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
@@ -627,6 +679,130 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
   return success();
 }
 
+// TODO: identify alpha and beta and pass them to the CUDA calls
+/// Match and rewrite an SDDMM kernel into GPU sparse library calls.
+static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
+                                  linalg::GenericOp op, bool enableRT) {
+  // The computation is done in place: the values of c are copied to the
+  // device, overwritten there by the SDDMM result, and copied back into c.
+  // As an ad hoc solution, this rewrite also assumes the linalg op takes a and
+  // b as its inputs (operands 0 and 1) and c as its output operand (operand
+  // 2); c is both the sparse sampling matrix and the destination of the
+  // result. It recognizes this pattern and rewrites it.
+
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value b = op.getOperand(1);
+  Value c = op.getOperand(2);
+
+  SmallVector<Value> tokens;
+
+  // Only an admissible sparse format for c (with dense a and b) is accepted.
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType bTp = getSparseTensorType(b);
+  SparseTensorType cTp = getSparseTensorType(c);
+
+  if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
+    return failure();
+
+  // cuSPARSE currently does not support COO in its SDDMM kernel.
+  if (isCOO) {
+    return failure();
+  }
+
+  // The SDDMM is computed in place on c.
+  // Start sparse kernel and copy data from host to device.
+  //   a : bufA           -> matA
+  //   b : bufB           -> matB
+  //   c : memR/memC/memV -> rowC,colC,valC
+  Value nseC = rewriter.create<NumberOfEntriesOp>(loc, c);
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+  Value bufA = genTensorToMemref(rewriter, loc, a);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value bufB = genTensorToMemref(rewriter, loc, b);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
+  Value memV = genToValues(rewriter, loc, c);
+  Value rowC = genAllocCopy(rewriter, loc, memR, tokens);
+  Value colC = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
+  Value valC = genAllocCopy(rewriter, loc, memV, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense matrix handles.
+  Type indexTp = rewriter.getIndexType();
+  Type envHandleTp = rewriter.getType<gpu::SparseEnvHandleType>();
+  Type dnMatHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  auto env =
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, envHandleTp, tokenTp, token);
+  Value handle = env.getResult(0);
+  token = env.getAsyncToken();
+
+  auto dmatA = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnMatHandleTp, tokenTp, token, handle, matA,
+      SmallVector<Value>{szm, szk});
+  Value dnA = dmatA.getResult(0);
+  token = dmatA.getAsyncToken();
+  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnMatHandleTp, tokenTp, token, handle, matB,
+      SmallVector<Value>{szk, szn});
+  Value dnB = dmatB.getResult(0);
+  token = dmatB.getAsyncToken();
+
+  Operation *spGenC =
+      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
+               rowC, colC, valC, isCOO, enableRT);
+  Value spMatC = spGenC->getResult(0);
+  token = spGenC->getResult(1);
+
+  auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
+  // Query the size of the workspace buffer required by SDDMM.
+  auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
+      loc, indexTp, tokenTp, token, handle, dnA, dnB, spMatC, dnCType);
+  Value bufferSz = bufferComp.getResult(0);
+  token = bufferComp.getAsyncToken();
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  // Perform the SDDMM.
+  auto sddmmComp = rewriter.create<gpu::SDDMMOp>(
+      loc, tokenTp, token, handle, dnA, dnB, spMatC, dnCType, buffer);
+  token = sddmmComp.getAsyncToken();
+
+  // Release all handles and device buffers; copy the result values to host.
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
+              .getAsyncToken();
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, rowC, token);
+  if (colC)
+    token = genDeallocMemRef(rewriter, loc, colC, token);
+  token = genCopyMemRef(rewriter, loc, memV, valC, token);
+  token = genDeallocMemRef(rewriter, loc, valC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  // c was updated in place; materialize it as the result of the op.
+  rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriting rules for direct code generation.
 //===----------------------------------------------------------------------===//
@@ -778,6 +954,18 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
       return rewriteSpMM(rewriter, op, enableRT);
     }
 
+    // Recognize a SDDMM kernel.
+    if (numLoops == 3 && numTensors == 3 &&
+        linalg::isParallelIterator(iteratorTypes[0]) &&
+        linalg::isParallelIterator(iteratorTypes[1]) &&
+        linalg::isReductionIterator(iteratorTypes[2]) &&
+        // TODO: add transposed {i, k}, {k, j}
+        // TODO: maybe add transposed {i, j} in future
+        maps == infer({{i, k}, {k, j}, {i, j}}) &&
+        matchSumReductionOfMulUnary(op)) {
+      return rewriteSDDMM(rewriter, op, enableRT);
+    }
+
     return failure();
   }
 

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 47fee38e84d2c..abb4d05760fba 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -426,6 +426,7 @@ mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
 
+// TODO: add support for passing alpha and beta as arguments
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
                     int32_t ctp, CUstream /*stream*/) {

diff  --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
new file mode 100644
index 0000000000000..0ac9b4d76b38c
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+
+#trait_sampled_dense_dense = {  // indexing of the SDDMM kernel below
+  indexing_maps = [
+  affine_map<(i,j,k) -> (i,k)>,  // A (in, dense)
+  affine_map<(i,j,k) -> (k,j)>,  // B (in, dense)
+  affine_map<(i,j,k) -> (i,j)>   // S (out, sparse; also the sampling pattern)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
+}
+
+#trait_vec_op = {  // NOTE(review): not referenced by any kernel in this file
+  indexing_maps = [
+  affine_map<(i,j) -> (i,j)>,  // a (in)
+  affine_map<(i,j) -> (i,j)>,  // b (in)
+  affine_map<(i,j) -> (i,j)>   // x (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+#CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+module {
+
+// CHECK-LABEL:   func.func @sparse_sampled_dd(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:                                 %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> {
+// CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+// CHECK:           %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK:           %[[VAL_7:.*]] = gpu.wait async
+// CHECK:           %[[VAL_8:.*]], %[[VAL_9:.*]] = gpu.alloc async {{\[}}%[[VAL_7]]] () : memref<8x8xf64>
+// CHECK:           %[[VAL_10:.*]] = gpu.memcpy async {{\[}}%[[VAL_9]]] %[[VAL_8]], %[[VAL_6]] : memref<8x8xf64>, memref<8x8xf64>
+// CHECK:           %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK:           %[[VAL_12:.*]] = gpu.wait async
+// CHECK:           %[[VAL_13:.*]], %[[VAL_14:.*]] = gpu.alloc async {{\[}}%[[VAL_12]]] () : memref<8x8xf64>
+// CHECK:           %[[VAL_15:.*]] = gpu.memcpy async {{\[}}%[[VAL_14]]] %[[VAL_13]], %[[VAL_11]] : memref<8x8xf64>, memref<8x8xf64>
+// CHECK:           %[[VAL_16:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK:           %[[VAL_19:.*]] = gpu.wait async
+// CHECK:           %[[VAL_20:.*]] = memref.dim %[[VAL_16]], %[[VAL_4]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]], %[[VAL_22:.*]] = gpu.alloc async {{\[}}%[[VAL_19]]] (%[[VAL_20]]) : memref<?xindex>
+// CHECK:           %[[VAL_23:.*]] = gpu.memcpy async {{\[}}%[[VAL_22]]] %[[VAL_21]], %[[VAL_16]] : memref<?xindex>, memref<?xindex>
+// CHECK:           %[[VAL_24:.*]] = gpu.wait async
+// CHECK:           %[[VAL_25:.*]] = memref.dim %[[VAL_17]], %[[VAL_4]] : memref<?xindex>
+// CHECK:           %[[VAL_26:.*]], %[[VAL_27:.*]] = gpu.alloc async {{\[}}%[[VAL_24]]] (%[[VAL_25]]) : memref<?xindex>
+// CHECK:           %[[VAL_28:.*]] = gpu.memcpy async {{\[}}%[[VAL_27]]] %[[VAL_26]], %[[VAL_17]] : memref<?xindex>, memref<?xindex>
+// CHECK:           %[[VAL_29:.*]] = gpu.wait async
+// CHECK:           %[[VAL_30:.*]] = memref.dim %[[VAL_18]], %[[VAL_4]] : memref<?xf64>
+// CHECK:           %[[VAL_31:.*]], %[[VAL_32:.*]] = gpu.alloc async {{\[}}%[[VAL_29]]] (%[[VAL_30]]) : memref<?xf64>
+// CHECK:           %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_18]] : memref<?xf64>, memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_10]], %[[VAL_15]], %[[VAL_23]], %[[VAL_28]], %[[VAL_33]]]
+// CHECK:           %[[VAL_34:.*]] = gpu.wait async
+// CHECK:           %[[VAL_35:.*]], %[[VAL_36:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_34]]]
+// CHECK:           %[[VAL_37:.*]], %[[VAL_38:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_36]]] %[[VAL_35]], %[[VAL_8]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64>
+// CHECK:           %[[VAL_39:.*]], %[[VAL_40:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_38]]] %[[VAL_35]], %[[VAL_13]], %[[VAL_3]], %[[VAL_3]] : index, index into memref<8x8xf64>
+// CHECK:           %[[VAL_41:.*]], %[[VAL_42:.*]] = gpu.create_csr async {{\[}}%[[VAL_40]]] %[[VAL_3]], %[[VAL_3]], %[[VAL_5]], %[[VAL_21]], %[[VAL_26]], %[[VAL_31]] : memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           %[[VAL_43:.*]], %[[VAL_44:.*]] = gpu.sddmm_buffer_size async {{\[}}%[[VAL_42]]] %[[VAL_35]], %[[VAL_37]], %[[VAL_39]], %[[VAL_41]] into f64
+// CHECK:           %[[VAL_45:.*]], %[[VAL_46:.*]] = gpu.alloc async {{\[}}%[[VAL_44]]] (%[[VAL_43]]) : memref<?xi8>
+// CHECK:           %[[VAL_47:.*]] = gpu.sddmm async {{\[}}%[[VAL_46]]] %[[VAL_35]], %[[VAL_37]], %[[VAL_39]], %[[VAL_41]], %[[VAL_45]] : memref<?xi8> into f64
+// CHECK:           %[[VAL_48:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_47]]] %[[VAL_37]]
+// CHECK:           %[[VAL_49:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_48]]] %[[VAL_39]]
+// CHECK:           %[[VAL_50:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_49]]] %[[VAL_41]]
+// CHECK:           %[[VAL_51:.*]] = gpu.destroy_sparse_env async {{\[}}%[[VAL_50]]] %[[VAL_35]]
+// CHECK:           %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_45]] : memref<?xi8>
+// CHECK:           %[[VAL_53:.*]] = gpu.dealloc async {{\[}}%[[VAL_52]]] %[[VAL_8]] : memref<8x8xf64>
+// CHECK:           %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_13]] : memref<8x8xf64>
+// CHECK:           %[[VAL_55:.*]] = gpu.dealloc async {{\[}}%[[VAL_54]]] %[[VAL_21]] : memref<?xindex>
+// CHECK:           %[[VAL_56:.*]] = gpu.dealloc async {{\[}}%[[VAL_55]]] %[[VAL_26]] : memref<?xindex>
+// CHECK:           %[[VAL_57:.*]] = gpu.memcpy async {{\[}}%[[VAL_56]]] %[[VAL_18]], %[[VAL_31]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_58:.*]] = gpu.dealloc async {{\[}}%[[VAL_57]]] %[[VAL_31]] : memref<?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_58]]]
+// CHECK:           %[[VAL_59:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+// CHECK:           return %[[VAL_59]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+// CHECK:         }
+//
+// A kernel that computes a sampled dense-dense matrix multiplication (SDDMM)
+// with a sparse result: every entry stored in S becomes
+// S(i,j) + SUM_k A(i,k) * B(k,j), i.e. S = S + spy(S) .* (A * B).
+//
+func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
+                               %argA: tensor<8x8xf64>,
+                               %argB: tensor<8x8xf64>) -> tensor<8x8xf64, #CSR> {
+    %result = linalg.generic #trait_sampled_dense_dense
+      ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%argS: tensor<8x8xf64, #CSR>) {
+        ^bb(%a: f64, %b: f64, %s: f64):
+           %f0 = arith.constant 0.0 : f64  // identity of the sum reduction
+           %u = sparse_tensor.unary %s : f64 to f64  // a * b where s is stored
+             present={
+                ^bb0(%p: f64):
+                  %mul = arith.mulf %a, %b : f64
+                  sparse_tensor.yield %mul : f64
+             }
+             absent={}
+           %r = sparse_tensor.reduce %s, %u, %f0 : f64 {  // s + SUM_k (a * b)
+              ^bb0(%p: f64, %q: f64):
+                %add = arith.addf %p, %q : f64
+                sparse_tensor.yield %add : f64
+            }
+           linalg.yield %r : f64
+    } -> tensor<8x8xf64, #CSR>
+    return %result : tensor<8x8xf64, #CSR>
+  }
+
+}

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
new file mode 100644
index 0000000000000..7eabf73739ebe
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
@@ -0,0 +1,119 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// RUN: mlir-opt %s \
+// RUN:   --sparse-compiler="enable-runtime-library=true enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71"  \
+// RUN: | TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
+// RUN:   mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --e entry --entry-point-result=void \
+// RUN: | FileCheck %s
+//
+
+!Filename = !llvm.ptr<i8>  // file name returned by @getTensorFilename
+
+#CSR = #sparse_tensor.encoding<{
+  lvlTypes = ["dense", "compressed"]
+}>
+
+#trait_sampled_dense_dense = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A (in, dense)
+    affine_map<(i,j,k) -> (k,j)>,  // B (in, dense)
+    affine_map<(i,j,k) -> (i,j)>   // S (out, sparse; also the sampling pattern)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to
+// actual sparse code, initializes a matching sparse storage scheme
+// from file, and runs the resulting code with the JIT compiler.
+//
+module {
+  //
+  // SDDMM kernel: each stored entry S(i,j) becomes S(i,j) + SUM_k A(i,k) * B(k,j).
+  //
+  func.func @sampled_dense_dense(%args: tensor<?x?xf32, #CSR>,
+                                 %arga: tensor<?x?xf32>,
+                                 %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #CSR> {
+    %result = linalg.generic #trait_sampled_dense_dense
+      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%args: tensor<?x?xf32, #CSR>) {
+        ^bb(%a: f32, %b: f32, %s: f32):
+           %f0 = arith.constant 0.0 : f32
+           %u = sparse_tensor.unary %s : f32 to f32
+             present={
+                ^bb0(%p: f32):
+                  %mul = arith.mulf %a, %b : f32
+                  sparse_tensor.yield %mul : f32
+             }
+             absent={}
+           %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
+              ^bb0(%p: f32, %q: f32):
+                %add = arith.addf %p, %q : f32
+                sparse_tensor.yield %add : f32
+            }
+           linalg.yield %r : f32
+      } -> tensor<?x?xf32, #CSR>
+    return %result : tensor<?x?xf32, #CSR>
+  }
+
+  func.func private @getTensorFilename(index) -> (!Filename)
+
+  //
+  // Main driver that reads matrix from file and calls the sparse kernel.
+  //
+  func.func @entry() {
+    %d0 = arith.constant 0.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c5 = arith.constant 5 : index
+    %c10 = arith.constant 10 : index
+
+    // Initialize the dense inputs A (5x10, A(i,k) = i + 1) and
+    // B (10x5, B(k,j) = j + 1). Hence SUM_k A(i,k) * B(k,j) equals
+    // 10 * (i + 1) * (j + 1), so the kernel turns every stored entry
+    // S(i,j) of the 5x5 sparse matrix into S(i,j) + 10 * (i + 1) * (j + 1);
+    // positions not stored in S stay absent from the result.
+
+    %a = tensor.generate %c5, %c10 {
+    ^bb0(%i: index, %j: index):
+      %p = arith.addi %i, %c1 : index
+      %q = arith.index_cast %p : index to i32
+      %d = arith.sitofp %q : i32 to f32
+      tensor.yield %d : f32
+    } : tensor<?x?xf32>
+
+    %b = tensor.generate %c10, %c5 {
+    ^bb0(%i: index, %j: index):
+      %p = arith.addi %j, %c1 : index
+      %q = arith.index_cast %p : index to i32
+      %d = arith.sitofp %q : i32 to f32
+      tensor.yield %d : f32
+    } : tensor<?x?xf32>
+
+    // Read the sparse matrix from file, construct sparse storage.
+    %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
+    %s = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #CSR>
+
+    // Call the kernel.
+    %0 = call @sampled_dense_dense(%s, %a, %b)
+       : (tensor<?x?xf32, #CSR>,
+          tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #CSR>
+
+    // Print the result for verification.
+    //
+    // CHECK: ( 11, 41.4, 42, 102.5, 93, 44.1, 164, 105.2, 255 )
+    %vm = sparse_tensor.values %0 : tensor<?x?xf32, #CSR> to memref<?xf32>
+    %vv = vector.transfer_read %vm[%c0], %d0 : memref<?xf32>, vector<9xf32>
+    vector.print %vv : vector<9xf32>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %0 : tensor<?x?xf32, #CSR>
+
+    return
+  }
+}


        


More information about the Mlir-commits mailing list