[Mlir-commits] [mlir] [mlir][sparse] fix bugs when generating sparse conv_3d kernels. (PR #74561)

Peiming Liu llvmlistbot at llvm.org
Wed Dec 6 10:18:48 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/74561

>From 8739240dff45087bd3ad247034cb40acc992d14d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 6 Dec 2023 05:16:22 +0000
Subject: [PATCH] [mlir][sparse] fix bugs when generating sparse conv_3d kernels.

---
 .../SparseTensor/Transforms/LoopEmitter.cpp   | 19 ++----
 .../SparseTensor/CPU/sparse_conv_3d.mlir      | 62 ++++++++++++++++++-
 2 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
index 6a2d7c33356f9..ff8561534a376 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -1454,28 +1454,19 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
 void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
                                                   Location loc, TensorId tid,
                                                   Level rootLvl, Value fcnt) {
+
   auto stt = getSparseTensorType(tensors[tid]);
 
   // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
-  // level (but not resolved). Since we forward an iterator at higher level of
-  // the tree, the subtree need to be pruned.
+  // sparse levels (but not resolved). Since we forward an iterator at higher
+  // level of the tree, the subtree needs to be pruned.
   Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && !dependentLvlMap[tid][leafLvl].empty() &&
-         depFullyReduced(tid, leafLvl)) {
+  while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
+         !stt.isDenseLvl(leafLvl)) {
     leafLvl++;
   }
 
   Level curLvl = rootLvl + 1;
-  // Prunes all denses subtree.
-  while (curLvl < leafLvl && isDenseLT(lvlTypes[tid][curLvl])) {
-    // One step forward in parent level results in forwarding `slice.size` step
-    // in child dense level.
-    auto [size, stride] = sliceMeta[tid][curLvl].back();
-    assert(stride == 1 && "Not yet implemented");
-    fcnt = MULI(size, fcnt);
-    curLvl++;
-  }
-
   Value nxPosPtr = nullptr;
   if (curLvl < leafLvl) {
     assert(!isDenseLT(lvlTypes[tid][curLvl]));
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
index dfb1bb71a68c4..451d2b8769461 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d.mlir
@@ -38,10 +38,14 @@
   map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed)
 }>
 
-#DDC = #sparse_tensor.encoding<{
+#DCC = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed)
 }>
 
+#DDC = #sparse_tensor.encoding<{
+  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed)
+}>
+
 // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f
 func.func @alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor<?x?x?xf32> {
   %buf = tensor.empty(%s1, %s2, %s3) : tensor<?x?x?xf32>
@@ -74,6 +78,15 @@ func.func @conv_3d_CDC(%arg0: tensor<?x?x?xf32, #CDC>, %arg1: tensor<?x?x?xf32>)
   return %ret : tensor<?x?x?xf32, #CDC>
 }
 
+func.func @conv_3d_DCC(%arg0: tensor<?x?x?xf32, #DCC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DCC> {
+  %c6 = arith.constant 6 : index
+  %s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DCC>
+  %ret = linalg.conv_3d
+     ins (%arg0, %arg1: tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>)
+    outs (%s: tensor<?x?x?xf32, #DCC>) -> tensor<?x?x?xf32, #DCC>
+  return %ret : tensor<?x?x?xf32, #DCC>
+}
+
 func.func @conv_3d_DDC(%arg0: tensor<?x?x?xf32, #DDC>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32, #DDC> {
   %c6 = arith.constant 6 : index
   %s = tensor.empty(%c6, %c6, %c6) : tensor<?x?x?xf32, #DDC>
@@ -102,12 +115,15 @@ func.func @entry() {
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CCC>
   %in3D_CDC = sparse_tensor.convert %in3D
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #CDC>
+  %in3D_DCC = sparse_tensor.convert %in3D
+    : tensor<?x?x?xf32> to tensor<?x?x?xf32, #DCC>
   %in3D_DDC = sparse_tensor.convert %in3D
     : tensor<?x?x?xf32> to tensor<?x?x?xf32, #DDC>
 
   %dense_ret = call @conv_3d(%in3D, %filter3D, %out3D) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>)
   %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor<?x?x?xf32, #CCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CCC>)
   %CDC_ret = call @conv_3d_CDC(%in3D_CDC, %filter3D) : (tensor<?x?x?xf32, #CDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #CDC>)
+  %DCC_ret = call @conv_3d_DCC(%in3D_DCC, %filter3D) : (tensor<?x?x?xf32, #DCC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DCC>)
   %DDC_ret = call @conv_3d_DDC(%in3D_DDC, %filter3D) : (tensor<?x?x?xf32, #DDC>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32, #DDC>)
 
   //      CHECK:( ( ( 108, 108, 108, 108, 108, 108 ),
@@ -276,6 +292,48 @@ func.func @entry() {
       : tensor<?x?x?xf32>, vector<6x6x6xf32>
   vector.print %v2 : vector<6x6x6xf32>
 
+  // CHECK-NEXT:( ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 124, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ),
+  // CHECK-SAME:  ( ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ),
+  // CHECK-SAME:    ( 108, 108, 108, 108, 108, 108 ) ) )
+  %4 = sparse_tensor.convert %DCC_ret
+    : tensor<?x?x?xf32, #DCC> to tensor<?x?x?xf32>
+  %v4 = vector.transfer_read %3[%c0, %c0, %c0], %zero
+      : tensor<?x?x?xf32>, vector<6x6x6xf32>
+  vector.print %v2 : vector<6x6x6xf32>
+
   // Free the resources
   bufferization.dealloc_tensor %in3D : tensor<?x?x?xf32>
   bufferization.dealloc_tensor %filter3D : tensor<?x?x?xf32>
@@ -284,9 +342,11 @@ func.func @entry() {
   bufferization.dealloc_tensor %in3D_CDC : tensor<?x?x?xf32, #CDC>
   bufferization.dealloc_tensor %in3D_CCC : tensor<?x?x?xf32, #CCC>
   bufferization.dealloc_tensor %in3D_DDC : tensor<?x?x?xf32, #DDC>
+  bufferization.dealloc_tensor %in3D_DCC : tensor<?x?x?xf32, #DCC>
 
   bufferization.dealloc_tensor %CCC_ret : tensor<?x?x?xf32, #CCC>
   bufferization.dealloc_tensor %CDC_ret : tensor<?x?x?xf32, #CDC>
   bufferization.dealloc_tensor %DDC_ret : tensor<?x?x?xf32, #DDC>
+  bufferization.dealloc_tensor %DCC_ret : tensor<?x?x?xf32, #DCC>
   return
 }



More information about the Mlir-commits mailing list