[Mlir-commits] [mlir] b52eb7c - [mlir][sparse] add a csr x bsr matmul test case (#73012)

llvmlistbot at llvm.org
Tue Nov 21 09:14:49 PST 2023


Author: Peiming Liu
Date: 2023-11-21T09:14:45-08:00
New Revision: b52eb7c2fe936697c538d49c723db59eb2fda1f5

URL: https://github.com/llvm/llvm-project/commit/b52eb7c2fe936697c538d49c723db59eb2fda1f5
DIFF: https://github.com/llvm/llvm-project/commit/b52eb7c2fe936697c538d49c723db59eb2fda1f5.diff

LOG: [mlir][sparse] add a csr x bsr matmul test case (#73012)
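
The new kernel multiplies a CSR-encoded operand with a BSR-encoded operand in
the existing sparse_block_matmul.mlir integration test; the runner now also
converts the dense input to #CSR, calls the kernel, and the FileCheck count
grows from three to four identical result matrices. On the codegen side, the
assertion in LoopEmitter::exitWhileLoop is relaxed so that trivial tensor/level
pairs may be present while a single slice-driven level is co-iterated. For
readers skimming the diff, a minimal annotated sketch of the encodings involved
follows; the #CSR attribute mirrors the one added below, while the 2x2 block
size in the BSR-style attribute is only illustrative, since the test's actual
BSR map lies outside the hunks shown here.

// CSR: level 0 visits every row densely, level 1 stores only the nonzero
// column coordinates of each row (same attribute as the one added below).
#CSR = #sparse_tensor.encoding<{
  map = (i, j) -> (i : dense, j : compressed)
}>

// Illustrative 2x2 block-sparse (BSR-style) encoding, not taken from the test
// file: the outer two levels address whole blocks (block rows dense, block
// columns compressed), and the inner two levels store each 2x2 block densely.
#BSR_2x2 = #sparse_tensor.encoding<{
  map = (i, j) -> (i floordiv 2 : dense,
                   j floordiv 2 : compressed,
                   i mod 2      : dense,
                   j mod 2      : dense)
}>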

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index ba798f09c4d583b..595ff793b1138f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1497,8 +1497,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
     levelReducedDep[tid][lvl]--;
     if (!resolved) {
       // TODO: support coiterating multiple slices
-      assert(loopInfo.trivialTidLvls.empty() &&
-             loopInfo.sliceDrivenInfo.size() == 1);
+      assert(loopInfo.sliceDrivenInfo.size() == 1);
       auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
           genSliceNextInduction(builder, loc, tid, lvl);
       // Update while loop induction operands.

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 63da3c80b3edbc7..4bc080fc538fc6e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -41,6 +41,10 @@
   doc = "X(i,j) *= A(i,j) * B(j,i)"
 }
 
+#CSR = #sparse_tensor.encoding<{
+  map = ( i, j ) -> (i : dense, j : compressed)
+}>
+
 
 #BSR = #sparse_tensor.encoding<{
   map = ( i, j ) ->
@@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
   return %0 : tensor<4x4xf64>
 }
 
+func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
+                       %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = arith.constant dense<0.0> : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
 func.func @mul_dense(%arg0: tensor<4x8xf64>,
                      %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
   %out = arith.constant dense<0.0> : tensor<4x4xf64>
@@ -132,6 +150,7 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
 
     %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
     %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
+    %4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
 
     %d = call @mul_dense(%td, %td)
          : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
@@ -139,11 +158,14 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
          : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
     %s24 = call @mul_24(%td, %3)
          : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
+    %scsr = call @mul_csr_bsr(%4, %2)
+         : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
 
-    // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
+    // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
     call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
     call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%scsr) : (tensor<4x4xf64>) -> ()
 
     return
   }


        

