[Mlir-commits] [mlir] dcadb68 - [mlir][sparse] Cleaning up OOB implementation details for `VarSet`
wren romano
llvmlistbot at llvm.org
Fri Jul 7 14:36:03 PDT 2023
Author: wren romano
Date: 2023-07-07T14:35:56-07:00
New Revision: dcadb68a5c5a50f41a023b39748a506df8c07248
URL: https://github.com/llvm/llvm-project/commit/dcadb68a5c5a50f41a023b39748a506df8c07248
DIFF: https://github.com/llvm/llvm-project/commit/dcadb68a5c5a50f41a023b39748a506df8c07248.diff
LOG: [mlir][sparse] Cleaning up OOB implementation details for `VarSet`
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D154674
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 35022d7cfa1b09..7b67f8eece0da2 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -56,22 +56,27 @@ static constexpr const VarKind everyVarKind[] = {
VarKind::Dimension, VarKind::Symbol, VarKind::Level};
VarSet::VarSet(Ranks const &ranks) {
- // FIXME(wrengr): will this DWIM, or do we need to worry about
- // `reserve` causing resizing/dangling issues?
+ // NOTE: This must be `resize` rather than `reserve`: `reserve` only
+ // sets the capacity and leaves `size() == 0`, which would make every
+ // `Var` OOB (so `add(Var)` would always fail its assertion).
for (const auto vk : everyVarKind)
- impl[vk].reserve(ranks.getRank(vk));
+ impl[vk].resize(ranks.getRank(vk));
}
bool VarSet::contains(Var var) const {
- // FIXME(wrengr): this implementation will raise assertion failure on OOB;
- // but perhaps we'd rather have this return false on OOB? That's
- // necessary for consistency with the `anyCommon` implementation of
- // `occursIn(VarSet)`.
+ // NOTE: We make sure to return false on OOB, for consistency with
+ // the `anyCommon` implementation of `VarSet::occursIn(VarSet)`.
+ // However beware that, as always with silencing OOB, this can hide
+ // bugs in client code.
const llvm::SmallBitVector &bits = impl[var.getKind()];
- // NOTE TO Wren: did this to avoid OOB but perhaps it is result of bug
- if (var.getNum() >= bits.size())
- return false;
- return bits[var.getNum()];
+ const auto num = var.getNum();
+ // FIXME(wrengr): If we `assert(num < bits.size())` then
+ // "roundtrip_encoding.mlir" will fail. So we need to figure out
+ // where exactly the OOB `var` is coming from, to determine whether
+ // that's a logic bug or not. NOTE(review): a constructor that only
+ // `reserve`s leaves every bitvector at `size() == 0`, which would explain
+ // that failure; re-check once the bitvectors are sized to the ranks.
+ return num < bits.size() && bits[num];
}
bool VarSet::occursIn(VarSet const &other) const {
@@ -105,13 +105,20 @@ bool VarSet::occursIn(DimLvlExpr expr) const {
}
void VarSet::add(Var var) {
- // FIXME(wrengr): this implementation will raise assertion failure on OOB;
- // but perhaps we'd rather have this be a noop on OOB? or to grow
- // the underlying bitvectors on OOB?
+ // NOTE: `SmallBitVector::operator[]` will raise assertion errors for OOB.
impl[var.getKind()][var.getNum()] = true;
}
-// TODO(wrengr): void VarSet::add(VarSet const& other);
+void VarSet::add(VarSet const &other) {
+ // NOTE: Adding a set is a union, hence `|=`. `SmallBitVector::operator|=`
+ // will implicitly resize the bitvector when `other` is larger, so we
+ // assert against OOB for consistency with the implementation of
+ // `VarSet::add(Var)`.
+ for (const auto vk : everyVarKind) {
+ assert(impl[vk].size() >= other.impl[vk].size());
+ impl[vk] |= other.impl[vk];
+ }
+}
void VarSet::add(DimLvlExpr expr) {
if (!expr)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index fef806b0fc1024..24dd631dba1bc4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -271,7 +271,6 @@ class DimLvlExpr;
//===----------------------------------------------------------------------===//
class Ranks final {
// Not using `VarKindArray` since `EnumeratedArray` doesn't support constexpr.
- // TODO(wrengr): to what extent do we actually care about constexpr here?
unsigned impl[3];
static constexpr unsigned to_index(VarKind vk) {
@@ -303,6 +302,14 @@ class Ranks final {
static_assert(IsZeroCostAbstraction<Ranks>);
//===----------------------------------------------------------------------===//
+/// Efficient representation of a set of `Var`.
+///
+/// NOTE: For the `contains`/`occursIn` methods: if variables occurring in
+/// the method parameter are OOB for the `VarSet`, then these methods will
+/// always return false. However, for the `add` methods: OOB parameters
+/// cause undefined behavior. Currently the `add` methods will raise an
+/// assertion error; though we may change that behavior in the future
+/// (e.g., to resize the underlying bitvectors).
class VarSet final {
// If we're willing to give up the possibility of resizing the
// individual bitvectors, then we could flatten this into a single
@@ -314,14 +321,12 @@ class VarSet final {
public:
explicit VarSet(Ranks const &ranks);
- // TODO(wrengr): can we come up with a single name that works for all three of
- // these?
bool contains(Var var) const;
bool occursIn(VarSet const &vars) const;
bool occursIn(DimLvlExpr expr) const;
void add(Var var);
- // TODO(wrengr): void add(VarSet const& vars);
+ void add(VarSet const &vars);
void add(DimLvlExpr expr);
};
@@ -397,7 +402,6 @@ class VarEnv final {
VarInfo::ID nextID() const { return static_cast<VarInfo::ID>(vars.size()); }
public:
- // NOTE TO Wren: initializer needed!
VarEnv() : nextNum(0) {}
/// Gets the underlying storage for the `VarInfo` identified by
More information about the Mlir-commits mailing list