[Mlir-commits] [mlir] [mlir][sparse] recognize NVidia 2:4 type for matmul (PR #76758)

Aart Bik llvmlistbot at llvm.org
Tue Jan 2 14:24:29 PST 2024


https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/76758

>From 93e9e93a47334d2a42ab2c03e8da24adf14bf8b3 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 2 Jan 2024 14:19:09 -0800
Subject: [PATCH 1/2] [mlir][sparse] recognize NVidia 2:4 type for matmul

This removes the temporary DENSE24 attribute and replaces
it with proper recognition of a dense to 2:4 conversion. The
compression will be performed on the device prior to
performing the matrix multiplication. Note that we no longer
need to start with the linalg.generic version; we can lift
this to the proper named linalg op. Also renames some files
to more consistent names.
---
 .../Transforms/SparseGPUCodegen.cpp           |  27 +++-
 .../ExecutionEngine/CudaRuntimeWrappers.cpp   |   2 +-
 ...ul_lib_2to4.mlir => gpu_matmul24_lib.mlir} |  28 ++--
 ...inalg.mlir => sparse-matmul-2-4-hand.mlir} | 139 +++++++++--------
 .../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir   | 141 ++++++++----------
 .../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir |  41 +++--
 6 files changed, 203 insertions(+), 175 deletions(-)
 rename mlir/test/Dialect/SparseTensor/GPU/{gpu_matmul_lib_2to4.mlir => gpu_matmul24_lib.mlir} (88%)
 rename mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/{sparse-matmul-2-4-lib-from-linalg.mlir => sparse-matmul-2-4-hand.mlir} (66%)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 8af3b694c4d975..3b19f0979d414c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -448,6 +448,22 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
   return false;
 }
 
+/// Test for 2:4 matrix with suitable metadata.
+static bool isAdmissible24(SparseTensorType &aTp) {
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+}
+
+/// Test for conversion into 2:4 matrix.
+static bool isConversionInto24(Value v) {
+  if (auto cnv = v.getDefiningOp<ConvertOp>()) {
+    Value a = cnv.getResult();
+    Value d = cnv.getSource();
+    SparseTensorType aTp = getSparseTensorType(a);
+    return isDenseTensor(d) && isAdmissible24(aTp);
+  }
+  return false;
+}
+
 /// Returns a suitable sparse format for the operation and given operand
 /// types with cuSparse, or kNone if none is available.
 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
@@ -925,6 +941,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
   Value C = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
+  // The cuSparselt API currently only allows pruning and compression
+  // to occur on the device. So we recognize the pattern
+  //    A' = convert A  ; dense to 2:4
+  //    C  = A'B        ; 2:4 matrix mult
+  // and then perform compression and matrix multiplication on device.
+  auto cnv = A.getDefiningOp<ConvertOp>();
+  assert(cnv);
+  A = cnv.getSource();
+
   // All input should be dense tensors.
   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
     return failure();
@@ -1260,7 +1285,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
         return rewriteSpGEMM(rewriter, op, enableRT);
-      if (op->getAttr("DENSE24"))
+      if (isConversionInto24(op.getOperand(0)))
         return rewrite2To4SpMM(rewriter, op);
       return rewriteSpMM(rewriter, op, enableRT);
     }
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index c45320a674568a..b9a3429e37b885 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -970,7 +970,7 @@ mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
   // Note that this adds a synchronization on the stream.
   // TODO: Do we want that?
   if (prune_flag == 2) {
-    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream);
+    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
     CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
         &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
     int valid = 0;
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
similarity index 88%
rename from mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
rename to mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index f584977e96415b..6fe7ec906f30e9 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -1,5 +1,13 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s
 
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  )
+}>
+
 // CHECK-LABEL:   func.func @matmul(
 // CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf16>,
 // CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?xf16>,
@@ -51,18 +59,14 @@
 // CHECK:           %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
 // CHECK:           return %[[VAL_55]] : tensor<?x?xf16>
 // CHECK:         }
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 module {
-  func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
-    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
-    ^bb0(%in: f16, %in_0: f16, %out: f16):
-      %1 = arith.mulf %in, %in_0 : f16
-      %2 = arith.addf %out, %1 : f16
-      linalg.yield %2 : f16
-    } -> tensor<?x?xf16>
-    return %0 : tensor<?x?xf16>
+  func.func @matmul(%Ad: tensor<?x?xf16>,
+                    %B: tensor<?x?xf16>,
+		    %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
+    %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
+      outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
+    return %C : tensor<?x?xf16>
   }
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir
similarity index 66%
rename from mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir
index d7e9cedec4ccd7..117832df95b464 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir
@@ -1,40 +1,58 @@
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
+// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// DEFINE: %s
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
 // DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
-// with RT lib:
-//
-// RUN: %{compile} enable-runtime-library=true"  | %{run}
-//
-// without RT lib:
-//
-// RUN: %{compile} enable-runtime-library=false" | %{run}
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// RUN: %{compile} | %{run}
 
 module {
   llvm.func @mgpuCreateSparseLtEnv()
   llvm.func @mgpuDestroySparseLtEnv()
 
-  //
-  // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
-  //
-  func.func @matmul_2to4(%arg0: tensor<16x32xf16>, %arg1: tensor<32x16xf16>, %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
-    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32xf16>, tensor<32x16xf16>) outs(%arg2 : tensor<16x16xf16>) {
-    ^bb0(%in: f16, %in_0: f16, %out: f16):
-      %1 = arith.mulf %in, %in_0 : f16
-      %2 = arith.addf %out, %1 : f16
-      linalg.yield %2 : f16
-    } -> tensor<16x16xf16>
-    return %0 : tensor<16x16xf16>
+  // cuSparselt version for matmul coded by hand.
+  func.func @matmul24(%a : memref<16x32xf16>,
+                      %b : memref<32x16xf16>,
+                      %c : memref<16x16xf16>) {
+    %c0  = arith.constant 0.0 : f16
+    %c1  = arith.constant 1   : index
+    %c2  = arith.constant 2   : index
+    %c8  = arith.constant 8   : index
+    %c16 = arith.constant 16  : index
+    %c32 = arith.constant 32  : index
+    %c1048576 = arith.constant 1048576 : index
+    %token0 = gpu.wait async
+    %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
+    %d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
+    %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
+    %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
+    %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
+    %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
+    %spmat, %token8 = gpu.create_2to4_spmat async  [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
+    %dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
+    %dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
+    %bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index,index into f16
+    %mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
+    %mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
+    %mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
+    %token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>,memref<?xf16> into f16
+    %token16 = gpu.destroy_sp_mat async [%token15] %spmat
+    %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
+    %token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
+    %token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
+    %token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
+    %token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
+    %token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
+    %token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
+    %token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
+    %token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
+    gpu.wait [%token25]
+    return
   }
 
   //
@@ -54,50 +72,49 @@ module {
     %c64 = arith.constant 64  : index
 
     // Matrices A, B, C (16x32, 32x16, 16x16).
+    %a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
+    %b = memref.alloc() : memref<32x16xf16> // regular dense   column-major
+    %c = memref.alloc() : memref<16x16xf16> // accumulator     row-major
 
     //
     // Setup matrix A.
     //
-    %DA = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      // (i+ j/2 + 1) if j %2 == 0 else 0
-      %cf0 = arith.constant 0.0 : f16
-      %cf1 = arith.constant 1.0 : f16
-      %j_2 = arith.floordivsi %j, %c2 : index
-      %quotient = arith.remsi %j, %c2 : index
-      %sum = arith.addi %i, %j_2 : index
-      %sum_i = arith.index_cast %sum : index to i64
-      %sum_f = arith.uitofp %sum_i : i64 to f16
-      %sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
-      %is_zero = arith.cmpi "eq", %quotient, %c0 : index
-      %s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
-      tensor.yield %s : f16
-    } : tensor<16x32xf16>
+    scf.for %ai = %c0 to %c16 step %c1 {
+      scf.for %aj = %c0 to %c16 step %c1 {
+        %cf0  = arith.constant 0.0: f16
+        %a0 = arith.addi %ai, %aj : index
+        %a1 = arith.addi %a0, %c1 : index
+        %a2 = arith.index_cast %a1 : index to i32
+        %a3 = arith.sitofp %a2 : i32 to f16
+        %ajj = arith.muli %aj, %c2 : index
+        %ajj2 = arith.addi %ajj, %c1 : index
+        memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
+        memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
+      }
+    }
 
     //
     // Setup matrix B.
     //
-    %DB = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      // if j_i >=8, j_i - 8 else 0
-      %is_ge8 = arith.cmpi "sge", %j, %c8 : index
-      %j_minus8 = arith.subi %j, %c8 : index
-      %j2 = arith.select %is_ge8, %j_minus8, %j : index
-      %r_i = arith.subi %j2, %i : index
-      %r_i64 = arith.index_cast %r_i : index to i64
-      %r_f = arith.sitofp %r_i64 : i64 to f16
-      tensor.yield %r_f : f16
-    } : tensor<32x16xf16>
+    scf.for %bi = %c0 to %c8 step %c1 {
+      scf.for %bj = %c0 to %c32 step %c1 {
+        %b0 = arith.subi %bi, %bj : index
+        %b1 = arith.index_cast %b0 : index to i32
+        %b2 = arith.sitofp %b1 : i32 to f16
+        %bii = arith.addi %bi, %c8 : index
+        memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
+        memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
+      }
+    }
 
     //
     // Reset matrix C.
     //
-    %DC = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      %cf0 = arith.constant 0.0 : f16
-      tensor.yield %cf0 : f16
-    } : tensor<16x16xf16>
-
+    scf.for %ci = %c0 to %c16 step %c1 {
+      scf.for %cj = %c0 to %c16 step %c1 {
+        memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
+      }
+    }
 
     //
     // Sanity check on 16x32 full 2:4 input matrix A.
@@ -121,7 +138,7 @@ module {
     // CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
     //
     scf.for %pai = %c0 to %c16 step %c1 {
-      %pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
+      %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
       vector.print %pa0 : vector<32xf16>
     }
 
@@ -163,14 +180,12 @@ module {
     //
     //
     scf.for %pbi = %c0 to %c32 step %c1 {
-      %pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
+      %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
       vector.print %pb0 : vector<16xf16>
     }
 
     // Call the kernel.
-    %t1  = arith.constant 1  : index
-    %t32 = arith.constant 32 : index
-    %c_out = call @matmul_2to4 (%DA, %DB, %DC): (tensor<16x32xf16>, tensor<32x16xf16>, tensor<16x16xf16>) -> tensor<16x16xf16>
+    call @matmul24(%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()
 
     //
     // Verify computed matrix C.
@@ -193,7 +208,7 @@ module {
     // CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688  )
     //
     scf.for %pci = %c0 to %c16 step %c1 {
-      %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
+      %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
       vector.print %pc0 : vector<16xf16>
     }
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index daf29d5290bab0..17b50b46d073ae 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -1,57 +1,41 @@
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
-// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
-// DEFINE: %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
 // DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
-// RUN: %{compile} | %{run}
+// with RT lib:
+//
+// RUN: %{compile} enable-runtime-library=true"  | %{run}
+//
+// without RT lib:
+//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
+
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  )
+}>
 
 module {
   llvm.func @mgpuCreateSparseLtEnv()
   llvm.func @mgpuDestroySparseLtEnv()
 
-  func.func @sampled_matmul(%a : memref<16x32xf16>,
-                            %b : memref<32x16xf16>,
-                            %c : memref<16x16xf16>) {
-    %c0  = arith.constant 0.0 : f16
-    %c1  = arith.constant 1   : index
-    %c2  = arith.constant 2   : index
-    %c8  = arith.constant 8   : index
-    %c16 = arith.constant 16  : index
-    %c32 = arith.constant 32  : index
-    %c1048576 = arith.constant 1048576 : index
-    %token0 = gpu.wait async
-    %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
-    %d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
-    %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
-    %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
-    %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
-    %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
-    %spmat, %token8 = gpu.create_2to4_spmat async  [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
-    %dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
-    %dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
-    %bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index,index into f16
-    %mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
-    %mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
-    %mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
-    %token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>,memref<?xf16> into f16
-    %token16 = gpu.destroy_sp_mat async [%token15] %spmat
-    %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
-    %token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
-    %token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
-    %token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
-    %token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
-    %token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
-    %token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
-    %token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
-    %token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
-    gpu.wait [%token25]
-    return
+  func.func @matmul24(%Ad: tensor<16x32xf16>,
+                      %B: tensor<32x16xf16>,
+                      %Cin: tensor<16x16xf16>) -> tensor<16x16xf16> {
+    %A = sparse_tensor.convert %Ad : tensor<16x32xf16> to tensor<16x32xf16, #NV_24>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<16x32xf16, #NV_24>, tensor<32x16xf16>)
+      outs(%Cin: tensor<16x16xf16>) -> tensor<16x16xf16>
+    return %C : tensor<16x16xf16>
   }
 
   //
@@ -71,49 +55,50 @@ module {
     %c64 = arith.constant 64  : index
 
     // Matrices A, B, C (16x32, 32x16, 16x16).
-    %a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
-    %b = memref.alloc() : memref<32x16xf16> // regular dense   column-major
-    %c = memref.alloc() : memref<16x16xf16> // accumulator     row-major
 
     //
     // Setup matrix A.
     //
-    scf.for %ai = %c0 to %c16 step %c1 {
-      scf.for %aj = %c0 to %c16 step %c1 {
-        %cf0  = arith.constant 0.0: f16
-        %a0 = arith.addi %ai, %aj : index
-        %a1 = arith.addi %a0, %c1 : index
-        %a2 = arith.index_cast %a1 : index to i32
-        %a3 = arith.sitofp %a2 : i32 to f16
-        %ajj = arith.muli %aj, %c2 : index
-        %ajj2 = arith.addi %ajj, %c1 : index
-        memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
-        memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
-      }
-    }
+    %DA = tensor.generate {
+    ^bb0(%i: index, %j: index):
+      // (i+ j/2 + 1) if j %2 == 0 else 0
+      %cf0 = arith.constant 0.0 : f16
+      %cf1 = arith.constant 1.0 : f16
+      %j_2 = arith.floordivsi %j, %c2 : index
+      %quotient = arith.remsi %j, %c2 : index
+      %sum = arith.addi %i, %j_2 : index
+      %sum_i = arith.index_cast %sum : index to i64
+      %sum_f = arith.uitofp %sum_i : i64 to f16
+      %sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
+      %is_zero = arith.cmpi "eq", %quotient, %c0 : index
+      %s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
+      tensor.yield %s : f16
+    } : tensor<16x32xf16>
 
     //
     // Setup matrix B.
     //
-    scf.for %bi = %c0 to %c8 step %c1 {
-      scf.for %bj = %c0 to %c32 step %c1 {
-        %b0 = arith.subi %bi, %bj : index
-        %b1 = arith.index_cast %b0 : index to i32
-        %b2 = arith.sitofp %b1 : i32 to f16
-        %bii = arith.addi %bi, %c8 : index
-        memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
-        memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
-      }
-    }
+    %DB = tensor.generate {
+    ^bb0(%i: index, %j: index):
+      // if j_i >=8, j_i - 8 else 0
+      %is_ge8 = arith.cmpi "sge", %j, %c8 : index
+      %j_minus8 = arith.subi %j, %c8 : index
+      %j2 = arith.select %is_ge8, %j_minus8, %j : index
+      %r_i = arith.subi %j2, %i : index
+      %r_i64 = arith.index_cast %r_i : index to i64
+      %r_f = arith.sitofp %r_i64 : i64 to f16
+      tensor.yield %r_f : f16
+    } : tensor<32x16xf16>
 
     //
     // Reset matrix C.
     //
-    scf.for %ci = %c0 to %c16 step %c1 {
-      scf.for %cj = %c0 to %c16 step %c1 {
-        memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
-      }
-    }
+    %DC = tensor.generate {
+    ^bb0(%i: index, %j: index):
+      %cf0 = arith.constant 0.0 : f16
+      tensor.yield %cf0 : f16
+    } : tensor<16x16xf16>
+
 
     //
     // Sanity check on 16x32 full 2:4 input matrix A.
@@ -137,7 +122,7 @@ module {
     // CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
     //
     scf.for %pai = %c0 to %c16 step %c1 {
-      %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
+      %pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
       vector.print %pa0 : vector<32xf16>
     }
 
@@ -179,12 +164,16 @@ module {
     //
     //
     scf.for %pbi = %c0 to %c32 step %c1 {
-      %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
+      %pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
       vector.print %pb0 : vector<16xf16>
     }
 
     // Call the kernel.
-    call @sampled_matmul (%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()
+    %t1  = arith.constant 1  : index
+    %t32 = arith.constant 32 : index
+    %c_out = call @matmul24(%DA, %DB, %DC): (tensor<16x32xf16>,
+                                             tensor<32x16xf16>,
+                                             tensor<16x16xf16>) -> tensor<16x16xf16>
 
     //
     // Verify computed matrix C.
@@ -207,7 +196,7 @@ module {
     // CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688  )
     //
     scf.for %pci = %c0 to %c16 step %c1 {
-      %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
+      %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
       vector.print %pc0 : vector<16xf16>
     }
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index e307286002e394..eb99a027a88600 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -16,34 +16,27 @@
 //
 // RUN: %{compile} enable-runtime-library=false" | %{run}
 
-#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  )
+}>
 
 module {
 
   llvm.func @mgpuCreateSparseLtEnv()
   llvm.func @mgpuDestroySparseLtEnv()
 
-  //
-  // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
-  //
-  func.func @matmul(%arg0: tensor<16x16xf16>,
-                    %arg1: tensor<16x16xf16>,
-		    %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
-    %0 = linalg.generic {
-       DENSE24,
-       indexing_maps = [#map0, #map1, #map2],
-       iterator_types = ["parallel", "parallel", "reduction"]
-    }
-     ins(%arg0, %arg1 : tensor<16x16xf16>, tensor<16x16xf16>)
-     outs(%arg2 : tensor<16x16xf16>) {
-         ^bb0(%in: f16, %in_0: f16, %out: f16):
-           %1 = arith.mulf %in, %in_0 : f16
-           %2 = arith.addf %out, %1 : f16
-           linalg.yield %2 : f16
-       } -> tensor<16x16xf16>
-    return %0 : tensor<16x16xf16>
+  func.func @matmul24(%Ad: tensor<16x16xf16>,
+                      %B: tensor<16x16xf16>,
+                      %Cin: tensor<16x16xf16>) -> tensor<16x16xf16> {
+    %A = sparse_tensor.convert %Ad : tensor<16x16xf16> to tensor<16x16xf16, #NV_24>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<16x16xf16, #NV_24>, tensor<16x16xf16>)
+      outs(%Cin: tensor<16x16xf16>) -> tensor<16x16xf16>
+    return %C : tensor<16x16xf16>
   }
 
   func.func @main() {
@@ -81,7 +74,9 @@ module {
     // By effectively computing D = A B + C with id(B) and zero(C)
     // the resulting matrix returns the pruned A back to the caller.
     //
-    %D = call @matmul(%A, %B, %C): (tensor<16x16xf16>, tensor<16x16xf16>, tensor<16x16xf16>) -> (tensor<16x16xf16>)
+    %D = call @matmul24(%A, %B, %C): (tensor<16x16xf16>,
+                                      tensor<16x16xf16>,
+                                      tensor<16x16xf16>) -> (tensor<16x16xf16>)
 
     //
     // This was the original matrix.

>From 2e3f16e0251f7371baefc4a927830edcf0c2bb1b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Tue, 2 Jan 2024 14:24:03 -0800
Subject: [PATCH 2/2] clang-format

---
 mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 3b19f0979d414c..87a37a7926e9e5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -450,7 +450,8 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
 
 /// Test for 2:4 matrix with suitable metadata.
 static bool isAdmissible24(SparseTensorType &aTp) {
-  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) && aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
+         aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
 }
 
 /// Test for conversion into 2:4 matrix.



More information about the Mlir-commits mailing list