[Mlir-commits] [mlir] [mlir][sparse] Parser cleanup (PR #69792)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 23 10:53:19 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Yinying Li (yinying-lisa-li)

<details>
<summary>Changes</summary>

Removed TODOs, FIXMEs and long notes that are better suited to a design doc.

---

Patch is 39.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69792.diff


8 Files Affected:

- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp (-35) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h (+2-63) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp (+4-60) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h (+9-16) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (-18) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h (+5-27) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (-37) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h (+11-95) 


``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 5f947b67c6d848e..851867926fe679e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
   return k ? std::make_optional(k.getValue()) : std::nullopt;
 }
 
-// This helper method is akin to `AffineExpr::operator==(int64_t)`
-// except it uses a different implementation, namely the implementation
-// used within `AsmPrinter::Impl::printAffineExprInternal`.
-//
-// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
-// this implementation because it avoids constructing the intermediate
-// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
-// However, if it is indeed faster, then the `AffineExpr::operator==`
-// method should be updated to do this instead.  And if it isn't any
-// faster, then we should be using `AffineExpr::operator==` instead.
 bool DimLvlExpr::hasConstantValue(int64_t val) const {
   const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
   return k && k.getValue() == val;
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && (!expr || ranks.isValid(expr));
 }
 
-bool DimSpec::isFunctionOf(VarSet const &vars) const {
-  return vars.occursIn(expr);
-}
-
-void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
 void DimSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
   return ranks.isValid(var) && ranks.isValid(expr);
 }
 
-bool LvlSpec::isFunctionOf(VarSet const &vars) const {
-  return vars.occursIn(expr);
-}
-
-void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
 void LvlSpec::dump() const {
   print(llvm::errs(), /*wantElision=*/false);
   llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   // below cannot cause OOB errors.
   assert(isWF());
 
-  // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
-  // Along the way we should set every `DimSpec::elideExpr` according
-  // to whether the given expression is inferable or not.  Notably, this
-  // needs to happen before the code for setting every `LvlSpec::elideVar`,
-  // since if the LvlVar is only used in elided DimExpr, then the
-  // LvlVar should also be elided.
-  // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
-  // to ensure that we maintain the invariant established by `isWF` above.
-
-  // Third, we set every `LvlSpec::elideVar` according to whether that
-  // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
-  // NOTE: The invariant established by `isWF` ensures that the following
-  // calls to `VarSet::add` cannot raise OOB errors.
   VarSet usedVars(getRanks());
   for (const auto &dimSpec : dimSpecs)
     if (!dimSpec.canElideExpr())
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 2d02eed2cf9972e..664b49509f070f4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -5,11 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
-// be named as the storage representation of the parameter for the tblgen
-// defn of STEA.  We may well need to make the other classes public too,
-// so that the rest of the compiler can use them when necessary.
-//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
 #define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
 namespace ir_detail {
 
 //===----------------------------------------------------------------------===//
-// TODO(wrengr): Give this enum a better name, so that it fits together
-// with the name of the `DimLvlExpr` class (which may also want a better
-// name).  Perhaps make this a nested-type too.
-//
-// NOTE: In the future we will extend this enum to include "counting
-// expressions" required for supporting ITPACK/ELL.  Therefore the current
-// underlying-type and representation values should not be relied upon.
 enum class ExprKind : bool { Dimension = false, Level = true };
 
-// TODO(wrengr): still needs a better name....
 constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
   using VK = std::underlying_type_t<VarKind>;
   return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
               getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
 
 //===----------------------------------------------------------------------===//
-// TODO(wrengr): The goal of this class is to capture a proof that
-// we've verified that the given `AffineExpr` only has variables of the
-// appropriate kind(s).  So we need to actually prove/verify that in the
-// ctor or all its callsites!
 class DimLvlExpr {
 private:
-  // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
-  // the `kind` field should be private and const.  However, beware
-  // that if we mark any field as `const` or if the fields have differing
-  // `private`/`protected` privileges then the `IsZeroCostAbstraction`
-  // assertion will fail!
-  // (Also, iirc, if we end up moving the `expr` to the subclasses
-  // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
   ExprKind kind;
   AffineExpr expr;
 
@@ -100,11 +76,6 @@ class DimLvlExpr {
   //
   // Getters for handling `AffineExpr` subclasses.
   //
-  // TODO(wrengr): is there any way to make these typesafe without too much
-  // templating?
-  // TODO(wrengr): Most if not all of these don't actually need to be
-  // methods, they could be free-functions instead.
-  //
   Var castAnyVar() const;
   std::optional<Var> dyn_castAnyVar() const;
   SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
   // Variant of `mlir::AsmPrinter::Impl::BindingStrength`
   enum class BindingStrength : bool { Weak = false, Strong = true };
 
-  // TODO(wrengr): Does our version of `printAffineExprInternal` really
-  // need to be a method, or could it be a free-function instead? (assuming
-  // `BindingStrength` goes with it).
   void printAffineExprInternal(llvm::raw_ostream &os,
                                BindingStrength enclosingTightness) const;
   void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
 };
 static_assert(IsZeroCostAbstraction<DimLvlExpr>);
 
-// FUTURE_CL(wrengr): It would be nice to have the subclasses override
-// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
-// the proper covariant return types.
-//
 class DimExpr final : public DimLvlExpr {
-  // FIXME(wrengr): These two are needed for the current RTTI implementation.
   friend class DimLvlExpr;
   constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
 
@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
 static_assert(IsZeroCostAbstraction<DimExpr>);
 
 class LvlExpr final : public DimLvlExpr {
-  // FIXME(wrengr): These two are needed for the current RTTI implementation.
   friend class DimLvlExpr;
   constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
 
@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
 };
 static_assert(IsZeroCostAbstraction<LvlExpr>);
 
-// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
 template <typename U>
 constexpr bool DimLvlExpr::isa() const {
   if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
   /// the result of this predicate.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
-  // TODO(wrengr): Use it or loose it.
-  bool isFunctionOf(Var var) const;
-  bool isFunctionOf(VarSet const &vars) const;
-  void getFreeVars(VarSet &vars) const;
-
   std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
 };
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
 static_assert(IsZeroCostAbstraction<DimSpec>);
 
 //===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
   /// whereas the `DimLvlMap` ctor will reset this as appropriate.
   bool elideVar = false;
   /// The level-expression.
-  //
-  // NOTE: For now we use `LvlExpr` because all level-expressions must be
-  // `AffineExpr`; however, in the future we will also want to allow "counting
-  // expressions", and potentially other kinds of non-affine level-expressions.
-  // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
-  // so we may consider defining another class for pairing those two together
-  // to ensure that the pair is well-formed.
   LvlExpr expr;
   /// The level-type (== level-format + lvl-properties).
   DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {
 
   /// Checks whether the variables bound/used by this spec are valid
   /// with respect to the given ranks.
-  //
-  // NOTE: Once we introduce "counting expressions" this will need
-  // a more sophisticated implementation than `DimSpec::isValid` does.
   [[nodiscard]] bool isValid(Ranks const &ranks) const;
 
-  // TODO(wrengr): Use it or loose it.
-  bool isFunctionOf(Var var) const;
-  bool isFunctionOf(VarSet const &vars) const;
-  void getFreeVars(VarSet &vars) const;
-
   std::string str(bool wantElision = true) const;
   void print(llvm::raw_ostream &os, bool wantElision = true) const;
   void print(AsmPrinter &printer, bool wantElision = true) const;
   void dump() const;
 };
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
 static_assert(IsZeroCostAbstraction<LvlSpec>);
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 680411f8008ea3d..6fb69d1397e6cfb 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
                                               VarInfo::ID &varID,
                                               bool &didCreate) {
   // Save the current location so that we can have error messages point to
-  // the right place.  Note that `Parser::emitWrongTokenError` starts off
-  // with the same location as `AsmParserImpl::getCurrentLocation` returns;
-  // however, `Parser` will then do some various munging with the location
-  // before actually using it, so `AsmParser::emitError` can't quite be used
-  // as a drop-in replacement for `Parser::emitWrongTokenError`.
+  // the right place.
   const auto loc = parser.getCurrentLocation();
-
-  // Several things to note.
-  // (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
-  //     same conditions as the `AffineParser.cpp`-static free-function
-  //     `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
-  // (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
-  //     methods do the same song and dance about using
-  //     `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
-  //     would want to do if we could use the `Parser` class directly.  It
-  //     doesn't provide the nice error handling we want, but we can work around
-  //     that.
   StringRef name;
   if (failed(parser.parseOptionalKeyword(&name))) {
-    // If not actually optional, then `emitError`.
     ERROR_IF(!isOptional, "expected bare identifier")
-    // If is actually optional, then return the null `OptionalParseResult`.
     return std::nullopt;
   }
 
-  // I don't know if we need to worry about the possibility of the caller
-  // recovering from error and then reusing the `DimLvlMapParser` for subsequent
-  // `parseVar`, but I'm erring on the side of caution by distinguishing
-  // all three possible creation policies.
   if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
     varID = res->first;
     didCreate = res->second;
     return success();
   }
-  // TODO(wrengr): these error messages make sense for our intended usage,
-  // but not in general; but it's unclear how best to factor that part out.
+
   switch (creationPolicy) {
   case Policy::MustNot:
     return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
   FAILURE_IF_FAILED(parseDimSpecList())
   FAILURE_IF_FAILED(parser.parseArrow())
   FAILURE_IF_FAILED(parseLvlSpecList())
-  // TODO(wrengr): Try to improve the error messages from
-  // `VarEnv::emitErrorIfAnyUnbound`.
   InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
   if (failed(ifd))
     return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
       " in symbol binding list");
 }
 
-// FIXME: The forward-declaration of level-vars is a stop-gap workaround
-// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
-// `DimLvlMapParser::parseDimSpec`.  (In particular, note that all the
-// variables must be bound before entering `AsmParser::parseAffineExpr`,
-// since that method requires every variable to already have a fixed/known
-// `Var::Num`.)
-//
-// However, the forward-declaration list duplicates information which is
-// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
-// the names of the variables themselves, and the order in which the names
-// are bound).  This redundancy causes bad UX, and also means we must be
-// sure to verify consistency between the two sources of information.
-//
-// Therefore, it would be best to remove the forward-declaration list from
-// the syntax.  This can be achieved by implementing our own version of
-// `AffineParser::parseAffineExpr` which calls
-// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
-// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
-// some `Var::Num` immediately).  This would also enable us to use the `VarEnv`
-// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
-// side, and thus would also enable us to avoid the O(n^2) behavior of copying
-// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
-// every time `AsmParser::parseAffineExpr` is called.
 ParseResult DimLvlMapParser::parseLvlVarBindingList() {
   return parser.parseCommaSeparatedList(
       OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
   AffineExpr affine;
   if (succeeded(parser.parseOptionalEqual())) {
     // Parse the dim affine expr, with only any lvl-vars in scope.
-    // FIXME(wrengr): This still has the O(n^2) behavior of copying
-    // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
-    // field every time `parseDimSpec` is called.
     FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
   }
   DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
   }
 }
 
-// NOTE: This is factored out as a separate method only because `Var`
-// lacks a default-ctor, which makes this conditional difficult to inline
-// at the one call-site.
 FailureOr<LvlVar>
 DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
   // Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
 }
 
 ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
-  // Parse the optional lvl-var binding.  (Actually, `requireLvlVarBinding`
-  // specifies whether that "optional" is actually Must or MustNot.)
+  // Parse the optional lvl-var binding. `requireLvlVarBinding`
+  // specifies whether that "optional" is actually Must or MustNot.
   const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
   FAILURE_IF_FAILED(varRes)
   const LvlVar var = *varRes;
 
   // Parse the lvl affine expr, with only the dim-vars in scope.
   AffineExpr affine;
-  // FIXME(wrengr): This still has the O(n^2) behavior of copying
-  // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
-  // field every time `parseLvlSpec` is called.
   FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
   LvlExpr expr{affine};
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index 013a89ea172b0b5..11f727c40644ae1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
   FailureOr<DimLvlMap> parseDimLvlMap();
 
 private:
-  /// The core code for parsing `Var`.  This method abstracts out a lot
-  /// of complex details to avoid code duplication; however, client code
-  /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
-  /// calling this method directly.
+  /// Client code should prefer using `parseVarUsage`
+  /// and `parseVarBinding` rather than calling this method directly.
   OptionalParseResult parseVar(VarKind vk, bool isOptional,
                                Policy creationPolicy, VarInfo::ID &id,
                                bool &didCreate);
 
-  /// Parse a variable occurence which is a *use* of that variable.
-  /// The `requireKnown` parameter specifies how to handle the case of
-  /// encountering a valid variable name which is currently unused: when
-  /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+  /// Parses a variable occurrence which is a *use* of that variable.
+  /// When a valid variable name is not yet known, if
+  /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
   /// a new unbound variable will be created.
-  ///
-  /// NOTE: Just because a variable is *known* (i.e., the name has been
-  /// associated with an `VarInfo::ID`), does not mean that the variable
-  /// is actually *in scope*.
   FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
 
-  /// Parse a variable occurence which is a *binding* of that variable.
+  /// Parses a variable occurrence which is a *binding* of that variable.
   /// The `requireKnown` parameter is for handling the binding of
   /// forward-declared variables.
   FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
 
-  /// Parse an optional variable binding.  When the next token is
+  /// Parses an optional variable binding. When the next token is
   /// not a valid variable name, this will bind a new unnamed variable.
   /// The returned `bool` indicates whether a variable name was parsed.
   FailureOr<std::pair<Var, bool>>
   parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
 
   /// Binds the given variable: both updating the `VarEnv` itself, and
-  /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
-  /// to `AsmParser::parseAffineExpr`).  This method is already called by the
+  /// the `{dims,lvls}AndSymbols` lists (which will be passed
+  /// to `AsmParser::parseAffineExpr`). This method is already called by the
   /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
   /// not need to be called elsewhere.
   Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 053e067fff64ddb..8cc7068e3113aff 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/69792


More information about the Mlir-commits mailing list