[Mlir-commits] [mlir] [mlir][sparse] add a csr x bsr matmul test case (PR #73012)

Peiming Liu llvmlistbot at llvm.org
Tue Nov 21 09:12:52 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/73012

From 5416c60008be6ab4470f37bb96a95b80c7324719 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 Nov 2023 16:47:42 +0000
Subject: [PATCH 1/2] [mlir][sparse] add a csr x bsr matmul test case

---
 .../SparseTensor/Transforms/LoopEmitter.cpp   |  3 +--
 .../SparseTensor/CPU/sparse_block_matmul.mlir | 25 +++++++++++++++++--
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index ba798f09c4d583b..595ff793b1138f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1497,8 +1497,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
     levelReducedDep[tid][lvl]--;
     if (!resolved) {
       // TODO: support coiterating multiple slices
-      assert(loopInfo.trivialTidLvls.empty() &&
-             loopInfo.sliceDrivenInfo.size() == 1);
+      assert(loopInfo.sliceDrivenInfo.size() == 1);
       auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
           genSliceNextInduction(builder, loc, tid, lvl);
       // Update while loop induction operands.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 63da3c80b3edbc7..e5c3b55559b0293 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -41,6 +41,10 @@
   doc = "X(i,j) *= A(i,j) * B(j,i)"
 }
 
+#CSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> (i : dense, j : compressed)
+}>
+
 
 #BSR = #sparse_tensor.encoding<{
   map = ( i, j ) ->
@@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
   return %0 : tensor<4x4xf64>
 }
 
+func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
+                       %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
 func.func @mul_dense(%arg0: tensor<4x8xf64>,
                      %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
   %out = arith.constant dense<0.0> : tensor<4x4xf64>
@@ -132,6 +150,7 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
 
     %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
     %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
+    %4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
 
     %d = call @mul_dense(%td, %td)
          : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
@@ -139,11 +158,13 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
          : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
     %s24 = call @mul_24(%td, %3)
          : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
-
-    // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
+    %scsr = call @mul_csr_bsr(%4, %2)
+         : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+    // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%scsr) : (tensor<4x4xf64>) -> ()
 
     return
   }

From 8f1d30a79be8ea5a8aad609239ba2b175624a9c9 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 Nov 2023 17:12:16 +0000
Subject: [PATCH 2/2] address comments

---
 .../Dialect/SparseTensor/CPU/sparse_block_matmul.mlir            | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index e5c3b55559b0293..4bc080fc538fc6e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -160,6 +160,7 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
          : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
     %scsr = call @mul_csr_bsr(%4, %2)
          : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+
     // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()



More information about the Mlir-commits mailing list