[Mlir-commits] [mlir] dcadb68 - [mlir][sparse] Cleaning up OOB implementation details for `VarSet`

wren romano llvmlistbot at llvm.org
Fri Jul 7 14:36:03 PDT 2023


Author: wren romano
Date: 2023-07-07T14:35:56-07:00
New Revision: dcadb68a5c5a50f41a023b39748a506df8c07248

URL: https://github.com/llvm/llvm-project/commit/dcadb68a5c5a50f41a023b39748a506df8c07248
DIFF: https://github.com/llvm/llvm-project/commit/dcadb68a5c5a50f41a023b39748a506df8c07248.diff

LOG: [mlir][sparse] Cleaning up OOB implementation details for `VarSet`

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D154674

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 35022d7cfa1b09..7b67f8eece0da2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -56,22 +56,22 @@ static constexpr const VarKind everyVarKind[] = {
     VarKind::Dimension, VarKind::Symbol, VarKind::Level};
 
 VarSet::VarSet(Ranks const &ranks) {
-  // FIXME(wrengr): will this DWIM, or do we need to worry about
-  // `reserve` causing resizing/dangling issues?
   for (const auto vk : everyVarKind)
-    impl[vk].reserve(ranks.getRank(vk));
+    impl[vk].resize(ranks.getRank(vk)); // `reserve` would leave size() == 0
 }
 
 bool VarSet::contains(Var var) const {
-  // FIXME(wrengr): this implementation will raise assertion failure on OOB;
-  // but perhaps we'd rather have this return false on OOB?  That's
-  // necessary for consistency with the `anyCommon` implementation of
-  // `occursIn(VarSet)`.
+  // NOTE: We make sure to return false on OOB, for consistency with
+  // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
+  // However, beware that, as always with silencing OOB, this can hide
+  // bugs in client code.
   const llvm::SmallBitVector &bits = impl[var.getKind()];
-  // NOTE TO Wren: did this to avoid OOB but perhaps it is result of bug
-  if (var.getNum() >= bits.size())
-    return false;
-  return bits[var.getNum()];
+  const auto num = var.getNum();
+  // FIXME(wrengr): If we `assert(num < bits.size())` then
+  // "roundtrip_encoding.mlir" will fail.  So we need to figure out
+  // where exactly the OOB `var` is coming from, to determine whether
+  // that's a logic bug or not.
+  return num < bits.size() && bits[num];
 }
 
 bool VarSet::occursIn(VarSet const &other) const {
@@ -105,13 +105,20 @@ bool VarSet::occursIn(DimLvlExpr expr) const {
 }
 
 void VarSet::add(Var var) {
-  // FIXME(wrengr): this implementation will raise assertion failure on OOB;
-  // but perhaps we'd rather have this be a noop on OOB?  or to grow
-  // the underlying bitvectors on OOB?
+  // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
   impl[var.getKind()][var.getNum()] = true;
 }
 
-// TODO(wrengr): void VarSet::add(VarSet const& other);
+void VarSet::add(VarSet const &other) {
+  // Adds every `Var` of `other` to this set (i.e., in-place set union).
+  // NOTE: `SmallBitVector::operator|=` will implicitly resize the
+  // bitvector, so we add an assertion against OOB for consistency
+  // with the implementation of `VarSet::add(Var)`.
+  for (const auto vk : everyVarKind) {
+    assert(impl[vk].size() >= other.impl[vk].size());
+    impl[vk] |= other.impl[vk];
+  }
+}
 
 void VarSet::add(DimLvlExpr expr) {
   if (!expr)

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index fef806b0fc1024..24dd631dba1bc4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -271,7 +271,6 @@ class DimLvlExpr;
 //===----------------------------------------------------------------------===//
 class Ranks final {
   // Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
-  // TODO(wrengr): to what extent do we actually care about constexpr here?
   unsigned impl[3];
 
   static constexpr unsigned to_index(VarKind vk) {
@@ -303,6 +302,14 @@ class Ranks final {
 static_assert(IsZeroCostAbstraction<Ranks>);
 
 //===----------------------------------------------------------------------===//
+/// Efficient representation of a set of `Var`.
+///
+/// NOTE: For the `contains`/`occursIn` methods: if variables occurring in
+/// the method parameter are OOB for the `VarSet`, then these methods will
+/// always return false.  However, for the `add` methods: OOB parameters
+/// cause undefined behavior.  Currently the `add` methods will raise an
+/// assertion error; though we may change that behavior in the future
+/// (e.g., to resize the underlying bitvectors).
 class VarSet final {
   // If we're willing to give up the possibility of resizing the
   // individual bitvectors, then we could flatten this into a single
@@ -314,14 +321,12 @@ class VarSet final {
 public:
   explicit VarSet(Ranks const &ranks);
 
-  // TODO(wrengr): can we come up with a single name that works for all three of
-  // these?
   bool contains(Var var) const;
   bool occursIn(VarSet const &vars) const;
   bool occursIn(DimLvlExpr expr) const;
 
   void add(Var var);
-  // TODO(wrengr): void add(VarSet const& vars);
+  void add(VarSet const &vars);
   void add(DimLvlExpr expr);
 };
 
@@ -397,7 +402,6 @@ class VarEnv final {
   VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
 
 public:
-  // NOTE TO Wren: initializer needed!
   VarEnv() : nextNum(0) {}
 
   /// Gets the underlying storage for the `VarInfo` identified by


        


More information about the Mlir-commits mailing list