[Mlir-commits] [mlir] [mlir][sparse] Support pretty print to debug sparse iteration. (PR #80207)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 31 13:58:51 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
<details>
<summary>Changes</summary>
---
Patch is 71.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/80207.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+17-2)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+7)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (+3-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp (+3-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp (+8-5)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+12-8)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+203-112)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h (+108-64)
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+48-228)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index e93e2aefb344f..8b2875a751d4a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -47,6 +47,12 @@ enum class ReinterpretMapScope {
kExceptGeneric, // reinterprets operation other than linalg.generic
};
+/// Defines the debugging mode used when emitting sparse iteration code.
+enum class DebugSparseIteration {
+ kNone, // generate fully inlined (and functional) sparse iteration
+ kInterfaceOnly, // generate only place-holder for sparse iteration
+};
+
#define GEN_PASS_DECL
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
@@ -74,11 +80,20 @@ std::unique_ptr<Pass> createPreSparsificationRewritePass();
/// Options for the Sparsification pass.
struct SparsificationOptions {
+ SparsificationOptions(SparseParallelizationStrategy p, DebugSparseIteration d,
+ bool enableRT)
+ : parallelizationStrategy(p), debugSparseIteration(d),
+ enableRuntimeLibrary(enableRT) {}
+
SparsificationOptions(SparseParallelizationStrategy p, bool enableRT)
- : parallelizationStrategy(p), enableRuntimeLibrary(enableRT) {}
+ : SparsificationOptions(p, DebugSparseIteration::kNone, enableRT) {}
+
SparsificationOptions()
- : SparsificationOptions(SparseParallelizationStrategy::kNone, true) {}
+ : SparsificationOptions(SparseParallelizationStrategy::kNone,
+ DebugSparseIteration::kNone, true) {}
+
SparseParallelizationStrategy parallelizationStrategy;
+ DebugSparseIteration debugSparseIteration;
bool enableRuntimeLibrary;
};
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
index f38779ed9ed2b..126b91510d391 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -130,6 +130,13 @@ def SparsificationPass : Pass<"sparsification", "ModuleOp"> {
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
"any-storage-any-loop",
"Enable sparse parallelization for any storage and loop."))}]>,
+ Option<"debugSparseIteration", "debug-sparse-iteration", "mlir::DebugSparseIteration",
+ "mlir::DebugSparseIteration::kNone",
+ "Pretty print sparse loops to debug sparse iteration", [{llvm::cl::values(
+ clEnumValN(mlir::DebugSparseIteration::kNone, "none",
+ "Turn off pretty printing and generates functional code."),
+ clEnumValN(mlir::DebugSparseIteration::kInterfaceOnly, "interface-only",
+ "Generate non-functional interfaces for sparse iteration."))}]>,
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
"true", "Enable runtime library for manipulating sparse tensors">,
];
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 375e10f9068e4..0ae9f6483588d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -82,13 +82,15 @@ struct SparsificationPass
SparsificationPass(const SparsificationPass &pass) = default;
SparsificationPass(const SparsificationOptions &options) {
parallelization = options.parallelizationStrategy;
+ debugSparseIteration = options.debugSparseIteration;
enableRuntimeLibrary = options.enableRuntimeLibrary;
}
void runOnOperation() override {
auto *ctx = &getContext();
// Translate strategy flags to strategy options.
- SparsificationOptions options(parallelization, enableRuntimeLibrary);
+ SparsificationOptions options(parallelization, debugSparseIteration,
+ enableRuntimeLibrary);
// Apply sparsification and cleanup rewriting.
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..2ceb214052aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1369,7 +1369,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
return failure();
// Recursively generates code if admissible.
- env.startEmit();
+ env.startEmit(options.debugSparseIteration);
genBuffers(env, rewriter);
// TODO: Constant affine expression should be handled differently when using
// slice-based codegen, it does not matter now because we already reject the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
index d3de55e4d59bd..0af1cc1745f51 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.cpp
@@ -59,7 +59,7 @@ LogicalResult CodegenEnv::initTensorExp() {
return success();
}
-void CodegenEnv::startEmit() {
+void CodegenEnv::startEmit(DebugSparseIteration emitStrategy) {
assert(insChain == nullptr && "must only start emitting once");
if (sparseOut) {
insChain = sparseOut->get();
@@ -96,7 +96,8 @@ void CodegenEnv::startEmit() {
/*dependentLvlGetter=*/
[this](TensorId t, Level lvl) -> std::vector<LoopCoeffPair> {
return merger().getDependentLoops(t, lvl);
- });
+ },
+ emitStrategy);
}
std::optional<Operation *> CodegenEnv::genLoopBoundary(
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
index 728af841cc7b1..7eeddac48f4f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
@@ -52,7 +52,7 @@ class CodegenEnv {
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }
- void startEmit();
+ void startEmit(DebugSparseIteration emitStrategy);
/// Generates loop boundary statements (entering/exiting loops). The function
/// passes and updates the passed-in parameters.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 3fa4004ae460e..8c1680a393181 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -81,17 +81,20 @@ static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
initialize(tensors, loopTag, hasOutput, isSparseOut, numLoops, dimGetter);
}
void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
bool isSparseOut, unsigned numLoops,
- DependentLvlGetter dimGetter) {
+ DependentLvlGetter dimGetter,
+ DebugSparseIteration emitStrategy) {
// First initialize the top-level type of the fields.
this->loopTag = loopTag;
this->hasOutput = hasOutput;
this->isSparseOut = isSparseOut;
+ SparseIterator::setDebugSparseIteration(emitStrategy);
const unsigned numManifestTensors = ts.size();
const unsigned synTensorId = numManifestTensors;
@@ -169,7 +172,7 @@ LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
Value offset = genSliceOffset(builder, loc, tensors[t], l);
Value stride = genSliceStride(builder, loc, tensors[t], l);
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
- lvls[t][l]->size());
+ lvls[t][l]->getSize());
return slicedIt;
}
return it;
@@ -465,7 +468,7 @@ std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
// Construct the while-loop with a parameter for each coordinate.
for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getItVals();
+ ValueRange itVals = it->getCursor();
ivs.append(itVals.begin(), itVals.end());
}
@@ -724,7 +727,7 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
// Forward the sparse iterator.
Value cmp = CMPI(eq, it.getCrd(), iv);
it.forwardIf(builder, loc, cmp);
- operands.append(it.getItVals().begin(), it.getItVals().end());
+ operands.append(it.getCursor().begin(), it.getCursor().end());
// const Value newPos = whileOp->getResult(o++);
// Following loops continue iteration from the break point of the
// current while loop.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index d0f447d926f71..e0b4f81487a68 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/IR/PatternMatch.h"
@@ -84,14 +85,17 @@ class LoopEmitter {
/// `isSparseOut` indicates that the sparse output tensor is empty,
/// so the loop emitter will generate loops over it according to the
/// level-sizes.
- void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0, DependentLvlGetter getter = nullptr);
-
- explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
- bool hasOutput = false, bool isSparseOut = false,
- unsigned numLoops = 0,
- DependentLvlGetter getter = nullptr);
+ void
+ initialize(ValueRange tensors, StringAttr loopTag = nullptr,
+ bool hasOutput = false, bool isSparseOut = false,
+ unsigned numLoops = 0, DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
+
+ explicit LoopEmitter(
+ ValueRange tensors, StringAttr loopTag = nullptr, bool hasOutput = false,
+ bool isSparseOut = false, unsigned numLoops = 0,
+ DependentLvlGetter getter = nullptr,
+ DebugSparseIteration emitStrategy = DebugSparseIteration::kNone);
/// Starts a loop emitting session by generating all the buffers needed
/// for iterating over the tensors.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..bdaf794744bea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -46,20 +46,6 @@ using ValueTuple = std::tuple<Value, Value, Value>;
namespace {
-class SparseLevel : public SparseTensorLevel {
-public:
- SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
-
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, crdBuffer, iv);
- }
-
-protected:
- const Value crdBuffer;
-};
-
class DenseLevel : public SparseTensorLevel {
public:
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
@@ -74,53 +60,27 @@ class DenseLevel : public SparseTensorLevel {
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
if (encoded) {
- Value posLo = MULI(p, lvlSize);
- return {posLo, lvlSize};
+ Value posLo = MULI(p, getSize());
+ return {posLo, getSize()};
}
// No need to linearize the position for non-annotated tensors.
- return {C_IDX(0), lvlSize};
+ return {C_IDX(0), getSize()};
}
const bool encoded;
};
-class CompressedLevel : public SparseLevel {
+class SparseLevel : public SparseTensorLevel {
public:
- CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
+ SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ ValueRange lvlBuf)
+ : SparseTensorLevel(tid, lvl, lt, lvlSize, lvlBuf) {
+ assert(!lvlBuf.empty());
}
-private:
- const Value posBuffer;
-};
-
-class LooseCompressedLevel : public SparseLevel {
-public:
- LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value posBuffer, Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
-
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
- p = MULI(p, C_IDX(2));
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
+ Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
+ return genIndexLoad(b, l, getLvlBufs().front(), iv);
}
-
-private:
- const Value posBuffer;
};
class SingletonLevel : public SparseLevel {
@@ -142,8 +102,8 @@ class SingletonLevel : public SparseLevel {
class TwoOutFourLevel : public SparseLevel {
public:
TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
- : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
+ Value crdBuf)
+ : SparseLevel(tid, lvl, lt, lvlSize, crdBuf) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
@@ -154,6 +114,39 @@ class TwoOutFourLevel : public SparseLevel {
}
};
+class CompressedLevel : public SparseLevel {
+public:
+ CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ if (max == nullptr) {
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+ llvm_unreachable("compressed-nu should be the first non-unique level.");
+ }
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value posBuffer, Value crdBuffer)
+ : SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer, posBuffer}) {}
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ Value max) const override {
+ assert(max == nullptr && "loss compressed level can not be non-unique.");
+ p = MULI(p, C_IDX(2));
+ Value pLo = genIndexLoad(b, l, getPosBuf(), p);
+ Value pHi = genIndexLoad(b, l, getPosBuf(), ADDI(p, C_IDX(1)));
+ return {pLo, pHi};
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -203,7 +196,8 @@ static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
// SparseIterator derived classes.
//===----------------------------------------------------------------------===//
-namespace {
+namespace mlir {
+namespace sparse_tensor {
// The iterator that traverses a concrete sparse tensor levels. High-level
// abstract iterators wrap it to achieve more complex goals (such as collapsing
@@ -212,12 +206,11 @@ namespace {
class ConcreteIterator : public SparseIterator {
protected:
ConcreteIterator(const SparseTensorLevel &stl, IterKind kind,
- unsigned itValCnt)
- : SparseIterator(kind, stl.tid, stl.lvl, itValCnt, itValsStorage),
- stl(stl) {
- // Allocate enough storage for iterator values.
- itValsStorage.resize(itValCnt);
- }
+ unsigned cursorValCnt)
+ : SparseIterator(kind, stl.tid, stl.lvl, cursorValCnt, cursorValsStorage),
+ stl(stl), cursorValsStorage(cursorValCnt, nullptr) {
+ assert(getCursor().size() == cursorValCnt);
+ };
public:
// For LLVM-style RTTI.
@@ -228,22 +221,34 @@ class ConcreteIterator : public SparseIterator {
bool randomAccessible() const override { return isDenseLT(stl.getLT()); };
bool iteratableByFor() const override { return kind != IterKind::kDedup; };
Value upperBound(OpBuilder &b, Location l) const override {
- return stl.size();
+ return stl.getSize();
};
protected:
+ const SparseTensorLevel &stl;
// Owner of the storage, all wrappers build on top of a concrete iterator
// share the same storage such that the iterator values are always
// synchronized.
- SmallVector<Value> itValsStorage;
- const SparseTensorLevel &stl;
+ SmallVector<Value> cursorValsStorage;
};
+} // namespace sparse_tensor
+} // namespace mlir
+
+namespace {
+
class TrivialIterator : public ConcreteIterator {
public:
TrivialIterator(const SparseTensorLevel &stl)
: ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
+ std::string getDebugInterfacePrefix() const override {
+ return std::string("trivial<") + stl.toString() + ">";
+ }
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ return {b.getIndexType()};
+ }
+
SmallVector<Value> serialize() const override {
SmallVector<Value> ret;
ret.push_back(getItPos());
@@ -286,12 +291,12 @@ class TrivialIterator : public ConcreteIterator {
return std::make_pair(getItPos(), posHi);
}
- Value genNotEnd(OpBuilder &b, Location l) override {
+ Value genNotEndImpl(OpBuilder &b, Location l) override {
// We used the first level bound as the bound the collapsed set of levels.
return CMPI(ult, getItPos(), posHi);
}
- Value deref(OpBuilder &b, Location l) override {
+ Value derefImpl(OpBuilder &b, Location l) override {
if (randomAccessible()) {
updateCrd(SUBI(getItPos(), posLo));
} else {
@@ -302,24 +307,24 @@ class TrivialIterator : public ConcreteIterator {
ValueRange forwardImpl(OpBuilder &b, Location l) override {
seek(ADDI(getItPos(), C_IDX(1)));
- return getItVals();
+ return getCursor();
}
ValueRange forwardIf(OpBuilder &b, Location l, Value cond) override {
- Value curPos = getItVals().front();
+ Value curPos = getCursor().front();
Value nxPos = forward(b, l).front();
seek(SELECT(cond, nxPos, curPos));
- return getItVals();
+ return getCursor();
}
- void locate(OpBuilder &b, Location l, Value crd) override {
+ void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
// Seek to the linearized position.
seek(ADDI(crd, posLo));
updateCrd(crd);
}
- Value getItPos() const { return getItVals().front(); }
+ Value getItPos() const { return getCursor().front(); }
Value posLo, posHi;
};
@@ -337,6 +342,13 @@ class DedupIterator : public ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/80207
More information about the Mlir-commits mailing list