[Mlir-commits] [mlir] 889f4bf - [mlir][sparse] Improve `DimLvlMapParser`'s handling of variable bindings
wren romano
llvmlistbot at llvm.org
Thu Jul 20 15:56:11 PDT 2023
Author: wren romano
Date: 2023-07-20T15:56:03-07:00
New Revision: 889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7
URL: https://github.com/llvm/llvm-project/commit/889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7
DIFF: https://github.com/llvm/llvm-project/commit/889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7.diff
LOG: [mlir][sparse] Improve `DimLvlMapParser`'s handling of variable bindings
This commit comprises a number of related changes:
(1) Reintroduces the semantic distinction between `parseVarUsage` and `parseVarBinding`, adds documentation explaining the distinction, and adds commentary to the one place that violates the desired/intended semantics.
(2) Improves documentation/commentary about the forward-declaration of level-vars, and about the meaning of the `bool` parameter to `parseLvlSpec`.
(3) Removes the `VarEnv::addVars` method, and instead has `DimLvlMapParser` handle the conversion issues directly. In particular, the parser now stores and maintains the `{dims,lvls}AndSymbols` arrays, thereby avoiding the O(n^2) behavior of scanning through the entire `VarEnv` for each `parse{Dim,Lvl}Spec` call. Unfortunately, there still remains another source of O(n^2) behavior, namely: the `AsmParser::parseAffineExpr` method will copy the `DimLvlMapParser::{dims,lvls}AndSymbols` arrays into `AffineParser::dimsAndSymbols` on each `parse{Dim,Lvl}Spec` call; but fixing that would require extensive changes to `AffineParser` itself.
Depends On D155532
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D155533
Added:
Modified:
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 5d3e70e2fb2eb3..3b6cedd6596297 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -12,8 +12,19 @@ using namespace mlir;
using namespace mlir::sparse_tensor;
using namespace mlir::sparse_tensor::ir_detail;
-#define FAILURE_IF_FAILED(STMT) \
- if (failed(STMT)) { \
+#define FAILURE_IF_FAILED(RES) \
+ if (failed(RES)) { \
+ return failure(); \
+ }
+
+/// Helper function for `FAILURE_IF_NULLOPT_OR_FAILED` to avoid duplicating
+/// its `RES` parameter.
+static inline bool didntSucceed(OptionalParseResult res) {
+ return !res.has_value() || failed(*res);
+}
+
+#define FAILURE_IF_NULLOPT_OR_FAILED(RES) \
+ if (didntSucceed(RES)) { \
return failure(); \
}
@@ -80,37 +91,70 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
llvm_unreachable("unknown Policy");
}
-FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
- VarInfo::ID varID;
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
+ bool requireKnown) {
+ VarInfo::ID id;
bool didCreate;
- const auto res =
- parseVar(vk, /*isOptional=*/false, Policy::MustNot, varID, didCreate);
- if (!res.has_value() || failed(*res))
- return failure();
- return varID;
+ const bool isOptional = false;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+ FAILURE_IF_NULLOPT_OR_FAILED(res)
+ assert(requireKnown ? !didCreate : true);
+ return id;
+}
+
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
+ bool requireKnown) {
+ const auto loc = parser.getCurrentLocation();
+ VarInfo::ID id;
+ bool didCreate;
+ const bool isOptional = false;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+ FAILURE_IF_NULLOPT_OR_FAILED(res)
+ assert(requireKnown ? !didCreate : didCreate);
+ bindVar(loc, id);
+ return id;
}
FailureOr<std::pair<Var, bool>>
-DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
+DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
+ const auto loc = parser.getCurrentLocation();
VarInfo::ID id;
bool didCreate;
- const auto res = parseVar(vk, isOptional, Policy::Must, id, didCreate);
+ const bool isOptional = true;
+ const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+ const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
if (res.has_value()) {
FAILURE_IF_FAILED(*res)
- return std::make_pair(env.bindVar(id), true);
+ assert(didCreate);
+ return std::make_pair(bindVar(loc, id), true);
}
+ assert(!didCreate);
return std::make_pair(env.bindUnusedVar(vk), false);
}
-FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
- // Nothing to parse, create a new lvl var right away.
- if (directAffine)
- return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
- // Parse a lvl var, always pulling from the existing pool.
- const auto use = parseVarUsage(VarKind::Level);
- FAILURE_IF_FAILED(use)
- FAILURE_IF_FAILED(parser.parseEqual())
- return env.toVar(*use);
+Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
+ MLIRContext *context = parser.getContext();
+ const auto var = env.bindVar(id);
+ const auto &info = std::as_const(env).access(id);
+ const auto name = info.getName();
+ const auto num = *info.getNum();
+ switch (info.getKind()) {
+ case VarKind::Symbol: {
+ const auto affine = getAffineSymbolExpr(num, context);
+ dimsAndSymbols.emplace_back(name, affine);
+ lvlsAndSymbols.emplace_back(name, affine);
+ return var;
+ }
+ case VarKind::Dimension:
+ dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+ return var;
+ case VarKind::Level:
+ lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+ return var;
+ }
+ llvm_unreachable("unknown VarKind");
}
//===----------------------------------------------------------------------===//
@@ -118,10 +162,8 @@ FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
//===----------------------------------------------------------------------===//
FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
- FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol,
- OpAsmParser::Delimiter::OptionalSquare))
- FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level,
- OpAsmParser::Delimiter::OptionalBraces))
+ FAILURE_IF_FAILED(parseSymbolBindingList())
+ FAILURE_IF_FAILED(parseLvlVarBindingList())
FAILURE_IF_FAILED(parseDimSpecList())
FAILURE_IF_FAILED(parser.parseArrow())
FAILURE_IF_FAILED(parseLvlSpecList())
@@ -133,14 +175,41 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
}
-ParseResult
-DimLvlMapParser::parseOptionalIdList(VarKind vk,
- OpAsmParser::Delimiter delimiter) {
- const auto parseIdBinding = [&]() -> ParseResult {
- return ParseResult(parseVarBinding(vk, /*isOptional=*/false));
- };
- return parser.parseCommaSeparatedList(delimiter, parseIdBinding,
- " in id list");
+ParseResult DimLvlMapParser::parseSymbolBindingList() {
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::OptionalSquare,
+ [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
+ " in symbol binding list");
+}
+
+// FIXME: The forward-declaration of level-vars is a stop-gap workaround
+// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
+// `DimLvlMapParser::parseDimSpec`. (In particular, note that all the
+// variables must be bound before entering `AsmParser::parseAffineExpr`,
+// since that method requires every variable to already have a fixed/known
+// `Var::Num`.)
+//
+// However, the forward-declaration list duplicates information which is
+// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
+// the names of the variables themselves, and the order in which the names
+// are bound). This redundancy causes bad UX, and also means we must be
+// sure to verify consistency between the two sources of information.
+//
+// Therefore, it would be best to remove the forward-declaration list from
+// the syntax. This can be achieved by implementing our own version of
+// `AffineParser::parseAffineExpr` which calls
+// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
+// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
+// some `Var::Num` immediately). This would also enable us to use the `VarEnv`
+// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
+// side, and thus would also enable us to avoid the O(n^2) behavior of copying
+// `DimLvlParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
+// every time `AsmParser::parseAffineExpr` is called.
+ParseResult DimLvlMapParser::parseLvlVarBindingList() {
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::OptionalBraces,
+ [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
+ " in level declaration list");
}
//===----------------------------------------------------------------------===//
@@ -150,22 +219,24 @@ DimLvlMapParser::parseOptionalIdList(VarKind vk,
ParseResult DimLvlMapParser::parseDimSpecList() {
return parser.parseCommaSeparatedList(
OpAsmParser::Delimiter::Paren,
- [&]() -> ParseResult { return parseDimSpec(); },
+ [this]() -> ParseResult { return parseDimSpec(); },
" in dimension-specifier list");
}
ParseResult DimLvlMapParser::parseDimSpec() {
- const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false);
- FAILURE_IF_FAILED(res)
- const DimVar var = res->first.cast<DimVar>();
+ // Parse the requisite dim-var binding.
+ const auto varID = parseVarBinding(VarKind::Dimension);
+ FAILURE_IF_FAILED(varID)
+ const DimVar var = env.getVar(*varID).cast<DimVar>();
// Parse an optional dimension expression.
AffineExpr affine;
if (succeeded(parser.parseOptionalEqual())) {
// Parse the dim affine expr, with only any lvl-vars in scope.
- SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
- env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext());
- FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
+ // FIXME(wrengr): This still has the O(n^2) behavior of copying
+ // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
+ // field every time `parseDimSpec` is called.
+ FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
}
DimExpr expr{affine};
@@ -188,32 +259,98 @@ ParseResult DimLvlMapParser::parseDimSpec() {
//===----------------------------------------------------------------------===//
ParseResult DimLvlMapParser::parseLvlSpecList() {
- // If no level variable is declared at this point, the following level
- // specification consists of direct affine expressions only, as in:
- // (d0, d1) -> (d0 : dense, d1 : compressed)
- // Otherwise, we are looking for a leading lvl-var, as in:
- // {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed)
- const bool directAffine = env.getRanks().getLvlRank() == 0;
- return parser.parseCommaSeparatedList(
+ // This method currently only supports two syntaxes:
+ //
+ // (1) There are no forward-declarations, and no lvl-var bindings:
+ // (d0, d1) -> (d0 : dense, d1 : compressed)
+ // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
+ // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
+ // the level-rank is correct at the end of parsing.
+ //
+ // (2) There are forward-declarations, and every lvl-spec must have
+ // a lvl-var binding:
+ // {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+ // However, this introduces duplicate information since the order of
+ // the lvl-vars in `parseLvlVarBindingList` must agree with their order
+ // in the list of lvl-specs. Therefore, `parseLvlSpec` will not call
+ // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
+ // and must also validate the consistency between the two lvl-var orders.
+ const auto declaredLvlRank = env.getRanks().getLvlRank();
+ const bool requireLvlVarBinding = declaredLvlRank != 0;
+ // Have `ERROR_IF` point to the start of the list.
+ const auto loc = parser.getCurrentLocation();
+ const auto res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::Paren,
- [&]() -> ParseResult { return parseLvlSpec(directAffine); },
+ [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
" in level-specifier list");
+ FAILURE_IF_FAILED(res)
+ const auto specLvlRank = lvlSpecs.size();
+ ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
+ "Level-rank mismatch between forward-declarations and specifiers. "
+ "Declared " +
+ Twine(declaredLvlRank) + " level-variables; but got " +
+ Twine(specLvlRank) + " level-specifiers.")
+ return success();
+}
+
+static inline Twine nth(Var::Num n) {
+ switch (n) {
+ case 1:
+ return "1st";
+ case 2:
+ return "2nd";
+ default:
+ return Twine(n) + "th";
+ }
}
-ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) {
- auto res = parseLvlVarBinding(directAffine);
- FAILURE_IF_FAILED(res);
- LvlVar var = res->cast<LvlVar>();
+// NOTE: This is factored out as a separate method only because `Var`
+// lacks a default-ctor, which makes this conditional difficult to inline
+// at the one call-site.
+FailureOr<LvlVar>
+DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
+ // Nothing to parse, just bind an unnamed variable.
+ if (!requireLvlVarBinding)
+ return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
+
+ const auto loc = parser.getCurrentLocation();
+ // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
+ // since the thing we're parsing is supposed to be a variable *binding*
+ // rather than a variable *use*. However, the call to `VarEnv::bindVar`
+ // (and its corresponding call to `DimLvlMapParser::recordVarBinding`)
+ // already occured in `parseLvlVarBindingList`, and therefore we must
+ // use `parseVarUsage` here in order to operationally do the right thing.
+ const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
+ FAILURE_IF_FAILED(varID)
+ const auto &info = std::as_const(env).access(*varID);
+ const auto var = info.getVar().cast<LvlVar>();
+ const auto forwardNum = var.getNum();
+ const auto specNum = lvlSpecs.size();
+ ERROR_IF(forwardNum != specNum,
+ "Level-variable ordering mismatch. The variable '" + info.getName() +
+ "' was forward-declared as the " + nth(forwardNum) +
+ " level; but is bound by the " + nth(specNum) +
+ " specification.")
+ FAILURE_IF_FAILED(parser.parseEqual())
+ return var;
+}
+
+ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
+ // Parse the optional lvl-var binding. (Actually, `requireLvlVarBinding`
+ // specifies whether that "optional" is actually Must or MustNot.)
+ const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
+ FAILURE_IF_FAILED(varRes)
+ const LvlVar var = *varRes;
// Parse the lvl affine expr, with only the dim-vars in scope.
AffineExpr affine;
- SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
- env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext());
+ // FIXME(wrengr): This still has the O(n^2) behavior of copying
+ // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
+ // field every time `parseLvlSpec` is called.
FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
LvlExpr expr{affine};
FAILURE_IF_FAILED(parser.parseColon())
-
const auto type = lvlTypeParser.parseLvlType(parser);
FAILURE_IF_FAILED(type)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index b14ef370270d6b..013a89ea172b0b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,22 +42,59 @@ class DimLvlMapParser final {
FailureOr<DimLvlMap> parseDimLvlMap();
private:
+ /// The core code for parsing `Var`. This method abstracts out a lot
+ /// of complex details to avoid code duplication; however, client code
+ /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
+ /// calling this method directly.
OptionalParseResult parseVar(VarKind vk, bool isOptional,
Policy creationPolicy, VarInfo::ID &id,
bool &didCreate);
- FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
- FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
- FailureOr<Var> parseLvlVarBinding(bool directAffine);
- ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter);
+ /// Parse a variable occurence which is a *use* of that variable.
+ /// The `requireKnown` parameter specifies how to handle the case of
+ /// encountering a valid variable name which is currently unused: when
+ /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+ /// a new unbound variable will be created.
+ ///
+ /// NOTE: Just because a variable is *known* (i.e., the name has been
+ /// associated with an `VarInfo::ID`), does not mean that the variable
+ /// is actually *in scope*.
+ FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
+
+ /// Parse a variable occurence which is a *binding* of that variable.
+ /// The `requireKnown` parameter is for handling the binding of
+ /// forward-declared variables.
+ FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
+
+ /// Parse an optional variable binding. When the next token is
+ /// not a valid variable name, this will bind a new unnamed variable.
+ /// The returned `bool` indicates whether a variable name was parsed.
+ FailureOr<std::pair<Var, bool>>
+ parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
+
+ /// Binds the given variable: both updating the `VarEnv` itself, and
+ /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
+ /// to `AsmParser::parseAffineExpr`). This method is already called by the
+ /// `parseVarBinding`/`parseOptionalVarBinding` methods, therefore should
+ /// not need to be called elsewhere.
+ Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
+
+ ParseResult parseSymbolBindingList();
+ ParseResult parseLvlVarBindingList();
ParseResult parseDimSpec();
ParseResult parseDimSpecList();
- ParseResult parseLvlSpec(bool directAffine);
+ FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
+ ParseResult parseLvlSpec(bool requireLvlVarBinding);
ParseResult parseLvlSpecList();
AsmParser &parser;
LvlTypeParser lvlTypeParser;
VarEnv env;
+ // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
+ // the O(n^2) cost of repeatedly constructing them inside of the
+ // `parse{Dim,Lvl}Spec` methods.
+ SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+ SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
SmallVector<DimSpec> dimSpecs;
SmallVector<LvlSpec> lvlSpecs;
};
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index e126dab02b6a02..7250d44b53d0d6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -296,16 +296,4 @@ InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
return {};
}
-void VarEnv::addVars(
- SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
- VarKind vk, MLIRContext *context) const {
- for (const auto &var : vars) {
- if (var.getKind() == vk) {
- assert(var.hasNum());
- dimsAndSymbols.push_back(std::make_pair(
- var.getName(), getAffineDimExpr(*var.getNum(), context)));
- }
- }
-}
-
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 8365ff2ae54318..313972b3ca79b6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -424,8 +424,6 @@ class VarEnv final {
return oid ? &access(*oid) : nullptr;
}
- Var toVar(VarInfo::ID id) const { return vars[to_underlying(id)].getVar(); }
-
private:
VarInfo &access(VarInfo::ID id) {
return const_cast<VarInfo &>(std::as_const(*this).access(id));
@@ -472,12 +470,20 @@ class VarEnv final {
InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
+ /// Returns the current ranks of bound variables. This method should
+ /// only be used after the environment is "finished", since binding new
+ /// variables will (semantically) invalidate any previously returned `Ranks`.
Ranks getRanks() const { return Ranks(nextNum); }
- /// Adds all variables of given kind to the vector.
- void
- addVars(SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
- VarKind vk, MLIRContext *context) const;
+ /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
+ /// failure if the variable is not bound.
+ Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
+
+ /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
+ /// if the variable is not bound.
+ std::optional<Var> tryGetVar(VarInfo::ID id) const {
+ return access(id).tryGetVar();
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index e76df6551c2e1c..2500a9d244cd77 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -69,3 +69,48 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
dimSlices = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}}
}>
func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+///////////////////////////////////////////////////////////////////////////////
+// Migration plan for new STEA surface syntax,
+// use the NEW_SYNTAX on selected examples
+// and then TODO: remove when fully migrated
+///////////////////////////////////////////////////////////////////////////////
+
+// -----
+
+// expected-error at +3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
+#TooManyLvlDecl = #sparse_tensor.encoding<{
+ NEW_SYNTAX =
+ {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @too_many_lvl_decl(%arg0: tensor<?x?xf64, #TooManyLvlDecl>) {
+ return
+}
+
+// -----
+
+// NOTE: We don't get the "level-rank mismatch" error here, because this
+// "undeclared identifier" error occurs first. The error message is a bit
+// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather
+// than `parseVarBinding` (and the error message generated by `parseVar`
+// is assuming that `parseVarUsage` is only called for *uses* of variables).
+// expected-error at +3 {{use of undeclared identifier 'l1'}}
+#TooFewLvlDecl = #sparse_tensor.encoding<{
+ NEW_SYNTAX =
+ {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @too_few_lvl_decl(%arg0: tensor<?x?xf64, #TooFewLvlDecl>) {
+ return
+}
+
+// -----
+
+// expected-error at +3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
+#WrongOrderLvlDecl = #sparse_tensor.encoding<{
+ NEW_SYNTAX =
+ {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @wrong_order_lvl_decl(%arg0: tensor<?x?xf64, #WrongOrderLvlDecl>) {
+ return
+}
+
More information about the Mlir-commits
mailing list