[Mlir-commits] [mlir] [mlir][sparse] fix bugs when generate sparse conv_3d kernels. (PR #74561)
Peiming Liu
llvmlistbot at llvm.org
Tue Dec 5 21:23:02 PST 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/74561
None
>From 3b6065410add0e8953c70277d6c0abef39fe7779 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 6 Dec 2023 05:16:22 +0000
Subject: [PATCH] [mlir][sparse] fix bugs when generate sparse conv_3d kernels.
---
.../SparseTensor/Transforms/LoopEmitter.cpp | 17 ++---
.../SparseTensor/CPU/sparse_conv_3d.mlir | 62 ++++++++++++++++++-
2 files changed, 65 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 6a2d7c33356f9..54eb52fe2e889 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1454,28 +1454,19 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
Location loc, TensorId tid,
Level rootLvl, Value fcnt) {
+
auto stt = getSparseTensorType(tensors[tid]);
// Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
- // level (but not resolved). Since we forward an iterator at higher level of
- // the tree, the subtree need to be pruned.
+ // sparse levels (but not resolved). Since we forward an iterator at a higher
+ // level of the tree, the subtree needs to be pruned.
Level leafLvl = rootLvl + 1;
while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
- depFullyReduced(tid, leafLvl)) {
+ depFullyReduced(tid, leafLvl) && !stt.isDenseLvl(leafLvl)) {
leafLvl++;
}
Level curLvl = rootLvl + 1;
- // Prunes all denses subtree.
- while (curLvl < leafLvl && isDenseLT(lvlTypes[tid][curLvl])) {
- // One step forward in parent level results in forwarding `slice.size` step
- // in child dense level.
- auto [size, stride] = sliceMeta[tid][curLvl].back();
- assert(stride == 1 && "Not yet implemented");
- fcnt = MULI(size, fcnt);
- curLvl++;
- }
-
Value nxPosPtr = nullptr;
if (curLvl < leafLvl) {
assert(!isDenseLT(lvlTypes[tid][curLvl]));
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
index dfb1bb71a68c4..451d2b8769461 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
@@ -38,10 +38,14 @@
map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed)
}>
-#DDC = #sparse_tensor.encoding<{
+#DCC = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed)
}>
+#DDC = #sparse_tensor.encoding<{
+ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
// Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
%buf = tensor.empty(%s1, %s2, %s3) : tensor<?x?x?xf32>
@@ -74,6 +78,15 @@ func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>)
return %ret : tensor<?x?x?xf32, #CDC>
}
+func.func @conv_3d_DCC(%arg0: tensor<?x?x?xf32, #DCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DCC> {
+ %c6 = arith.constant 6 : index
+ %s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DCC>
+ %ret = linalg.conv_3d
+ ins (%arg0, %arg1: tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>)
+ outs (%s: tensor<?x?x?xf32, #DCC>) -> tensor<?x?x?xf32, #DCC>
+ return %ret : tensor<?x?x?xf32, #DCC>
+}
+
func.func @conv_3d_DDC(%arg0: tensor<?x?x?xf32, #DDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DDC> {
%c6 = arith.constant 6 : index
%s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DDC>
@@ -102,12 +115,15 @@ func.func @entry() {
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
%in3D_CDC = sparse_tensor.convert %in3D
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
+ %in3D_DCC = sparse_tensor.convert %in3D
+ : tensor<?x?x?xf32> to tensor<?x?x?xf32, #DCC>
%in3D_DDC = sparse_tensor.convert %in3D
: tensor<?x?x?xf32> to tensor<?x?x?xf32, #DDC>
%dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
%CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
%CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
+ %DCC_ret = call @conv_3d_DCC(%in3D_DCC, %filter3D) : (tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DCC>)
%DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DDC>)
// CHECK:( ( ( 108, 108, 108, 108, 108, 108 ),
@@ -276,6 +292,48 @@ func.func @entry() {
: tensor<?x?x?xf32>, vector<6x6x6xf32>
vector.print %v2 : vector<6x6x6xf32>
+ // CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
+ // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
+ // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
+ // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
+ // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ),
+ // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) )
+ %4 = sparse_tensor.convert %DCC_ret // Densify the DCC result so it can be read back.
+ : tensor<?x?x?xf32, #DCC> to tensor<?x?x?xf32>
+ %v4 = vector.transfer_read %4[%c0, %c0, %c0], %zero // Read from %4 (the DCC result), not an earlier layout's tensor.
+ : tensor<?x?x?xf32>, vector<6x6x6xf32>
+ vector.print %v4 : vector<6x6x6xf32>
+
// Free the resources
bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
@@ -284,9 +342,11 @@ func.func @entry() {
bufferization.dealloc_tensor %in3D_CDC : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %in3D_DDC : tensor<?x?x?xf32, #DDC>
+ bufferization.dealloc_tensor %in3D_DCC : tensor<?x?x?xf32, #DCC>
bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
bufferization.dealloc_tensor %DDC_ret : tensor<?x?x?xf32, #DDC>
+ bufferization.dealloc_tensor %DCC_ret : tensor<?x?x?xf32, #DCC>
return
}
More information about the Mlir-commits
mailing list