[Mlir-commits] [mlir] [mlir][sparse] support BSR for cuSPARSE (libgen path only) (PR #69646)

Aart Bik llvmlistbot at llvm.org
Thu Oct 19 15:04:56 PDT 2023


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/69646

From ea861d7daaad51013b69560cf8e8eb31ef9b06ca Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 19 Oct 2023 14:32:27 -0700
Subject: [PATCH 1/2] [mlir][sparse] support BSR for cuSPARSE (libgen path
 only)

---
 .../Transforms/SparseGPUCodegen.cpp           |  69 +++++--
 .../GPU/CUDA/sparse-sampled-matmul-lib.mlir   |  13 +-
 .../GPU/CUDA/sparse-sddmm-lib.mlir            | 189 ++++++++++++++++++
 3 files changed, 246 insertions(+), 25 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index a6e963181816f7b..e411032e1c1c37c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -39,7 +39,7 @@ enum class CuSparseFormat {
   kCOO,
   kCSR,
   kCSC,
-  kBSR, // TODO: coming soon!
+  kBSR,
 };
 
 //===----------------------------------------------------------------------===//
@@ -428,6 +428,19 @@ static bool isAdmissibleCSC(SparseTensorType &aTp) {
          aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
 }
 
+/// Test for BSR matrix with suitable metadata.
+static bool isAdmissibleBSR(SparseTensorType &aTp) {
+  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
+      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
+      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
+    // CuSparse only supports "square" blocks currently.
+    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
+    assert (dims.size() == 2);
+    return dims[0] == dims[1];
+  }
+  return false;
+}
+
 /// Returns a suitable sparse format for the operation and given operand
 /// types with cuSparse, or kNone if none is available.
 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
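
For concreteness, an encoding that satisfies the new isAdmissibleBSR check
(dimension rank 2, level rank 4, levels dense-compressed-dense-dense, and a
square block) is the 2x2 example exercised by the integration test added
below:

  #BSR = #sparse_tensor.encoding<{
    map = (i, j) -> (
      i floordiv 2 : dense,
      j floordiv 2 : compressed,
      i mod 2 : dense,
      j mod 2 : dense)
  }>
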
@@ -448,6 +461,8 @@ static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
     return CuSparseFormat::kCSR;
   if (isAdmissibleCSC(aTp))
     return CuSparseFormat::kCSC;
+  if (isAdmissibleBSR(aTp))
+    return CuSparseFormat::kBSR;
   return CuSparseFormat::kNone;
 }
 
@@ -475,9 +490,10 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
 }
 
 /// Generates the sparse matrix handle.
-static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
-                           Type tokenTp, Value token, Value sz1, Value sz2,
-                           Value nseA, Value rowA, Value colA, Value valA,
+static Operation *genSpMat(OpBuilder &builder, Location loc,
+                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
+                           Value token, Value sz1, Value sz2, Value nseA,
+                           Value rowA, Value colA, Value valA,
                            CuSparseFormat format, bool enableRT) {
   if (format == CuSparseFormat::kCOO) {
     // Library uses SoA COO, direct IR uses AoS COO.
@@ -498,9 +514,24 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
   if (format == CuSparseFormat::kCSR)
     return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                             sz2, nseA, rowA, colA, valA);
-  assert(format == CuSparseFormat::kCSC);
-  return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
-                                          sz2, nseA, rowA, colA, valA);
+  if (format == CuSparseFormat::kCSC)
+    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
+                                            sz2, nseA, rowA, colA, valA);
+  // BSR requires a bit more work since we need to pass in the block size
+  // and all other sizes in terms of blocks (#block-rows, #block-cols,
+  // #nonzero-blocks).
+  assert(format == CuSparseFormat::kBSR);
+  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
+  assert(dims.size() == 2 && dims[0] == dims[1]);
+  uint64_t b = dims[0];
+  Value bSz = constantIndex(builder, loc, b);
+  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
+  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
+  Value bNum = builder.create<arith::DivUIOp>(
+      loc, nseA, constantIndex(builder, loc, b * b));
+  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
+                                          bCols, bNum, bSz, bSz, rowA, colA,
+                                          valA);
 }
 
 /// Match and rewrite SpMV kernel.
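
As a worked example of the size computation above: for the 4x6 sampling
matrix in the new integration test (2x2 blocks, 12 stored values), the
generated code passes bRows = 4/2 = 2, bCols = 6/2 = 3, and
bNum = 12/(2*2) = 3 nonzero blocks to gpu.create_bsr, i.e. all sizes are
expressed in blocks rather than the element-level sizes used for CSR/CSC.
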
@@ -566,8 +597,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -691,8 +722,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -806,13 +837,13 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   Operation *spGenB =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
-               rowB, colB, valB, format, enableRT);
+      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
+               nseB, rowB, colB, valB, format, enableRT);
   Value spMatB = spGenB->getResult(0);
   token = spGenB->getResult(1);
 
@@ -830,8 +861,8 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value valC = e3.getResult(0); // no free needed
   token = e3.getAsyncToken();
   Operation *spGenC =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
-               rowC, colC, valC, format, enableRT);
+      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
+               zero, rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
 
@@ -1137,8 +1168,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value dnB = dmatB.getResult(0);
   token = dmatB.getAsyncToken();
   Operation *spGenC =
-      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
-               rowC, colC, valC, format, enableRT);
+      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
+               nseC, rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
index 61de57564beda2e..4c1466f4202e378 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir
@@ -3,7 +3,8 @@
 //
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
-// DEFINE: %{run} = TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
+// DEFINE: %{run} = \
+// DEFINE:   TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
 // DEFINE:   mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
@@ -12,16 +13,16 @@
 //
 // with RT lib:
 //
-//  RUN:  %{compile} enable-runtime-library=true" | %{run}
-//  RUN:  %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
-//  Tracker #64316
-//  RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
+// RUN:  %{compile} enable-runtime-library=true" | %{run}
+// RUN:  %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
+// TODO: Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 // without RT lib:
 //
 // RUN:  %{compile} enable-runtime-library=false" | %{run}
 // RUN:  %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
-//  Tracker #64316
+// TODO: Tracker #64316
 // RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir
new file mode 100644
index 000000000000000..ae84bcc4013e40e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sddmm-lib.mlir
@@ -0,0 +1,189 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{run} = \
+// DEFINE:   TENSOR0="%mlir_src_dir/test/Integration/data/block.mtx" \
+// DEFINE:   mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e entry --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
+// with RT lib:
+//
+// RUN:  %{compile} enable-runtime-library=true" | %{run}
+//
+// without RT lib:
+//
+// TODO: make this work
+// R_UN:  %{compile} enable-runtime-library=false" | %{run}
+//
+
+!Filename = !llvm.ptr<i8>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#BSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : dense,
+    j floordiv 2 : compressed,
+    i mod 2 : dense,
+    j mod 2 : dense)
+}>
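+
+// Note: in the #BSR map, (i floordiv 2, j floordiv 2) select a 2x2 block
+// (dense block rows over compressed block columns), while (i mod 2, j mod 2)
+// address the position within each dense block.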
+
+#trait_SDDMM = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S (in/out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to
+// actual sparse code, initializes sparse storage schemes, and
+// runs the resulting code with the JIT compiler.
+//
+module {
+  llvm.func @mgpuCreateSparseEnv()
+  llvm.func @mgpuDestroySparseEnv()
+
+  //
+  // A kernel that computes a CSR sampled dense matrix matrix multiplication
+  // using a "spy" function and in-place update of the sampling sparse matrix.
+  //
+  func.func @SDDMM(%args: tensor<?x?xf32, #CSR>,
+                   %arga: tensor<?x?xf32>,
+                   %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #CSR> {
+    %result = linalg.generic #trait_SDDMM
+      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%args: tensor<?x?xf32, #CSR>) {
+        ^bb(%a: f32, %b: f32, %s: f32):
+           %f0 = arith.constant 0.0 : f32
+           %u = sparse_tensor.unary %s : f32 to f32
+             present={
+                ^bb0(%p: f32):
+                  %mul = arith.mulf %a, %b : f32
+                  sparse_tensor.yield %mul : f32
+             }
+             absent={}
+           %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
+              ^bb0(%p: f32, %q: f32):
+                %add = arith.addf %p, %q : f32
+                sparse_tensor.yield %add : f32
+            }
+           linalg.yield %r : f32
+      } -> tensor<?x?xf32, #CSR>
+    return %result : tensor<?x?xf32, #CSR>
+  }
+
+  //
+  // A kernel that computes a BSR sampled dense matrix matrix multiplication
+  // using a "spy" function and in-place update of the sampling sparse matrix.
+  //
+  func.func @SDDMM_block(%args: tensor<?x?xf32, #BSR>,
+                         %arga: tensor<?x?xf32>,
+                         %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #BSR> {
+    %result = linalg.generic #trait_SDDMM
+      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%args: tensor<?x?xf32, #BSR>) {
+        ^bb(%a: f32, %b: f32, %s: f32):
+           %f0 = arith.constant 0.0 : f32
+           %u = sparse_tensor.unary %s : f32 to f32
+             present={
+                ^bb0(%p: f32):
+                  %mul = arith.mulf %a, %b : f32
+                  sparse_tensor.yield %mul : f32
+             }
+             absent={}
+           %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
+              ^bb0(%p: f32, %q: f32):
+                %add = arith.addf %p, %q : f32
+                sparse_tensor.yield %add : f32
+            }
+           linalg.yield %r : f32
+      } -> tensor<?x?xf32, #BSR>
+    return %result : tensor<?x?xf32, #BSR>
+  }
+
+  func.func private @getTensorFilename(index) -> (!Filename)
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    llvm.call @mgpuCreateSparseEnv() : () -> ()
+    %d0 = arith.constant 0.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c6 = arith.constant 6 : index
+
+    // Initialize dense matrices: A(i,j) = i+1 (constant per row) and
+    // B(i,j) = j+1 (constant per column).
+    %a = tensor.generate %c4, %c4 {
+    ^bb0(%i: index, %j: index):
+      %p = arith.addi %i, %c1 : index
+      %q = arith.index_cast %p : index to i32
+      %d = arith.sitofp %q : i32 to f32
+      tensor.yield %d : f32
+    } : tensor<?x?xf32>
+    %b = tensor.generate %c4, %c6 {
+    ^bb0(%i: index, %j: index):
+      %p = arith.addi %j, %c1 : index
+      %q = arith.index_cast %p : index to i32
+      %d = arith.sitofp %q : i32 to f32
+      tensor.yield %d : f32
+    } : tensor<?x?xf32>
+
+    // Read the sparse matrix from file, construct sparse storage.
+    //
+    //      +-----+-----+-----+
+    //      | 1 2 | . . | 4 . |
+    //      | . 3 | . . | . 5 |
+    //      +-----+-----+-----+
+    //      | . . | 6 7 | . . |
+    //      | . . | 8 . | . . |
+    //      +-----+-----+-----+
+    //
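+    // With 2x2 blocks, only 3 of the 6 blocks are nonempty, so the BSR
+    // storage holds 3 x 4 = 12 values (explicit zeros included), whereas
+    // the CSR storage holds just the 8 nonzero entries.
+    //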
+    %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
+    %m_csr = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #CSR>
+    %m_bsr = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #BSR>
+
+    // Call the kernel.
+    %0 = call @SDDMM(%m_csr, %a, %b)
+       : (tensor<?x?xf32, #CSR>,
+          tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #CSR>
+    %1 = call @SDDMM_block(%m_bsr, %a, %b)
+       : (tensor<?x?xf32, #BSR>,
+          tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #BSR>
+
+    //
+    // Print the result for verification. Note that the "spy" determines which
+    // dot products are sampled, but the original contents are added back to
+    // the result (which is why the block sparse version has actual results
+    // at the original zero positions).
+    //
+    // CHECK:      ( 5, 10, 24, 19, 53, 42, 55, 56 )
+    // CHECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 )
+    //
+    %v0 = sparse_tensor.values %0 : tensor<?x?xf32, #CSR> to memref<?xf32>
+    %vv0 = vector.transfer_read %v0[%c0], %d0 : memref<?xf32>, vector<8xf32>
+    vector.print %vv0 : vector<8xf32>
+    %v1 = sparse_tensor.values %1 : tensor<?x?xf32, #BSR> to memref<?xf32>
+    %vv1 = vector.transfer_read %v1[%c0], %d0 : memref<?xf32>, vector<12xf32>
+    vector.print %vv1 : vector<12xf32>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %0 : tensor<?x?xf32, #CSR>
+    bufferization.dealloc_tensor %1 : tensor<?x?xf32, #BSR>
+
+    llvm.call @mgpuDestroySparseEnv() : () -> ()
+    return
+  }
+}
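
To sanity-check the CHECK lines above: A(i,k) = i+1 and B(k,j) = j+1, so each
sampled dot product adds SUM_k (i+1)(j+1) = 4(i+1)(j+1) to the stored value.
For example, entry (0,0) becomes 1 + 4*1*1 = 5 and entry (3,2) becomes
8 + 4*4*3 = 56, matching the CSR line, while the explicit zero at (1,0)
inside the first stored block becomes 0 + 4*2*1 = 8, accounting for the
extra values in the BSR line.
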

From 954aeef23bd3b1f0e319bcd14b49a2a70b925097 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 19 Oct 2023 15:04:31 -0700
Subject: [PATCH 2/2] reject block sizes of 0 and 1

---
 mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index e411032e1c1c37c..3561e187fb041d6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -436,7 +436,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
     // CuSparse only supports "square" blocks currently.
     SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
     assert (dims.size() == 2);
-    return dims[0] == dims[1];
+    return dims[0] == dims[1] && dims[0] > 1;
   }
   return false;
 }


