[Mlir-commits] [mlir] 5df63ad - [mlir][sparse] Adding `Ranks::operator==` and `VarSet::getRank` etc

wren romano llvmlistbot at llvm.org
Mon Jul 31 17:09:07 PDT 2023


Author: wren romano
Date: 2023-07-31T17:09:00-07:00
New Revision: 5df63ad826930ea56ffd2f20c5649ff67ff3e0b7

URL: https://github.com/llvm/llvm-project/commit/5df63ad826930ea56ffd2f20c5649ff67ff3e0b7
DIFF: https://github.com/llvm/llvm-project/commit/5df63ad826930ea56ffd2f20c5649ff67ff3e0b7.diff

LOG: [mlir][sparse] Adding `Ranks::operator==` and `VarSet::getRank` etc

Depends On D156001

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D156010

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 7e848d325fb743..63f55bd43e8db3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -13,6 +13,14 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 using namespace mlir::sparse_tensor::ir_detail;
 
+//===----------------------------------------------------------------------===//
+// `VarKind` helpers.
+//===----------------------------------------------------------------------===//
+
+/// For use in foreach loops.
+static constexpr const VarKind everyVarKind[] = {
+    VarKind::Dimension, VarKind::Symbol, VarKind::Level};
+
 //===----------------------------------------------------------------------===//
 // `Var` implementation.
 //===----------------------------------------------------------------------===//
@@ -32,6 +40,13 @@ void Var::dump() const {
 // `Ranks` implementation.
 //===----------------------------------------------------------------------===//
 
+bool Ranks::operator==(Ranks const &other) const {
+  for (const auto vk : everyVarKind)
+    if (getRank(vk) != other.getRank(vk))
+      return false;
+  return true;
+}
+
 bool Ranks::isValid(DimLvlExpr expr) const {
   assert(expr);
   // Compute the maximum identifiers for symbol-vars and dim/lvl-vars
@@ -49,9 +64,6 @@ bool Ranks::isValid(DimLvlExpr expr) const {
 // `VarSet` implementation.
 //===----------------------------------------------------------------------===//
 
-static constexpr const VarKind everyVarKind[] = {
-    VarKind::Dimension, VarKind::Symbol, VarKind::Level};
-
 VarSet::VarSet(Ranks const &ranks) {
   // NOTE: We must not use `reserve` here, since that doesn't change
   // the `size` of the bitvectors and therefore will result in unexpected
@@ -59,6 +71,7 @@ VarSet::VarSet(Ranks const &ranks) {
   // move-ctor since it should be (marginally) more efficient.
   for (const auto vk : everyVarKind)
     impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
+  assert(getRanks() == ranks);
 }
 
 bool VarSet::contains(Var var) const {

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index f23d554aee6c24..18c68dd5e1118e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -292,6 +292,9 @@ class Ranks final {
       : Ranks(ranks[VarKind::Symbol], ranks[VarKind::Dimension],
               ranks[VarKind::Level]) {}
 
+  bool operator==(Ranks const &other) const;
+  bool operator!=(Ranks const &other) const { return !(*this == other); }
+
   constexpr unsigned getRank(VarKind vk) const { return impl[to_index(vk)]; }
   constexpr unsigned getSymRank() const { return getRank(VarKind::Symbol); }
   constexpr unsigned getDimRank() const { return getRank(VarKind::Dimension); }
@@ -324,6 +327,14 @@ class VarSet final {
 public:
   explicit VarSet(Ranks const &ranks);
 
+  unsigned getRank(VarKind vk) const { return impl[vk].size(); }
+  unsigned getSymRank() const { return getRank(VarKind::Symbol); }
+  unsigned getDimRank() const { return getRank(VarKind::Dimension); }
+  unsigned getLvlRank() const { return getRank(VarKind::Level); }
+  Ranks getRanks() const {
+    return Ranks(getSymRank(), getDimRank(), getLvlRank());
+  }
+
   bool contains(Var var) const;
   bool occursIn(VarSet const &vars) const;
   bool occursIn(DimLvlExpr expr) const;


        


More information about the Mlir-commits mailing list