[Mlir-commits] [mlir] 47a715d - [mlir][sparse] minor merger API simplification
Aart Bik
llvmlistbot at llvm.org
Tue Sep 13 18:07:36 PDT 2022
Author: Aart Bik
Date: 2022-09-13T18:07:24-07:00
New Revision: 47a715d43e290142c79a05d12c105ae49a43acca
URL: https://github.com/llvm/llvm-project/commit/47a715d43e290142c79a05d12c105ae49a43acca
DIFF: https://github.com/llvm/llvm-project/commit/47a715d43e290142c79a05d12c105ae49a43acca.diff
LOG: [mlir][sparse] minor merger API simplification
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D133821
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 584df8476466b..f9be4762ba640 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -261,8 +261,8 @@ class Merger {
return getDimLevelFormat(t, i).levelType == tp;
}
- /// Returns true if any set bit corresponds to given dimension level type.
- bool hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const;
+ /// Returns true if any set bit corresponds to sparse dimension level type.
+ bool hasAnySparse(const BitVector &bits) const;
/// Dimension level format getter.
DimLevelFormat getDimLevelFormat(unsigned t, unsigned i) const {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 642c66635a38d..918018a20c457 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1641,8 +1641,7 @@ static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder,
unsigned lsize = merger.set(lts).size();
for (unsigned i = 1; i < lsize; i++) {
unsigned li = merger.set(lts)[i];
- if (!merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kCompressed) &&
- !merger.hasAnyDimLevelTypeOf(merger.lat(li).simple, DimLvlType::kSingleton))
+ if (!merger.hasAnySparse(merger.lat(li).simple))
return true;
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index cd5d1347c1356..bd8c8f2a65b81 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -262,13 +262,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
}
}
// Now apply the two basic rules.
- //
- // TODO: improve for singleton and properties
- //
BitVector simple = latPoints[p0].bits;
- bool reset = isSingleton &&
- (hasAnyDimLevelTypeOf(simple, DimLvlType::kCompressed) ||
- hasAnyDimLevelTypeOf(simple, DimLvlType::kSingleton));
+ bool reset = isSingleton && hasAnySparse(simple);
for (unsigned b = 0, be = simple.size(); b < be; b++) {
if (simple[b] &&
(!isDimLevelType(b, DimLvlType::kCompressed) &&
@@ -297,8 +292,7 @@ bool Merger::latGT(unsigned i, unsigned j) const {
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
BitVector tmp = latPoints[j].bits;
tmp ^= latPoints[i].bits;
- return !hasAnyDimLevelTypeOf(tmp, DimLvlType::kCompressed) &&
- !hasAnyDimLevelTypeOf(tmp, DimLvlType::kSingleton);
+ return !hasAnySparse(tmp);
}
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
@@ -384,9 +378,10 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
llvm_unreachable("unexpected kind");
}
-bool Merger::hasAnyDimLevelTypeOf(const BitVector &bits, DimLvlType tp) const {
+bool Merger::hasAnySparse(const BitVector &bits) const {
for (unsigned b = 0, be = bits.size(); b < be; b++)
- if (bits[b] && isDimLevelType(b, tp))
+ if (bits[b] && (isDimLevelType(b, DimLvlType::kCompressed) ||
+ isDimLevelType(b, DimLvlType::kSingleton)))
return true;
return false;
}
More information about the Mlir-commits mailing list