[Mlir-commits] [mlir] [mlir][sparse] Support pretty print to debug sparse iteration. (PR #80207)
Peiming Liu
llvmlistbot at llvm.org
Thu Feb 1 12:10:32 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/80207
>From 372d11372712bf827a9f6877c9b74636e256fee5 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 30 Jan 2024 22:08:32 +0000
Subject: [PATCH 1/5] [mlir][sparse] Support pretty print to debug sparse
 iteration.
---
.../Transforms/SparseTensorPasses.cpp | 4 +-
.../Transforms/Sparsification.cpp | 2 +-
.../Transforms/Utils/CodegenEnv.cpp | 5 +-
.../Transforms/Utils/CodegenEnv.h | 2 +-
.../Transforms/Utils/LoopEmitter.cpp | 13 +-
.../Transforms/Utils/LoopEmitter.h | 20 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 315 +++++++++++-------
.../Transforms/Utils/SparseTensorLevel.h | 172 ++++++----
.../sparse_conv_2d_slice_based.mlir | 276 +++------------
9 files changed, 387 insertions(+), 422 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 375e10f9068e4..0ae9f6483588d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,13 +82,15 @@ struct SparsificationPass
SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass(const SparsificationOptions &options) {
parallelization = options.parallelizationStrategy;
+ debugSparseIteration = options.debugSparseIteration;
enableRuntimeLibrary = options.enableRuntimeLibrary;
}
void runOnOperation() override {
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
- SparsificationOptions options(parallelization, enableRuntimeLibrary);
+ SparsificationOptions options(parallelization, debugSparseIteration,
+ enableRuntimeLibrary);
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..2ceb214052aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Recursively generates code if admissible.
- env.startEmit();
+ env.startEmit(options.debugSparseIteration);
genBuffers(env, rewriter);
// TODO: Constant affine expression should be handled differently when using
// slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd..0af1cc1745f51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
return success();
}
-void CodegenEnv::startEmit() {
+void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
assert(insChain == nullptr && "must only start emitting once");
if (sparseOut) {
insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
/*dependentLvlGetter=*/
[this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
return merger().getDependentLoops(t, lvl);
- });
+ },
+ emitStrategy);
}
std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b1..7eeddac48f4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }
- void startEmit();
+ void startEmit(DebugSparseIteration emitStrategy);
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 3fa4004ae460e..8c1680a393181 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -81,17 +81,20 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
+ SparseIterator::setDebugSparseIteration(emitStrategy);
const unsigned numManifestTensors = ts.size();
const unsigned synTensorId = numManifestTensors;
@@ -169,7 +172,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Value offset = genSliceOffset(builder, loc, tensors[t], l);
Value stride = genSliceStride(builder, loc, tensors[t], l);
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
- lvls[t][l]->size());
+ lvls[t][l]->getSize());
return slicedIt;
}
return it;
@@ -465,7 +468,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
// Construct the while-loop with a parameter for each coordinate.
for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getItVals();
+ ValueRange itVals = it->getCursor();
ivs.append(itVals.begin(), itVals.end());
}
@@ -724,7 +727,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// Forward the sparse iterator.
Value cmp = CMPI(eq, it.getCrd(), iv);
it.forwardIf(builder, loc, cmp);
- operands.append(it.getItVals().begin(), it.getItVals().end());
+ operands.append(it.getCursor().begin(), it.getCursor().end());
// const Value newPos = whileOp->getResult(o++);
// Following loops continue iteration from the break point of the
// current while loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index d0f447d926f71..e0b4f81487a68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/IR/PatternMatch.h"
@@ -84,14 +85,17 @@ class LoopEmitter {
/// `isSparseOut` indicates that the sparse output tensor is empty,
/// so the loop emitter will generate loops over it according to the
/// level-sizes.
- void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
-
- explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0,
- DependentLvlGetter getter = nullptr);
+ void
+ initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+ bool hasOutput = false, bool isSparseOut = false,
+ unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+
+ explicit LoopEmitter(
+ ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
+ bool isSparseOut = false, unsigned numLoops = 0,
+ DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..bdaf794744bea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,20 +46,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
-class SparseLevel : public SparseTensorLevel {
-public:
- SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
-
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, crdBuffer, iv);
- }
-
-protected:
- const Value crdBuffer;
-};
-
class DenseLevel : public SparseTensorLevel {
public:
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
@@ -74,53 +60,27 @@ class DenseLevel : public SparseTensorLevel {
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
if (encoded) {
- Value posLo = MULI(p, lvlSize);
- return {posLo, lvlSize};
+ Value posLo = MULI(p, getSize());
+ return {posLo, getSize()};
}
// No need to linearize the position for non-annotated tensors.
- return {C_IDX(0), lvlSize};
+ return {C_IDX(0), getSize()};
}
const bool encoded;
};
-class CompressedLevel : public SparseLevel {
+class SparseLevel : public SparseTensorLevel {
public:
- CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
+ SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ ValueRange lvlBuf)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
+ assert(!lvlBuf.empty());
}
-private:
- const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
- LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
- p = MULI(p, C_IDX(2));
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
+ Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+ return genIndexLoad(b, l, getLvlBufs().front(), iv);
}
-
-private:
- const Value posBuffer;
};
class SingletonLevel : public SparseLevel {
@@ -142,8 +102,8 @@ class SingletonLevel : public SparseLevel {
class TwoOutFourLevel : public SparseLevel {
public:
TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ Value crdBuf)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
@@ -154,6 +114,39 @@ class TwoOutFourLevel : public SparseLevel {
}
};
+class CompressedLevel : public SparseLevel {
+public:
+ CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ if (max == nullptr) {
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+ llvm_unreachable("compressed-nu should be the first non-unique level.");
+ }
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ assert(max == nullptr && "loss compressed level can not be non-unique.");
+ p = MULI(p, C_IDX(2));
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -203,7 +196,8 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
-namespace {
+namespace mlir {
+namespace sparse_tensor {
// The iterator that traverses a concrete sparse tensor levels. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -212,12 +206,11 @@ namespace {
class ConcreteIterator : public SparseIterator {
protected:
ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
- unsigned itValCnt)
- : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
- stl(stl) {
- // Allocate enough storage for iterator values.
- itValsStorage.resize(itValCnt);
- }
+ unsigned cursorValCnt)
+ : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
+ stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
+ assert(getCursor().size() == cursorValCnt);
+ };
public:
// For LLVM-style RTTI.
@@ -228,22 +221,34 @@ class ConcreteIterator : public SparseIterator {
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
bool iteratableByFor() const override { return kind != IterKind::kDedup; };
Value upperBound(OpBuilder &b, Location l) const override {
- return stl.size();
+ return stl.getSize();
};
protected:
+ const SparseTensorLevel &stl;
// Owner of the storage, all wrappers build on top of a concrete iterator
// share the same storage such that the iterator values are always
// synchronized.
- SmallVector<Value> itValsStorage;
- const SparseTensorLevel &stl;
+ SmallVector<Value> cursorValsStorage;
};
+} // namespace sparse_tensor
+} // namespace mlir
+
+namespace {
+
class TrivialIterator : public ConcreteIterator {
public:
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("trivial<") + stl.toString() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ return {b.getIndexType()};
+ }
+
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
ret.push_back(getItPos());
@@ -286,12 +291,12 @@ class TrivialIterator : public ConcreteIterator {
return std::make_pair(getItPos(), posHi);
}
- Value genNotEnd(OpBuilder &b, Location l) override {
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
return CMPI(ult, getItPos(), posHi);
}
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
if (randomAccessible()) {
updateCrd(SUBI(getItPos(), posLo));
} else {
@@ -302,24 +307,24 @@ class TrivialIterator : public ConcreteIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override {
seek(ADDI(getItPos(), C_IDX(1)));
- return getItVals();
+ return getCursor();
}
ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
- Value curPos = getItVals().front();
+ Value curPos = getCursor().front();
Value nxPos = forward(b, l).front();
seek(SELECT(cond, nxPos, curPos));
- return getItVals();
+ return getCursor();
}
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
// Seek to the linearized position.
seek(ADDI(crd, posLo));
updateCrd(crd);
}
- Value getItPos() const { return getItVals().front(); }
+ Value getItPos() const { return getCursor().front(); }
Value posLo, posHi;
};
@@ -337,6 +342,13 @@ class DedupIterator : public ConcreteIterator {
return from->kind == IterKind::kDedup;
}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("dedup<") + stl.toString() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ return {b.getIndexType(), b.getIndexType()};
+ }
+
ValuePair getCurPosition() const override { return {getPos(), getSegHi()}; }
void genInitImpl(OpBuilder &b, Location l,
@@ -355,21 +367,21 @@ class DedupIterator : public ConcreteIterator {
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
- ret.append(getItVals().begin(), getItVals().end());
+ ret.append(getCursor().begin(), getCursor().end());
ret.push_back(posHi);
return ret;
};
void deserialize(ValueRange vs) override {
assert(vs.size() == 3);
- seek(vs.take_front(getItVals().size()));
+ seek(vs.take_front(getCursor().size()));
posHi = vs.back();
};
- Value genNotEnd(OpBuilder &b, Location l) override {
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
return CMPI(ult, getPos(), posHi);
}
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
updateCrd(stl.peekCrdAt(b, l, getPos()));
return getCrd();
};
@@ -377,11 +389,11 @@ class DedupIterator : public ConcreteIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override {
Value nxPos = getSegHi(); // forward the position to the next segment.
seek({nxPos, genSegmentHigh(b, l, nxPos)});
- return getItVals();
+ return getCursor();
}
- Value getPos() const { return getItVals()[0]; }
- Value getSegHi() const { return getItVals()[1]; }
+ Value getPos() const { return getCursor()[0]; }
+ Value getSegHi() const { return getCursor()[1]; }
Value posHi;
};
@@ -419,6 +431,13 @@ class FilterIterator : public SparseIterator {
return from->kind == IterKind::kFilter;
}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("filter<") + wrap->getDebugInterfacePrefix() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ return wrap->getCursorValTypes(b);
+ }
+
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override { return size; };
@@ -441,14 +460,14 @@ class FilterIterator : public SparseIterator {
}
}
- Value genNotEnd(OpBuilder &b, Location l) override;
+ Value genNotEndImpl(OpBuilder &b, Location l) override;
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
updateCrd(fromWrapCrd(b, l, wrap->deref(b, l)));
return getCrd();
}
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
wrap->locate(b, l, toWrapCrd(b, l, crd));
updateCrd(crd);
@@ -469,8 +488,7 @@ class NonEmptySubSectIterator : public SparseIterator {
const SparseIterator *parent,
std::unique_ptr<SparseIterator> &&delegate,
Value subSectSz)
- : SparseIterator(IterKind::kNonEmptySubSect, delegate->tid, delegate->lvl,
- 3, /*itVals=*/subSectMeta),
+ : SparseIterator(IterKind::kNonEmptySubSect, 3, subSectMeta, *delegate),
parent(parent), delegate(std::move(delegate)),
tupleSz(this->delegate->serialize().size()), subSectSz(subSectSz) {
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -497,6 +515,14 @@ class NonEmptySubSectIterator : public SparseIterator {
return from->kind == IterKind::kNonEmptySubSect;
}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("ne_sub<") + delegate->getDebugInterfacePrefix() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ // minCrd, absolute offset, notEnd
+ return {b.getIndexType(), b.getIndexType(), b.getI1Type()};
+ }
+
// The sliced pointer buffer is organized as:
// [[itVal0, itVal1, ..., pNx0],
// [itVal0, itVal1, ..., pNx0],
@@ -519,8 +545,8 @@ class NonEmptySubSectIterator : public SparseIterator {
ValueRange{tupleId, C_IDX(tupleSz)});
}
- void storeItVals(OpBuilder &b, Location l, Value tupleId,
- ValueRange itVals) const {
+ void storeCursorVals(OpBuilder &b, Location l, Value tupleId,
+ ValueRange itVals) const {
assert(itVals.size() == tupleSz);
for (unsigned i = 0; i < tupleSz; i++) {
b.create<memref::StoreOp>(l, itVals[i], subSectPosBuf,
@@ -528,7 +554,8 @@ class NonEmptySubSectIterator : public SparseIterator {
}
}
- SmallVector<Value> loadItVals(OpBuilder &b, Location l, Value tupleId) const {
+ SmallVector<Value> loadCursorVals(OpBuilder &b, Location l,
+ Value tupleId) const {
SmallVector<Value> ret;
for (unsigned i = 0; i < tupleSz; i++) {
Value v = b.create<memref::LoadOp>(l, subSectPosBuf,
@@ -560,7 +587,7 @@ class NonEmptySubSectIterator : public SparseIterator {
void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override;
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
Value absOff = crd;
if (isSubSectRoot())
@@ -576,9 +603,11 @@ class NonEmptySubSectIterator : public SparseIterator {
return SUBI(wrapCrd, getAbsOff());
}
- Value genNotEnd(OpBuilder &b, Location l) override { return getNotEnd(); };
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
+ return getNotEnd();
+ };
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
// Use the relative offset to coiterate.
Value crd;
auto *p = dyn_cast_or_null<NonEmptySubSectIterator>(parent);
@@ -638,7 +667,7 @@ class SubSectIterator : public SparseIterator {
std::unique_ptr<SparseIterator> &&wrap, Value size,
unsigned stride)
: SparseIterator(IterKind::kSubSect, *wrap,
- /*extraVal=*/wrap->randomAccessible() ? 0 : 1),
+ /*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
stride(stride), helper(*this) {
assert(stride == 1 && "Not implemented.");
@@ -651,6 +680,16 @@ class SubSectIterator : public SparseIterator {
return from->kind == IterKind::kSubSect;
}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("subsect<") + wrap->getDebugInterfacePrefix() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ SmallVector<Type> ret = wrap->getCursorValTypes(b);
+ if (!randomAccessible())
+ ret.push_back(b.getIndexType()); // The extra counter.
+ return ret;
+ }
+
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return randomAccessible(); };
Value upperBound(OpBuilder &b, Location l) const override { return size; }
@@ -662,7 +701,7 @@ class SubSectIterator : public SparseIterator {
if (randomAccessible()) {
return ADDI(getCrd(), nxLvlTupleStart);
};
- return ADDI(getItVals().back(), nxLvlTupleStart);
+ return ADDI(getCursor().back(), nxLvlTupleStart);
}
void genInitImpl(OpBuilder &b, Location l, const SparseIterator *) override {
@@ -680,10 +719,10 @@ class SubSectIterator : public SparseIterator {
return;
}
assert(!randomAccessible());
- assert(getItVals().size() == wrap->getItVals().size() + 1);
+ assert(getCursor().size() == wrap->getCursor().size() + 1);
// Extra counter that counts the number of actually visited coordinates in
// the sparse subsection.
- getMutItVals().back() = C_IDX(0);
+ getMutCursorVals().back() = C_IDX(0);
Value tupleId;
if (auto *p = llvm::dyn_cast<SubSectIterator>(&parent)) {
assert(p->lvl + 1 == lvl);
@@ -696,16 +735,16 @@ class SubSectIterator : public SparseIterator {
helper.deserializeFromTupleId(b, l, tupleId);
}
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
helper.locate(b, l, crd);
updateCrd(crd);
}
- Value genNotEnd(OpBuilder &b, Location l) override {
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
return helper.genNotEnd(b, l);
}
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
Value crd = helper.deref(b, l);
updateCrd(crd);
return crd;
@@ -714,9 +753,9 @@ class SubSectIterator : public SparseIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override {
helper.forward(b, l);
assert(!randomAccessible());
- assert(getItVals().size() == wrap->getItVals().size() + 1);
- getMutItVals().back() = ADDI(getItVals().back(), C_IDX(1));
- return getItVals();
+ assert(getCursor().size() == wrap->getCursor().size() + 1);
+ getMutCursorVals().back() = ADDI(getCursor().back(), C_IDX(1));
+ return getCursor();
};
Value nxLvlTupleStart;
@@ -737,30 +776,82 @@ class SubSectIterator : public SparseIterator {
// SparseIterator derived classes implementation.
//===----------------------------------------------------------------------===//
+DebugSparseIteration SparseIterator::emitStrategy = DebugSparseIteration::kNone;
+
void SparseIterator::genInit(OpBuilder &b, Location l,
const SparseIterator *p) {
+ if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ std::string prefix = getDebugInterfacePrefix();
+ Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
+ getCursorValTypes(b));
+ seek(begin->getResults());
+ return;
+ }
// TODO: support lowering to function call.
return genInitImpl(b, l, p);
}
-ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
+ if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ std::string prefix = getDebugInterfacePrefix();
+ Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
+ getCursor(), b.getI1Type());
+ return notEnd->getResult(0);
+ }
// TODO: support lowering to function call.
+ return genNotEndImpl(b, l);
+}
+
+void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
+ if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ std::string prefix = getDebugInterfacePrefix();
+ SmallVector<Value> args = getCursor();
+ args.push_back(crd);
+ Operation *locate = b.create(l, b.getStringAttr(prefix + ".locate"), args,
+ getCursorValTypes(b));
+ seek(locate->getResults());
+ updateCrd(crd);
+ return;
+ }
+ return locateImpl(b, l, crd);
+}
+
+Value SparseIterator::deref(OpBuilder &b, Location l) {
+ if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ std::string prefix = getDebugInterfacePrefix();
+ SmallVector<Value> args = getCursor();
+ Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
+ getCursor(), b.getIndexType());
+ updateCrd(deref->getResult(0));
+ return getCrd();
+ }
+ return derefImpl(b, l);
+}
+
+ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
+ if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ std::string prefix = getDebugInterfacePrefix();
+ Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
+ getCursor(), getCursorValTypes(b));
+ seek(next->getResults());
+ return getCursor();
+ }
return forwardImpl(b, l);
}
ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
- auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), cond, true);
+ auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), cond, true);
// Generate else branch first, otherwise iterator values will be updated by
// `forward()`.
b.setInsertionPointToStart(ifOp.elseBlock());
- YIELD(getItVals());
+ YIELD(getCursor());
b.setInsertionPointToStart(ifOp.thenBlock());
YIELD(forward(b, l));
b.setInsertionPointAfter(ifOp);
seek(ifOp.getResults());
- return getItVals();
+ return getCursor();
}
Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
@@ -817,7 +908,7 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
return r.front();
}
-Value FilterIterator::genNotEnd(OpBuilder &b, Location l) {
+Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
assert(!wrap->randomAccessible());
auto r = genWhenInBound(
b, l, *wrap, C_FALSE,
@@ -844,7 +935,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
// forward a subsection).
Value isFirst = constantI1(b, l, true);
- SmallVector<Value> whileArgs(getItVals().begin(), getItVals().end());
+ SmallVector<Value> whileArgs(getCursor().begin(), getCursor().end());
whileArgs.push_back(isFirst);
auto whileOp = b.create<scf::WhileOp>(
l, ValueRange(whileArgs).getTypes(), whileArgs,
@@ -870,14 +961,14 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
[this](OpBuilder &b, Location l, ValueRange ivs) {
linkNewScope(ivs);
wrap->forward(b, l);
- SmallVector<Value> yieldVals(getItVals().begin(), getItVals().end());
+ SmallVector<Value> yieldVals(getCursor().begin(), getCursor().end());
yieldVals.push_back(constantI1(b, l, false));
YIELD(yieldVals);
});
b.setInsertionPointAfter(whileOp);
linkNewScope(whileOp.getResults());
- return getItVals();
+ return getCursor();
}
SubSectIterHelper::SubSectIterHelper(const NonEmptySubSectIterator &subSect)
@@ -889,7 +980,7 @@ SubSectIterHelper::SubSectIterHelper(const SubSectIterator &iter)
void SubSectIterHelper::deserializeFromTupleId(OpBuilder &b, Location l,
Value tupleId) {
assert(!subSect.randomAccessible());
- wrap.deserialize(subSect.loadItVals(b, l, tupleId));
+ wrap.deserialize(subSect.loadCursorVals(b, l, tupleId));
}
void SubSectIterHelper::locate(OpBuilder &b, Location l, Value crd) {
@@ -943,7 +1034,7 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
// is corresponding to the current node.
helper.subSect.storeNxLvlStart(b, l, tupleId, cnt);
- SmallVector<Value> whileArgs(helper.wrap.getItVals());
+ SmallVector<Value> whileArgs(helper.wrap.getCursor());
whileArgs.append(iterArgs.begin(), iterArgs.end());
auto whileOp = b.create<scf::WhileOp>(
@@ -1039,7 +1130,7 @@ void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
.front();
// Cache the sparse range.
- storeItVals(b, l, tupleId, helper.wrap.serialize());
+ storeCursorVals(b, l, tupleId, helper.wrap.serialize());
tupleId = ADDI(tupleId, C_IDX(1));
return {minCrd, tupleId};
});
@@ -1068,7 +1159,7 @@ void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
// Only have one root node.
tupleCnt = C_IDX(1);
// Cache the sparse range.
- storeItVals(b, l, c0, delegate->serialize());
+ storeCursorVals(b, l, c0, delegate->serialize());
SmallVector<Value> elseRet{c0, c0, /*notEnd=*/C_FALSE};
auto meta = genWhenInBound(
b, l, *delegate, elseRet,
@@ -1095,7 +1186,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
// if (offset + size > parents.size)
// isNonEmpty = false;
Value fastPathP = CMPI(ugt, getMinCrd(), getAbsOff());
- auto ifOp = b.create<scf::IfOp>(l, getItVals().getTypes(), fastPathP, true);
+ auto ifOp = b.create<scf::IfOp>(l, getCursor().getTypes(), fastPathP, true);
{
OpBuilder::InsertionGuard guard(b);
// Take the fast path
@@ -1134,7 +1225,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
// Update the forwarded iterator values if needed.
auto ifIsMin = b.create<scf::IfOp>(l, isMin, false);
b.setInsertionPointToStart(&ifIsMin.getThenRegion().front());
- storeItVals(b, l, tupleId, delegate->serialize());
+ storeCursorVals(b, l, tupleId, delegate->serialize());
b.setInsertionPointAfter(ifIsMin);
// if (!wrap.end())
// yield(min(nxMinCrd, *wrap), true)
@@ -1172,7 +1263,7 @@ ValueRange NonEmptySubSectIterator::forwardImpl(OpBuilder &b, Location l) {
nxNotEnd = ANDI(nxNotEnd, CMPI(ult, crd, upperBound(b, l)));
seek(ValueRange{nxMinCrd, nxAbsOff, nxNotEnd});
- return getItVals();
+ return getCursor();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index bf115712bdfc1..2faa2a8de5651 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -10,13 +10,16 @@
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
namespace mlir {
namespace sparse_tensor {
-/// The base class for all types of sparse tensor levels. It provides interfaces
-/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
-/// `peekCrdAt`).
+class ConcreteIterator;
+
+/// The base class for all types of sparse tensor levels. It provides
+/// interfaces to query the loop range (see `peekRangeAt`) and look up the
+/// coordinates (see `peekCrdAt`).
class SparseTensorLevel {
SparseTensorLevel(SparseTensorLevel &&) = delete;
SparseTensorLevel(const SparseTensorLevel &) = delete;
@@ -26,6 +29,10 @@ class SparseTensorLevel {
public:
virtual ~SparseTensorLevel() = default;
+ std::string toString() const {
+ return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
+ std::to_string(lvl) + "]";
+ }
virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
@@ -46,7 +53,17 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
- Value size() const { return lvlSize; }
+ Value getSize() const { return lvlVals.front(); }
+ Value getCrdBuf() const {
+ assert(lvlVals.size() > 1);
+ return lvlVals[1];
+ }
+ Value getPosBuf() const {
+ assert(lvlVals.size() > 2);
+ return lvlVals[2];
+ }
+ ValueRange getLvlVals() const { return lvlVals; }
+ ValueRange getLvlBufs() const { return ValueRange(lvlVals).drop_front(); }
//
// Level properties
@@ -55,12 +72,24 @@ class SparseTensorLevel {
protected:
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
- : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
+ : tid(tid), lvl(lvl), lt(lt), lvlVals() {
+ lvlVals.push_back(lvlSize);
+ };
+
+ SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize,
+ ValueRange lvlBufs)
+ : tid(tid), lvl(lvl), lt(lt), lvlVals() {
+ lvlVals.push_back(lvlSize);
+ lvlVals.append(lvlBufs.begin(), lvlBufs.end());
+ };
public:
const unsigned tid, lvl;
const LevelType lt;
- const Value lvlSize;
+ // The first value in the vector is always lvlSize; for sparse levels, the
+ // second value is always the coordinate buffer; for sparse levels with
+ // position buffers, the third value is always the position buffer.
+ SmallVector<Value, 3> lvlVals;
};
enum class IterKind : uint8_t {
@@ -80,37 +109,47 @@ class SparseIterator {
SparseIterator &operator=(const SparseIterator &) = delete;
protected:
- SparseIterator(IterKind kind, unsigned tid, unsigned lvl, unsigned itValsCnt,
- SmallVectorImpl<Value> &storage)
- : kind(kind), tid(tid), lvl(lvl), crd(nullptr), itValsCnt(itValsCnt),
- itValsStorageRef(storage){};
-
- SparseIterator(IterKind kind, const SparseIterator &wrap)
- : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
- itValsCnt(wrap.itValsCnt), itValsStorageRef(wrap.itValsStorageRef) {
- assert(wrap.itValsCnt == itValsStorageRef.size());
- };
-
- SparseIterator(IterKind kind, const SparseIterator &wrap, unsigned extraVal)
- : kind(kind), tid(wrap.tid), lvl(wrap.lvl), crd(nullptr),
- itValsCnt(wrap.itValsCnt + extraVal),
- itValsStorageRef(wrap.itValsStorageRef) {
- itValsStorageRef.append(extraVal, nullptr);
- assert(itValsCnt == itValsStorageRef.size());
+ SparseIterator(IterKind kind, unsigned tid, unsigned lvl,
+ unsigned cursorValsCnt,
+ SmallVectorImpl<Value> &cursorValStorage)
+ : kind(kind), tid(tid), lvl(lvl), crd(nullptr),
+ cursorValsCnt(cursorValsCnt), cursorValsStorageRef(cursorValStorage){};
+
+ SparseIterator(IterKind kind, unsigned cursorValsCnt,
+ SmallVectorImpl<Value> &cursorValStorage,
+ const SparseIterator &delegate)
+ : SparseIterator(kind, delegate.tid, delegate.lvl, cursorValsCnt,
+ cursorValStorage){};
+
+ SparseIterator(IterKind kind, const SparseIterator &wrap,
+ unsigned extraCursorCnt = 0)
+ : SparseIterator(kind, wrap.tid, wrap.lvl,
+ extraCursorCnt + wrap.cursorValsCnt,
+ wrap.cursorValsStorageRef) {
+ assert(wrap.cursorValsCnt == wrap.cursorValsStorageRef.size());
+ cursorValsStorageRef.append(extraCursorCnt, nullptr);
+ assert(cursorValsStorageRef.size() == wrap.cursorValsCnt + extraCursorCnt);
};
public:
virtual ~SparseIterator() = default;
+ static void setDebugSparseIteration(DebugSparseIteration strategy) {
+ SparseIterator::emitStrategy = strategy;
+ }
+
+ virtual std::string getDebugInterfacePrefix() const = 0;
+ virtual SmallVector<Type> getCursorValTypes(OpBuilder &b) const = 0;
+
Value getCrd() const { return crd; }
- ValueRange getItVals() const {
- return ValueRange(itValsStorageRef).take_front(itValsCnt);
+ ValueRange getCursor() const {
+ return ValueRange(cursorValsStorageRef).take_front(cursorValsCnt);
};
// Sets the iterate to the specified position.
void seek(ValueRange vals) {
- assert(vals.size() == itValsCnt);
- std::copy(vals.begin(), vals.end(), itValsStorageRef.begin());
+ assert(vals.size() == cursorValsCnt);
+ std::copy(vals.begin(), vals.end(), cursorValsStorageRef.begin());
// Now that the iterator is re-positioned, the coordinate becomes invalid.
crd = nullptr;
}
@@ -120,20 +159,21 @@ class SparseIterator {
//
// Whether the iterator support random access (i.e., support look up by
- // *coordinate*). A random access iterator must also traverses a dense space.
+ // *coordinate*). A random access iterator must also traverse a dense
+ // space.
virtual bool randomAccessible() const = 0;
// Whether the iterator can simply traversed by a for loop.
virtual bool iteratableByFor() const { return false; };
- // Get the upper bound of the sparse space that the iterator might visited. A
- // sparse space is a subset of a dense space [0, bound), this function returns
- // *bound*.
+ // Get the upper bound of the sparse space that the iterator might visit.
+ // A sparse space is a subset of a dense space [0, bound), this function
+ // returns *bound*.
virtual Value upperBound(OpBuilder &b, Location l) const = 0;
- // Serializes and deserializes the current status to/from a set of values. The
- // ValueRange should contain values that are sufficient to recover the current
- // iterating postion (i.e., itVals) as well as loop bound.
+ // Serializes and deserializes the current status to/from a set of values.
+ // The ValueRange should contain values that are sufficient to recover the
+ // current iterating position (i.e., the cursor) as well as loop bound.
//
// Not every type of iterator supports the operations, e.g., non-empty
// subsection iterator does not because the the number of non-empty
@@ -155,23 +195,31 @@ class SparseIterator {
// Forwards the iterator to the next element.
ValueRange forward(OpBuilder &b, Location l);
- // Actual Implementation provided by derived class.
- virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
- virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+ // Locates the iterator at *crd*; only valid for random-access iterators.
+ void locate(OpBuilder &b, Location l, Value crd);
// Returns a boolean value that equals `!it.end()`
- virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
+ Value genNotEnd(OpBuilder &b, Location l);
// Dereferences the iterator, loads the coordinate at the current position.
//
// The method assumes that the iterator is not currently exhausted (i.e.,
// it != it.end()).
- virtual Value deref(OpBuilder &b, Location l) = 0;
+ Value deref(OpBuilder &b, Location l);
- // Gets the current position and the optional *position high* (for non-unique
- // iterators), the value is essentially the number of sparse coordinate that
- // the iterator is current visiting. It should be able to uniquely identify
- // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
+ // Actual Implementation provided by derived class.
+ virtual void genInitImpl(OpBuilder &, Location, const SparseIterator *) = 0;
+ virtual ValueRange forwardImpl(OpBuilder &b, Location l) = 0;
+ virtual void locateImpl(OpBuilder &b, Location l, Value crd) {
+ llvm_unreachable("Unsupported");
+ }
+ virtual Value genNotEndImpl(OpBuilder &b, Location l) = 0;
+ virtual Value derefImpl(OpBuilder &b, Location l) = 0;
+ // Gets the current position and the optional *position high* (for
+ // non-unique iterators), the value is essentially the number of sparse
+ // coordinates that the iterator is currently visiting. It should be able to
+ // uniquely identify the sparse range for the next level. See
+ // SparseTensorLevel::peekRangeAt();
//
// Not every type of iterator supports the operation, e.g., non-empty
// subsection iterator does not because it represent a range of coordinates
@@ -202,33 +250,29 @@ class SparseIterator {
// yield it
//
// The function is virtual to allow alternative implementation. For example,
- // if it.next() is trivial to compute, we can use a select operation instead.
- // E.g.,
+ // if it.next() is trivial to compute, we can use a select operation
+ // instead. E.g.,
//
// it = select cond ? it+1 : it
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
- // Locate the iterator to the position specified by *crd*, this can only
- // be done on an iterator that supports randm access.
- virtual void locate(OpBuilder &b, Location l, Value crd) {
- llvm_unreachable("Unsupported");
- }
-
// Update the SSA value for the iterator after entering a new scope.
ValueRange linkNewScope(ValueRange pos) {
assert(!randomAccessible() && "random accessible iterators are traversed "
"by coordinate, call locate() instead.");
- seek(pos.take_front(itValsCnt));
- return pos.drop_front(itValsCnt);
+ seek(pos.take_front(cursorValsCnt));
+ return pos.drop_front(cursorValsCnt);
};
protected:
void updateCrd(Value crd) { this->crd = crd; }
- MutableArrayRef<Value> getMutItVals() {
- MutableArrayRef<Value> ref = itValsStorageRef;
- return ref.take_front(itValsCnt);
+ MutableArrayRef<Value> getMutCursorVals() {
+ MutableArrayRef<Value> ref = cursorValsStorageRef;
+ return ref.take_front(cursorValsCnt);
}
+ static DebugSparseIteration emitStrategy;
+
public:
const IterKind kind; // For LLVM-style RTTI.
const unsigned tid, lvl; // tensor level identifier.
@@ -239,14 +283,14 @@ class SparseIterator {
// A range of value that together defines the current state of the
// iterator. Only loop variants should be included.
//
- // For trivial iterators, it is the position; for dedup iterators, it consists
- // of the positon and the segment high, for non-empty subsection iterator, it
- // is the metadata that specifies the subsection.
- // Note that the wrapped iterator shares the same storage to maintain itVals
- // with it wrapper, which means the wrapped iterator might only own a subset
- // of all the values stored in itValStorage.
- const unsigned itValsCnt;
- SmallVectorImpl<Value> &itValsStorageRef;
+ // For trivial iterators, it is the position; for dedup iterators, it
+ // consists of the position and the segment high; for non-empty subsection
+ // iterator, it is the metadata that specifies the subsection. Note that the
+ // wrapped iterator shares the same storage to maintain its cursor values
+ // with its wrapper, which means the wrapped iterator might only own a
+ // subset of all the values stored in cursorValsStorageRef.
+ const unsigned cursorValsCnt;
+ SmallVectorImpl<Value> &cursorValsStorageRef;
};
/// Helper function to create a TensorLevel object from given `tensor`.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 70cf0f9af45b5..18118cab9e52c 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,5 +1,4 @@
-// TODO: re-enable after lowering coo.next to function call (such that loop structure is more clear).
-// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="debug-sparse-iteration=interface-only" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
@@ -8,233 +7,54 @@
#DCSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
+
// CHECK-LABEL: func.func @conv2d_all_sparse_CSR(
-// C_HECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse>,
-// C_HECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse> {
-// C_HECK-DAG: %[[VAL_2:.*]] = arith.constant true
-// C_HECK-DAG: %[[VAL_3:.*]] = arith.constant -2 : index
-// C_HECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
-// C_HECK-DAG: %[[VAL_5:.*]] = arith.constant 8 : index
-// C_HECK-DAG: %[[VAL_6:.*]] = arith.constant 3 : index
-// C_HECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
-// C_HECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// C_HECK-DAG: %[[VAL_9:.*]] = arith.constant 0 : i32
-// C_HECK-DAG: %[[VAL_10:.*]] = arith.constant false
-// C_HECK-DAG: %[[VAL_11:.*]] = tensor.empty() : tensor<6x6xi32, #sparse>
-// C_HECK-DAG: %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG: %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG: %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG: %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xi32, #sparse> to memref<?xindex>
-// C_HECK-DAG: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xi32, #sparse> to memref<?xi32>
-// C_HECK-DAG: %[[VAL_17:.*]] = memref.alloca() : memref<9xindex>
-// C_HECK-DAG: %[[VAL_18:.*]] = memref.alloca() : memref<3xindex>
-// C_HECK-DAG: %[[POS_LO:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_8]]] : memref<?xindex>
-// C_HECK-DAG: %[[POS_HI:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_7]]] : memref<?xindex>
-// C_HECK: memref.store %[[POS_LO]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK: memref.store %[[POS_HI]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK: %[[VAL_20:.*]] = arith.cmpi ult, %[[POS_LO]], %[[POS_HI]] : index
-// C_HECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[POS_LO]]] : memref<?xindex>
-// C_HECK: %[[VAL_22:.*]] = arith.cmpi uge, %[[VAL_21]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_23:.*]] = arith.andi %[[VAL_20]], %[[VAL_22]] : i1
-// C_HECK: %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
-// C_HECK: %[[VAL_25:.*]] = arith.select %[[VAL_23]], %[[VAL_24]], %[[VAL_8]] : index
-// C_HECK: %[[VAL_26:.*]]:3 = scf.while (%[[VAL_27:.*]] = %[[VAL_20]], %[[VAL_28:.*]] = %[[VAL_21]], %[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_11]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// C_HECK: scf.condition(%[[VAL_27]]) %[[VAL_28]], %[[VAL_29]], %[[VAL_30]] : index, index, tensor<6x6xi32, #sparse>
-// C_HECK: } do {
-// C_HECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: tensor<6x6xi32, #sparse>):
-// C_HECK: %[[VAL_34:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK: %[[VAL_35:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK: memref.store %[[VAL_8]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<3xindex>
-// C_HECK: %[[VAL_36:.*]] = arith.addi %[[VAL_32]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_37:.*]]:5 = scf.while (%[[VAL_38:.*]] = %[[VAL_34]], %[[VAL_39:.*]] = %[[VAL_10]], %[[VAL_40:.*]] = %[[VAL_5]], %[[VAL_41:.*]] = %[[VAL_8]], %[[VAL_42:.*]] = %[[VAL_8]]) : (index, i1, index, index, index) -> (index, i1, index, index, index) {
-// C_HECK: %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_38]], %[[VAL_35]] : index
-// C_HECK: %[[VAL_44:.*]] = scf.if %[[VAL_43]] -> (i1) {
-// C_HECK: %[[VAL_45:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
-// C_HECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_36]] : index
-// C_HECK: scf.yield %[[VAL_46]] : i1
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_10]] : i1
-// C_HECK: }
-// C_HECK: scf.condition(%[[VAL_44]]) %[[VAL_38]], %[[VAL_39]], %[[VAL_40]], %[[VAL_41]], %[[VAL_42]] : index, i1, index, index, index
-// C_HECK: } do {
-// C_HECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: i1, %[[VAL_49:.*]]: index, %[[VAL_50:.*]]: index, %[[VAL_51:.*]]: index):
-// C_HECK-DAG: %[[VAL_52:.*]] = arith.addi %[[VAL_47]], %[[VAL_7]] : index
-// C_HECK-DAG: %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_47]]] : memref<?xindex>
-// C_HECK: %[[VAL_54:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_52]]] : memref<?xindex>
-// C_HECK: %[[VAL_55:.*]] = arith.cmpi ult, %[[VAL_53]], %[[VAL_54]] : index
-// C_HECK: %[[VAL_56:.*]] = arith.ori %[[VAL_55]], %[[VAL_48]] : i1
-// C_HECK: %[[VAL_57:.*]] = scf.if %[[VAL_55]] -> (index) {
-// C_HECK: %[[VAL_58:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_53]]] : memref<?xindex>
-// C_HECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_49]] : index
-// C_HECK: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_58]], %[[VAL_49]] : index
-// C_HECK: scf.yield %[[VAL_60]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_49]] : index
-// C_HECK: }
-// C_HECK: memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_50]]] : memref<9xindex>
-// C_HECK: %[[VAL_61:.*]] = arith.addi %[[VAL_50]], %[[VAL_6]] : index
-// C_HECK: memref.store %[[VAL_54]], %[[VAL_17]]{{\[}}%[[VAL_61]]] : memref<9xindex>
-// C_HECK: %[[VAL_62:.*]] = arith.addi %[[VAL_50]], %[[VAL_7]] : index
-// C_HECK: %[[VAL_63:.*]] = arith.addi %[[VAL_51]], %[[VAL_7]] : index
-// C_HECK: scf.yield %[[VAL_52]], %[[VAL_56]], %[[VAL_57]], %[[VAL_62]], %[[VAL_63]] : index, i1, index, index, index
-// C_HECK: }
-// C_HECK: %[[VAL_64:.*]] = arith.cmpi uge, %[[VAL_65:.*]]#2, %[[VAL_6]] : index
-// C_HECK: %[[VAL_66:.*]] = arith.andi %[[VAL_65]]#1, %[[VAL_64]] : i1
-// C_HECK: %[[VAL_67:.*]] = arith.addi %[[VAL_65]]#2, %[[VAL_3]] : index
-// C_HECK: %[[VAL_68:.*]] = arith.select %[[VAL_66]], %[[VAL_67]], %[[VAL_8]] : index
-// C_HECK: %[[VAL_69:.*]]:3 = scf.while (%[[VAL_70:.*]] = %[[VAL_65]]#1, %[[VAL_71:.*]] = %[[VAL_65]]#2, %[[VAL_72:.*]] = %[[VAL_68]], %[[VAL_73:.*]] = %[[VAL_33]]) : (i1, index, index, tensor<6x6xi32, #sparse>) -> (index, index, tensor<6x6xi32, #sparse>) {
-// C_HECK: scf.condition(%[[VAL_70]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, tensor<6x6xi32, #sparse>
-// C_HECK: } do {
-// C_HECK: ^bb0(%[[VAL_74:.*]]: index, %[[VAL_75:.*]]: index, %[[VAL_76:.*]]: tensor<6x6xi32, #sparse>):
-// C_HECK: %[[VAL_77:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK: %[[VAL_78:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK: %[[VAL_79:.*]]:3 = scf.while (%[[VAL_80:.*]] = %[[VAL_77]], %[[VAL_81:.*]] = %[[VAL_9]], %[[VAL_82:.*]] = %[[VAL_10]]) : (index, i32, i1) -> (index, i32, i1) {
-// C_HECK: %[[VAL_83:.*]] = arith.cmpi ult, %[[VAL_80]], %[[VAL_78]] : index
-// C_HECK: %[[VAL_84:.*]] = scf.if %[[VAL_83]] -> (i1) {
-// C_HECK: %[[VAL_85:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_80]]] : memref<?xindex>
-// C_HECK: %[[VAL_86:.*]] = arith.cmpi ult, %[[VAL_85]], %[[VAL_36]] : index
-// C_HECK: scf.yield %[[VAL_86]] : i1
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_10]] : i1
-// C_HECK: }
-// C_HECK: scf.condition(%[[VAL_84]]) %[[VAL_80]], %[[VAL_81]], %[[VAL_82]] : index, i32, i1
-// C_HECK: } do {
-// C_HECK: ^bb0(%[[VAL_87:.*]]: index, %[[VAL_88:.*]]: i32, %[[VAL_89:.*]]: i1):
-// C_HECK: %[[VAL_90:.*]] = arith.subi %[[VAL_87]], %[[VAL_77]] : index
-// C_HECK: %[[VAL_91:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_87]]] : memref<?xindex>
-// C_HECK: %[[VAL_92:.*]] = arith.subi %[[VAL_91]], %[[VAL_32]] : index
-// C_HECK: %[[VAL_93:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_90]]] : memref<9xindex>
-// C_HECK: %[[VAL_94:.*]] = arith.addi %[[VAL_90]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_95:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_94]]] : memref<9xindex>
-// C_HECK: %[[VAL_96:.*]] = arith.addi %[[VAL_75]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_97:.*]]:2 = scf.while (%[[VAL_98:.*]] = %[[VAL_93]], %[[VAL_99:.*]] = %[[VAL_88]]) : (index, i32) -> (index, i32) {
-// C_HECK: %[[VAL_100:.*]] = arith.cmpi ult, %[[VAL_98]], %[[VAL_95]] : index
-// C_HECK: %[[VAL_101:.*]] = scf.if %[[VAL_100]] -> (i1) {
-// C_HECK: %[[VAL_102:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_98]]] : memref<?xindex>
-// C_HECK: %[[VAL_103:.*]] = arith.cmpi ult, %[[VAL_102]], %[[VAL_96]] : index
-// C_HECK: scf.yield %[[VAL_103]] : i1
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_10]] : i1
-// C_HECK: }
-// C_HECK: scf.condition(%[[VAL_101]]) %[[VAL_98]], %[[VAL_99]] : index, i32
-// C_HECK: } do {
-// C_HECK: ^bb0(%[[VAL_104:.*]]: index, %[[VAL_105:.*]]: i32):
-// C_HECK: %[[VAL_106:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_104]]] : memref<?xindex>
-// C_HECK: %[[VAL_107:.*]] = arith.subi %[[VAL_106]], %[[VAL_75]] : index
-// C_HECK: %[[VAL_108:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_104]]] : memref<?xi32>
-// C_HECK: %[[VAL_109:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_92]], %[[VAL_107]]] : tensor<3x3xi32>
-// C_HECK: %[[VAL_110:.*]] = arith.muli %[[VAL_108]], %[[VAL_109]] : i32
-// C_HECK: %[[VAL_111:.*]] = arith.addi %[[VAL_105]], %[[VAL_110]] : i32
-// C_HECK: %[[VAL_112:.*]] = arith.addi %[[VAL_104]], %[[VAL_7]] : index
-// C_HECK: scf.yield %[[VAL_112]], %[[VAL_111]] : index, i32
-// C_HECK: }
-// C_HECK: %[[VAL_113:.*]] = arith.addi %[[VAL_87]], %[[VAL_7]] : index
-// C_HECK: scf.yield %[[VAL_113]], %[[VAL_114:.*]]#1, %[[VAL_2]] : index, i32, i1
-// C_HECK: }
-// C_HECK: %[[VAL_115:.*]] = scf.if %[[VAL_116:.*]]#2 -> (tensor<6x6xi32, #sparse>) {
-// C_HECK: %[[VAL_117:.*]] = sparse_tensor.insert %[[VAL_116]]#1 into %[[VAL_76]]{{\[}}%[[VAL_32]], %[[VAL_75]]] : tensor<6x6xi32, #sparse>
-// C_HECK: scf.yield %[[VAL_117]] : tensor<6x6xi32, #sparse>
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_76]] : tensor<6x6xi32, #sparse>
-// C_HECK: }
-// C_HECK: %[[VAL_118:.*]] = arith.cmpi ugt, %[[VAL_74]], %[[VAL_75]] : index
-// C_HECK: %[[VAL_119:.*]]:3 = scf.if %[[VAL_118]] -> (index, i1, index) {
-// C_HECK: %[[VAL_120:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// C_HECK: scf.yield %[[VAL_74]], %[[VAL_2]], %[[VAL_120]] : index, i1, index
-// C_HECK: } else {
-// C_HECK: %[[VAL_121:.*]]:2 = scf.for %[[VAL_122:.*]] = %[[VAL_8]] to %[[VAL_65]]#3 step %[[VAL_7]] iter_args(%[[VAL_123:.*]] = %[[VAL_5]], %[[VAL_124:.*]] = %[[VAL_10]]) -> (index, i1) {
-// C_HECK: %[[VAL_125:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// C_HECK: %[[VAL_126:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_127:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_126]]] : memref<9xindex>
-// C_HECK: %[[VAL_128:.*]] = arith.cmpi ult, %[[VAL_125]], %[[VAL_127]] : index
-// C_HECK: %[[VAL_129:.*]] = scf.if %[[VAL_128]] -> (index) {
-// C_HECK: %[[VAL_130:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_125]]] : memref<?xindex>
-// C_HECK: %[[VAL_131:.*]] = arith.cmpi eq, %[[VAL_130]], %[[VAL_74]] : index
-// C_HECK: %[[VAL_132:.*]] = scf.if %[[VAL_131]] -> (index) {
-// C_HECK: %[[VAL_133:.*]] = arith.addi %[[VAL_125]], %[[VAL_7]] : index
-// C_HECK: memref.store %[[VAL_133]], %[[VAL_17]]{{\[}}%[[VAL_122]]] : memref<9xindex>
-// C_HECK: scf.yield %[[VAL_133]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_125]] : index
-// C_HECK: }
-// C_HECK: scf.yield %[[VAL_132]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_125]] : index
-// C_HECK: }
-// C_HECK: %[[VAL_134:.*]] = arith.cmpi ult, %[[VAL_129]], %[[VAL_127]] : index
-// C_HECK: %[[VAL_135:.*]] = scf.if %[[VAL_134]] -> (index) {
-// C_HECK: %[[VAL_136:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_129]]] : memref<?xindex>
-// C_HECK: scf.yield %[[VAL_136]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_123]] : index
-// C_HECK: }
-// C_HECK: %[[VAL_137:.*]] = arith.ori %[[VAL_134]], %[[VAL_124]] : i1
-// C_HECK: %[[VAL_138:.*]] = arith.cmpi ult, %[[VAL_135]], %[[VAL_123]] : index
-// C_HECK: %[[VAL_139:.*]] = arith.select %[[VAL_138]], %[[VAL_135]], %[[VAL_123]] : index
-// C_HECK: scf.yield %[[VAL_139]], %[[VAL_137]] : index, i1
-// C_HECK: }
-// C_HECK: %[[VAL_140:.*]] = arith.addi %[[VAL_141:.*]]#0, %[[VAL_7]] : index
-// C_HECK: %[[VAL_142:.*]] = arith.addi %[[VAL_141]]#0, %[[VAL_3]] : index
-// C_HECK: %[[VAL_143:.*]] = arith.cmpi uge, %[[VAL_140]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_144:.*]] = arith.select %[[VAL_143]], %[[VAL_142]], %[[VAL_8]] : index
-// C_HECK: scf.yield %[[VAL_141]]#0, %[[VAL_141]]#1, %[[VAL_144]] : index, i1, index
-// C_HECK: }
-// C_HECK: %[[VAL_145:.*]] = arith.addi %[[VAL_75]], %[[VAL_7]] : index
-// C_HECK: %[[VAL_146:.*]] = arith.cmpi ugt, %[[VAL_147:.*]]#2, %[[VAL_145]] : index
-// C_HECK: %[[VAL_148:.*]] = arith.select %[[VAL_146]], %[[VAL_147]]#2, %[[VAL_145]] : index
-// C_HECK: %[[VAL_149:.*]] = arith.addi %[[VAL_148]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_150:.*]] = arith.cmpi ule, %[[VAL_149]], %[[VAL_5]] : index
-// C_HECK: %[[VAL_151:.*]] = arith.andi %[[VAL_147]]#1, %[[VAL_150]] : i1
-// C_HECK: scf.yield %[[VAL_151]], %[[VAL_147]]#0, %[[VAL_148]], %[[VAL_115]] : i1, index, index, tensor<6x6xi32, #sparse>
-// C_HECK: }
-// C_HECK: %[[VAL_152:.*]] = arith.cmpi ugt, %[[VAL_31]], %[[VAL_32]] : index
-// C_HECK: %[[VAL_153:.*]]:3 = scf.if %[[VAL_152]] -> (index, i1, index) {
-// C_HECK: %[[VAL_154:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// C_HECK: scf.yield %[[VAL_31]], %[[VAL_2]], %[[VAL_154]] : index, i1, index
-// C_HECK: } else {
-// C_HECK: %[[VAL_155:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK: %[[VAL_156:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<3xindex>
-// C_HECK: %[[VAL_157:.*]] = arith.cmpi ult, %[[VAL_155]], %[[VAL_156]] : index
-// C_HECK: %[[VAL_158:.*]] = scf.if %[[VAL_157]] -> (index) {
-// C_HECK: %[[VAL_159:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_155]]] : memref<?xindex>
-// C_HECK: %[[VAL_160:.*]] = arith.cmpi eq, %[[VAL_159]], %[[VAL_31]] : index
-// C_HECK: %[[VAL_161:.*]] = scf.if %[[VAL_160]] -> (index) {
-// C_HECK: %[[VAL_162:.*]] = arith.addi %[[VAL_155]], %[[VAL_7]] : index
-// C_HECK: memref.store %[[VAL_162]], %[[VAL_18]]{{\[}}%[[VAL_8]]] : memref<3xindex>
-// C_HECK: scf.yield %[[VAL_162]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_155]] : index
-// C_HECK: }
-// C_HECK: scf.yield %[[VAL_161]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_155]] : index
-// C_HECK: }
-// C_HECK: %[[VAL_163:.*]] = arith.cmpi ult, %[[VAL_158]], %[[VAL_156]] : index
-// C_HECK: %[[VAL_164:.*]] = scf.if %[[VAL_163]] -> (index) {
-// C_HECK: %[[VAL_165:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_158]]] : memref<?xindex>
-// C_HECK: scf.yield %[[VAL_165]] : index
-// C_HECK: } else {
-// C_HECK: scf.yield %[[VAL_5]] : index
-// C_HECK: }
-// C_HECK: %[[VAL_166:.*]] = arith.cmpi ult, %[[VAL_164]], %[[VAL_5]] : index
-// C_HECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_164]], %[[VAL_5]] : index
-// C_HECK: %[[VAL_168:.*]] = arith.addi %[[VAL_167]], %[[VAL_7]] : index
-// C_HECK: %[[VAL_169:.*]] = arith.addi %[[VAL_167]], %[[VAL_3]] : index
-// C_HECK: %[[VAL_170:.*]] = arith.cmpi uge, %[[VAL_168]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_171:.*]] = arith.select %[[VAL_170]], %[[VAL_169]], %[[VAL_8]] : index
-// C_HECK: scf.yield %[[VAL_167]], %[[VAL_163]], %[[VAL_171]] : index, i1, index
-// C_HECK: }
-// C_HECK: %[[VAL_172:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
-// C_HECK: %[[VAL_173:.*]] = arith.cmpi ugt, %[[VAL_174:.*]]#2, %[[VAL_172]] : index
-// C_HECK: %[[VAL_175:.*]] = arith.select %[[VAL_173]], %[[VAL_174]]#2, %[[VAL_172]] : index
-// C_HECK: %[[VAL_176:.*]] = arith.addi %[[VAL_175]], %[[VAL_6]] : index
-// C_HECK: %[[VAL_177:.*]] = arith.cmpi ule, %[[VAL_176]], %[[VAL_5]] : index
-// C_HECK: %[[VAL_178:.*]] = arith.andi %[[VAL_174]]#1, %[[VAL_177]] : i1
-// C_HECK: scf.yield %[[VAL_178]], %[[VAL_174]]#0, %[[VAL_175]], %[[VAL_179:.*]]#2 : i1, index, index, tensor<6x6xi32, #sparse>
-// C_HECK: }
-// C_HECK: %[[VAL_180:.*]] = sparse_tensor.load %[[VAL_181:.*]]#2 hasInserts : tensor<6x6xi32, #sparse>
-// C_HECK: return %[[VAL_180]] : tensor<6x6xi32, #sparse>
-// C_HECK: }
+// CHECK: "ne_sub<trivial<compressed[0,0]>>.begin"
+// CHECK: scf.while {{.*}} {
+// CHECK: "ne_sub<trivial<compressed[0,0]>>.not_end"
+// CHECK: } do {
+// CHECK: %[[D0:.*]] = "ne_sub<trivial<compressed[0,0]>>.deref"
+// CHECK: "ne_sub<trivial<compressed[0,1]>>.begin"
+// CHECK: scf.while {{.*}} {
+// CHECK: "ne_sub<trivial<compressed[0,1]>>.not_end"
+// CHECK: } do {
+// CHECK: %[[D1:.*]] = "ne_sub<trivial<compressed[0,1]>>.deref"
+// CHECK: "subsect<trivial<compressed[0,0]>>.begin"
+// CHECK: scf.while {{.*}} {
+// CHECK: "subsect<trivial<compressed[0,0]>>.not_end"
+// CHECK: } do {
+// CHECK: %[[D2:.*]] = "subsect<trivial<compressed[0,0]>>.deref"
+// CHECK: "trivial<dense[1,0]>.locate"(%{{.*}}, %[[D2]])
+// CHECK: "subsect<trivial<compressed[0,1]>>.begin"
+// CHECK: scf.while {{.*}} {
+// CHECK: "subsect<trivial<compressed[0,1]>>.not_end"
+// CHECK: } do {
+// CHECK: %[[D3:.*]] = "subsect<trivial<compressed[0,1]>>.deref"
+// CHECK: "trivial<dense[1,1]>.locate"(%{{.*}}, %[[D3]])
+// CHECK: tensor.extract %{{.*}}{{\[}}%[[D2]], %[[D3]]]
+// CHECK: arith.muli
+// CHECK: arith.addi
+// CHECK: "subsect<trivial<compressed[0,1]>>.next"
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: "subsect<trivial<compressed[0,0]>>.next"
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: scf.if {{.*}} {
+// CHECK: sparse_tensor.insert %{{.*}} into %{{.*}}{{\[}}%[[D0]], %[[D1]]]
+// CHECK: scf.yield
+// CHECK: } else {
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: "ne_sub<trivial<compressed[0,1]>>.next"
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: "ne_sub<trivial<compressed[0,0]>>.next"
+// CHECK: scf.yield
+// CHECK: }
+// CHECK: sparse_tensor.load
+// CHECK: return
+// CHECK: }
func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>,
%arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> {
%0 = tensor.empty() : tensor<6x6xi32, #DCSR>
>From b5744e8d32609cc50049f719edd974f13d959e73 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 31 Jan 2024 22:15:32 +0000
Subject: [PATCH 2/5] revert unintended change
---
.../Transforms/Utils/SparseTensorLevel.cpp | 105 +++++++++---------
.../Transforms/Utils/SparseTensorLevel.h | 37 ++----
2 files changed, 59 insertions(+), 83 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index bdaf794744bea..604136c3884f9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -45,6 +45,19 @@ using ValueTuple = std::tuple<Value, Value, Value>;
//===----------------------------------------------------------------------===//
namespace {
+class SparseLevel : public SparseTensorLevel {
+public:
+ SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value crdBuffer)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
+
+ Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+ return genIndexLoad(b, l, crdBuffer, iv);
+ }
+
+protected:
+ const Value crdBuffer;
+};
class DenseLevel : public SparseTensorLevel {
public:
@@ -60,27 +73,53 @@ class DenseLevel : public SparseTensorLevel {
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
if (encoded) {
- Value posLo = MULI(p, getSize());
- return {posLo, getSize()};
+ Value posLo = MULI(p, lvlSize);
+ return {posLo, lvlSize};
}
// No need to linearize the position for non-annotated tensors.
- return {C_IDX(0), getSize()};
+ return {C_IDX(0), lvlSize};
}
const bool encoded;
};
-class SparseLevel : public SparseTensorLevel {
+class CompressedLevel : public SparseLevel {
public:
- SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- ValueRange lvlBuf)
- : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
- assert(!lvlBuf.empty());
+ CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ if (max == nullptr) {
+ Value pLo = genIndexLoad(b, l, posBuffer, p);
+ Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+ llvm_unreachable("compressed-nu should be the first non-unique level.");
}
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, getLvlBufs().front(), iv);
+private:
+ const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ assert(max == nullptr && "loss compressed level can not be non-unique.");
+ p = MULI(p, C_IDX(2));
+ Value pLo = genIndexLoad(b, l, posBuffer, p);
+ Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
}
+
+private:
+ const Value posBuffer;
};
class SingletonLevel : public SparseLevel {
@@ -102,8 +141,8 @@ class SingletonLevel : public SparseLevel {
class TwoOutFourLevel : public SparseLevel {
public:
TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuf)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
+ Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
@@ -114,39 +153,6 @@ class TwoOutFourLevel : public SparseLevel {
}
};
-class CompressedLevel : public SparseLevel {
-public:
- CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, getPosBuf(), p);
- Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
- }
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
- LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
- p = MULI(p, C_IDX(2));
- Value pLo = genIndexLoad(b, l, getPosBuf(), p);
- Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
-};
-
} // namespace
//===----------------------------------------------------------------------===//
@@ -195,9 +201,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
//===----------------------------------------------------------------------===//
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
-
-namespace mlir {
-namespace sparse_tensor {
+namespace {
// The iterator that traverses a concrete sparse tensor levels. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -232,11 +236,6 @@ class ConcreteIterator : public SparseIterator {
SmallVector<Value> cursorValsStorage;
};
-} // namespace sparse_tensor
-} // namespace mlir
-
-namespace {
-
class TrivialIterator : public ConcreteIterator {
public:
TrivialIterator(const SparseTensorLevel &stl)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 2faa2a8de5651..eb75df9feaae9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -15,11 +15,9 @@
namespace mlir {
namespace sparse_tensor {
-class ConcreteIterator;
-
-/// The base class for all types of sparse tensor levels. It provides
-/// interfaces to query the loop range (see `peekRangeAt`) and look up the
-/// coordinates (see `peekCrdAt`).
+/// The base class for all types of sparse tensor levels. It provides interfaces
+/// to query the loop range (see `peekRangeAt`) and look up the coordinates (see
+/// `peekCrdAt`).
class SparseTensorLevel {
SparseTensorLevel(SparseTensorLevel &&) = delete;
SparseTensorLevel(const SparseTensorLevel &) = delete;
@@ -33,6 +31,7 @@ class SparseTensorLevel {
return std::string(toMLIRString(lt)) + "[" + std::to_string(tid) + "," +
std::to_string(lvl) + "]";
}
+
virtual Value peekCrdAt(OpBuilder &b, Location l, Value iv) const = 0;
/// Peeks the lower and upper bound to *fully* traverse the level with
@@ -53,17 +52,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
- Value getSize() const { return lvlVals.front(); }
- Value getCrdBuf() const {
- assert(lvlVals.size() > 1);
- return lvlVals[1];
- }
- Value getPosBuf() const {
- assert(lvlVals.size() > 2);
- return lvlVals[2];
- }
- ValueRange getLvlVals() const { return lvlVals; }
- ValueRange getLvlBufs() const { return ValueRange(lvlVals).drop_front(); }
+ Value getSize() const { return lvlSize; }
//
// Level properties
@@ -72,24 +61,12 @@ class SparseTensorLevel {
protected:
SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize)
- : tid(tid), lvl(lvl), lt(lt), lvlVals() {
- lvlVals.push_back(lvlSize);
- };
-
- SparseTensorLevel(unsigned tid, unsigned lvl, LevelType lt, Value lvlSize,
- ValueRange lvlBufs)
- : tid(tid), lvl(lvl), lt(lt), lvlVals() {
- lvlVals.push_back(lvlSize);
- lvlVals.append(lvlBufs.begin(), lvlBufs.end());
- };
+ : tid(tid), lvl(lvl), lt(lt), lvlSize(lvlSize){};
public:
const unsigned tid, lvl;
const LevelType lt;
- // The first value in the vector is always lvlsize; for sparse levels, the
- // second value is always the coordinate buffer; for sparse level with
- // position buffers, the third value is always the position buffer.
- SmallVector<Value, 3> lvlVals;
+ const Value lvlSize;
};
enum class IterKind : uint8_t {
>From afd24ca0d52ca80a00570eabdd6b4022bac3f5ef Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 31 Jan 2024 22:21:33 +0000
Subject: [PATCH 3/5] revert unintended change
---
.../Transforms/Utils/SparseTensorLevel.cpp | 2 ++
.../Transforms/Utils/SparseTensorLevel.h | 31 +++++++++----------
2 files changed, 17 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 604136c3884f9..97d51dbec4a5e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -45,6 +45,7 @@ using ValueTuple = std::tuple<Value, Value, Value>;
//===----------------------------------------------------------------------===//
namespace {
+
class SparseLevel : public SparseTensorLevel {
public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
@@ -201,6 +202,7 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
//===----------------------------------------------------------------------===//
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
+
namespace {
// The iterator that traverses a concrete sparse tensor levels. High-level
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index eb75df9feaae9..728e2973e83c3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -136,21 +136,20 @@ class SparseIterator {
//
// Whether the iterator support random access (i.e., support look up by
- // *coordinate*). A random access iterator must also traverses a dense
- // space.
+ // *coordinate*). A random access iterator must also traverse a dense space.
virtual bool randomAccessible() const = 0;
// Whether the iterator can simply traversed by a for loop.
virtual bool iteratableByFor() const { return false; };
- // Get the upper bound of the sparse space that the iterator might visited.
- // A sparse space is a subset of a dense space [0, bound), this function
- // returns *bound*.
+ // Get the upper bound of the sparse space that the iterator might visit. A
+ // sparse space is a subset of a dense space [0, bound); this function returns
+ // *bound*.
virtual Value upperBound(OpBuilder &b, Location l) const = 0;
- // Serializes and deserializes the current status to/from a set of values.
- // The ValueRange should contain values that are sufficient to recover the
- // current iterating postion (i.e., itVals) as well as loop bound.
+ // Serializes and deserializes the current status to/from a set of values. The
+ // ValueRange should contain values that are sufficient to recover the current
+ // iterating position (i.e., itVals) as well as the loop bound.
//
// Not every type of iterator supports the operations, e.g., non-empty
// subsection iterator does not because the the number of non-empty
@@ -227,8 +226,8 @@ class SparseIterator {
// yield it
//
// The function is virtual to allow alternative implementation. For example,
- // if it.next() is trivial to compute, we can use a select operation
- // instead. E.g.,
+ // if it.next() is trivial to compute, we can use a select operation instead.
+ // E.g.,
//
// it = select cond ? it+1 : it
virtual ValueRange forwardIf(OpBuilder &b, Location l, Value cond);
@@ -260,12 +259,12 @@ class SparseIterator {
// A range of value that together defines the current state of the
// iterator. Only loop variants should be included.
//
- // For trivial iterators, it is the position; for dedup iterators, it
- // consists of the positon and the segment high, for non-empty subsection
- // iterator, it is the metadata that specifies the subsection. Note that the
- // wrapped iterator shares the same storage to maintain itVals with it
- // wrapper, which means the wrapped iterator might only own a subset of all
- // the values stored in itValStorage.
+ // For trivial iterators, it is the position; for dedup iterators, it consists
+ // of the position and the segment high; for non-empty subsection iterators,
+ // it is the metadata that specifies the subsection.
+ // Note that the wrapped iterator shares the same storage to maintain itVals
+ // with its wrapper, which means the wrapped iterator might only own a subset
+ // of all the values stored in itValStorage.
const unsigned cursorValsCnt;
SmallVectorImpl<Value> &cursorValsStorageRef;
};
>From f0c9c744b7c2000b181ebf8da6b7a619bd28d6ae Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 1 Feb 2024 20:08:49 +0000
Subject: [PATCH 4/5] rebase
---
.../SparseTensor/Transforms/SparseTensorPasses.cpp | 4 ++--
.../SparseTensor/Transforms/Sparsification.cpp | 2 +-
.../SparseTensor/Transforms/Utils/CodegenEnv.cpp | 2 +-
.../SparseTensor/Transforms/Utils/CodegenEnv.h | 2 +-
.../SparseTensor/Transforms/Utils/LoopEmitter.cpp | 6 +++---
.../SparseTensor/Transforms/Utils/LoopEmitter.h | 4 ++--
.../Transforms/Utils/SparseTensorLevel.cpp | 13 +++++++------
.../Transforms/Utils/SparseTensorLevel.h | 4 ++--
.../SparseTensor/sparse_conv_2d_slice_based.mlir | 2 +-
9 files changed, 20 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 0ae9f6483588d..8b89bd4dcdd03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,14 +82,14 @@ struct SparsificationPass
SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass(const SparsificationOptions &options) {
parallelization = options.parallelizationStrategy;
- debugSparseIteration = options.debugSparseIteration;
+ sparseEmitStrategy = options.sparseEmitStrategy;
enableRuntimeLibrary = options.enableRuntimeLibrary;
}
void runOnOperation() override {
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
- SparsificationOptions options(parallelization, debugSparseIteration,
+ SparsificationOptions options(parallelization, sparseEmitStrategy,
enableRuntimeLibrary);
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 2ceb214052aa3..ab38ab5cc3f78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Recursively generates code if admissible.
- env.startEmit(options.debugSparseIteration);
+ env.startEmit(options.sparseEmitStrategy);
genBuffers(env, rewriter);
// TODO: Constant affine expression should be handled differently when using
// slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index 0af1cc1745f51..86c13d03c7ec6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
return success();
}
-void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
+void CodegenEnv::startEmit(SparseEmitStrategy emitStrategy) {
assert(insChain == nullptr && "must only start emitting once");
if (sparseOut) {
insChain = sparseOut->get();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 7eeddac48f4f1..d69ae53fb0f29 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }
- void startEmit(DebugSparseIteration emitStrategy);
+ void startEmit(SparseEmitStrategy emitStrategy);
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 8c1680a393181..70488c34e440c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -82,19 +82,19 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter,
- DebugSparseIteration emitStrategy) {
+ SparseEmitStrategy emitStrategy) {
initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
DependentLvlGetter dimGetter,
- DebugSparseIteration emitStrategy) {
+ SparseEmitStrategy emitStrategy) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
- SparseIterator::setDebugSparseIteration(emitStrategy);
+ SparseIterator::setSparseEmitStrategy(emitStrategy);
const unsigned numManifestTensors = ts.size();
const unsigned synTensorId = numManifestTensors;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index e0b4f81487a68..5bab2c6a86081 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -89,13 +89,13 @@ class LoopEmitter {
initialize(ValueRange tensors, StringAttr loopTag = nullptr,
bool hasOutput = false, bool isSparseOut = false,
unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
- DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+ SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
explicit LoopEmitter(
ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
bool isSparseOut = false, unsigned numLoops = 0,
DependentLvlGetter getter = nullptr,
- DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+ SparseEmitStrategy emitStrategy = SparseEmitStrategy::kFunctional);
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 97d51dbec4a5e..c1fc2a062fa10 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -777,11 +777,12 @@ class SubSectIterator : public SparseIterator {
// SparseIterator derived classes implementation.
//===----------------------------------------------------------------------===//
-DebugSparseIteration SparseIterator::emitStrategy = DebugSparseIteration::kNone;
+SparseEmitStrategy SparseIterator::emitStrategy =
+ SparseEmitStrategy::kFunctional;
void SparseIterator::genInit(OpBuilder &b, Location l,
const SparseIterator *p) {
- if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *begin = b.create(l, b.getStringAttr(prefix + ".begin"), {},
getCursorValTypes(b));
@@ -793,7 +794,7 @@ void SparseIterator::genInit(OpBuilder &b, Location l,
}
Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
- if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *notEnd = b.create(l, b.getStringAttr(prefix + ".not_end"),
getCursor(), b.getI1Type());
@@ -804,7 +805,7 @@ Value SparseIterator::genNotEnd(OpBuilder &b, Location l) {
}
void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
- if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
SmallVector<Value> args = getCursor();
args.push_back(crd);
@@ -818,7 +819,7 @@ void SparseIterator::locate(OpBuilder &b, Location l, Value crd) {
}
Value SparseIterator::deref(OpBuilder &b, Location l) {
- if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
SmallVector<Value> args = getCursor();
Operation *deref = b.create(l, b.getStringAttr(prefix + ".deref"),
@@ -830,7 +831,7 @@ Value SparseIterator::deref(OpBuilder &b, Location l) {
}
ValueRange SparseIterator::forward(OpBuilder &b, Location l) {
- if (emitStrategy == DebugSparseIteration::kInterfaceOnly) {
+ if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
std::string prefix = getDebugInterfacePrefix();
Operation *next = b.create(l, b.getStringAttr(prefix + ".next"),
getCursor(), getCursorValTypes(b));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 728e2973e83c3..6f5da8073cb60 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -111,7 +111,7 @@ class SparseIterator {
public:
virtual ~SparseIterator() = default;
- static void setDebugSparseIteration(DebugSparseIteration strategy) {
+ static void setSparseEmitStrategy(SparseEmitStrategy strategy) {
SparseIterator::emitStrategy = strategy;
}
@@ -247,7 +247,7 @@ class SparseIterator {
return ref.take_front(cursorValsCnt);
}
- static DebugSparseIteration emitStrategy;
+ static SparseEmitStrategy emitStrategy;
public:
const IterKind kind; // For LLVM-style RTTI.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
index 18118cab9e52c..6aba0ada947e1 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="debug-sparse-iteration=interface-only" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map --sparsification="sparse-emit-strategy=debug-interface" --canonicalize --cse --allow-unregistered-dialect | FileCheck %s
#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
>From 6d99e7efb2c0a363784d2fa62e9de2a767543142 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 1 Feb 2024 20:10:16 +0000
Subject: [PATCH 5/5] address comments
---
.../Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
index 6f5da8073cb60..318530cda7632 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h
@@ -171,6 +171,7 @@ class SparseIterator {
// Forwards the iterator to the next element.
ValueRange forward(OpBuilder &b, Location l);
+ // Locates the iterator at the position specified by *crd*; this can only
// be done on an iterator that supports randm access.
void locate(OpBuilder &b, Location l, Value crd);
More information about the Mlir-commits
mailing list