[Mlir-commits] [mlir] [mlir][sparse] add a csr x bsr matmul test case (PR #73012)
Peiming Liu
llvmlistbot at llvm.org
Tue Nov 21 08:49:14 PST 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/73012
None
From 5416c60008be6ab4470f37bb96a95b80c7324719 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 21 Nov 2023 16:47:42 +0000
Subject: [PATCH] [mlir][sparse] add a csr x bsr matmul test case
---
.../SparseTensor/Transforms/LoopEmitter.cpp | 3 +--
.../SparseTensor/CPU/sparse_block_matmul.mlir | 25 +++++++++++++++++--
2 files changed, 24 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index ba798f09c4d583b..595ff793b1138f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1497,8 +1497,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
levelReducedDep[tid][lvl]--;
if (!resolved) {
// TODO: support coiterating multiple slices
- assert(loopInfo.trivialTidLvls.empty() &&
- loopInfo.sliceDrivenInfo.size() == 1);
+ assert(loopInfo.sliceDrivenInfo.size() == 1);
auto [nxNonEmpty, nxMinCrd, nxAbsOffset] =
genSliceNextInduction(builder, loc, tid, lvl);
// Update while loop induction operands.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 63da3c80b3edbc7..e5c3b55559b0293 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -41,6 +41,10 @@
doc = "X(i,j) *= A(i,j) * B(j,i)"
}
+#CSR = #sparse_tensor.encoding<{
+ map = ( i, j ) -> (i : dense, j : compressed)
+}>
+
#BSR = #sparse_tensor.encoding<{
map = ( i, j ) ->
@@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
return %0 : tensor<4x4xf64>
}
+func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>,
+ %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> {
+ %out = arith.constant dense<0.0> : tensor<4x4xf64>
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>)
+ outs(%out: tensor<4x4xf64>) {
+ ^bb(%x: f64, %y : f64, %z : f64):
+ %1 = arith.mulf %x, %y : f64
+ %2 = arith.addf %1, %z : f64
+ linalg.yield %2 : f64
+ } -> tensor<4x4xf64>
+ return %0 : tensor<4x4xf64>
+}
+
func.func @mul_dense(%arg0: tensor<4x8xf64>,
%arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
%out = arith.constant dense<0.0> : tensor<4x4xf64>
@@ -132,6 +150,7 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
%2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
%3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
+ %4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR>
%d = call @mul_dense(%td, %td)
: (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
@@ -139,11 +158,13 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
: (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
%s24 = call @mul_24(%td, %3)
: (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
-
- // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
+ %scsr = call @mul_csr_bsr(%4, %2)
+ : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
+ // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) )
call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
+ call @dumpf64(%scsr) : (tensor<4x4xf64>) -> ()
return
}
More information about the Mlir-commits
mailing list