[Mlir-commits] [mlir] b60de1d - [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId`
wren romano
llvmlistbot at llvm.org
Wed Mar 15 12:27:55 PDT 2023
Author: wren romano
Date: 2023-03-15T12:27:47-07:00
New Revision: b60de1dfcc15d9505de958fe160b45bea11286f2
URL: https://github.com/llvm/llvm-project/commit/b60de1dfcc15d9505de958fe160b45bea11286f2
DIFF: https://github.com/llvm/llvm-project/commit/b60de1dfcc15d9505de958fe160b45bea11286f2.diff
LOG: [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId`
Since every callsite of `foreachTensorLoopId` simply looks up the `LatPointId` to extract its `BitVector`, it's cleaner to let the `Merger` perform that lookup itself. This better captures the intent of the `foreachTensorLoopId` method and improves decoupling, because it removes one place that leaks the implementation detail that a lattice point's bits are stored as a `BitVector`.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146082
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 3c5d2d37e3e03..59c5b78fda7b8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -437,11 +437,11 @@ class Merger {
/// for each `TensorLoopId` and passing it the corresponding tensor
/// identifier, level, and level-type.
void
- foreachTensorLoopId(const BitVector &bits,
+ foreachTensorLoopId(LatPointId p,
function_ref<void(TensorLoopId, TensorId,
std::optional<Level>, DimLevelType)>
callback) const {
- for (const TensorLoopId b : bits.set_bits())
+ for (const TensorLoopId b : latPoints[p].bits.set_bits())
callback(b, tensor(b), getLvl(b), getDimLevelType(b));
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9fedd5a78658d..2779e2d4be786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1273,18 +1273,18 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
SmallVector<TensorId> tids;
SmallVector<Level> lvls;
- env.merger().foreachTensorLoopId(
- env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl, DimLevelType dlt) {
- assert(env.merger().loop(b) == idx);
- if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
- needsUniv = true;
- } else {
- // sparse/singleton levels.
- tids.push_back(tid);
- lvls.push_back(*lvl);
- }
- });
+ env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
+ std::optional<Level> lvl,
+ DimLevelType dlt) {
+ assert(env.merger().loop(b) == idx);
+ if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+ needsUniv = true;
+ } else {
+ // sparse/singleton levels.
+ tids.push_back(tid);
+ lvls.push_back(*lvl);
+ }
+ });
env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);
@@ -1342,7 +1342,6 @@ static bool translateBitsToTidLvlPairs(
CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
- const BitVector &all = env.lat(li).bits;
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
@@ -1350,8 +1349,8 @@ static bool translateBitsToTidLvlPairs(
unsigned numloopCond = 0;
bool hasNonUnique = false;
env.merger().foreachTensorLoopId(
- all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
- DimLevelType dlt) {
+ li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+ DimLevelType dlt) {
if (simple.test(b)) {
if (isUndefDLT(dlt)) {
// An undefined dlt in the lattices, we probably mean to
More information about the Mlir-commits
mailing list