[Mlir-commits] [mlir] 47a715d - [mlir][sparse] minor merger API simplification

Aart Bik llvmlistbot at llvm.org
Tue Sep 13 18:07:36 PDT 2022


Author: Aart Bik
Date: 2022-09-13T18:07:24-07:00
New Revision: 47a715d43e290142c79a05d12c105ae49a43acca

URL: https://github.com/llvm/llvm-project/commit/47a715d43e290142c79a05d12c105ae49a43acca
DIFF: https://github.com/llvm/llvm-project/commit/47a715d43e290142c79a05d12c105ae49a43acca.diff

LOG: [mlir][sparse] minor merger API simplification

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D133821

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 584df8476466b..f9be4762ba640 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -261,8 +261,8 @@ class Merger {
     return getDimLevelFormat(t, i).levelType == tp;
   }
 
-  /// Returns true if any set bit corresponds to given dimension level type.
-  bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+  /// Returns true if any set bit corresponds to sparse dimension level type.
+  bool hasAnySparse(const BitVector &bits) const;
 
   /// Dimension level format getter.
   DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 642c66635a38d..918018a20c457 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1641,8 +1641,7 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     unsigned lsize = merger.set(lts).size();
     for (unsigned i = 1; i < lsize; i++) {
       unsigned li = merger.set(lts)[i];
-      if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
-          !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
+      if (!merger.hasAnySparse(merger.lat(li).simple))
         return true;
     }
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index cd5d1347c1356..bd8c8f2a65b81 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -262,13 +262,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
     }
   }
   // Now apply the two basic rules.
-  //
-  // TODO: improve for singleton and properties
-  //
   BitVector simple = latPoints[p0].bits;
-  bool reset = isSingleton &&
-      (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
-       hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
+  bool reset = isSingleton && hasAnySparse(simple);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
     if (simple[b] &&
         (!isDimLevelType(b, DimLvlType::kCompressed) &&
@@ -297,8 +292,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
   BitVector tmp = latPoints[j].bits;
   tmp ^= latPoints[i].bits;
-  return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
-         !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
+  return !hasAnySparse(tmp);
 }
 
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -384,9 +378,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   llvm_unreachable("unexpected kind");
 }
 
-bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+bool Merger::hasAnySparse(const BitVector &bits) const {
   for (unsigned b = 0, be = bits.size(); b < be; b++)
-    if (bits[b] && isDimLevelType(b, tp))
+    if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
+                    isDimLevelType(b, DimLvlType::kSingleton)))
       return true;
   return false;
 }


        


More information about the Mlir-commits mailing list