[Mlir-commits] [mlir] 13e9afd - [mlir][sparse] Adding new `Merger::addLat` overload

wren romano llvmlistbot at llvm.org
Tue Mar 21 16:22:12 PDT 2023


Author: wren romano
Date: 2023-03-21T16:22:04-07:00
New Revision: 13e9afd16d8aac49caf3abaa35bc97b5430331d3

URL: https://github.com/llvm/llvm-project/commit/13e9afd16d8aac49caf3abaa35bc97b5430331d3
DIFF: https://github.com/llvm/llvm-project/commit/13e9afd16d8aac49caf3abaa35bc97b5430331d3.diff

LOG: [mlir][sparse] Adding new `Merger::addLat` overload

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146559

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 6e39404bb28aa..991c920c17399 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -280,6 +280,7 @@ class Merger {
 
   /// Constructs a new iteration lattice point, and returns its identifier.
   LatPointId addLat(TensorId t, LoopId i, ExprId e);
+  LatPointId addLat(const BitVector &bits, ExprId e);
 
   /// Constructs a new (initially empty) set, and returns its identifier.
   LatSetId addSet();

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 4a8c3cbfbe584..0691d2554f438 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -247,6 +247,13 @@ LatPointId Merger::addLat(TensorId t, LoopId i, ExprId e) {
   return p;
 }
 
+LatPointId Merger::addLat(const BitVector &bits, ExprId e) {
+  assert(numTensors * numLoops == bits.size());
+  const LatPointId newId = latPoints.size();
+  latPoints.emplace_back(bits, e);
+  return newId;
+}
+
 LatSetId Merger::addSet() {
   const LatSetId s = latSets.size();
   latSets.emplace_back();
@@ -322,8 +329,7 @@ LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
   const LatSetId s = addSet();
   for (const LatPointId p : latSets[s0]) {
     const ExprId e = addExp(kind, latPoints[p].exp, v, op);
-    latPoints.emplace_back(latPoints[p].bits, e);
-    latSets[s].push_back(latPoints.size() - 1);
+    latSets[s].push_back(addLat(latPoints[p].bits, e));
   }
   return s;
 }


        


More information about the Mlir-commits mailing list