[Mlir-commits] [mlir] [mlir][sparse] Parser cleanup (PR #69792)
Yinying Li
llvmlistbot at llvm.org
Fri Oct 20 15:42:22 PDT 2023
https://github.com/yinying-lisa-li created https://github.com/llvm/llvm-project/pull/69792
Removed TODOs, FIXMEs, and long notes that are better suited for a design doc.
From f0ce8bde075f39046eb2b365f2decbb3383da349 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Fri, 20 Oct 2023 21:45:34 +0000
Subject: [PATCH] [mlir][sparse] Parser cleanup
Removed TODOs, FIXMEs, and long notes that are better suited for a design doc.
---
.../SparseTensor/IR/Detail/DimLvlMap.cpp | 35 ------
.../SparseTensor/IR/Detail/DimLvlMap.h | 65 +----------
.../IR/Detail/DimLvlMapParser.cpp | 64 +----------
.../SparseTensor/IR/Detail/DimLvlMapParser.h | 25 ++---
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 18 ---
.../SparseTensor/IR/Detail/TemplateExtras.h | 32 +-----
.../Dialect/SparseTensor/IR/Detail/Var.cpp | 37 ------
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h | 106 ++----------------
8 files changed, 31 insertions(+), 351 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index 5f947b67c6d848e..851867926fe679e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -60,16 +60,6 @@ std::optional<int64_t> DimLvlExpr::dyn_castConstantValue() const {
return k ? std::make_optional(k.getValue()) : std::nullopt;
}
-// This helper method is akin to `AffineExpr::operator==(int64_t)`
-// except it uses a different implementation, namely the implementation
-// used within `AsmPrinter::Impl::printAffineExprInternal`.
-//
-// wrengr guesses that `AsmPrinter::Impl::printAffineExprInternal` uses
-// this implementation because it avoids constructing the intermediate
-// `AffineConstantExpr(val)` and thus should in theory be a bit faster.
-// However, if it is indeed faster, then the `AffineExpr::operator==`
-// method should be updated to do this instead. And if it isn't any
-// faster, then we should be using `AffineExpr::operator==` instead.
bool DimLvlExpr::hasConstantValue(int64_t val) const {
const auto k = expr.dyn_cast_or_null<AffineConstantExpr>();
return k && k.getValue() == val;
@@ -216,12 +206,6 @@ bool DimSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && (!expr || ranks.isValid(expr));
}
-bool DimSpec::isFunctionOf(VarSet const &vars) const {
- return vars.occursIn(expr);
-}
-
-void DimSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
void DimSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -262,12 +246,6 @@ bool LvlSpec::isValid(Ranks const &ranks) const {
return ranks.isValid(var) && ranks.isValid(expr);
}
-bool LvlSpec::isFunctionOf(VarSet const &vars) const {
- return vars.occursIn(expr);
-}
-
-void LvlSpec::getFreeVars(VarSet &vars) const { vars.add(expr); }
-
void LvlSpec::dump() const {
print(llvm::errs(), /*wantElision=*/false);
llvm::errs() << "\n";
@@ -301,19 +279,6 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
// below cannot cause OOB errors.
assert(isWF());
- // TODO: Second, we need to infer/validate the `lvlToDim` mapping.
- // Along the way we should set every `DimSpec::elideExpr` according
- // to whether the given expression is inferable or not. Notably, this
- // needs to happen before the code for setting every `LvlSpec::elideVar`,
- // since if the LvlVar is only used in elided DimExpr, then the
- // LvlVar should also be elided.
- // NOTE: Be sure to use `DimLvlMap::setDimExpr` for setting the new exprs,
- // to ensure that we maintain the invariant established by `isWF` above.
-
- // Third, we set every `LvlSpec::elideVar` according to whether that
- // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
- // NOTE: The invariant established by `isWF` ensures that the following
- // calls to `VarSet::add` cannot raise OOB errors.
VarSet usedVars(getRanks());
for (const auto &dimSpec : dimSpecs)
if (!dimSpec.canElideExpr())
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
index 2d02eed2cf9972e..664b49509f070f4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.h
@@ -5,11 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-// FIXME(wrengr): The `DimLvlMap` class must be public so that it can
-// be named as the storage representation of the parameter for the tblgen
-// defn of STEA. We may well need to make the other classes public too,
-// so that the rest of the compiler can use them when necessary.
-//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
#define MLIR_DIALECT_SPARSETENSOR_IR_DETAIL_DIMLVLMAP_H
@@ -23,16 +18,8 @@ namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
-// TODO(wrengr): Give this enum a better name, so that it fits together
-// with the name of the `DimLvlExpr` class (which may also want a better
-// name). Perhaps make this a nested-type too.
-//
-// NOTE: In the future we will extend this enum to include "counting
-// expressions" required for supporting ITPACK/ELL. Therefore the current
-// underlying-type and representation values should not be relied upon.
enum class ExprKind : bool { Dimension = false, Level = true };
-// TODO(wrengr): still needs a better name....
constexpr VarKind getVarKindAllowedInExpr(ExprKind ek) {
using VK = std::underlying_type_t<VarKind>;
return VarKind{2 * static_cast<VK>(!to_underlying(ek))};
@@ -41,19 +28,8 @@ static_assert(getVarKindAllowedInExpr(ExprKind::Dimension) == VarKind::Level &&
getVarKindAllowedInExpr(ExprKind::Level) == VarKind::Dimension);
//===----------------------------------------------------------------------===//
-// TODO(wrengr): The goal of this class is to capture a proof that
-// we've verified that the given `AffineExpr` only has variables of the
-// appropriate kind(s). So we need to actually prove/verify that in the
-// ctor or all its callsites!
class DimLvlExpr {
private:
- // FIXME(wrengr): Per <https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html>,
- // the `kind` field should be private and const. However, beware
- // that if we mark any field as `const` or if the fields have differing
- // `private`/`protected` privileges then the `IsZeroCostAbstraction`
- // assertion will fail!
- // (Also, iirc, if we end up moving the `expr` to the subclasses
- // instead, that'll also cause `IsZeroCostAbstraction` to fail.)
ExprKind kind;
AffineExpr expr;
@@ -100,11 +76,6 @@ class DimLvlExpr {
//
// Getters for handling `AffineExpr` subclasses.
//
- // TODO(wrengr): is there any way to make these typesafe without too much
- // templating?
- // TODO(wrengr): Most if not all of these don't actually need to be
- // methods, they could be free-functions instead.
- //
Var castAnyVar() const;
std::optional<Var> dyn_castAnyVar() const;
SymVar castSymVar() const;
@@ -131,9 +102,6 @@ class DimLvlExpr {
// Variant of `mlir::AsmPrinter::Impl::BindingStrength`
enum class BindingStrength : bool { Weak = false, Strong = true };
- // TODO(wrengr): Does our version of `printAffineExprInternal` really
- // need to be a method, or could it be a free-function instead? (assuming
- // `BindingStrength` goes with it).
void printAffineExprInternal(llvm::raw_ostream &os,
BindingStrength enclosingTightness) const;
void printStrong(llvm::raw_ostream &os) const {
@@ -145,12 +113,7 @@ class DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<DimLvlExpr>);
-// FUTURE_CL(wrengr): It would be nice to have the subclasses override
-// `getRHS`, `getLHS`, `unpackBinop`, and `castDimLvlVar` to give them
-// the proper covariant return types.
-//
class DimExpr final : public DimLvlExpr {
- // FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit DimExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
@@ -170,7 +133,6 @@ class DimExpr final : public DimLvlExpr {
static_assert(IsZeroCostAbstraction<DimExpr>);
class LvlExpr final : public DimLvlExpr {
- // FIXME(wrengr): These two are needed for the current RTTI implementation.
friend class DimLvlExpr;
constexpr explicit LvlExpr(DimLvlExpr expr) : DimLvlExpr(expr) {}
@@ -189,7 +151,6 @@ class LvlExpr final : public DimLvlExpr {
};
static_assert(IsZeroCostAbstraction<LvlExpr>);
-// FIXME(wrengr): See comments elsewhere re RTTI implementation issues/questions
template <typename U>
constexpr bool DimLvlExpr::isa() const {
if constexpr (std::is_same_v<U, DimExpr>)
@@ -247,18 +208,12 @@ class DimSpec final {
/// the result of this predicate.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
- // TODO(wrengr): Use it or loose it.
- bool isFunctionOf(Var var) const;
- bool isFunctionOf(VarSet const &vars) const;
- void getFreeVars(VarSet &vars) const;
-
std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
static_assert(IsZeroCostAbstraction<DimSpec>);
//===----------------------------------------------------------------------===//
@@ -270,13 +225,6 @@ class LvlSpec final {
/// whereas the `DimLvlMap` ctor will reset this as appropriate.
bool elideVar = false;
/// The level-expression.
- //
- // NOTE: For now we use `LvlExpr` because all level-expressions must be
- // `AffineExpr`; however, in the future we will also want to allow "counting
- // expressions", and potentially other kinds of non-affine level-expressions.
- // Which kinds of `DimLvlExpr` are allowed will depend on the `DimLevelType`,
- // so we may consider defining another class for pairing those two together
- // to ensure that the pair is well-formed.
LvlExpr expr;
/// The level-type (== level-format + lvl-properties).
DimLevelType type;
@@ -298,23 +246,14 @@ class LvlSpec final {
/// Checks whether the variables bound/used by this spec are valid
/// with respect to the given ranks.
- //
- // NOTE: Once we introduce "counting expressions" this will need
- // a more sophisticated implementation than `DimSpec::isValid` does.
[[nodiscard]] bool isValid(Ranks const &ranks) const;
- // TODO(wrengr): Use it or loose it.
- bool isFunctionOf(Var var) const;
- bool isFunctionOf(VarSet const &vars) const;
- void getFreeVars(VarSet &vars) const;
-
std::string str(bool wantElision = true) const;
void print(llvm::raw_ostream &os, bool wantElision = true) const;
void print(AsmPrinter &printer, bool wantElision = true) const;
void dump() const;
};
-// Although this class is more than just a newtype/wrapper, we do want
-// to ensure that storing them into `SmallVector` is efficient.
+
static_assert(IsZeroCostAbstraction<LvlSpec>);
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 680411f8008ea3d..6fb69d1397e6cfb 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -44,42 +44,20 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
VarInfo::ID &varID,
bool &didCreate) {
// Save the current location so that we can have error messages point to
- // the right place. Note that `Parser::emitWrongTokenError` starts off
- // with the same location as `AsmParserImpl::getCurrentLocation` returns;
- // however, `Parser` will then do some various munging with the location
- // before actually using it, so `AsmParser::emitError` can't quite be used
- // as a drop-in replacement for `Parser::emitWrongTokenError`.
+ // the right place.
const auto loc = parser.getCurrentLocation();
-
- // Several things to note.
- // (1) the `Parser::isCurrentTokenAKeyword` method checks the exact
- // same conditions as the `AffineParser.cpp`-static free-function
- // `isIdentifier` which is used by `AffineParser::parseBareIdExpr`.
- // (2) the `{Parser,AsmParserImpl}::parseOptionalKeyword(StringRef*)`
- // methods do the same song and dance about using
- // `isCurrentTokenAKeyword`, `getTokenSpelling`, et `consumeToken` as we
- // would want to do if we could use the `Parser` class directly. It
- // doesn't provide the nice error handling we want, but we can work around
- // that.
StringRef name;
if (failed(parser.parseOptionalKeyword(&name))) {
- // If not actually optional, then `emitError`.
ERROR_IF(!isOptional, "expected bare identifier")
- // If is actually optional, then return the null `OptionalParseResult`.
return std::nullopt;
}
- // I don't know if we need to worry about the possibility of the caller
- // recovering from error and then reusing the `DimLvlMapParser` for subsequent
- // `parseVar`, but I'm erring on the side of caution by distinguishing
- // all three possible creation policies.
if (const auto res = env.lookupOrCreate(creationPolicy, name, loc, vk)) {
varID = res->first;
didCreate = res->second;
return success();
}
- // TODO(wrengr): these error messages make sense for our intended usage,
- // but not in general; but it's unclear how best to factor that part out.
+
switch (creationPolicy) {
case Policy::MustNot:
return parser.emitError(loc, "use of undeclared identifier '" + name + "'");
@@ -167,8 +145,6 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
- // TODO(wrengr): Try to improve the error messages from
- // `VarEnv::emitErrorIfAnyUnbound`.
InFlightDiagnostic ifd = env.emitErrorIfAnyUnbound(parser);
if (failed(ifd))
return ifd;
@@ -182,29 +158,6 @@ ParseResult DimLvlMapParser::parseSymbolBindingList() {
" in symbol binding list");
}
-// FIXME: The forward-declaration of level-vars is a stop-gap workaround
-// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
-// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
-// variables must be bound before entering `AsmParser::parseAffineExpr`,
-// since that method requires every variable to already have a fixed/known
-// `Var::Num`.)
-//
-// However, the forward-declaration list duplicates information which is
-// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
-// the names of the variables themselves, and the order in which the names
-// are bound). This redundancy causes bad UX, and also means we must be
-// sure to verify consistency between the two sources of information.
-//
-// Therefore, it would be best to remove the forward-declaration list from
-// the syntax. This can be achieved by implementing our own version of
-// `AffineParser::parseAffineExpr` which calls
-// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
-// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
-// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
-// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
-// side, and thus would also enable us to avoid the O(n^2) behavior of copying
-// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
-// every time `AsmParser::parseAffineExpr` is called.
ParseResult DimLvlMapParser::parseLvlVarBindingList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::OptionalBraces,
@@ -233,9 +186,6 @@ ParseResult DimLvlMapParser::parseDimSpec() {
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
- // FIXME(wrengr): This still has the O(n^2) behavior of copying
- // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
- // field every time `parseDimSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
@@ -304,9 +254,6 @@ static inline Twine nth(Var::Num n) {
}
}
-// NOTE: This is factored out as a separate method only because `Var`
-// lacks a default-ctor, which makes this conditional difficult to inline
-// at the one call-site.
FailureOr<LvlVar>
DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
// Nothing to parse, just bind an unnamed variable.
@@ -336,17 +283,14 @@ DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
}
ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
- // Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
- // specifies whether that "optional" is actually Must or MustNot.)
+ // Parse the optional lvl-var binding. `requireLvlVarBinding`
+ // specifies whether that "optional" is actually Must or MustNot.
const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
FAILURE_IF_FAILED(varRes)
const LvlVar var = *varRes;
// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
- // FIXME(wrengr): This still has the O(n^2) behavior of copying
- // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
- // field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index 013a89ea172b0b5..11f727c40644ae1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,39 +42,32 @@ class DimLvlMapParser final {
FailureOr<DimLvlMap> parseDimLvlMap();
private:
- /// The core code for parsing `Var`. This method abstracts out a lot
- /// of complex details to avoid code duplication; however, client code
- /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
- /// calling this method directly.
+ /// Core routine for parsing a `Var`. Client code should prefer using
+ /// `parseVarUsage` and `parseVarBinding` rather than calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);
- /// Parse a variable occurence which is a *use* of that variable.
- /// The `requireKnown` parameter specifies how to handle the case of
- /// encountering a valid variable name which is currently unused: when
- /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+ /// Parses a variable occurrence which is a *use* of that variable.
+ /// When a valid variable name is currently unused, if
+ /// `requireKnown=true`, an error is raised; if `requireKnown=false`,
/// a new unbound variable will be created.
- ///
- /// NOTE: Just because a variable is *known* (i.e., the name has been
- /// associated with an `VarInfo::ID`), does not mean that the variable
- /// is actually *in scope*.
FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
- /// Parse a variable occurence which is a *binding* of that variable.
+ /// Parses a variable occurrence which is a *binding* of that variable.
/// The `requireKnown` parameter is for handling the binding of
/// forward-declared variables.
FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
- /// Parse an optional variable binding. When the next token is
+ /// Parses an optional variable binding. When the next token is
/// not a valid variable name, this will bind a new unnamed variable.
/// The returned `bool` indicates whether a variable name was parsed.
FailureOr<std::pair<Var, bool>>
parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
/// Binds the given variable: both updating the `VarEnv` itself, and
- /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
- /// to `AsmParser::parseAffineExpr`). This method is already called by the
+ /// the `{dims,lvls}AndSymbols` lists (which will be passed
+ /// to `AsmParser::parseAffineExpr`). This method is already called by the
/// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
/// not need to be called elsewhere.
Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 053e067fff64ddb..8cc7068e3113aff 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -14,29 +14,11 @@ using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
//===----------------------------------------------------------------------===//
-// TODO(wrengr): rephrase these to do the trick for gobbling up any trailing
-// semicolon
-//
-// NOTE: There's no way for `FAILURE_IF_FAILED` to simultaneously support
-// both `OptionalParseResult` and `InFlightDiagnostic` return types.
-// We can get the compiler to accept the code if we returned "`{}`",
-// however for `OptionalParseResult` that would become the nullopt result,
-// whereas for `InFlightDiagnostic` it would become a result that can
-// be implicitly converted to success. By using "`failure()`" we ensure
-// that `OptionalParseResult` behaves as intended, however that means the
-// macro cannot be used for `InFlightDiagnostic` since there's no implicit
-// conversion.
#define FAILURE_IF_FAILED(STMT) \
if (failed(STMT)) { \
return failure(); \
}
-// Although `ERROR_IF` is phrased to return `InFlightDiagnostic`, that type
-// can be implicitly converted to all four of `LogicalResult, `FailureOr`,
-// `ParseResult`, and `OptionalParseResult`. (However, beware that the
-// conversion to `OptionalParseResult` doesn't properly delegate to
-// `InFlightDiagnostic::operator ParseResult`.)
-//
// NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
#define ERROR_IF(COND, MSG) \
if (COND) { \
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h
index 5c222dd966f4f73..7f0c1fd8c46c78b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/TemplateExtras.h
@@ -19,8 +19,6 @@ namespace sparse_tensor {
namespace ir_detail {
//===----------------------------------------------------------------------===//
-// These two templates are like `AsmPrinter::{,detect_}has_print_method`,
-// except they detect print methods taking `raw_ostream` (not `AsmPrinter`).
template <typename T>
using has_print_method =
decltype(std::declval<T>().print(std::declval<llvm::raw_ostream &>()));
@@ -31,9 +29,7 @@ using enable_if_has_print_method =
std::enable_if_t<detect_has_print_method<T>::value, R>;
/// Generic template for defining `operator<<` overloads which delegate
-/// to `T::print(raw_ostream&) const`. Note that there's already another
-/// generic template which defines `operator<<(AsmPrinterT&, T const&)`
-/// via delegating to `operator<<(raw_ostream&, T const&)`.
+/// to `T::print(raw_ostream&) const`.
template <typename T>
inline enable_if_has_print_method<T, llvm::raw_ostream &>
operator<<(llvm::raw_ostream &os, T const &t) {
@@ -42,9 +38,7 @@ operator<<(llvm::raw_ostream &os, T const &t) {
}
//===----------------------------------------------------------------------===//
-/// Convert an enum to its underlying type. This template is designed
-/// to avoid introducing implicit conversions to other integral types,
-/// and is a backport of C++23 `std::to_underlying`.
+/// Convert an enum to its underlying type.
template <typename Enum>
constexpr std::underlying_type_t<Enum> to_underlying(Enum e) noexcept {
return static_cast<std::underlying_type_t<Enum>>(e);
@@ -53,28 +47,12 @@ constexpr std::underlying_type_t<Enum> to_underlying(Enum e) noexcept {
//===----------------------------------------------------------------------===//
template <typename T>
static constexpr bool IsZeroCostAbstraction =
- // These two predicates license the compiler to make several optimizations;
- // some of which are explicitly documented by the C++ standard:
- // <https://en.cppreference.com/w/cpp/types/is_trivially_copyable#Notes>
- // <https://en.cppreference.com/w/cpp/types/is_trivially_destructible#Notes>
- // However, some key optimizations aren't mentioned by the standard; e.g.,
- // that trivially-copyable enables passing-by-value, and the conjunction
- // of trivially-copyable and trivially-destructible enables passing those
- // values in registers rather than on the stack (cf.,
- // <https://www.agner.org/optimize/calling_conventions.pdf>).
+ // These two predicates license the compiler to make optimizations.
std::is_trivially_copyable_v<T> && std::is_trivially_destructible_v<T> &&
- // This one helps ensure ABI compatibility (e.g., padding and alignment):
- // <https://en.cppreference.com/w/cpp/types/is_standard_layout#Notes>
- // <https://en.cppreference.com/w/cpp/language/classes#Standard-layout_class>
- // In particular, the standard mentions that passing/returning a `struct`
- // by value can sometimes introduce ABI overhead compared to using
- // `enum class`; so this assertion is attempting to avoid that.
- // <https://en.cppreference.com/w/cpp/language/enum#enum_relaxed_init_cpp17>
+ // This helps ensure ABI compatibility (e.g., padding and alignment).
std::is_standard_layout_v<T> &&
// These two are what SmallVector uses to determine whether it can
- // use memcpy. The commentary there mentions that it's intended to be
- // an approximation of `is_trivially_copyable`, so this may be redundant
- // with the above, but we include it just to make sure.
+ // use memcpy.
std::is_trivially_copy_constructible<T>::value &&
std::is_trivially_move_constructible<T>::value;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 44eba668021ba79..966e32401c1f9e3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -61,9 +61,6 @@ bool Ranks::isValid(DimLvlExpr expr) const {
int64_t maxSym = -1, maxVar = -1;
mlir::getMaxDimAndSymbol<ArrayRef<AffineExpr>>({{expr.getAffineExpr()}},
maxVar, maxSym);
- // TODO(wrengr): We may want to add a call to `LLVM_DEBUG` like
- // `willBeValidAffineMap` does. And/or should return `InFlightDiagnostic`
- // instead of bool.
return maxSym < getSymRank() && maxVar < getRank(expr.getAllowedVarKind());
}
@@ -72,10 +69,6 @@ bool Ranks::isValid(DimLvlExpr expr) const {
//===----------------------------------------------------------------------===//
VarSet::VarSet(Ranks const &ranks) {
- // NOTE: We must not use `reserve` here, since that doesn't change
- // the `size` of the bitvectors and therefore will result in unexpected
- // OOB errors. Either `resize` or copy/move-ctor work; we opt for the
- // move-ctor since it should be (marginally) more efficient.
for (const auto vk : everyVarKind)
impl[vk] = llvm::SmallBitVector(ranks.getRank(vk));
assert(getRanks() == ranks);
@@ -180,9 +173,6 @@ void VarInfo::setNum(Var::Num n) {
/// Helper function for `assertUsageConsistency` to better handle SMLoc
/// mismatches.
-// TODO(wrengr): If we switch to the `LocatedVar` design, then there's
-// no need for anything like `minSMLoc` since `assertUsageConsistency`
-// won't need to do anything about locations.
LLVM_ATTRIBUTE_UNUSED static llvm::SMLoc
minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
const auto loc1 = parser.getEncodedSourceLoc(sm1).dyn_cast<FileLineColLoc>();
@@ -201,28 +191,13 @@ bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
return (var.getName() == name && var.getID() == id);
}
-// NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
-// (or find some other way to convert SMLoc to FileLineColLoc), then this
-// would no longer be `const VarEnv` (and couldn't be a free-function either).
bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
VarKind vk) {
const auto &var = env.access(id);
- // Since the same variable can occur at several locations,
- // it would not be appropriate to do `assert(var.getLoc() == loc)`.
- /* TODO(wrengr):
- const auto minLoc = minSMLoc(_, var.getLoc(), loc);
- assert(minLoc && "Location mismatch/incompatibility");
- var.loc = minLoc;
- // */
return var.getKind() == vk;
}
std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
- // NOTE: `StringMap::lookup` will return a default-constructed value if
- // the key isn't found; which for enums means zero, and therefore makes
- // it impossible to distinguish between actual zero-VarInfo::ID vs not-found.
- // Whereas `StringMap::at` asserts that the key is found, which we don't
- // want either.
const auto iter = ids.find(name);
if (iter == ids.end())
return std::nullopt;
@@ -277,22 +252,10 @@ Var VarEnv::bindUnusedVar(VarKind vk) { return Var(vk, nextNum[vk]++); }
Var VarEnv::bindVar(VarInfo::ID id) {
auto &info = access(id);
const auto var = bindUnusedVar(info.getKind());
- // NOTE: `setNum` already checks wellformedness of the `Var::Num`.
info.setNum(var.getNum());
return var;
}
-// TODO(wrengr): Alternatively there's `mlir::emitError(Location, Twine const&)`
-// which is what `Operation::emitError` uses; though I'm not sure if
-// that's appropriate to use here... But if it is, then that means
-// we can have `VarInfo` store `Location` rather than `SMLoc`, which
-// means we can use `FusedLoc` to handle the combination issue in
-// `VarEnv::lookupOrCreate`.
-//
-// TODO(wrengr): is there any way to combine multiple IFDs, so that
-// we can report all unbound variables instead of just the first one
-// encountered?
-//
InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
for (const auto &var : vars)
if (!var.hasNum())
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 145586a83a2528c..2606dd399eec81c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -20,36 +20,6 @@ namespace mlir {
namespace sparse_tensor {
namespace ir_detail {
-// Throughout this namespace we use the name `isWF` (is "well-formed")
-// for predicates that detect intrinsic structural integrity criteria,
-// and hence which should always be assertively true. Whereas we reserve
-// the name `isValid` for predicates that detect extrinsic semantic
-// integrity criteria, and hence which may legitimately return false even
-// in well-formed programs. Moreover, "validity" is often a relational
-// or contextual property, and therefore the same term may be considered
-// valid in one context yet invalid in another.
-//
-// As an example of why we make this distinction, consider `Var`.
-// A variable is well-formed if its kind and identifier are both well-formed;
-// this can be checked locally, and the resulting truth-value holds globally.
-// Whereas, a variable is valid with respect to a particular `Ranks` only if
-// it is within bounds; and a variable is valid with respect to a particular
-// `DimLvlMap` only if the variable is bound and all uses of the variable
-// are within the scope of that binding.
-
-// Throughout this namespace we use `enum class` types to form "newtypes".
-// The enum-based implementation of newtypes only serves to block implicit
-// conversions; it cannot enforce any wellformedness constraints, since
-// `enum class` permits using direct-list-initialization to construct
-// arbitrary values[1]. Consequently, we use the syntax "`E{u}`" whenever
-// we intend that ctor to be a noop (i.e., `std::is_same_v<decltype(u),
-// std::underlying_type_t<E>>`), since the compiler will ensure that that's
-// the case. Whereas we only use the "`static_cast<E>(u)`" syntax when we
-// specifically intend to introduce conversions.
-//
-// [1]:
-// <https://en.cppreference.com/w/cpp/language/enum#enum_relaxed_init_cpp17>
-
//===----------------------------------------------------------------------===//
/// The three kinds of variables that `Var` can be.
///
@@ -93,37 +63,9 @@ using VarKindArray = llvm::EnumeratedArray<T, VarKind, VarKind::Level>;
//===----------------------------------------------------------------------===//
/// A concrete variable, to be used in our variant of `AffineExpr`.
+/// Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
+/// support for subclasses with a fixed `VarKind`.
class Var {
- // Design Note: This class makes several distinctions which may at first
- // seem unnecessary but are in fact needed for implementation reasons.
- // These distinctions are summarized as follows:
- //
- // * `Var`
- // Client-facing class for `VarKind` + `Var::Num` pairs, with RTTI
- // support for subclasses with a fixed `VarKind`.
- // * `Var::Num`
- // Client-facing typedef for the type of variable numbers; defined
- // so that client code can use it to disambiguate/document when things
- // are intended to be variable numbers, as opposed to some other thing
- // which happens to be represented as `unsigned`.
- // * `Var::Storage`
- // Private typedef for the storage of `Var::Impl`; defined only because
- // it's also needed for defining `kMaxNum`. Note that this type must be
- // kept distinct from `Var::Num`: not only can they be different C++ types
- // (even though they currently happen to be the same), but also because
- // they use different bitwise representations.
- // * `Var::Impl`
- // The underlying implementation of `Var`; needed by RTTI to serve as
- // an intermediary between `Var` and `Var::Storage`. That is, we want
- // the RTTI methods to select the `U(Var::Impl)` ctor, without any
- // possibility of confusing that with the `U(Var::Num)` ctor nor with
- // the copy-ctor. (Although the `U(Var::Impl)` ctor is effectively
- // identical to the copy-ctor, it doesn't have the type that C++ expects
- // for a copy-ctor.)
- //
- // TODO: See if it'd be cleaner to use "llvm/ADT/Bitfields.h" in lieu
- // of doing our own bitbashing (though that seems to only be used by LLVM
- // for defining machine/assembly ops, and not anywhere else in LLVM/MLIR).
public:
/// Typedef for the type of variable numbers.
using Num = unsigned;
@@ -179,7 +121,6 @@ class Var {
public:
constexpr Var(VarKind vk, Num n) : impl(Impl(vk, n)) {}
Var(AffineSymbolExpr sym) : Var(VarKind::Symbol, sym.getPosition()) {}
- // TODO(wrengr): Should make the first argument an `ExprKind` instead...?
Var(VarKind vk, AffineDimExpr var) : Var(vk, var.getPosition()) {
assert(vk != VarKind::Symbol);
}
@@ -248,11 +189,6 @@ constexpr bool Var::isa() const {
return getKind() == VarKind::Dimension;
if constexpr (std::is_same_v<U, LvlVar>)
return getKind() == VarKind::Level;
- // NOTE: The `AffineExpr::isa` implementation doesn't have a fallthrough
- // case returning `false`; wrengr guesses that's so things will fail
- // to compile whenever `!std::is_base_of<Var, U>`. Though it's unclear
- // why they implemented it that way rather than using SFINAE for that,
- // especially since it would give better error messages.
}
template <typename U>
@@ -310,19 +246,7 @@ static_assert(IsZeroCostAbstraction<Ranks>);
//===----------------------------------------------------------------------===//
/// Efficient representation of a set of `Var`.
-///
-/// NOTE: For the `contains`/`occursIn` methods: if variables occurring in
-/// the method parameter are OOB for the `VarSet`, then these methods will
-/// always return false. However, for the `add` methods: OOB parameters
-/// cause undefined behavior. Currently the `add` methods will raise an
-/// assertion error; though we may change that behavior in the future
-/// (e.g., to resize the underlying bitvectors).
class VarSet final {
- // If we're willing to give up the possibility of resizing the
- // individual bitvectors, then we could flatten this into a single
- // bitvector (akin to how `mlir::presburger::PresburgerSpace` does it);
- // however, doing so would greatly complicate the implementation of the
- // `occursIn(VarSet)` method.
VarKindArray<llvm::SmallBitVector> impl;
public:
@@ -335,11 +259,15 @@ class VarSet final {
Ranks getRanks() const {
return Ranks(getSymRank(), getDimRank(), getLvlRank());
}
-
+ /// For the `contains`/`occursIn` methods: if variables occurring in
+ /// the method parameter are OOB for the `VarSet`, then these methods will
+ /// always return false.
bool contains(Var var) const;
bool occursIn(VarSet const &vars) const;
bool occursIn(DimLvlExpr expr) const;
+ /// For the `add` methods: OOB parameters cause undefined behavior.
+ /// Currently the `add` methods will raise an assertion error.
void add(Var var);
void add(VarSet const &vars);
void add(DimLvlExpr expr);
@@ -394,10 +322,6 @@ class VarInfo final {
return num ? std::make_optional(Var(kind, *num)) : std::nullopt;
}
};
-// We don't actually require this, since `VarInfo` is a proper struct
-// rather than a newtype. But it passes, so for now we'll keep it around.
-// TODO: Uncomment the static assert, it fails the build with gcc7 right now.
-// static_assert(IsZeroCostAbstraction<VarInfo>);
//===----------------------------------------------------------------------===//
enum class Policy { MustNot, May, Must };
@@ -423,9 +347,6 @@ class VarEnv final {
/// object is mutated during the lifetime of the pointer. Therefore,
/// client code should not store the reference nor otherwise allow it
/// to live too long.
- //
- // FUTURE_CL(wrengr): Consider trying to define/use a nested class
- // `struct{VarEnv*; VarInfo::ID}` akin to `BitVector::reference`.
VarInfo const &access(VarInfo::ID id) const {
// `SmallVector::operator[]` already asserts the index is in-bounds.
return vars[to_underlying(id)];
@@ -443,29 +364,24 @@ class VarEnv final {
}
public:
- /// Attempts to look up the variable with the given name.
+ /// Looks up the variable with the given name.
std::optional<VarInfo::ID> lookup(StringRef name) const;
- /// Attempts to create a new currently-unbound variable. When a variable
+ /// Creates a new currently-unbound variable. When a variable
/// of that name already exists: if `verifyUsage` is true, then will assert
/// that the variable has the same kind and a consistent location; otherwise,
/// when `verifyUsage` is false, this is a noop. Returns the identifier
- /// for the variable with the given name (i.e., either the newly created
- /// variable, or the pre-existing variable), and a bool indicating whether
+ /// for the variable with the given name, and a bool indicating whether
/// a new variable was created.
std::optional<std::pair<VarInfo::ID, bool>>
create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
- /// Attempts to lookup or create a variable according to the given
+ /// Looks up or creates a variable according to the given
/// `Policy`. Returns nullopt in one of two circumstances:
/// (1) the policy says we `Must` create, yet the variable already exists;
/// (2) the policy says we `MustNot` create, yet no such variable exists.
/// Otherwise, if the variable already exists then it is validated against
/// the given kind and location to ensure consistency.
- //
- // TODO(wrengr): Define an enum of error codes, to avoid `nullopt`-blindness
- // TODO(wrengr): Prolly want to rename this to `create` and move the
- // current method of that name to being a private `createImpl`.
std::optional<std::pair<VarInfo::ID, bool>>
lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
VarKind vk);
More information about the Mlir-commits
mailing list