[Mlir-commits] [mlir] 5df63ad - [mlir][sparse] Adding `Ranks::operator==` and `VarSet::getRank` etc
wren romano
llvmlistbot at llvm.org
Mon Jul 31 17:09:07 PDT 2023
Author: wren romano
Date: 2023-07-31T17:09:00-07:00
New Revision: 5df63ad826930ea56ffd2f20c5649ff67ff3e0b7
URL: https://github.com/llvm/llvm-project/commit/5df63ad826930ea56ffd2f20c5649ff67ff3e0b7
DIFF: https://github.com/llvm/llvm-project/commit/5df63ad826930ea56ffd2f20c5649ff67ff3e0b7.diff
LOG: [mlir][sparse] Adding `Ranks::operator==` and `VarSet::getRank` etc
Depends On D156001
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D156010
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 7e848d325fb743..63f55bd43e8db3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -13,6 +13,14 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
+//===----------------------------------------------------------------------===//
+// `VarKind` helpers.
+//===----------------------------------------------------------------------===//
+
+/// The `VarKind`s (Dimension, Symbol, Level), for use in foreach loops;
+/// iteration visits them in exactly that order.  Both `Ranks::operator==`
+/// and the `VarSet` constructor below iterate over this array.
+/// NOTE(review): presumably this is the complete set of `VarKind`
+/// enumerators -- keep it in sync if the enum ever grows.
+static constexpr const VarKind everyVarKind[] = {
+ VarKind::Dimension, VarKind::Symbol, VarKind::Level};
+
//===----------------------------------------------------------------------===//
// `Var` implementation.
//===----------------------------------------------------------------------===//
@@ -32,6 +40,13 @@ void Var::dump() const {
// `Ranks` implementation.
//===----------------------------------------------------------------------===//
+/// Two `Ranks` are equal iff they report the same rank for every `VarKind`
+/// listed in `everyVarKind`; returns false at the first kind that differs.
+bool Ranks::operator==(Ranks const &other) const {
+ for (const auto vk : everyVarKind)
+ if (getRank(vk) != other.getRank(vk))
+ return false;
+ return true;
+}
+
bool Ranks::isValid(DimLvlExpr expr) const {
assert(expr);
// Compute the maximum identifiers for symbol-vars and dim/lvl-vars
@@ -49,9 +64,6 @@ bool Ranks::isValid(DimLvlExpr expr) const {
// `VarSet` implementation.
//===----------------------------------------------------------------------===//
-static constexpr const VarKind everyVarKind[] = {
- VarKind::Dimension, VarKind::Symbol, VarKind::Level};
-
VarSet::VarSet(Ranks const &ranks) {
// NOTE: We must not use `reserve` here, since that doesn't change
// the `size` of the bitvectors and therefore will result in unexpected
@@ -59,6 +71,7 @@ VarSet::VarSet(Ranks const &ranks) {
// move-ctor since it should be (marginally) more efficient.
for (const auto vk : everyVarKind)
impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
+ assert(getRanks() == ranks);
}
bool VarSet::contains(Var var) const {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index f23d554aee6c24..18c68dd5e1118e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -292,6 +292,9 @@ class Ranks final {
: Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
ranks[VarKind::Level]) {}
+ /// Compares the rank of each `VarKind`; defined out-of-line in Var.cpp.
+ bool operator==(Ranks const &other) const;
+ /// Logical negation of `operator==`.
+ bool operator!=(Ranks const &other) const { return !(*this == other); }
+
constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
@@ -324,6 +327,14 @@ class VarSet final {
public:
explicit VarSet(Ranks const &ranks);
+ /// Returns the size of the bitvector tracking variables of kind `vk`;
+ /// the constructor sizes it to `ranks.getRank(vk)`.
+ unsigned getRank(VarKind vk) const { return impl[vk].size(); }
+ /// Convenience wrappers around `getRank` for each `VarKind`.
+ unsigned getSymRank() const { return getRank(VarKind::Symbol); }
+ unsigned getDimRank() const { return getRank(VarKind::Dimension); }
+ unsigned getLvlRank() const { return getRank(VarKind::Level); }
+ /// Rebuilds a `Ranks` from the bitvector sizes (symbol, dimension, level
+ /// order, matching the `Ranks` constructor).  The `VarSet` constructor
+ /// asserts that this equals the `Ranks` it was built from.
+ Ranks getRanks() const {
+ return Ranks(getSymRank(), getDimRank(), getLvlRank());
+ }
+
bool contains(Var var) const;
bool occursIn(VarSet const &vars) const;
bool occursIn(DimLvlExpr expr) const;
More information about the Mlir-commits mailing list