[Mlir-commits] [mlir] [mlir][sparse] support tensor.pad on CSR tensors (PR #90687)

Peiming Liu llvmlistbot at llvm.org
Wed May 1 13:56:41 PDT 2024


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/90687

>From 520c703cd79dc23bd1d53b0f649f6e0d15e98da2 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Apr 2024 23:22:38 +0000
Subject: [PATCH 1/2] [mlir][sparse] support tensor.pad on CSR tensors

---
 .../Transforms/Utils/SparseTensorIterator.cpp | 132 +++++++++++++-----
 .../Transforms/Utils/SparseTensorIterator.h   |   6 +-
 .../CPU/padded_sparse_conv_2d.mlir            |  30 ++--
 3 files changed, 111 insertions(+), 57 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32e6..112b9f6c252786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
   ValueRange getLvlBuffers() const override { return {}; }
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
 
     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();
 
-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+    // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
+    SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+    b.create<scf::YieldOp>(l, emptyRange);
+
+    // False branch.
+    b.setInsertionPointToStart(posRangeIf.elseBlock());
+    auto [pLo, pHi] = loadRange();
+    SmallVector<Value, 2> loadedRange{pLo, pHi};
+    b.create<scf::YieldOp>(l, loadedRange);
+
+    b.setInsertionPointAfter(posRangeIf);
+    ValueRange posRange = posRangeIf.getResults();
+    return {posRange.front(), posRange.back()};
   }
 };
 
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
 
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
       : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
   };
 
   void genInitImpl(OpBuilder &b, Location l,
-                   const SparseIterator *parent) override {
-
-    if (isBatchIterator() && batchCrds.size() <= stl.lvl)
-      batchCrds.resize(stl.lvl + 1, nullptr);
-
-    Value c0 = C_IDX(0);
-    ValueRange pPos = c0;
-    // If the parent iterator is a batch iterator, we also start from 0 (but
-    // on a different batch).
-    if (parent && !parent->isBatchIterator())
-      pPos = parent->getCurPosition();
-
-    ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
-    std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
-    // Seek to the lowest position.
-    seek(posLo);
-  }
+                   const SparseIterator *parent) override;
 
   ValuePair genForCond(OpBuilder &b, Location l) override {
     if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
 // A util base-iterator that delegates all methods to the wrapped iterator.
 class SimpleWrapIterator : public SparseIterator {
 public:
-  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
-      : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+  SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+                     unsigned extraCursorVal = 0)
+      : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
 
   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
 public:
   PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
               Value padHigh)
-      : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
-        padHigh(padHigh) {
-    assert(!randomAccessible() && "Not implemented.");
+      : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+                           wrap->randomAccessible() ? 1 : 0),
+        padLow(padLow), padHigh(padHigh) {
+    // assert(!randomAccessible());
   }
 
   // For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // For a padded dense iterator, we append an `inPadZone: bool` value in
+  // addition to the values used by the wrapped iterator.
+  ValueRange getCurPosition() const override { return getCursor(); }
+
+  SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+    SmallVector<Type> ret = wrap->getCursorValTypes(b);
+    // Need an extra boolean value `inPadZone` for the padded dense iterator.
+    if (randomAccessible())
+      ret.push_back(b.getI1Type());
+
+    return ret;
+  }
+
   // The upper bound after padding becomes `size + padLow + padHigh`.
   Value upperBound(OpBuilder &b, Location l) const override {
     return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
 
   void locateImpl(OpBuilder &b, Location l, Value crd) override {
     assert(randomAccessible());
+    wrap->locate(b, l, SUBI(crd, padLow));
+
+    // inPadZone = crd < padLow || crd >= size + padLow.
+    Value inPadLow = CMPI(ult, crd, padLow);
+    Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+    getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+    updateCrd(crd);
   }
 
   Value padLow, padHigh;
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
   return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
 }
 
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+                                  const SparseIterator *parent) {
+
+  if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+    batchCrds.resize(stl.lvl + 1, nullptr);
+
+  Value c0 = C_IDX(0);
+  ValueRange pPos = c0;
+  Value inPadZone = nullptr;
+  // If the parent iterator is a batch iterator, we also start from 0 (but
+  // on a different batch).
+  if (parent && !parent->isBatchIterator()) {
+    pPos = parent->getCurPosition();
+    if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+      // A padded dense iterator creates a "sparse" padded zone, which needs
+      // to be handled specially.
+      inPadZone = pPos.back();
+      pPos = pPos.drop_back();
+    }
+  }
+
+  ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+  std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+  // Seek to the lowest position.
+  seek(posLo);
+}
+
 void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
                                           const SparseIterator *) {
   Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f998..120a806536f190 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ class SparseTensorLevel {
   ///
   /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
   /// to load coordinate from the coordinate buffer.
-  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
-                                              ValueRange batchPrefix,
-                                              ValueRange parentPos) const = 0;
+  virtual std::pair<Value, Value>
+  peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+              ValueRange parentPos, Value inPadZone = nullptr) const = 0;
 
   Level getLevel() const { return lvl; }
   LevelType getLT() const { return lt; }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf54558237..50dd989416e2a0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
 // Do the same run, but now with direct IR generation and VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
-#CCCC = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
-  map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+  map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
 }>
 
 // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
   return %ret : tensor<3x8x8x1xf32>
 }
 
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
   %cst_0 = arith.constant 0.00000e+00 : f32
   %buf = tensor.empty() : tensor<3x8x8x1xf32>
   %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
   %padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
    ^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
      tensor.yield %cst_0 : f32
-   } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+   } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
 
   %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
                                      strides = dense<1> : tensor<2xi64>}
-     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+     ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
     outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
   return %ret : tensor<3x8x8x1xf32>
 }
@@ -105,8 +97,8 @@ func.func @main() {
 
   %dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
 
-  %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
-  %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+  %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+  %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
 
 
   // CHECK:   ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
   // CHECK-SAME:     ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
   // CHECK-SAME:     ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
   // CHECK-SAME:     ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
-  %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+  %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
       : tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
-  vector.print %CCCC_v : vector<3x8x8x1xf32>
+  vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
 
   bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
   bufferization.dealloc_tensor %static_input  : tensor<3x8x8x3xf32>
   bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+  bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
 
-  bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+  bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
 
   return
 }

>From 0681112fe9be3e394e116678dbe50d1086eaf148 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 1 May 2024 20:25:47 +0000
Subject: [PATCH 2/2] address comments

---
 .../Transforms/Utils/SparseTensorIterator.cpp | 19 +++--
 .../fuse_sparse_pad_with_consumer.mlir        | 79 +++++++++++++++++++
 2 files changed, 92 insertions(+), 6 deletions(-)
 create mode 100644 mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index 112b9f6c252786..a74e177f031401 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -151,13 +151,14 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 
     SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
     scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
-    // True branch.
+    // True branch: returns a "fake" empty range [0, 0) if the parent
+    // iterator is in the pad zone.
     b.setInsertionPointToStart(posRangeIf.thenBlock());
-    // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
+
     SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
     b.create<scf::YieldOp>(l, emptyRange);
 
-    // False branch.
+    // False branch: returns the actual range.
     b.setInsertionPointToStart(posRangeIf.elseBlock());
     auto [pLo, pHi] = loadRange();
     SmallVector<Value, 2> loadedRange{pLo, pHi};
@@ -487,6 +488,7 @@ class SimpleWrapIterator : public SparseIterator {
   bool isBatchIterator() const override { return wrap->isBatchIterator(); }
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return wrap->iteratableByFor(); };
+
   SmallVector<Value> serialize() const override { return wrap->serialize(); };
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
   ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
@@ -601,9 +603,7 @@ class PadIterator : public SimpleWrapIterator {
               Value padHigh)
       : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                            wrap->randomAccessible() ? 1 : 0),
-        padLow(padLow), padHigh(padHigh) {
-    // assert(!randomAccessible());
-  }
+        padLow(padLow), padHigh(padHigh) {}
 
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
@@ -614,6 +614,13 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // Returns a pair of values for the *lower* and *upper* bound, respectively.
+  ValuePair genForCond(OpBuilder &b, Location l) override {
+    if (randomAccessible())
+      return {getCrd(), upperBound(b, l)};
+    return wrap->genForCond(b, l);
+  }
+
   // For a padded dense iterator, we append an `inPadZone: bool` value in
   // addition to the values used by the wrapped iterator.
   ValueRange getCurPosition() const override { return getCursor(); }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
new file mode 100644
index 00000000000000..4f509bf747ab68
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -canonicalize | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#elemwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>,  // B
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+
+// CHECK-LABEL:   func.func @padded_mul(
+// CHECK-SAME:                          %[[VAL_0:.*]]: tensor<4x4xf32, #sparse>,
+// CHECK-SAME:                          %[[VAL_1:.*]]: tensor<8x8xf32>) -> tensor<8x8xf32> {
+// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 6 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
+// CHECK:           %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
+// CHECK:           linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
+// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_18:.*]] = arith.cmpi uge, %[[VAL_15]], %[[VAL_3]] : index
+// CHECK:             %[[VAL_19:.*]] = arith.ori %[[VAL_17]], %[[VAL_18]] : i1
+// CHECK:             %[[VAL_20:.*]]:2 = scf.if %[[VAL_19]] -> (index, index) {
+// CHECK:               scf.yield %[[VAL_6]], %[[VAL_6]] : index, index
+// CHECK:             } else {
+// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:               scf.yield %[[VAL_21]], %[[VAL_23]] : index, index
+// CHECK:             }
+// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_20]]#0 to %[[VAL_20]]#1 step %[[VAL_5]] {
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:               %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
+// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK:               %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
+// CHECK:               memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<8x8xf32>
+// CHECK:           return %[[VAL_31]] : tensor<8x8xf32>
+// CHECK:         }
+func.func @padded_mul(%arg0: tensor<4x4xf32, #CSR>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst_0 = arith.constant 0.00000e+00 : f32
+  %buf = tensor.empty() : tensor<8x8xf32>
+  %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+  %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
+  ^bb0(%arg75: index, %arg76: index):
+    tensor.yield %cst_0 : f32
+  } : tensor<4x4xf32, #CSR> to tensor<8x8xf32, #CSR>
+
+  %0 = linalg.generic #elemwise
+     ins(%padded, %arg1: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
+    outs(%s: tensor<8x8xf32>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %0 = arith.mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<8x8xf32>
+
+  return %0 : tensor<8x8xf32>
+}
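
A note for readers skimming the patches: the key change for CSR inputs is that
CompressedLevel::peekRangeAt now takes an optional `inPadZone` flag, set when
the parent is a padded dense (random-access) iterator, and wraps the position
loads in an scf.if. A minimal sketch of the emitted guard, simplified from the
FileCheck output above (SSA value names here are illustrative, not the actual
generated names):

  // When the padded parent's coordinate falls in the pad zone, yield the
  // empty range [0, 0); otherwise load [positions[p], positions[p+1]).
  %lo, %hi = scf.if %in_pad_zone -> (index, index) {
    scf.yield %c0, %c0 : index, index
  } else {
    %p_lo = memref.load %positions[%p] : memref<?xindex>
    %p_next = arith.addi %p, %c1 : index
    %p_hi = memref.load %positions[%p_next] : memref<?xindex>
    scf.yield %p_lo, %p_hi : index, index
  }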



More information about the Mlir-commits mailing list