[Mlir-commits] [mlir] b60de1d - [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId`

wren romano llvmlistbot at llvm.org
Wed Mar 15 12:27:55 PDT 2023


Author: wren romano
Date: 2023-03-15T12:27:47-07:00
New Revision: b60de1dfcc15d9505de958fe160b45bea11286f2

URL: https://github.com/llvm/llvm-project/commit/b60de1dfcc15d9505de958fe160b45bea11286f2
DIFF: https://github.com/llvm/llvm-project/commit/b60de1dfcc15d9505de958fe160b45bea11286f2.diff

LOG: [mlir][sparse] Updating `Merger::foreachTensorLoopId` to take `LatPointId`

Since every call site of `foreachTensorLoopId` simply looked up the `LatPointId` to extract its `BitVector`, it is cleaner to let the `Merger` do that lookup instead. This better captures the intent of the `foreachTensorLoopId` method and improves decoupling, since it removes one place that leaked the implementation detail that lattice points are represented with a `BitVector`.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D146082

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 3c5d2d37e3e03..59c5b78fda7b8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -437,11 +437,11 @@ class Merger {
   /// for each `TensorLoopId` and passing it the corresponding tensor
   /// identifier, level, and level-type.
   void
-  foreachTensorLoopId(const BitVector &bits,
+  foreachTensorLoopId(LatPointId p,
                       function_ref<void(TensorLoopId, TensorId,
                                         std::optional<Level>, DimLevelType)>
                           callback) const {
-    for (const TensorLoopId b : bits.set_bits())
+    for (const TensorLoopId b : latPoints[p].bits.set_bits())
       callback(b, tensor(b), getLvl(b), getDimLevelType(b));
   }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 9fedd5a78658d..2779e2d4be786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1273,18 +1273,18 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
 
   SmallVector<TensorId> tids;
   SmallVector<Level> lvls;
-  env.merger().foreachTensorLoopId(
-      env.lat(l0).bits, [&](TensorLoopId b, TensorId tid,
-                            std::optional<Level> lvl, DimLevelType dlt) {
-        assert(env.merger().loop(b) == idx);
-        if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
-          needsUniv = true;
-        } else {
-          // sparse/singleton levels.
-          tids.push_back(tid);
-          lvls.push_back(*lvl);
-        }
-      });
+  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
+                                           std::optional<Level> lvl,
+                                           DimLevelType dlt) {
+    assert(env.merger().loop(b) == idx);
+    if (isDenseDLT(dlt) || isUndefDLT(dlt)) {
+      needsUniv = true;
+    } else {
+      // sparse/singleton levels.
+      tids.push_back(tid);
+      lvls.push_back(*lvl);
+    }
+  });
 
   env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls);
 
@@ -1342,7 +1342,6 @@ static bool translateBitsToTidLvlPairs(
     CodegenEnv &env, LatPointId li, LoopId ldx, SmallVectorImpl<TensorId> &tids,
     SmallVectorImpl<Level> &lvls, SmallVectorImpl<TensorId> &affineTids,
     SmallVectorImpl<Level> &affineLvls, SmallVectorImpl<AffineExpr> &exps) {
-  const BitVector &all = env.lat(li).bits;
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
@@ -1350,8 +1349,8 @@ static bool translateBitsToTidLvlPairs(
   unsigned numloopCond = 0;
   bool hasNonUnique = false;
   env.merger().foreachTensorLoopId(
-      all, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
-                    DimLevelType dlt) {
+      li, [&, ldx](TensorLoopId b, TensorId tid, std::optional<Level> lvl,
+                   DimLevelType dlt) {
         if (simple.test(b)) {
           if (isUndefDLT(dlt)) {
             // An undefined dlt in the lattices, we probably mean to


        


More information about the Mlir-commits mailing list