[Mlir-commits] [mlir] [mlir][sparse] support tensor.pad on CSR tensors (PR #90687)

Aart Bik llvmlistbot at llvm.org
Wed May 1 11:43:52 PDT 2024


================
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
       : SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
-                        ValueRange parentPos) const override {
+                        ValueRange parentPos, Value inPadZone) const override {
 
     assert(parentPos.size() == 1 &&
            "compressed level must be the first non-unique level.");
-    Value p = parentPos.front();
 
-    SmallVector<Value> memCrd(batchPrefix);
-    memCrd.push_back(p);
-    Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
-    memCrd.back() = ADDI(p, C_IDX(1));
-    Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
-    return {pLo, pHi};
+    auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+      Value p = parentPos.front();
+      SmallVector<Value> memCrd(batchPrefix);
+      memCrd.push_back(p);
+      Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+      memCrd.back() = ADDI(p, C_IDX(1));
+      Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+      return {pLo, pHi};
+    };
+
+    if (inPadZone == nullptr)
+      return loadRange();
+
+    SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+    scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+    // True branch.
+    b.setInsertionPointToStart(posRangeIf.thenBlock());
+    // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
----------------
aartbik wrote:

Can you move this comment up onto the "True branch" comment line,
so that it is clear what each of the True/False branches does?

For symmetry, you can also add a comment saying what the False branch does at L160.

https://github.com/llvm/llvm-project/pull/90687


More information about the Mlir-commits mailing list