[Mlir-commits] [mlir] [mlir][sparse] set up the skeleton for SparseTensorLevel abstraction. (PR #75645)

Peiming Liu llvmlistbot at llvm.org
Fri Dec 15 11:51:12 PST 2023


https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/75645

>From c6f97d52f616cb15643cd4beb7b7479dd629e436 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Dec 2023 19:37:33 +0000
Subject: [PATCH 1/2] [mlir][sparse] set up the skeleton for SparseTensorLevel
 abstraction.

---
 .../SparseTensor/Transforms/CMakeLists.txt    |   1 +
 .../Transforms/Utils/LoopEmitter.cpp          |  92 ++++++---------
 .../Transforms/Utils/LoopEmitter.h            |  13 +--
 .../Transforms/Utils/SparseTensorLevels.cpp   |  46 ++++++++
 .../Transforms/Utils/SparseTensorLevels.h     | 109 ++++++++++++++++++
 5 files changed, 197 insertions(+), 64 deletions(-)
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
 create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index ad8b0d02eca35e..d3ab65e4e1793a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   Utils/IterationGraphSorter.cpp
   Utils/LoopEmitter.cpp
   Utils/SparseTensorDescriptor.cpp
+  Utils/SparseTensorLevels.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 784c793c9bd119..0ba7cf33b6cbad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -126,15 +126,15 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
 // Generates a bool value for while loop condition that tries to iterate over a
 // fully reduced level with affine index expression.
 static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
-                                        Value crdBuf, Value crdHi, Value posit,
-                                        Value posHi) {
+                                        const SparseTensorLevel &level,
+                                        Value crdHi, Value posit, Value posHi) {
   Value inBound = CMPI(ult, posit, posHi);
   auto ifOp =
       builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
   // if (inbound)
   //   yield coord < crdHi
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value crd = genIndexLoad(builder, loc, crdBuf, posit);
+  Value crd = level.peekCrdAt(builder, loc, posit);
   YIELD(CMPI(ult, crd, crdHi));
   // else
   //   yield false
@@ -244,13 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
 Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
                                   TensorId tid, Level lvl, Value pLo,
                                   Value pHi) {
-  const auto coordinates = coordinatesBuffers[tid][lvl];
-  const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo);
+  SparseTensorLevel &level = *lvls[tid][lvl];
+  const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
   auto whileOp = builder.create<scf::WhileOp>(
       loc, builder.getIndexType(), pLo,
       /*beforeBuilder=*/
-      [pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
-                                  ValueRange ivs) {
+      [pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
         const auto pos = ivs[0];
         Value inBound = builder.create<arith::CmpIOp>(
             loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -261,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
           // Load the next coordinates only when inbound (to avoid OOB
           // accesses).
           builder.setInsertionPointToStart(ifInBound.thenBlock());
-          Value crd = genIndexLoad(builder, loc, coordinates, pos);
+          Value crd = level.peekCrdAt(builder, loc, pos);
           Value isSameCrd = builder.create<arith::CmpIOp>(
               loc, arith::CmpIPredicate::eq, crd, sameCrd);
           YIELD(isSameCrd);
@@ -284,11 +283,8 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
 
 Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
-  // A load on the coordinates array yields the coordinate.
-  const Value mem = coordinatesBuffers[tid][lvl];
-  /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   const Value pos = posits[tid][lvl];
-  const Value crd = genIndexLoad(builder, loc, mem, pos);
+  const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
   return crd;
 }
 
@@ -318,9 +314,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->segHi.assign(numTensors, std::vector<Value>());
   this->posits.assign(numTensors, std::vector<Value>());
   this->coords.assign(numTensors, std::vector<Value>());
-  this->positionsBuffers.assign(numTensors, std::vector<Value>());
-  this->coordinatesBuffers.assign(numTensors, std::vector<Value>());
   this->valBuffer.assign(numTensors, nullptr);
+  this->lvls.resize(numTensors);
   this->isSparseSlices.assign(numTensors, false);
   this->sliceOffsets.assign(numTensors, std::vector<Value>());
   this->sliceStrides.assign(numTensors, std::vector<Value>());
@@ -377,8 +372,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     segHi[tid].assign(lvlRank, Value());
     posits[tid].assign(lvlRank, Value());
     coords[tid].assign(lvlRank, Value());
-    positionsBuffers[tid].assign(lvlRank, Value());
-    coordinatesBuffers[tid].assign(lvlRank, Value());
+    lvls[tid].resize(lvlRank);
+
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
 
@@ -448,22 +443,7 @@ void LoopEmitter::initializeLoopEmit(
 
     // Scan all levels of current tensor.
     for (Level l = 0; l < lvlRank; l++) {
-      // This should be called only once at beginning.
-      assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
-             !highs[t][l]);
-      const auto lvlTp = lvlTypes[t][l];
-      // Handle sparse storage schemes.
-      if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
-        // Generate sparse primitives to obtain positions and coordinates.
-        positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
-        coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
-      } else if (isSingletonLT(lvlTp) || is2OutOf4LT(lvlTp)) {
-        // Singleton level, fetch coordinates.
-        coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
-      } else {
-        // Dense level, nothing to fetch.
-        assert(isDenseLT(lvlTp));
-      }
+      lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
 
       // Find upper bound in current dimension.
       highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
@@ -756,8 +736,7 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
       crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz);
     }
     assert(crdHi);
-    return genSparseReducedAffineCond(builder, loc,
-                                      coordinatesBuffers[tid][lvl], crdHi,
+    return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi,
                                       ivs[0], highs[tid][lvl]);
   }
   case LoopCondKind::SparseAffineUnRedCond: {
@@ -802,10 +781,9 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
     sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
     // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
     Value posit = ivs[0];
-    Value crdBuf = coordinatesBuffers[tid][lvl];
     // We need to substract the offset to get relative coordinates.
     // TODO: Maybe assert relC >=0 during runtime in debug build?
-    Value absC = genIndexLoad(builder, loc, crdBuf, posit);
+    Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit);
     auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
     posits[tid][lvl] = posit;
     coords[tid][lvl] = relC;
@@ -1189,9 +1167,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
   // The induction variable gives the position.
   const Value pos = forOp.getInductionVar();
   posits[tid][lvl] = pos;
-  // Generating a load on the coordinates array yields the crd.
-  const Value mem = coordinatesBuffers[tid][lvl];
-  const Value crd = genIndexLoad(builder, loc, mem, pos);
+  const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
   coords[tid][lvl] = crd;
 
   // Generate an if-condition to filter out coordinates that are not
@@ -1255,7 +1231,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
   /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
   assert(lvl == 0 || posits[tid][lvl - 1]);
   if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
-    const Value mem = positionsBuffers[tid][lvl];
+    // TODO: eliminate the cast upon feature complete.
+    const Value mem =
+        isCompressedLT(lvlTp)
+            ? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
+            : static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
 
     Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
     if (isLooseCompressedLT(lvlTp))
@@ -1623,8 +1603,7 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
       /*beforeBuilder=*/
       [this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
                                        ValueRange args) {
-        Value cond = genSparseReducedAffineCond(builder, loc,
-                                                coordinatesBuffers[tid][lvl],
+        Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl],
                                                 sliceHi, args[0], posHi);
         // continue if not yet break nor out of bound.
         builder.create<scf::ConditionOp>(loc, cond, args);
@@ -1848,12 +1827,14 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value pHi, pLo;
   if (lvl == 0) {
     pLo = c0;
-    pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+    // TODO: eliminate the cast upon feature complete.
+    Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
+    pHi = genIndexLoad(builder, loc, pBuf, c1);
   } else {
-    pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
-                       posits[tid][lvl - 1]);
-    pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
-                       ADDI(posits[tid][lvl - 1], c1));
+    // TODO: eliminate the cast upon feature complete.
+    Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+    pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
+    pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
   }
   // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
@@ -1868,7 +1849,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
   // nonempty. though we assume that even on empty sparse tensors, a non-empty
   // ptr/idx buffer is allocated for each level so it would not cause OOB to
   // avoid generating a ifOp here.
-  Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+  Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
 
   // FIXME: We need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1955,9 +1936,10 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
         Value &curTupleCnt = reduc[2];
 
         Value pHi = ADDI(iv, c1);
-        Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
-        Value sPHi =
-            genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
+        // TODO: eliminate the cast upon feature complete.
+        Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+        Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
+        Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
 
         // isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
         // one non-empty lvl, the slice is non-empty.
@@ -1975,8 +1957,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
           // }
           OpBuilder::InsertionGuard guard(builder);
           builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
-          Value curC =
-              genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo);
+          Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
           Value isSmaller = CMPI(ult, curC, minCrd);
           Value newMin = SELECT(isSmaller, curC, minCrd);
           YIELD(newMin);
@@ -2176,8 +2157,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
           /* if pLo < pHi */ {
             builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
             // coord = load[pLo]
-            Value coord =
-                genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+            Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
             Value pred = CMPI(eq, coord, info.minCrd);
             auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
             /* if coord == minCrd */ {
@@ -2209,7 +2189,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
           auto newMin =
               builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
           builder.setInsertionPointToStart(&newMin.getThenRegion().front());
-          YIELD(genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo));
+          YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo));
 
           builder.setInsertionPointToStart(&newMin.getElseRegion().front());
           YIELD(curMinCrd);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 78bb53e4483f60..272d2bf0e89c2e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,6 +11,8 @@
 
 #include <vector>
 
+#include "SparseTensorLevels.h"
+
 #include "mlir/Dialect/SparseTensor/IR/Enums.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -241,12 +243,6 @@ class LoopEmitter {
   const std::vector<std::vector<Value>> &getPosits() const { return posits; };
   const std::vector<std::vector<Value>> &getCoords() const { return coords; };
   const std::vector<std::vector<Value>> &getHighs() const { return highs; };
-  const std::vector<std::vector<Value>> &getPositionBuffers() const {
-    return positionsBuffers;
-  };
-  const std::vector<std::vector<Value>> &getCoordinateBuffers() const {
-    return coordinatesBuffers;
-  };
   const std::vector<Value> &getValBuffer() const { return valBuffer; };
 
   constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
@@ -648,8 +644,9 @@ class LoopEmitter {
   std::vector<std::vector<Value>> segHi;
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> lvlSizes;
-  std::vector<std::vector<Value>> positionsBuffers;   // to_positions
-  std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+  // std::vector<std::vector<Value>> positionsBuffers;   // to_positions
+  // std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+  std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
   std::vector<Value> valBuffer;                       // to_value
 
   //
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
new file mode 100644
index 00000000000000..a9dae17e9de055
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
@@ -0,0 +1,46 @@
+#include "SparseTensorLevels.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
+                                     Level l) {
+  auto stt = getSparseTensorType(t);
+
+  LevelType lt = stt.getLvlType(l);
+  Value lvlSz = stt.hasEncoding()
+                    ? builder.create<LvlOp>(loc, t, l).getResult()
+                    : builder.create<tensor::DimOp>(loc, t, l).getResult();
+
+  switch (*getLevelFormat(lt)) {
+  case LevelFormat::Dense:
+    return std::make_unique<DenseLevel>(lvlSz);
+  case LevelFormat::Compressed: {
+    Value posBuf = genToPositions(builder, loc, t, l);
+    Value crdBuf = genToCoordinates(builder, loc, t, l);
+    return std::make_unique<CompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+  }
+  case LevelFormat::LooseCompressed: {
+    Value posBuf = genToPositions(builder, loc, t, l);
+    Value crdBuf = genToCoordinates(builder, loc, t, l);
+    return std::make_unique<LooseCompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+  }
+  case LevelFormat::Singleton: {
+    Value crdBuf = genToCoordinates(builder, loc, t, l);
+    return std::make_unique<SingletonLevel>(lt, lvlSz, crdBuf);
+  }
+  case LevelFormat::TwoOutOfFour: {
+    Value crdBuf = genToCoordinates(builder, loc, t, l);
+    return std::make_unique<TwoOutFourLevel>(lt, lvlSz, crdBuf);
+  }
+  }
+  llvm_unreachable("unrecognizable level format");
+}
+
+Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
+  return genIndexLoad(b, l, crdBuffer, pos);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
new file mode 100644
index 00000000000000..c6574295ca7fae
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
@@ -0,0 +1,109 @@
+//===- TensorLevels.h -------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+class SparseTensorLevel {
+  SparseTensorLevel(SparseTensorLevel &&) = delete;
+  SparseTensorLevel(const SparseTensorLevel &) = delete;
+
+public:
+  SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){};
+  virtual ~SparseTensorLevel() = default;
+
+  virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
+
+  LevelType getLT() const { return lt; }
+  Value getPos() const { return pos; }
+  Value getCrd() const { return crd; }
+  Value getLoopHi() const { return loopHi; }
+  Value getLoopLo() const { return loopLo; }
+
+protected:
+  SparseTensorLevel(LevelType lt, Value lvlSize)
+      : lt(lt), lvlSize(lvlSize), pos(nullptr), crd(nullptr), loopHi(nullptr),
+        loopLo(nullptr){};
+
+  const LevelType lt;
+  const Value lvlSize;
+
+public: // TODO: make these values private upon feature complete.
+  Value pos;
+  Value crd;
+  Value loopHi;
+  Value loopLo;
+};
+
+/// Helper function to create a SparseTensorLevel for level `l` of tensor `t`.
+std::unique_ptr<SparseTensorLevel>
+makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
+
+class DenseLevel : public SparseTensorLevel {
+public:
+  DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
+    // Dense level: the loop upper bound equals the level size.
+    loopHi = lvlSize;
+  }
+
+  Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
+    return pos;
+  }
+};
+
+class SparseLevel : public SparseTensorLevel {
+public:
+  SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
+
+  Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override;
+
+public: // TODO: make these values private upon feature complete.
+  const Value crdBuffer;
+};
+
+class CompressedLevel : public SparseLevel {
+public:
+  CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+  const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+  LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
+                       Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+  const Value posBuffer;
+};
+
+class SingletonLevel : public SparseLevel {
+public:
+  SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer) {}
+};
+
+class TwoOutFourLevel : public SparseLevel {
+public:
+  TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+      : SparseLevel(lt, lvlSize, crdBuffer) {}
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_

>From a09eb73346c8dc80e942d4ee935db1325f9dc031 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Dec 2023 19:50:59 +0000
Subject: [PATCH 2/2] address comments.

---
 .../Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h   | 2 --
 .../SparseTensor/Transforms/Utils/SparseTensorLevels.h    | 8 ++++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 272d2bf0e89c2e..8fbba896d8621c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -644,8 +644,6 @@ class LoopEmitter {
   std::vector<std::vector<Value>> segHi;
   std::vector<std::vector<Value>> highs;
   std::vector<std::vector<Value>> lvlSizes;
-  // std::vector<std::vector<Value>> positionsBuffers;   // to_positions
-  // std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
   std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
   std::vector<Value> valBuffer;                       // to_value
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
index c6574295ca7fae..d4fc518e9efb6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
@@ -1,4 +1,4 @@
-//===- TensorLevels.h -------------------------------------------*- C++ -*-===//
+//===- SparseTensorLevels.h -------------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 
@@ -106,4 +106,4 @@ class TwoOutFourLevel : public SparseLevel {
 } // namespace sparse_tensor
 } // namespace mlir
 
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_



More information about the Mlir-commits mailing list