[Mlir-commits] [mlir] [mlir][sparse] support tensor.pad on CSR tensors (PR #90687)
Peiming Liu
llvmlistbot at llvm.org
Wed May 1 10:00:46 PDT 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/90687
From 7918fe4e4ce48a271c19e55177fa4d46b3228d1c Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Apr 2024 23:22:38 +0000
Subject: [PATCH] [mlir][sparse] support tensor.pad on CSR tensors
---
.../Transforms/Utils/SparseTensorIterator.cpp | 132 +++++++++++++-----
.../Transforms/Utils/SparseTensorIterator.h | 6 +-
.../CPU/padded_sparse_conv_2d.mlir | 30 ++--
3 files changed, 111 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32e6..112b9f6c252786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
+ assert(!inPadZone && "Not implemented");
assert(parentPos.size() == 1 && "Batch level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
- Value p = parentPos.front();
- SmallVector<Value> memCrd(batchPrefix);
- memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
- memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
- return {pLo, pHi};
+ auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+ Value p = parentPos.front();
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+ return {pLo, pHi};
+ };
+
+ if (inPadZone == nullptr)
+ return loadRange();
+
+ SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+ scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+ // True branch.
+ b.setInsertionPointToStart(posRangeIf.thenBlock());
+ // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
+ SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+ b.create<scf::YieldOp>(l, emptyRange);
+
+ // False branch.
+ b.setInsertionPointToStart(posRangeIf.elseBlock());
+ auto [pLo, pHi] = loadRange();
+ SmallVector<Value, 2> loadedRange{pLo, pHi};
+ b.create<scf::YieldOp>(l, loadedRange);
+
+ b.setInsertionPointAfter(posRangeIf);
+ ValueRange posRange = posRangeIf.getResults();
+ return {posRange.front(), posRange.back()};
}
};
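For context, the guarded lookup above materializes as an scf.if around the
usual position-buffer loads. A minimal sketch of the expected IR shape, with
illustrative SSA names (%in_pad_zone, %positions, and %p are placeholders,
not names produced by the actual lowering):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Yield the "fake" empty range [0, 0) inside the pad zone; otherwise
  // load the real range [positions[p], positions[p+1]).
  %lo, %hi = scf.if %in_pad_zone -> (index, index) {
    scf.yield %c0, %c0 : index, index
  } else {
    %plo = memref.load %positions[%p] : memref<?xindex>
    %p1  = arith.addi %p, %c1 : index
    %phi = memref.load %positions[%p1] : memref<?xindex>
    scf.yield %plo, %phi : index, index
  }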
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
+ assert(!inPadZone && "Not implemented");
SmallVector<Value> memCrd(batchPrefix);
Value p = parentPos.front();
p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 || parentPos.size() == 2);
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && isUnique() &&
"n:m level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
};
void genInitImpl(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
-
- if (isBatchIterator() && batchCrds.size() <= stl.lvl)
- batchCrds.resize(stl.lvl + 1, nullptr);
-
- Value c0 = C_IDX(0);
- ValueRange pPos = c0;
- // If the parent iterator is a batch iterator, we also start from 0 (but
- // on a different batch).
- if (parent && !parent->isBatchIterator())
- pPos = parent->getCurPosition();
-
- ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
- // Seek to the lowest position.
- seek(posLo);
- }
+ const SparseIterator *parent) override;
ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
// A util base-iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
- SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
- : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+ SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+ unsigned extraCursorVal = 0)
+ : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
public:
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
Value padHigh)
- : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
- padHigh(padHigh) {
- assert(!randomAccessible() && "Not implemented.");
+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+ wrap->randomAccessible() ? 1 : 0),
+ padLow(padLow), padHigh(padHigh) {
}
// For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
}
+ // For a padded dense iterator, we append an `inPadZone: bool` value in
+ // addition to the values used by the wrapped iterator.
+ ValueRange getCurPosition() const override { return getCursor(); }
+
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ SmallVector<Type> ret = wrap->getCursorValTypes(b);
+ // Need an extra boolean value `inPadZone` for the padded dense iterator.
+ if (randomAccessible())
+ ret.push_back(b.getI1Type());
+
+ return ret;
+ }
+
// The upper bound after padding becomes `size + padLow + padHigh`.
Value upperBound(OpBuilder &b, Location l) const override {
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
+ wrap->locate(b, l, SUBI(crd, padLow));
+
+ // inPadZone = crd < padLow || crd >= size + padLow.
+ Value inPadLow = CMPI(ult, crd, padLow);
+ Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+ getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+ updateCrd(crd);
}
Value padLow, padHigh;
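The locate logic above lowers to a handful of arith ops. A sketch of the
emitted predicate, again with illustrative names (%crd is the coordinate in
the padded space, %pad_low and %size are the low padding and the unpadded
level size): for padLow = 2 and size = 8, the padded space is [0, 12) and
the pad zone is crd < 2 || crd >= 10.

  %unpadded = arith.subi %crd, %pad_low : index      // coordinate passed to wrap
  %in_lo    = arith.cmpi ult, %crd, %pad_low : index // below the low padding
  %bound    = arith.addi %size, %pad_low : index     // first high-pad coordinate
  %in_hi    = arith.cmpi uge, %crd, %bound : index   // at or above the high padding
  %in_pad   = arith.ori %in_lo, %in_hi : i1          // stored as the extra cursor value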
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) {
+
+ if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+ batchCrds.resize(stl.lvl + 1, nullptr);
+
+ Value c0 = C_IDX(0);
+ ValueRange pPos = c0;
+ Value inPadZone = nullptr;
+ // If the parent iterator is a batch iterator, we also start from 0 (but
+ // on a different batch).
+ if (parent && !parent->isBatchIterator()) {
+ pPos = parent->getCurPosition();
+ if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+ // A padded dense iterator creates a "sparse" pad zone, which needs to be
+ // handled specially.
+ inPadZone = pPos.back();
+ pPos = pPos.drop_back();
+ }
+ }
+
+ ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+ // Seek to the lowest position.
+ seek(posLo);
+}
+
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
const SparseIterator *) {
Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f998..120a806536f190 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ class SparseTensorLevel {
///
/// For a sparse level, [posLo, posHi) specifies the range of index pointers
/// used to load coordinates from the coordinate buffer.
- virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
- ValueRange batchPrefix,
- ValueRange parentPos) const = 0;
+ virtual std::pair<Value, Value>
+ peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ ValueRange parentPos, Value inPadZone = nullptr) const = 0;
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf54558237..50dd989416e2a0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-#CCCC = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
}>
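The PR title mentions CSR: the encoding above is the NHWC analogue of
CSR-style storage, with a dense level directly above compressed ones, so the
pad zone introduced on the dense level must be propagated to its compressed
children. A minimal 2-D sketch of the newly supported pattern (hypothetical
example, not part of this patch):

  #CSR = #sparse_tensor.encoding<{
    map = (d0, d1) -> (d0 : dense, d1 : compressed)
  }>

  func.func @pad_csr(%arg0: tensor<8x8xf32, #CSR>) -> tensor<12x12xf32, #CSR> {
    %f0 = arith.constant 0.0 : f32
    %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
    ^bb0(%i: index, %j: index):
      tensor.yield %f0 : f32
    } : tensor<8x8xf32, #CSR> to tensor<12x12xf32, #CSR>
    return %padded : tensor<12x12xf32, #CSR>
  }

As in the conv test below, the padded result would typically feed a
subsequent kernel rather than be materialized; the iterator changes above
make the compressed level yield empty position ranges for the padded rows.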
// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
return %ret : tensor<3x8x8x1xf32>
}
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
%cst_0 = arith.constant 0.00000e+00 : f32
%buf = tensor.empty() : tensor<3x8x8x1xf32>
%s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
%padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
tensor.yield %cst_0 : f32
- } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+ } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+ ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
return %ret : tensor<3x8x8x1xf32>
}
@@ -105,8 +97,8 @@ func.func @main() {
%dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
- %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
- %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+ %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+ %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
// CHECK: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
// CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
// CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
// CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
- %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+ %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
: tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
- vector.print %CCCC_v : vector<3x8x8x1xf32>
+ vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
bufferization.dealloc_tensor %static_input : tensor<3x8x8x3xf32>
bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+ bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+ bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
return
}