[Mlir-commits] [mlir] [mlir][sparse] fix error when convolution stride is applied on a dense level (PR #79521)

Peiming Liu llvmlistbot at llvm.org
Thu Jan 25 15:29:24 PST 2024


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/79521

Fix an error when a convolution stride is applied on a dense level.

From fef22ec23fc86f6b5cc3390f04284bdec89ce312 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 25 Jan 2024 23:27:53 +0000
Subject: [PATCH] [mlir][sparse] fix error when convolution stride is applied
 on a dense level.

---
 .../Transforms/Utils/LoopEmitter.cpp          |  4 ++--
 .../Transforms/Utils/SparseTensorLevel.cpp    |  9 +++++---
 .../Transforms/Utils/SparseTensorLevel.h      |  2 +-
 .../CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir | 21 +++++++++++++++++++
 4 files changed, 30 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 6d832fdc0c2201..3fa4004ae460ef 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -313,8 +313,8 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
           Value loopHi = loopHighs[loop];
           size = ADDI(size, MULI(loopHi, C_IDX(stride)));
         }
-        it = makeNonEmptySubSectIterator(builder, loc, parent, std::move(lvlIt),
-                                         size, curDep.second);
+        it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
+                                         std::move(lvlIt), size, curDep.second);
       } else {
         Value size = loopHighs[loop];
         const SparseIterator &subSectIter = *iters[t][lvl].back();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 226cccbc422b9b..e43896942d7fe6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1271,7 +1271,7 @@ static const IterType *unwrapFilter(const SparseIterator *it) {
 }
 
 std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
-    OpBuilder &b, Location l, const SparseIterator *parent,
+    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
 
   // Try unwrap the NonEmptySubSectIterator from a filter parent.
@@ -1279,9 +1279,12 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
   auto it = std::make_unique<NonEmptySubSectIterator>(
       b, l, parent, std::move(delegate), size);
 
-  if (stride != 1)
+  if (stride != 1) {
+    // TODO: We can safely skip bound checking on sparse levels, but for dense
+    // iteration space, we need the bound to infer the dense loop range.
     return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
-                                            C_IDX(stride), /*size=*/C_IDX(-1));
+                                            C_IDX(stride), /*size=*/loopBound);
+  }
   return it;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 08f7c6a747eb57..d2b3396b72836c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -246,7 +246,7 @@ makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
 /// Helper function to create a SparseIterator object that iterate over the
 /// non-empty subsections set.
 std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
-    OpBuilder &b, Location l, const SparseIterator *parent,
+    OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
 
 /// Helper function to create a SparseIterator object that iterate over a
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
index 98adc26f06e56b..8ee80045afc760 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_strided_conv_2d_nhwc_hwcf.mlir
@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
   return %ret : tensor<?x?x?x?xf32>
 }
 
+func.func @conv_2d_nhwc_hwcf_dual_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tensor<?x?x?x?xf32, #CDCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                     strides = dense<2> : tensor<2xi64>}
+     ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>)
+    outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %ret : tensor<?x?x?x?xf32>
+}
+
 
 func.func @entry() {
   %c0 = arith.constant 0 : index
@@ -87,6 +95,8 @@ func.func @entry() {
 
   %in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
+  %filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
+    : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
   %in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
     : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CDCC>
 
@@ -94,9 +104,19 @@ func.func @entry() {
   %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
   %CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
 
+  %dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
+    : (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
+
   // CHECK:      ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
   // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
   // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
+  %v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
+      : tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
+  vector.print %v_dual : vector<3x3x3x1xf32>
+
+  // CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+  // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
+  // CHECK-SAME:   ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
   %dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
       : tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
   vector.print %dense_v : vector<3x3x3x1xf32>
@@ -120,6 +140,7 @@ func.func @entry() {
   bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
   bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>
 
+  bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
   bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
   bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor<?x?x?x?xf32, #CDCC>
   return



More information about the Mlir-commits mailing list