[Mlir-commits] [mlir] 497050c - [mlir][sparse] Fixes bug in VarSet ctor
wren romano
llvmlistbot at llvm.org
Mon Jul 31 14:28:07 PDT 2023
Author: wren romano
Date: 2023-07-31T14:27:59-07:00
New Revision: 497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec
URL: https://github.com/llvm/llvm-project/commit/497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec
DIFF: https://github.com/llvm/llvm-project/commit/497050c961bbd4c13ea7ac8d6e80b9d7dd51e7ec.diff
LOG: [mlir][sparse] Fixes bug in VarSet ctor
Previously, the commented-out code in the `DimLvlMap` ctor would cause `VarSet::add` to raise an OOB error. That should be impossible, because the ctor asserts `DimLvlMap::isWF`, which ensures that all variables occurring in the map are within bounds for the ranks.
The root cause of that bug was the `VarSet` ctor using `SmallBitVector::reserve` which does not actually change the size of the bitvectors (hence the subsequent OOB). This is corrected by using any of `SmallBitVector::resize`, the move-ctor, or the copy-ctor. Since the default-initialized bitvectors being modified/overwritten have size zero, there shouldn't be any significant performance difference between these three implementations.
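The `reserve` versus `resize` distinction described above can be reproduced with standard containers. The following is a minimal sketch, not the actual MLIR code: it assumes `std::vector<bool>` as a stand-in for `llvm::SmallBitVector` (the two are different classes, but in both `reserve` leaves `size()` unchanged), and the names `Bits`, `makeReserved`, `makeSized`, `add`, and `contains` are hypothetical helpers introduced only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one bitvector inside VarSet. std::vector<bool>
// is an assumption made for illustration; the real code uses
// llvm::SmallBitVector.
using Bits = std::vector<bool>;

// Mirrors the buggy ctor: storage may be allocated, but size() stays 0.
inline Bits makeReserved(std::size_t rank) {
  Bits bits;
  bits.reserve(rank);
  return bits;
}

// Mirrors the fix: construct a bitvector holding `rank` bits, all false.
inline Bits makeSized(std::size_t rank) { return Bits(rank); }

// Insertion that treats an out-of-range index as a bug.
inline void add(Bits &bits, std::size_t num) {
  assert(num < bits.size() && "variable number is out of bounds");
  bits[num] = true;
}

// Bounds-checked membership test: an out-of-range index is simply "absent".
inline bool contains(const Bits &bits, std::size_t num) {
  return num < bits.size() && bits[num];
}
```

With these helpers, `makeReserved(4).size()` is 0, so `add` on any index would trip the assertion, whereas `makeSized(4)` accepts indices 0 through 3.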
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D155999
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index d87258756ed87b..ca26a83dcd0cdd 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -252,6 +252,8 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
ArrayRef<LvlSpec> lvlSpecs)
: symRank(symRank), dimSpecs(dimSpecs), lvlSpecs(lvlSpecs) {
// First, check integrity of the variable-binding structure.
+ // NOTE: This establishes the invariant that calls to `VarSet::add`
+ // below cannot cause OOB errors.
assert(isWF());
// TODO: Second, we need to infer/validate the `lvlToDim` mapping.
@@ -260,14 +262,19 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
// needs to happen before the code for setting every `LvlSpec::elideVar`,
// since if the LvlVar is only used in elided DimExpr, then the
// LvlVar should also be elided.
+ // NOTE: Whenever we set a new DimExpr, we must make sure to validate it
+ // against our ranks, to restore the invariant established by `isWF` above.
+ // TODO(wrengr): We might should adjust the `DimLvlExpr` ctor to take a
+ // `Ranks` argument and perform the validation then.
// Third, we set every `LvlSpec::elideVar` according to whether that
// LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
+ // NOTE: The invariant established by `isWF` ensures that the following
+ // calls to `VarSet::add` cannot raise OOB errors.
VarSet usedVars(getRanks());
- // NOTE TO Wren: bypassed for now
- // for (const auto &dimSpec : dimSpecs)
- // if (!dimSpec.canElideExpr())
- // usedVars.add(dimSpec.getExpr());
+ for (const auto &dimSpec : dimSpecs)
+ if (!dimSpec.canElideExpr())
+ usedVars.add(dimSpec.getExpr());
for (auto &lvlSpec : this->lvlSpecs)
lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar()));
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 7250d44b53d0d6..a30058d512d44f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -56,8 +56,12 @@ static constexpr const VarKind everyVarKind[] = {
VarKind::Dimension, VarKind::Symbol, VarKind::Level};
VarSet::VarSet(Ranks const &ranks) {
+ // NOTE: We must not use `reserve` here, since that doesn't change
+ // the `size` of the bitvectors and therefore will result in unexpected
+ // OOB errors. Either `resize` or copy/move-ctor work; we opt for the
+ // move-ctor since it should be (marginally) more efficient.
for (const auto vk : everyVarKind)
- impl[vk].reserve(ranks.getRank(vk));
+ impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
}
bool VarSet::contains(Var var) const {
@@ -67,10 +71,6 @@ bool VarSet::contains(Var var) const {
// bugs in client code.
const llvm::SmallBitVector &bits = impl[var.getKind()];
const auto num = var.getNum();
- // FIXME(wrengr): If we `assert(num < bits.size())` then
- // "roundtrip_encoding.mlir" will fail. So we need to figure out
- // where exactly the OOB `var` is coming from, to determine whether
- // that's a logic bug or not.
return num < bits.size() && bits[num];
}