[Mlir-commits] [mlir] 13e9afd - [mlir][sparse] Adding new `Merger::addLat` overload
wren romano
llvmlistbot at llvm.org
Tue Mar 21 16:22:12 PDT 2023
Author: wren romano
Date: 2023-03-21T16:22:04-07:00
New Revision: 13e9afd16d8aac49caf3abaa35bc97b5430331d3
URL: https://github.com/llvm/llvm-project/commit/13e9afd16d8aac49caf3abaa35bc97b5430331d3
DIFF: https://github.com/llvm/llvm-project/commit/13e9afd16d8aac49caf3abaa35bc97b5430331d3.diff
LOG: [mlir][sparse] Adding new `Merger::addLat` overload
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D146559
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 6e39404bb28aa..991c920c17399 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -280,6 +280,7 @@ class Merger {
/// Constructs a new iteration lattice point, and returns its identifier.
LatPointId addLat(TensorId t, LoopId i, ExprId e);
+ LatPointId addLat(const BitVector &bits, ExprId e);
/// Constructs a new (initially empty) set, and returns its identifier.
LatSetId addSet();
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 4a8c3cbfbe584..0691d2554f438 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -247,6 +247,13 @@ LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
return p;
}
+LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
+ assert(bits.size() == numLoops * numTensors);
+ const LatPointId p = latPoints.size();
+ latPoints.emplace_back(bits, e);
+ return p;
+}
+
LatSetId Merger::addSet() {
const LatSetId s = latSets.size();
latSets.emplace_back();
@@ -322,8 +329,7 @@ LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
const LatSetId s = addSet();
for (const LatPointId p : latSets[s0]) {
const ExprId e = addExp(kind, latPoints[p].exp, v, op);
- latPoints.emplace_back(latPoints[p].bits, e);
- latSets[s].push_back(latPoints.size() - 1);
+ latSets[s].push_back(addLat(latPoints[p].bits, e));
}
return s;
}
More information about the Mlir-commits mailing list