[Mlir-commits] [mlir] [mlir][sparse] set up the skeleton for SparseTensorLevel abstraction. (PR #75645)
Peiming Liu
llvmlistbot at llvm.org
Fri Dec 15 11:51:12 PST 2023
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/75645
From c6f97d52f616cb15643cd4beb7b7479dd629e436 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Dec 2023 19:37:33 +0000
Subject: [PATCH 1/2] [mlir][sparse] set up the skeleton for SparseTensorLevel
abstraction.
---
.../SparseTensor/Transforms/CMakeLists.txt | 1 +
.../Transforms/Utils/LoopEmitter.cpp | 92 ++++++---------
.../Transforms/Utils/LoopEmitter.h | 13 +--
.../Transforms/Utils/SparseTensorLevels.cpp | 46 ++++++++
.../Transforms/Utils/SparseTensorLevels.h | 109 ++++++++++++++++++
5 files changed, 197 insertions(+), 64 deletions(-)
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
create mode 100644 mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index ad8b0d02eca35e..d3ab65e4e1793a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
+ Utils/SparseTensorLevels.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 784c793c9bd119..0ba7cf33b6cbad 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -126,15 +126,15 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
// Generates a bool value for the while-loop condition that iterates over a
// fully reduced level with an affine index expression.
static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
- Value crdBuf, Value crdHi, Value posit,
- Value posHi) {
+ const SparseTensorLevel &level,
+ Value crdHi, Value posit, Value posHi) {
Value inBound = CMPI(ult, posit, posHi);
auto ifOp =
builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
// if (inbound)
// yield coord < crdHi
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value crd = genIndexLoad(builder, loc, crdBuf, posit);
+ Value crd = level.peekCrdAt(builder, loc, posit);
YIELD(CMPI(ult, crd, crdHi));
// else
// yield false
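Note that the change above is pure indirection: instead of threading the raw coordinates memref into the helper, callers hand over the SparseTensorLevel and let it decide how a coordinate is materialized. A minimal sketch of the contrast, using names from this patch (illustrative, not additional patch content):

    // Before: the caller must know which memref backs the level.
    Value crd = genIndexLoad(builder, loc, crdBuf, posit);
    // After: storage details are hidden; e.g. a dense level can return the
    // position itself without emitting any load (see the peekCrdAt overrides
    // in SparseTensorLevels.h below).
    Value crd = level.peekCrdAt(builder, loc, posit);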
@@ -244,13 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
TensorId tid, Level lvl, Value pLo,
Value pHi) {
- const auto coordinates = coordinatesBuffers[tid][lvl];
- const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo);
+ SparseTensorLevel &level = *lvls[tid][lvl];
+ const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
auto whileOp = builder.create<scf::WhileOp>(
loc, builder.getIndexType(), pLo,
/*beforeBuilder=*/
- [pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
- ValueRange ivs) {
+ [pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
const auto pos = ivs[0];
Value inBound = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -261,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
// Load the next coordinates only when inbound (to avoid OOB
// accesses).
builder.setInsertionPointToStart(ifInBound.thenBlock());
- Value crd = genIndexLoad(builder, loc, coordinates, pos);
+ Value crd = level.peekCrdAt(builder, loc, pos);
Value isSameCrd = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, crd, sameCrd);
YIELD(isSameCrd);
@@ -284,11 +283,8 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
Level lvl) {
- // A load on the coordinates array yields the coordinate.
- const Value mem = coordinatesBuffers[tid][lvl];
- /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
const Value pos = posits[tid][lvl];
- const Value crd = genIndexLoad(builder, loc, mem, pos);
+ const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
return crd;
}
@@ -318,9 +314,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->segHi.assign(numTensors, std::vector<Value>());
this->posits.assign(numTensors, std::vector<Value>());
this->coords.assign(numTensors, std::vector<Value>());
- this->positionsBuffers.assign(numTensors, std::vector<Value>());
- this->coordinatesBuffers.assign(numTensors, std::vector<Value>());
this->valBuffer.assign(numTensors, nullptr);
+ this->lvls.resize(numTensors);
this->isSparseSlices.assign(numTensors, false);
this->sliceOffsets.assign(numTensors, std::vector<Value>());
this->sliceStrides.assign(numTensors, std::vector<Value>());
@@ -377,8 +372,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
segHi[tid].assign(lvlRank, Value());
posits[tid].assign(lvlRank, Value());
coords[tid].assign(lvlRank, Value());
- positionsBuffers[tid].assign(lvlRank, Value());
- coordinatesBuffers[tid].assign(lvlRank, Value());
+ lvls[tid].resize(lvlRank);
+
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());
@@ -448,22 +443,7 @@ void LoopEmitter::initializeLoopEmit(
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
- // This should be called only once at beginning.
- assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
- !highs[t][l]);
- const auto lvlTp = lvlTypes[t][l];
- // Handle sparse storage schemes.
- if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
- // Generate sparse primitives to obtain positions and coordinates.
- positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
- coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
- } else if (isSingletonLT(lvlTp) || is2OutOf4LT(lvlTp)) {
- // Singleton level, fetch coordinates.
- coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
- } else {
- // Dense level, nothing to fetch.
- assert(isDenseLT(lvlTp));
- }
+ lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
// Find upper bound in current dimension.
highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
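The single factory call above subsumes the per-format branching deleted in this hunk; every level now follows the same two-step pattern. A hedged sketch of the resulting call-site shape, with identifiers as in this patch:

    // Construct the level abstraction once per (tensor, level) pair ...
    lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
    // ... then access coordinates uniformly afterwards, independent of
    // whether the level is dense, compressed, singleton, or 2:4.
    Value crd = lvls[t][l]->peekCrdAt(builder, loc, pos);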
@@ -756,8 +736,7 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz);
}
assert(crdHi);
- return genSparseReducedAffineCond(builder, loc,
- coordinatesBuffers[tid][lvl], crdHi,
+ return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi,
ivs[0], highs[tid][lvl]);
}
case LoopCondKind::SparseAffineUnRedCond: {
@@ -802,10 +781,9 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
Value posit = ivs[0];
- Value crdBuf = coordinatesBuffers[tid][lvl];
// We need to subtract the offset to get relative coordinates.
// TODO: Maybe assert relC >= 0 at runtime in debug builds?
- Value absC = genIndexLoad(builder, loc, crdBuf, posit);
+ Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit);
auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
posits[tid][lvl] = posit;
coords[tid][lvl] = relC;
@@ -1189,9 +1167,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
// The induction variable gives the position.
const Value pos = forOp.getInductionVar();
posits[tid][lvl] = pos;
- // Generating a load on the coordinates array yields the crd.
- const Value mem = coordinatesBuffers[tid][lvl];
- const Value crd = genIndexLoad(builder, loc, mem, pos);
+ const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
coords[tid][lvl] = crd;
// Generate an if-condition to filter out coordinates that are not
@@ -1255,7 +1231,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
assert(lvl == 0 || posits[tid][lvl - 1]);
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
- const Value mem = positionsBuffers[tid][lvl];
+ // TODO: eliminate the cast upon feature complete.
+ const Value mem =
+ isCompressedLT(lvlTp)
+ ? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
+ : static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
if (isLooseCompressedLT(lvlTp))
@@ -1623,8 +1603,7 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
/*beforeBuilder=*/
[this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
ValueRange args) {
- Value cond = genSparseReducedAffineCond(builder, loc,
- coordinatesBuffers[tid][lvl],
+ Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl],
sliceHi, args[0], posHi);
// Continue while not yet broken out of the loop and still in bounds.
builder.create<scf::ConditionOp>(loc, cond, args);
@@ -1848,12 +1827,14 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
Value pHi, pLo;
if (lvl == 0) {
pLo = c0;
- pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+      // TODO: eliminate the cast upon feature complete.
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
+ pHi = genIndexLoad(builder, loc, pBuf, c1);
} else {
- pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
- posits[tid][lvl - 1]);
- pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
- ADDI(posits[tid][lvl - 1], c1));
+      // TODO: eliminate the cast upon feature complete.
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+ pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
+ pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
}
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
@@ -1868,7 +1849,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
// nonempty. Though we assume that, even for empty sparse tensors, a non-empty
// ptr/idx buffer is allocated for each level, so the load would not be OOB;
// this lets us avoid generating an ifOp here.
- Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+ Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
// FIXME: We need the relative offset related to the base slice.
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
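For reference, the loads in this hunk rely on the usual compressed-storage invariant: positions[p] and positions[p + 1] delimit the half-open range of entries under parent position p. A sketch, with `p` standing in for posits[tid][lvl - 1] and pBuf/ADDI as in the hunk:

    // Children of parent position `p` occupy crdBuffer[pLo .. pHi).
    Value pLo = genIndexLoad(builder, loc, pBuf, p);
    Value pHi = genIndexLoad(builder, loc, pBuf, ADDI(p, c1));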
@@ -1955,9 +1936,10 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
Value &curTupleCnt = reduc[2];
Value pHi = ADDI(iv, c1);
- Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
- Value sPHi =
- genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
+ // TODO: eliminate the cast upon feature complete.
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+ Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
+ Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
// isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
// one non-empty lvl, the slice is non-empty.
@@ -1975,8 +1957,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
// }
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
- Value curC =
- genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo);
+ Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
Value isSmaller = CMPI(ult, curC, minCrd);
Value newMin = SELECT(isSmaller, curC, minCrd);
YIELD(newMin);
@@ -2176,8 +2157,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
/* if pLo < pHi */ {
builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
// coord = load[pLo]
- Value coord =
- genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+ Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
Value pred = CMPI(eq, coord, info.minCrd);
auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
/* if coord == minCrd */ {
@@ -2209,7 +2189,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
auto newMin =
builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
builder.setInsertionPointToStart(&newMin.getThenRegion().front());
- YIELD(genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo));
+ YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo));
builder.setInsertionPointToStart(&newMin.getElseRegion().front());
YIELD(curMinCrd);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 78bb53e4483f60..272d2bf0e89c2e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,6 +11,8 @@
#include <vector>
+#include "SparseTensorLevels.h"
+
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -241,12 +243,6 @@ class LoopEmitter {
const std::vector<std::vector<Value>> &getPosits() const { return posits; };
const std::vector<std::vector<Value>> &getCoords() const { return coords; };
const std::vector<std::vector<Value>> &getHighs() const { return highs; };
- const std::vector<std::vector<Value>> &getPositionBuffers() const {
- return positionsBuffers;
- };
- const std::vector<std::vector<Value>> &getCoordinateBuffers() const {
- return coordinatesBuffers;
- };
const std::vector<Value> &getValBuffer() const { return valBuffer; };
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
@@ -648,8 +644,9 @@ class LoopEmitter {
std::vector<std::vector<Value>> segHi;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- std::vector<std::vector<Value>> positionsBuffers; // to_positions
- std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+ // std::vector<std::vector<Value>> positionsBuffers; // to_positions
+ // std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+ std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
std::vector<Value> valBuffer; // to_value
//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
new file mode 100644
index 00000000000000..a9dae17e9de055
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
@@ -0,0 +1,46 @@
+#include "SparseTensorLevels.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
+ Level l) {
+ auto stt = getSparseTensorType(t);
+
+ LevelType lt = stt.getLvlType(l);
+ Value lvlSz = stt.hasEncoding()
+ ? builder.create<LvlOp>(loc, t, l).getResult()
+ : builder.create<tensor::DimOp>(loc, t, l).getResult();
+
+ switch (*getLevelFormat(lt)) {
+ case LevelFormat::Dense:
+ return std::make_unique<DenseLevel>(lvlSz);
+ case LevelFormat::Compressed: {
+ Value posBuf = genToPositions(builder, loc, t, l);
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<CompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ }
+ case LevelFormat::LooseCompressed: {
+ Value posBuf = genToPositions(builder, loc, t, l);
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<LooseCompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ }
+ case LevelFormat::Singleton: {
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<SingletonLevel>(lt, lvlSz, crdBuf);
+ }
+ case LevelFormat::TwoOutOfFour: {
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<TwoOutFourLevel>(lt, lvlSz, crdBuf);
+ }
+ }
+ llvm_unreachable("unrecognizable level format");
+}
+
+Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
+ return genIndexLoad(b, l, crdBuffer, pos);
+}
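A hedged usage sketch of the new factory (illustrative only; assumes an OpBuilder, a Location, and a sparse tensor value `t` in scope, with getSparseTensorType/getLvlRank coming from the SparseTensor dialect utilities):

    // Wrap every level of `t` behind the uniform interface.
    auto stt = getSparseTensorType(t);
    SmallVector<std::unique_ptr<SparseTensorLevel>> levels;
    for (Level l = 0; l < stt.getLvlRank(); l++)
      levels.push_back(makeSparseTensorLevel(builder, loc, t, l));
    // Coordinate access no longer depends on the level format.
    Value crd = levels[lvl]->peekCrdAt(builder, loc, pos);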
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
new file mode 100644
index 00000000000000..c6574295ca7fae
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
@@ -0,0 +1,109 @@
+//===- TensorLevels.h -------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+class SparseTensorLevel {
+ SparseTensorLevel(SparseTensorLevel &&) = delete;
+ SparseTensorLevel(const SparseTensorLevel &) = delete;
+
+public:
+  SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr) {}
+ virtual ~SparseTensorLevel() = default;
+
+ virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
+
+ LevelType getLT() const { return lt; }
+ Value getPos() const { return pos; }
+ Value getCrd() const { return crd; }
+ Value getLoopHi() const { return loopHi; }
+ Value getLoopLo() const { return loopLo; }
+
+protected:
+  SparseTensorLevel(LevelType lt, Value lvlSize)
+      : lt(lt), lvlSize(lvlSize), pos(nullptr), crd(nullptr), loopHi(nullptr),
+        loopLo(nullptr) {}
+
+ const LevelType lt;
+ const Value lvlSize;
+
+public: // TODO: make these values private upon feature complete.
+ Value pos;
+ Value crd;
+ Value loopHi;
+ Value loopLo;
+};
+
+/// Helper function to create a SparseTensorLevel object from the given
+/// `tensor`.
+std::unique_ptr<SparseTensorLevel>
+makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
+
+class DenseLevel : public SparseTensorLevel {
+public:
+ DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
+    // Dense level: the loop upper bound equals the level size.
+ loopHi = lvlSize;
+ }
+
+ Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
+ return pos;
+ }
+};
+
+class SparseLevel : public SparseTensorLevel {
+public:
+ SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+ : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
+
+ Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override;
+
+public: // TODO: make these values private upon feature complete.
+ const Value crdBuffer;
+};
+
+class CompressedLevel : public SparseLevel {
+public:
+ CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+ const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
+ Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+ const Value posBuffer;
+};
+
+class SingletonLevel : public SparseLevel {
+public:
+ SingletonLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer) {}
+};
+
+class TwoOutFourLevel : public SparseLevel {
+public:
+ TwoOutFourLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer) {}
+};
+
+} // namespace sparse_tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
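To make the dispatch concrete: DenseLevel::peekCrdAt is the identity on the position (dense storage keeps no coordinates buffer), while every SparseLevel subclass loads crdBuffer[pos] via genIndexLoad. A sketch of the observable difference under the definitions above:

    // Given some current position `pos` within the level:
    // - DenseLevel: peekCrdAt yields `pos` unchanged; no load is generated.
    // - SparseLevel subclasses (compressed, loose compressed, singleton,
    //   2:4): peekCrdAt emits an index load of crdBuffer at `pos`.
    Value crd = lvl->peekCrdAt(builder, loc, pos);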
From a09eb73346c8dc80e942d4ee935db1325f9dc031 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Dec 2023 19:50:59 +0000
Subject: [PATCH 2/2] address comments.
---
.../Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h | 2 --
.../SparseTensor/Transforms/Utils/SparseTensorLevels.h | 8 ++++----
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 272d2bf0e89c2e..8fbba896d8621c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -644,8 +644,6 @@ class LoopEmitter {
std::vector<std::vector<Value>> segHi;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- // std::vector<std::vector<Value>> positionsBuffers; // to_positions
- // std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
std::vector<Value> valBuffer; // to_value
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
index c6574295ca7fae..d4fc518e9efb6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
@@ -1,4 +1,4 @@
-//===- TensorLevels.h -------------------------------------------*- C++ -*-===//
+//===- SparseTensorLevels.h -------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//
-#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
-#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -106,4 +106,4 @@ class TwoOutFourLevel : public SparseLevel {
} // namespace sparse_tensor
} // namespace mlir
-#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_