[Mlir-commits] [mlir] [mlir][sparse] support tensor.pad on CSR tensors (PR #90687)
Aart Bik
llvmlistbot at llvm.org
Wed May 1 11:43:52 PDT 2024
================
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
- Value p = parentPos.front();
- SmallVector<Value> memCrd(batchPrefix);
- memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
- memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
- return {pLo, pHi};
+ auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+ Value p = parentPos.front();
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+ return {pLo, pHi};
+ };
+
+ if (inPadZone == nullptr)
+ return loadRange();
+
+ SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+ scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+ // True branch.
+ b.setInsertionPointToStart(posRangeIf.thenBlock());
+ // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
----------------
aartbik wrote:
Can you move this comment onto the "// True branch." line,
so that it is clear what the True and False branches each do?
For symmetry, you can also add a comment saying what the False branch does at L160.
https://github.com/llvm/llvm-project/pull/90687
More information about the Mlir-commits
mailing list