[Mlir-commits] [mlir] 68785c1 - [mlir][sparse] Correcting RTTI implementation for the Var class

wren romano llvmlistbot at llvm.org
Thu Jul 6 18:57:09 PDT 2023


Author: wren romano
Date: 2023-07-06T18:57:02-07:00
New Revision: 68785c1c44bdf68bd9ad09cc141d708108e83479

URL: https://github.com/llvm/llvm-project/commit/68785c1c44bdf68bd9ad09cc141d708108e83479
DIFF: https://github.com/llvm/llvm-project/commit/68785c1c44bdf68bd9ad09cc141d708108e83479.diff

LOG: [mlir][sparse] Correcting RTTI implementation for the Var class

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D154662

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 0c9a2cf348c139..fef806b0fc1024 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -94,56 +94,105 @@ using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
 //===----------------------------------------------------------------------===//
 /// A concrete variable, to be used in our variant of `AffineExpr`.
 class Var {
+  // Design Note: This class makes several distinctions which may at first
+  // seem unnecessary but are in fact needed for implementation reasons.
+  // These distinctions are summarized as follows:
+  //
+  // * `Var`
+  //   Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
+  //   support for subclasses with a fixed `VarKind`.
+  // * `Var::Num`
+  //   Client-facing typedef for the type of variable numbers; defined
+  //   so that client code can use it to disambiguate/document when things
+  //   are intended to be variable numbers, as opposed to some other thing
+  //   which happens to be represented as `unsigned`.
+  // * `Var::Storage`
+  //   Private typedef for the storage of `Var::Impl`; defined only because
+  //   it's also needed for defining `kMaxNum`.  Note that this type must be
+  //   kept distinct from `Var::Num`: not only can they be different C++ types
+  //   (even though they currently happen to be the same), but also because
+  //   they use different bitwise representations.
+  // * `Var::Impl`
+  //   The underlying implementation of `Var`; needed by RTTI to serve as
+  //   an intermediary between `Var` and `Var::Storage`.  That is, we want
+  //   the RTTI methods to select the `U(Var::Impl)` ctor, without any
+  //   possibility of confusing that with the `U(Var::Num)` ctor nor with
+  //   the copy-ctor.  (Although the `U(Var::Impl)` ctor is effectively
+  //   identical to the copy-ctor, it doesn't have the type that C++ expects
+  //   for a copy-ctor.)
+  //
+  // TODO: See if it'd be cleaner to use "llvm/ADT/Bitfields.h" in lieu
+  // of doing our own bitbashing (though that seems to only be used by LLVM
+  // for defining machine/assembly ops, and not anywhere else in LLVM/MLIR).
 public:
-  /// Typedef to help disambiguate different uses of `unsigned`.
+  /// Typedef for the type of variable numbers.
   using Num = unsigned;
 
 private:
-  /// The underlying storage representation of `Var`.  Note that this type
-  /// should be kept distinct from `Num`.  Not only can they be different
-  /// C++ types (even though they currently happen to be the same), but
-  /// they also use different bitwise representations.
-  //
-  // FUTURE_CL(wrengr): Rather than rolling our own, we should
-  // consider using "llvm/ADT/Bitfields.h"; though that seems to only
-  // be used by LLVM for the sake of defining machine/assembly ops.
-  // Or we could consider abusing `PointerIntPair`...
-  using Impl = unsigned;
-  Impl impl;
-
-  /// The largest `Var::Num` supported by `Var::Impl`.  Two low-order
-  /// bits are reserved for storing the `VarKind`, and one high-order bit
-  /// is reserved for future use (e.g., to support `DenseMapInfo<Var>` while
-  /// maintaining the usual numeric values for "empty" and "tombstone").
+  /// Typedef for the underlying storage of `Var::Impl`.
+  using Storage = unsigned;
+
+  /// The largest `Var::Num` supported by `Var`/`Var::Impl`/`Var::Storage`.
+  /// Two low-order bits are reserved for storing the `VarKind`,
+  /// and one high-order bit is reserved for future use (e.g., to support
+  /// `DenseMapInfo<Var>` while maintaining the usual numeric values for
+  /// "empty" and "tombstone").
   static constexpr Num kMaxNum =
-      static_cast<Num>(std::numeric_limits<Impl>::max() >> 3);
+      static_cast<Num>(std::numeric_limits<Storage>::max() >> 3);
 
 public:
+  /// Checks whether the number would be accepted by `Var(VarKind,Var::Num)`.
+  //
   // This must be public for `VarInfo` to use it (whereas we don't want
   // to expose the `impl` field via friendship).
   static constexpr bool isWF_Num(Num n) { return n <= kMaxNum; }
 
-  constexpr Var(VarKind vk, Num n)
-      : impl((static_cast<Impl>(n) << 2) |
-             static_cast<Impl>(to_underlying(vk))) {
-    assert(isWF(vk) && "unknown VarKind");
-    assert(isWF_Num(n) && "Var::Num is too large");
-  }
+protected:
+  /// The underlying implementation of `Var`.  Note that this must be kept
+  /// distinct from `Var` itself, since we want to ensure that the RTTI
+  /// methods will select the `U(Var::Impl)` ctor rather than selecting
+  /// the `U(Var::Num)` ctor.
+  class Impl final {
+    Storage data;
+
+  public:
+    constexpr Impl(VarKind vk, Num n)
+        : data((static_cast<Storage>(n) << 2) |
+               static_cast<Storage>(to_underlying(vk))) {
+      assert(isWF(vk) && "unknown VarKind");
+      assert(isWF_Num(n) && "Var::Num is too large");
+    }
+    constexpr bool operator==(Impl other) const { return data == other.data; }
+    constexpr bool operator!=(Impl other) const { return !(*this == other); }
+    constexpr VarKind getKind() const { return static_cast<VarKind>(data & 3); }
+    constexpr Num getNum() const { return static_cast<Num>(data >> 2); }
+  };
+  static_assert(IsZeroCostAbstraction<Impl>);
+
+private:
+  Impl impl;
+
+protected:
+  /// Protected ctor for the RTTI methods to use.
+  constexpr explicit Var(Impl impl) : impl(impl) {}
+
+public:
+  constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
   Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
   Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {}
 
   constexpr bool operator==(Var other) const { return impl == other.impl; }
   constexpr bool operator!=(Var other) const { return !(*this == other); }
 
-  constexpr VarKind getKind() const { return static_cast<VarKind>(impl & 3); }
-  constexpr Num getNum() const { return static_cast<Num>(impl >> 2); }
+  constexpr VarKind getKind() const { return impl.getKind(); }
+  constexpr Num getNum() const { return impl.getNum(); }
 
   template <typename U>
   constexpr bool isa() const;
   template <typename U>
   constexpr U cast() const;
   template <typename U>
-  constexpr U dyn_cast() const;
+  constexpr std::optional<U> dyn_cast() const;
 
   void print(llvm::raw_ostream &os) const;
   void print(AsmPrinter &printer) const;
@@ -152,6 +201,7 @@ class Var {
 static_assert(IsZeroCostAbstraction<Var>);
 
 class SymVar final : public Var {
+  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
 public:
   static constexpr VarKind Kind = VarKind::Symbol;
   static constexpr bool classof(Var const *var) {
@@ -163,6 +213,7 @@ class SymVar final : public Var {
 static_assert(IsZeroCostAbstraction<SymVar>);
 
 class DimVar final : public Var {
+  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
 public:
   static constexpr VarKind Kind = VarKind::Dimension;
   static constexpr bool classof(Var const *var) {
@@ -174,6 +225,7 @@ class DimVar final : public Var {
 static_assert(IsZeroCostAbstraction<DimVar>);
 
 class LvlVar final : public Var {
+  using Var::Var; // inherit `Var(Impl)` ctor for RTTI use.
 public:
   static constexpr VarKind Kind = VarKind::Level;
   static constexpr bool classof(Var const *var) {
@@ -202,12 +254,14 @@ constexpr bool Var::isa() const {
 template <typename U>
 constexpr U Var::cast() const {
   assert(isa<U>());
-  return U(impl >> 2); // NOTE TO Wren: confirm this fix
+  // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
+  return U(impl);
 }
 
 template <typename U>
-constexpr U Var::dyn_cast() const {
-  return isa<U>() ? U(impl >> 2) : U();
+constexpr std::optional<U> Var::dyn_cast() const {
+  // NOTE: This should select the `U(Var::Impl)` ctor, *not* `U(Var::Num)`
+  return isa<U>() ? std::make_optional(U(impl)) : std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list