[Mlir-commits] [mlir] 497050c - [mlir][sparse] Fixes bug in VarSet ctor

wren romano llvmlistbot at llvm.org
Mon Jul 31 14:28:07 PDT 2023


Author: wren romano
Date: 2023-07-31T14:27:59-07:00
New Revision: 497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec

URL: https://github.com/llvm/llvm-project/commit/497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec
DIFF: https://github.com/llvm/llvm-project/commit/497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec.diff

LOG: [mlir][sparse] Fixes bug in VarSet ctor

Previously, the commented-out code in the `DimLvlMap` ctor would result in `VarSet::add` raising an out-of-bounds (OOB) error, which should be impossible because the ctor asserts `DimLvlMap::isWF`, which ensures that all variables occurring in the map are within bounds for the ranks.

The root cause of the bug was that the `VarSet` ctor used `SmallBitVector::reserve`, which only preallocates capacity and does not actually change the size of the bitvectors (hence the subsequent OOB).  This is corrected by using any of `SmallBitVector::resize`, the move-ctor, or the copy-ctor.  Since the default-initialized bitvectors being overwritten have size zero, there shouldn't be any significant performance difference between these three implementations.
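
For illustration only (not part of the patch), here is a minimal standalone sketch of the `reserve`-vs-`resize` distinction on `llvm::SmallBitVector`; the size 8 is an arbitrary stand-in for a rank:

  // Minimal standalone sketch (not part of this patch) of the pitfall:
  // `reserve` only preallocates capacity, while `resize` (or constructing
  // with a size) actually makes the indices valid.
  #include "llvm/ADT/SmallBitVector.h"
  #include <cassert>

  int main() {
    llvm::SmallBitVector bits;
    bits.reserve(8);           // capacity only...
    assert(bits.size() == 0);  // ...so bits[3] would still be out of bounds

    bits.resize(8);            // now size() == 8 and indices 0..7 are valid
    bits.set(3);

    llvm::SmallBitVector other;
    other = llvm::SmallBitVector(8);  // move-assignment, as the fix below does
    assert(other.size() == 8 && !other[3]);

    return bits.test(3) ? 0 : 1;
  }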

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D155999

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index d87258756ed87b..ca26a83dcd0cdd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -252,6 +252,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
                      ArrayRef<LvlSpec> lvlSpecs)
     : symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) {
   // First, check integrity of the variable-binding structure.
+  // NOTE: This establishes the invariant that calls to `VarSet::add`
+  // below cannot cause OOB errors.
   assert(isWF());
 
   // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
@@ -260,14 +262,19 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   // needs to happen before the code for setting every `LvlSpec::elideVar`,
   // since if the LvlVar is only used in elided DimExpr, then the
   // LvlVar should also be elided.
+  // NOTE: Whenever we set a new DimExpr, we must make sure to validate it
+  // against our ranks, to restore the invariant established by `isWF` above.
+  // TODO(wrengr): We might want to adjust the `DimLvlExpr` ctor to take a
+  // `Ranks` argument and perform the validation then.
 
   // Third, we set every `LvlSpec::elideVar` according to whether that
   // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
+  // NOTE: The invariant established by `isWF` ensures that the following
+  // calls to `VarSet::add` cannot raise OOB errors.
   VarSet usedVars(getRanks());
-  // NOTE TO Wren: bypassed for now
-  // for (const auto &dimSpec : dimSpecs)
-  //  if (!dimSpec.canElideExpr())
-  //    usedVars.add(dimSpec.getExpr());
+  for (const auto &dimSpec : dimSpecs)
+    if (!dimSpec.canElideExpr())
+      usedVars.add(dimSpec.getExpr());
   for (auto &lvlSpec : this->lvlSpecs)
     lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar()));
 }

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 7250d44b53d0d6..a30058d512d44f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -56,8 +56,12 @@ static constexpr const VarKind everyVarKind[] = {
     VarKind::Dimension, VarKind::Symbol, VarKind::Level};
 
 VarSet::VarSet(Ranks const &ranks) {
+  // NOTE: We must not use `reserve` here, since that doesn't change
+  // the `size` of the bitvectors and therefore will result in unexpected
+  // OOB errors.  Either `resize` or copy/move-ctor work; we opt for the
+  // move-ctor since it should be (marginally) more efficient.
   for (const auto vk : everyVarKind)
-    impl[vk].reserve(ranks.getRank(vk));
+    impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
 }
 
 bool VarSet::contains(Var var) const {
@@ -67,10 +71,6 @@ bool VarSet::contains(Var var) const {
   // bugs in client code.
   const llvm::SmallBitVector &bits = impl[var.getKind()];
   const auto num = var.getNum();
-  // FIXME(wrengr): If we `assert(num < bits.size())` then
-  // "roundtrip_encoding.mlir" will fail.  So we need to figure out
-  // where exactly the OOB `var` is coming from, to determine whether
-  // that's a logic bug or not.
   return num < bits.size() && bits[num];
 }
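
For readers outside the LLVM tree, the following self-contained mock (hypothetical names; not the actual `VarSet`, `Var`, or `Ranks` classes) mirrors the pattern the fix establishes: each per-kind bitvector is sized up front in the ctor, so the indexing in `add` and `contains` stays in bounds.

  // Hypothetical, simplified mock of the fixed pattern; only the
  // llvm::SmallBitVector usage matches the real code.
  #include "llvm/ADT/SmallBitVector.h"
  #include <cassert>

  namespace mock {
  enum class VarKind { Symbol = 0, Dimension = 1, Level = 2 };

  struct Ranks {
    unsigned ranks[3];
    unsigned getRank(VarKind vk) const { return ranks[unsigned(vk)]; }
  };

  class VarSet {
    llvm::SmallBitVector impl[3];

  public:
    explicit VarSet(const Ranks &ranks) {
      // Move-assigning a bitvector constructed with a size (rather than
      // calling `reserve`) guarantees size() == rank, so the indexing in
      // `add`/`contains` below cannot go out of bounds.
      for (unsigned vk = 0; vk < 3; ++vk)
        impl[vk] = llvm::SmallBitVector(ranks.getRank(VarKind(vk)));
    }

    void add(VarKind vk, unsigned num) {
      assert(num < impl[unsigned(vk)].size() && "variable out of bounds");
      impl[unsigned(vk)].set(num);
    }

    bool contains(VarKind vk, unsigned num) const {
      const llvm::SmallBitVector &bits = impl[unsigned(vk)];
      return num < bits.size() && bits[num];
    }
  };
  } // namespace mock

  int main() {
    mock::Ranks ranks{{/*symbols=*/1, /*dims=*/2, /*lvls=*/2}};
    mock::VarSet used(ranks);
    used.add(mock::VarKind::Dimension, 1);
    assert(used.contains(mock::VarKind::Dimension, 1));
    assert(!used.contains(mock::VarKind::Level, 1));
    return 0;
  }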
 

