[Mlir-commits] [mlir] [mlir][sparse] support tensor.pad on CSR tensors (PR #90687)

Peiming Liu llvmlistbot at llvm.org
Wed May 1 10:00:46 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/90687

>From 7918fe4e4ce48a271c19e55177fa4d46b3228d1c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Apr 2024 23:22:38 +0000
Subject: [PATCH] [mlir][sparse] support tensor.pad on CSR tensors

---
 .../Transforms/Utils/SparseTensorIterator.cpp | 132 +++++++++++++-----
 .../Transforms/Utils/SparseTensorIterator.h   |   6 +-
 .../CPU/padded_sparse_conv_2d.mlir            |  30 ++--
 3 files changed, 111 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32e6..112b9f6c252786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
 
     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();
 
-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch: the parent iterator is in the pad zone.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+    // Yield a "fake" empty range [0, 0) for positions inside the pad zone.
+    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+    b.create<scf::YieldOp>(l, emptyRange);
+
+    // False branch: load the real [pLo, pHi) range from the position buffer.
+    b.setInsertionPointToStart(posRangeIf.elseBlock());
+    auto [pLo, pHi] = loadRange();
+    SmallVector<Value, 2> loadedRange{pLo, pHi};
+    b.create<scf::YieldOp>(l, loadedRange);
+
+    b.setInsertionPointAfter(posRangeIf);
+    ValueRange posRange = posRangeIf.getResults();
+    return {posRange.front(), posRange.back()};
   }
 };
 
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
 
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
   };
 
   void genInitImpl(OpBuilder &b, Location l,
-                   const SparseIterator *parent) override {
-
-    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
-      batchCrds.resize(stl.lvl + 1, nullptr);
-
-    Value c0 = C_IDX(0);
-    ValueRange pPos = c0;
-    // If the parent iterator is a batch iterator, we also start from 0 (but
-    // on a different batch).
-    if (parent && !parent->isBatchIterator())
-      pPos = parent->getCurPosition();
-
-    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
-    // Seek to the lowest position.
-    seek(posLo);
-  }
+                   const SparseIterator *parent) override;
 
   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
 // A util base-iterator that delegates all methods to the wrapped iterator.
 class SimpleWrapIterator : public SparseIterator {
 public:
-  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
-      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+                     unsigned extraCursorVal = 0)
+      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
 
   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
 public:
   PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
               Value padHigh)
-      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
-        padHigh(padHigh) {
-    assert(!randomAccessible() && "Not implemented.");
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+                           wrap->randomAccessible() ? 1 : 0),
+        padLow(padLow), padHigh(padHigh) {
+    // A random-accessible wrapped iterator adds an `inPadZone` cursor value.
   }
 
   // For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // The position is the whole cursor: for a padded dense iterator, that is
+  // the wrapped iterator's values followed by a trailing `inPadZone: bool`.
+  ValueRange getCurPosition() const override { return getCursor(); }
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    // A padded dense iterator needs an extra boolean `inPadZone` value.
+    if (randomAccessible())
+      ret.push_back(b.getI1Type());
+
+    return ret;
+  }
+
   // The upper bound after padding becomes `size + padLow + padHigh`.
   Value upperBound(OpBuilder &b, Location l) const override {
     return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
 
   void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
+    wrap->locate(b, l, SUBI(crd, padLow));
+
+    // inPadZone = crd < padLow || crd >= size + padLow.
+    Value inPadLow = CMPI(ult, crd, padLow);
+    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+    updateCrd(crd);
   }
 
   Value padLow, padHigh;
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }
 
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+                                  const SparseIterator *parent) {
+
+  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+    batchCrds.resize(stl.lvl + 1, nullptr);
+
+  Value c0 = C_IDX(0);
+  ValueRange pPos = c0;
+  Value inPadZone = nullptr;
+  // If the parent iterator is a batch iterator, we also start from 0 (but
+  // on a different batch).
+  if (parent && !parent->isBatchIterator()) {
+    pPos = parent->getCurPosition();
+    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+      // A padded dense parent creates a "sparse" pad zone: its position ends
+      // with an `inPadZone` flag that is split off and passed separately.
+      inPadZone = pPos.back();
+      pPos = pPos.drop_back();
+    }
+  }
+
+  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+  // Seek to the lowest position.
+  seek(posLo);
+}
+
 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                           const SparseIterator *) {
   Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f998..120a806536f190 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ class SparseTensorLevel {
   ///
   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
   /// to load coordinate from the coordinate buffer.
-  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix,
-                                              ValueRange parentPos) const = 0;
+  virtual std::pair<Value, Value>
+  peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+              ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf54558237..50dd989416e2a0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
 // Do the same run, but now with direct IR generation and VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
-#CCCC = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
 }>
 
 // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
   return %ret : tensor<3x8x8x1xf32>
 }
 
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
   %cst_0 = arith.constant 0.00000e+00 : f32
   %buf = tensor.empty() : tensor<3x8x8x1xf32>
   %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
   %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
    ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
      tensor.yield %cst_0 : f32
-   } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+   } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
 
   %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                      strides = dense<1> : tensor<2xi64>}
-     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
     outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
   return %ret : tensor<3x8x8x1xf32>
 }
@@ -105,8 +97,8 @@ func.func @main() {
 
   %dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
 
-  %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
-  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+  %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+  %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
 
 
   // CHECK:   ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
   // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
   // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
   // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
-  %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+  %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
       : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
-  vector.print %CCCC_v : vector<3x8x8x1xf32>
+  vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
 
   bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
   bufferization.dealloc_tensor %static_input  : tensor<3x8x8x3xf32>
   bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+  bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+  bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
 
   return
 }



More information about the Mlir-commits mailing list