[Mlir-commits] [mlir] [mlir][sparse] Parser cleanup (PR #69792)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 23 10:53:19 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Yinying Li (yinying-lisa-li)
<details>
<summary>Changes</summary>
Removed TODOs, FIXMEs, and long notes that are better suited to a design doc.
---
Patch is 39.63 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69792.diff
8 Files Affected:
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp (-35)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h (+2-63)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp (+4-60)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h (+9-16)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (-18)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h (+5-27)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp (-37)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h (+11-95)
``````````diff
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 5f947b67c6d848e..851867926fe679e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
return k ? std::make_optional(k.getValue()) : std::nullopt;
}
-// This helper method is akin to `AffineExpr::operator==(int64_t)`
-// except it uses a different implementation, namely the implementation
-// used within `AsmPrinter::Impl::printAffineExprInternal`.
-//
-// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
-// this implementation because it avoids constructing the intermediate
-// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
-// However, if it is indeed faster, then the `AffineExpr::operator==`
-// method should be updated to do this instead. And if it isn't any
-// faster, then we should be using `AffineExpr::operator==` instead.
bool DimLvlExpr::hasConstantValue(int64_t val) const {
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
return k && k.getValue() == val;
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && (!expr || ranks.isValid(expr));
}
-bool DimSpec::isFunctionOf(VarSet const &vars) const {
- return vars.occursIn(expr);
-}
-
-void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
void DimSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && ranks.isValid(expr);
}
-bool LvlSpec::isFunctionOf(VarSet const &vars) const {
- return vars.occursIn(expr);
-}
-
-void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
void LvlSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
// below cannot cause OOB errors.
assert(isWF());
- // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
- // Along the way we should set every `DimSpec::elideExpr` according
- // to whether the given expression is inferable or not. Notably, this
- // needs to happen before the code for setting every `LvlSpec::elideVar`,
- // since if the LvlVar is only used in elided DimExpr, then the
- // LvlVar should also be elided.
- // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
- // to ensure that we maintain the invariant established by `isWF` above.
-
- // Third, we set every `LvlSpec::elideVar` according to whether that
- // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
- // NOTE: The invariant established by `isWF` ensures that the following
- // calls to `VarSet::add` cannot raise OOB errors.
VarSet usedVars(getRanks());
for (const auto &dimSpec : dimSpecs)
if (!dimSpec.canElideExpr())
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 2d02eed2cf9972e..664b49509f070f4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -5,11 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
-// be named as the storage representation of the parameter for the tblgen
-// defn of STEA. We may well need to make the other classes public too,
-// so that the rest of the compiler can use them when necessary.
-//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
-// TODO(wrengr): Give this enum a better name, so that it fits together
-// with the name of the `DimLvlExpr` class (which may also want a better
-// name). Perhaps make this a nested-type too.
-//
-// NOTE: In the future we will extend this enum to include "counting
-// expressions" required for supporting ITPACK/ELL. Therefore the current
-// underlying-type and representation values should not be relied upon.
enum class ExprKind : bool { Dimension = false, Level = true };
-// TODO(wrengr): still needs a better name....
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
using VK = std::underlying_type_t<VarKind>;
return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
//===----------------------------------------------------------------------===//
-// TODO(wrengr): The goal of this class is to capture a proof that
-// we've verified that the given `AffineExpr` only has variables of the
-// appropriate kind(s). So we need to actually prove/verify that in the
-// ctor or all its callsites!
class DimLvlExpr {
private:
- // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
- // the `kind` field should be private and const. However, beware
- // that if we mark any field as `const` or if the fields have differing
- // `private`/`protected` privileges then the `IsZeroCostAbstraction`
- // assertion will fail!
- // (Also, iirc, if we end up moving the `expr` to the subclasses
- // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
ExprKind kind;
AffineExpr expr;
@@ -100,11 +76,6 @@ class DimLvlExpr {
//
// Getters for handling `AffineExpr` subclasses.
//
- // TODO(wrengr): is there any way to make these typesafe without too much
- // templating?
- // TODO(wrengr): Most if not all of these don't actually need to be
- // methods, they could be free-functions instead.
- //
Var castAnyVar() const;
std::optional<Var> dyn_castAnyVar() const;
SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
enum class BindingStrength : bool { Weak = false, Strong = true };
- // TODO(wrengr): Does our version of `printAffineExprInternal` really
- // need to be a method, or could it be a free-function instead? (assuming
- // `BindingStrength` goes with it).
void printAffineExprInternal(llvm::raw_ostream &os,
BindingStrength enclosingTightness) const;
void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
-// FUTURE_CL(wrengr): It would be nice to have the subclasses override
-// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
-// the proper covariant return types.
-//
class DimExpr final : public DimLvlExpr {
- // FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
static_assert(IsZeroCostAbstraction<DimExpr>);
class LvlExpr final : public DimLvlExpr {
- // FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<LvlExpr>);
-// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
template <typename U>
constexpr bool DimLvlExpr::isa() const {
if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
/// the result of this predicate.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
- // TODO(wrengr): Use it or loose it.
- bool isFunctionOf(Var var) const;
- bool isFunctionOf(VarSet const &vars) const;
- void getFreeVars(VarSet &vars) const;
-
std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
static_assert(IsZeroCostAbstraction<DimSpec>);
//===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
/// whereas the `DimLvlMap` ctor will reset this as appropriate.
bool elideVar = false;
/// The level-expression.
- //
- // NOTE: For now we use `LvlExpr` because all level-expressions must be
- // `AffineExpr`; however, in the future we will also want to allow "counting
- // expressions", and potentially other kinds of non-affine level-expressions.
- // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
- // so we may consider defining another class for pairing those two together
- // to ensure that the pair is well-formed.
LvlExpr expr;
/// The level-type (== level-format + lvl-properties).
DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {
/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
- //
- // NOTE: Once we introduce "counting expressions" this will need
- // a more sophisticated implementation than `DimSpec::isValid` does.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
- // TODO(wrengr): Use it or loose it.
- bool isFunctionOf(Var var) const;
- bool isFunctionOf(VarSet const &vars) const;
- void getFreeVars(VarSet &vars) const;
-
std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
static_assert(IsZeroCostAbstraction<LvlSpec>);
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 680411f8008ea3d..6fb69d1397e6cfb 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
VarInfo::ID &varID,
bool &didCreate) {
// Save the current location so that we can have error messages point to
- // the right place. Note that `Parser::emitWrongTokenError` starts off
- // with the same location as `AsmParserImpl::getCurrentLocation` returns;
- // however, `Parser` will then do some various munging with the location
- // before actually using it, so `AsmParser::emitError` can't quite be used
- // as a drop-in replacement for `Parser::emitWrongTokenError`.
+ // the right place.
const auto loc = parser.getCurrentLocation();
-
- // Several things to note.
- // (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
- // same conditions as the `AffineParser.cpp`-static free-function
- // `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
- // (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
- // methods do the same song and dance about using
- // `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
- // would want to do if we could use the `Parser` class directly. It
- // doesn't provide the nice error handling we want, but we can work around
- // that.
StringRef name;
if (failed(parser.parseOptionalKeyword(&name))) {
- // If not actually optional, then `emitError`.
ERROR_IF(!isOptional, "expected bare identifier")
- // If is actually optional, then return the null `OptionalParseResult`.
return std::nullopt;
}
- // I don't know if we need to worry about the possibility of the caller
- // recovering from error and then reusing the `DimLvlMapParser` for subsequent
- // `parseVar`, but I'm erring on the side of caution by distinguishing
- // all three possible creation policies.
if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
varID = res->first;
didCreate = res->second;
return success();
}
- // TODO(wrengr): these error messages make sense for our intended usage,
- // but not in general; but it's unclear how best to factor that part out.
+
switch (creationPolicy) {
case Policy::MustNot:
return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
- // TODO(wrengr): Try to improve the error messages from
- // `VarEnv::emitErrorIfAnyUnbound`.
InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
if (failed(ifd))
return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
" in symbol binding list");
}
-// FIXME: The forward-declaration of level-vars is a stop-gap workaround
-// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
-// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
-// variables must be bound before entering `AsmParser::parseAffineExpr`,
-// since that method requires every variable to already have a fixed/known
-// `Var::Num`.)
-//
-// However, the forward-declaration list duplicates information which is
-// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
-// the names of the variables themselves, and the order in which the names
-// are bound). This redundancy causes bad UX, and also means we must be
-// sure to verify consistency between the two sources of information.
-//
-// Therefore, it would be best to remove the forward-declaration list from
-// the syntax. This can be achieved by implementing our own version of
-// `AffineParser::parseAffineExpr` which calls
-// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
-// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
-// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
-// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
-// side, and thus would also enable us to avoid the O(n^2) behavior of copying
-// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
-// every time `AsmParser::parseAffineExpr` is called.
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
- // FIXME(wrengr): This still has the O(n^2) behavior of copying
- // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
- // field every time `parseDimSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
}
}
-// NOTE: This is factored out as a separate method only because `Var`
-// lacks a default-ctor, which makes this conditional difficult to inline
-// at the one call-site.
FailureOr<LvlVar>
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
// Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
}
ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
- // Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
- // specifies whether that "optional" is actually Must or MustNot.)
+ // Parse the optional lvl-var binding. `requireLvlVarBinding`
+ // specifies whether that "optional" is actually Must or MustNot.
const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
FAILURE_IF_FAILED(varRes)
const LvlVar var = *varRes;
// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
- // FIXME(wrengr): This still has the O(n^2) behavior of copying
- // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
- // field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index 013a89ea172b0b5..11f727c40644ae1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
FailureOr<DimLvlMap> parseDimLvlMap();
private:
- /// The core code for parsing `Var`. This method abstracts out a lot
- /// of complex details to avoid code duplication; however, client code
- /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
- /// calling this method directly.
+ /// Core helper for parsing a `Var`. Client code should prefer using `parseVarUsage`
+ /// and `parseVarBinding` rather than calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);
- /// Parse a variable occurence which is a *use* of that variable.
- /// The `requireKnown` parameter specifies how to handle the case of
- /// encountering a valid variable name which is currently unused: when
- /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+ /// Parses a variable occurrence which is a *use* of that variable.
+ /// When a valid variable name is currently unused, if
+ /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
/// a new unbound variable will be created.
- ///
- /// NOTE: Just because a variable is *known* (i.e., the name has been
- /// associated with an `VarInfo::ID`), does not mean that the variable
- /// is actually *in scope*.
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
- /// Parse a variable occurence which is a *binding* of that variable.
+ /// Parses a variable occurrence which is a *binding* of that variable.
/// The `requireKnown` parameter is for handling the binding of
/// forward-declared variables.
FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
- /// Parse an optional variable binding. When the next token is
+ /// Parses an optional variable binding. When the next token is
/// not a valid variable name, this will bind a new unnamed variable.
/// The returned `bool` indicates whether a variable name was parsed.
FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
/// Binds the given variable: both updating the `VarEnv` itself, and
- /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
- /// to `AsmParser::parseAffineExpr`). This method is already called by the
+ /// the `{dims,lvls}AndSymbols` lists (which will be passed
+ /// to `AsmParser::parseAffineExpr`). This method is already called by the
/// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
/// not need to be called elsewhere.
Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 053e067fff64ddb..8cc7068e3113aff 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/69792
More information about the Mlir-commits
mailing list