[Mlir-commits] [mlir] [mlir][sparse] support loop range query using SparseTensorLevel. (PR #75670)

Peiming Liu llvmlistbot at llvm.org
Fri Dec 15 16:00:21 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/75670

>From a42484d5a549ab2002206f32321611c5b87e276d Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Dec 2023 23:40:25 +0000
Subject: [PATCH 1/2] [mlir][sparse] implement peekRange for SparseTensorLevel.

---
 .../Transforms/Utils/LoopEmitter.cpp          |  71 +++------
 .../Transforms/Utils/SparseTensorLevel.cpp    | 146 ++++++++++++++++++
 .../Transforms/Utils/SparseTensorLevel.h      |  63 ++------
 .../sparse_conv_2d_slice_based.mlir           |  15 +-
 4 files changed, 186 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 0ba7cf33b6cbad..35faf1769746d8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -244,12 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
 Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
                                   TensorId tid, Level lvl, Value pLo,
                                   Value pHi) {
-  SparseTensorLevel &level = *lvls[tid][lvl];
-  const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
+  SparseTensorLevel &stl = *lvls[tid][lvl];
+  const Value sameCrd = stl.peekCrdAt(builder, loc, pLo);
   auto whileOp = builder.create<scf::WhileOp>(
       loc, builder.getIndexType(), pLo,
       /*beforeBuilder=*/
-      [pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
+      [pHi, &stl, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
         const auto pos = ivs[0];
         Value inBound = builder.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -260,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
           // Load the next coordinates only when inbound (to avoid OOB
           // accesses).
           builder.setInsertionPointToStart(ifInBound.thenBlock());
-          Value crd = level.peekCrdAt(builder, loc, pos);
+          Value crd = stl.peekCrdAt(builder, loc, pos);
           Value isSameCrd = builder.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::eq, crd, sameCrd);
           YIELD(isSameCrd);
@@ -1226,27 +1226,19 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
 
   const Value c0 = C_IDX(0);
   const Value c1 = C_IDX(1);
-  const Value c2 = C_IDX(2);
   // Either the first level, or the previous level has been set.
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   assert(lvl == 0 || posits[tid][lvl - 1]);
-  if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
-    // TODO: eliminate the cast upon feature complete.
-    const Value mem =
-        isCompressedLT(lvlTp)
-            ? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
-            : static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
-
-    Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
-    if (isLooseCompressedLT(lvlTp))
-      pLo = builder.create<arith::MulIOp>(loc, pLo, c2);
-    posits[tid][lvl] = genIndexLoad(builder, loc, mem, pLo);
-
-    const Value pHi = ADDI(pLo, c1);
-    highs[tid][lvl] = genIndexLoad(builder, loc, mem, pHi);
+  if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp) ||
+      is2OutOf4LT(lvlTp)) {
+
+    Value pos = lvl == 0 ? c0 : posits[tid][lvl - 1];
+    std::tie(posits[tid][lvl], highs[tid][lvl]) =
+        lvls[tid][lvl]->peekRangeAt(builder, loc, pos);
     return;
   }
   if (isSingletonLT(lvlTp)) {
+    // TODO: merge this as well once SparseTensorLevel supports dedup.
     const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
     posits[tid][lvl] = pLo;
 
@@ -1262,13 +1254,6 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
                           : ADDI(pLo, c1);
     return;
   }
-  if (is2OutOf4LT(lvlTp)) {
-    const Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
-    // Each 2:4 block has exactly two specified elements.
-    posits[tid][lvl] = MULI(pLo, c2);
-    highs[tid][lvl] = ADDI(posits[tid][lvl], c2);
-    return;
-  }
   llvm_unreachable("Unrecognized level-type!");
 }
 
@@ -1824,18 +1809,11 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   auto [nxSz, stride] = sliceMeta[tid][lvl][1];
   assert(stride == 1 && "Not yet implemented");
   Value sPtrBuf = slicePosBuffer[tid][lvl][0];
-  Value pHi, pLo;
-  if (lvl == 0) {
-    pLo = c0;
-    // TODO: eliminate the cast upon feature complete.pLo = c0;
-    Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
-    pHi = genIndexLoad(builder, loc, pBuf, c1);
-  } else {
-    // TODO: eliminate the cast upon feature complete.} else {
-    Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
-    pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
-    pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
-  }
+  const SparseTensorLevel &stl = *lvls[tid][lvl];
+
+  Value p = lvl == 0 ? c0 : posits[tid][lvl - 1];
+  auto [pLo, pHi] = stl.peekRangeAt(builder, loc, p);
+
   // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
   updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
@@ -1849,7 +1827,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   // nonempty. though we assume that even on empty sparse tensors, a non-empty
   // ptr/idx buffer is allocated for each level so it would not cause OOB to
   // avoid generating a ifOp here.
-  Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
+  Value minCrd = stl.peekCrdAt(builder, loc, pLo);
 
   // FIXME: We need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1879,7 +1857,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
 // }
 void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
                                           TensorId tid, Level lvl) {
-  Value c0 = C_IDX(0), c1 = C_IDX(1);
+  Value c0 = C_IDX(0);
   unsigned depth = levelReducedDep[tid][lvl];
   // The remaining slice size after reduction.
   Value remSz = sliceMeta[tid][lvl][depth + 1].first;
@@ -1929,17 +1907,14 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
 
   ValueRange result = genUnResolvedSliceTreeTraverse(
       builder, loc, tid, unResSlices, firstResLvl, reduc,
-      [this, c1, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
-                                    MutableArrayRef<Value> reduc) {
+      [this, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, Value iv,
+                                MutableArrayRef<Value> reduc) {
         Value &nonEmpty = reduc[0];
         Value &minCrd = reduc[1];
         Value &curTupleCnt = reduc[2];
 
-        Value pHi = ADDI(iv, c1);
-        // TODO: eliminate the cast upon feature complete.
-        Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
-        Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
-        Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
+        const SparseTensorLevel &stl = *lvls[tid][lvl];
+        auto [sPLo, sPHi] = stl.peekRangeAt(builder, loc, iv);
 
         // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
         // one non-empty lvl, the slice is non-empty.
@@ -1957,7 +1932,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
           // }
           OpBuilder::InsertionGuard guard(builder);
           builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
-          Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
+          Value curC = stl.peekCrdAt(builder, loc, sPLo);
           Value isSmaller = CMPI(ult, curC, minCrd);
           Value newMin = SELECT(isSmaller, curC, minCrd);
           YIELD(newMin);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index d9d26794d7bcec..1af1e99b2833b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -13,6 +13,79 @@
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
+using ValuePair = std::pair<Value, Value>;
+
+//===----------------------------------------------------------------------===//
+// SparseTensorLevel derived classes.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class SparseLevel : public SparseTensorLevel {
+public:
+  SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
+
+  Value peekCrdAt(OpBuilder &, Location, Value) const override;
+
+protected:
+  const Value crdBuffer;
+};
+
+class DenseLevel : public SparseTensorLevel {
+public:
+  DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
+    // Dense level: the loop upper bound equals the level size.
+    loopHi = lvlSize;
+  }
+
+  Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
+    return pos;
+  }
+
+  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+};
+
+class CompressedLevel : public SparseLevel {
+public:
+  CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+
+private:
+  const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+  LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
+                       Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+
+private:
+  const Value posBuffer;
+};
+
+class SingletonLevel : public SparseLevel {
+public:
+  SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+};
+
+class TwoOutFourLevel : public SparseLevel {
+public:
+  TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer) {}
+
+  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+};
+
+} // namespace
 
 std::unique_ptr<SparseTensorLevel>
 sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
@@ -49,6 +122,79 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
   llvm_unreachable("unrecognizable level format");
 }
 
+//===----------------------------------------------------------------------===//
+// File local helper functions/macros.
+//===----------------------------------------------------------------------===//
+#define CMPI(p, lhs, rhs)                                                      \
+  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)))
+
+#define C_IDX(v) (constantIndex(b, l, (v)))
+#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
+#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)))
+#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)))
+#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)))
+#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)))
+#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)))
+#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
+#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
+
+static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) {
+  return std::make_pair(lo, ADDI(lo, sz));
+}
+
+//===----------------------------------------------------------------------===//
+// SparseTensorLevel derived classes implemetation.
+//===----------------------------------------------------------------------===//
+
 Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
   return genIndexLoad(b, l, crdBuffer, pos);
 }
+
+// PeekRange Implementation for all sparse levels.
+ValuePair DenseLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
+                                  Value max) const {
+  assert(max == nullptr && "Dense level can not be non-unique.");
+  return constantRange(b, l, C_IDX(0), lvlSize);
+}
+ValuePair CompressedLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
+                                       Value max) const {
+  if (max == nullptr) {
+    Value pLo = genIndexLoad(b, l, posBuffer, p);
+    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+    return {pLo, pHi};
+  }
+  llvm_unreachable("TODO: dedup not implemented");
+}
+ValuePair LooseCompressedLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
+                                            Value max) const {
+  // Allows this?
+  assert(max == nullptr && "loss compressed level can not be non-unique.");
+
+  p = MULI(p, C_IDX(2));
+  Value pLo = genIndexLoad(b, l, posBuffer, p);
+  Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+  return {pLo, pHi};
+}
+ValuePair SingletonLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
+                                      Value max) const {
+
+  if (max == nullptr)
+    return constantRange(b, l, p, C_IDX(1));
+  llvm_unreachable("TODO: dedup not implemented");
+}
+ValuePair TwoOutFourLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
+                                       Value max) const {
+  assert(max == nullptr && "2:4 level can not be non-unique.");
+  // Each 2:4 block has exactly two specified elements.
+  Value c2 = C_IDX(2);
+  return constantRange(b, l, MULI(p, c2), c2);
+}
+
+#undef CMPI
+#undef C_IDX
+#undef YIELD
+#undef ADDI
+#undef ANDI
+#undef SUBI
+#undef MULI
+#undef SELECT
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index e10356a55cc7e3..f5c29cda7c54f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -17,6 +17,8 @@ namespace sparse_tensor {
 class SparseTensorLevel {
   SparseTensorLevel(SparseTensorLevel &&) = delete;
   SparseTensorLevel(const SparseTensorLevel &) = delete;
+  SparseTensorLevel &operator=(SparseTensorLevel &&) = delete;
+  SparseTensorLevel &operator=(const SparseTensorLevel &) = delete;
 
 public:
   SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){};
@@ -24,6 +26,13 @@ class SparseTensorLevel {
 
   virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
 
+  /// Peeks the half-open range [lo, hi) needed to *fully* traverse the level,
+  /// given the position `p` that the immediate parent level is currently at.
+  /// `bound` is only used when the level is `non-unique` and deduplication is
+  /// required. It specifies the max upper bound of the non-unique segment.
+  virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l, Value p,
+                                              Value bound = Value()) const = 0;
+
   LevelType getLT() const { return lt; }
   Value getPos() const { return pos; }
   Value getCrd() const { return crd; }
@@ -49,60 +58,6 @@ class SparseTensorLevel {
 std::unique_ptr<SparseTensorLevel>
 makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
 
-class DenseLevel : public SparseTensorLevel {
-public:
-  DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
-    // Dense level, loop upper bound equals to the level size.
-    loopHi = lvlSize;
-  }
-
-  Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
-    return pos;
-  }
-};
-
-class SparseLevel : public SparseTensorLevel {
-public:
-  SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
-      : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
-
-  Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override;
-
-public: // TODO: make these values private upon feature complete.
-  const Value crdBuffer;
-};
-
-class CompressedLevel : public SparseLevel {
-public:
-  CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
-      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-public: // TODO: make these values private upon feature complete.
-  const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
-  LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
-                       Value crdBuffer)
-      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
-public: // TODO: make these values private upon feature complete.
-  const Value posBuffer;
-};
-
-class SingletonLevel : public SparseLevel {
-public:
-  SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
-      : SparseLevel(lt, lvlSize, crdBuffer) {}
-};
-
-class TwoOutFourLevel : public SparseLevel {
-public:
-  TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
-      : SparseLevel(lt, lvlSize, crdBuffer) {}
-};
-
 } // namespace sparse_tensor
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index a3c1e76a3d09af..bf61e792ffbe05 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -27,11 +27,12 @@
 // CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
 // CHECK-DAG:       %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
 // CHECK-DAG:       %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// CHECK-DAG:       %[[VAL_19:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// CHECK:           memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// CHECK:           memref.store %[[VAL_19]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// CHECK:           %[[VAL_20:.*]] = arith.cmpi ugt, %[[VAL_19]], %[[VAL_8]] : index
-// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK-DAG:       %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
+// CHECK-DAG:       %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
+// CHECK:           memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref<?xindex>
 // CHECK:           %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
 // CHECK:           %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
 // CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
@@ -56,8 +57,8 @@
 // CHECK:               scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index
 // CHECK:             } do {
 // CHECK:             ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index):
-// CHECK:               %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
-// CHECK:               %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
+// CHECK-DAG:           %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
+// CHECK-DAG:           %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
 // CHECK:               %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref<?xindex>
 // CHECK:               %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index
 // CHECK:               %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1

>From 85db22ec58e2ac3add5f4e135f38cfd770b28fe0 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Sat, 16 Dec 2023 00:00:10 +0000
Subject: [PATCH 2/2] address comments.

---
 .../Transforms/Utils/SparseTensorLevel.cpp    | 134 ++++++++----------
 1 file changed, 60 insertions(+), 74 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 1af1e99b2833b0..aea0910d980ab7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -15,6 +15,26 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 using ValuePair = std::pair<Value, Value>;
 
+//===----------------------------------------------------------------------===//
+// File local helper functions/macros.
+//===----------------------------------------------------------------------===//
+#define CMPI(p, lhs, rhs)                                                      \
+  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)))
+
+#define C_IDX(v) (constantIndex(b, l, (v)))
+#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
+#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)))
+#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)))
+#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)))
+#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)))
+#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)))
+#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
+#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
+
+static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) {
+  return std::make_pair(lo, ADDI(lo, sz));
+}
+
 //===----------------------------------------------------------------------===//
 // SparseTensorLevel derived classes.
 //===----------------------------------------------------------------------===//
@@ -26,7 +46,9 @@ class SparseLevel : public SparseTensorLevel {
   SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
       : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
 
-  Value peekCrdAt(OpBuilder &, Location, Value) const override;
+  Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override {
+    return genIndexLoad(b, l, crdBuffer, pos);
+  }
 
 protected:
   const Value crdBuffer;
@@ -43,7 +65,11 @@ class DenseLevel : public SparseTensorLevel {
     return pos;
   }
 
-  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "Dense level can not be non-unique.");
+    return constantRange(b, l, C_IDX(0), lvlSize);
+  }
 };
 
 class CompressedLevel : public SparseLevel {
@@ -51,7 +77,15 @@ class CompressedLevel : public SparseLevel {
   CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
       : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    if (max == nullptr) {
+      Value pLo = genIndexLoad(b, l, posBuffer, p);
+      Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+      return {pLo, pHi};
+    }
+    llvm_unreachable("TODO: dedup not implemented");
+  }
 
 private:
   const Value posBuffer;
@@ -63,7 +97,16 @@ class LooseCompressedLevel : public SparseLevel {
                        Value crdBuffer)
       : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    // TODO: should non-unique loose compressed levels be supported?
+    assert(max == nullptr && "loss compressed level can not be non-unique.");
+
+    p = MULI(p, C_IDX(2));
+    Value pLo = genIndexLoad(b, l, posBuffer, p);
+    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+    return {pLo, pHi};
+  }
 
 private:
   const Value posBuffer;
@@ -74,7 +117,12 @@ class SingletonLevel : public SparseLevel {
   SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
       : SparseLevel(lt, lvlSize, crdBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    if (max == nullptr)
+      return constantRange(b, l, p, C_IDX(1));
+    llvm_unreachable("TODO: dedup not implemented");
+  }
 };
 
 class TwoOutFourLevel : public SparseLevel {
@@ -82,7 +130,13 @@ class TwoOutFourLevel : public SparseLevel {
   TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
       : SparseLevel(lt, lvlSize, crdBuffer) {}
 
-  ValuePair peekRangeAt(OpBuilder &, Location, Value, Value) const override;
+  ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+                        Value max) const override {
+    assert(max == nullptr && "2:4 level can not be non-unique.");
+    // Each 2:4 block has exactly two specified elements.
+    Value c2 = C_IDX(2);
+    return constantRange(b, l, MULI(p, c2), c2);
+  }
 };
 
 } // namespace
@@ -122,74 +176,6 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
   llvm_unreachable("unrecognizable level format");
 }
 
-//===----------------------------------------------------------------------===//
-// File local helper functions/macros.
-//===----------------------------------------------------------------------===//
-#define CMPI(p, lhs, rhs)                                                      \
-  (b.create<arith::CmpIOp>(l, arith::CmpIPredicate::p, (lhs), (rhs)))
-
-#define C_IDX(v) (constantIndex(b, l, (v)))
-#define YIELD(vs) (b.create<scf::YieldOp>(l, (vs)))
-#define ADDI(lhs, rhs) (b.create<arith::AddIOp>(l, (lhs), (rhs)))
-#define ANDI(lhs, rhs) (b.create<arith::AndIOp>(l, (lhs), (rhs)))
-#define SUBI(lhs, rhs) (b.create<arith::SubIOp>(l, (lhs), (rhs)))
-#define MULI(lhs, rhs) (b.create<arith::MulIOp>(l, (lhs), (rhs)))
-#define REMUI(lhs, rhs) (b.create<arith::RemUIOp>(l, (lhs), (rhs)))
-#define DIVUI(lhs, rhs) (b.create<arith::DivUIOp>(l, (lhs), (rhs)))
-#define SELECT(c, lhs, rhs) (b.create<arith::SelectOp>(l, (c), (lhs), (rhs)))
-
-static ValuePair constantRange(OpBuilder &b, Location l, Value lo, Value sz) {
-  return std::make_pair(lo, ADDI(lo, sz));
-}
-
-//===----------------------------------------------------------------------===//
-// SparseTensorLevel derived classes implemetation.
-//===----------------------------------------------------------------------===//
-
-Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
-  return genIndexLoad(b, l, crdBuffer, pos);
-}
-
-// PeekRange Implementation for all sparse levels.
-ValuePair DenseLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
-                                  Value max) const {
-  assert(max == nullptr && "Dense level can not be non-unique.");
-  return constantRange(b, l, C_IDX(0), lvlSize);
-}
-ValuePair CompressedLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
-                                       Value max) const {
-  if (max == nullptr) {
-    Value pLo = genIndexLoad(b, l, posBuffer, p);
-    Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-    return {pLo, pHi};
-  }
-  llvm_unreachable("TODO: dedup not implemented");
-}
-ValuePair LooseCompressedLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
-                                            Value max) const {
-  // Allows this?
-  assert(max == nullptr && "loss compressed level can not be non-unique.");
-
-  p = MULI(p, C_IDX(2));
-  Value pLo = genIndexLoad(b, l, posBuffer, p);
-  Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
-  return {pLo, pHi};
-}
-ValuePair SingletonLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
-                                      Value max) const {
-
-  if (max == nullptr)
-    return constantRange(b, l, p, C_IDX(1));
-  llvm_unreachable("TODO: dedup not implemented");
-}
-ValuePair TwoOutFourLevel::peekRangeAt(OpBuilder &b, Location l, Value p,
-                                       Value max) const {
-  assert(max == nullptr && "2:4 level can not be non-unique.");
-  // Each 2:4 block has exactly two specified elements.
-  Value c2 = C_IDX(2);
-  return constantRange(b, l, MULI(p, c2), c2);
-}
-
 #undef CMPI
 #undef C_IDX
 #undef YIELD



More information about the Mlir-commits mailing list