[Mlir-commits] [mlir] 889f4bf - [mlir][sparse] Improve `DimLvlMapParser`'s handling of variable bindings

wren romano llvmlistbot at llvm.org
Thu Jul 20 15:56:11 PDT 2023


Author: wren romano
Date: 2023-07-20T15:56:03-07:00
New Revision: 889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7

URL: https://github.com/llvm/llvm-project/commit/889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7
DIFF: https://github.com/llvm/llvm-project/commit/889f4bf26406d22e24b7e85bb4f6a8eb57d04fb7.diff

LOG: [mlir][sparse] Improve `DimLvlMapParser`'s handling of variable bindings

This commit comprises a number of related changes:

(1) Reintroduces the semantic distinction between `parseVarUsage` and `parseVarBinding`, adds documentation explaining the distinction, and adds commentary to the one place that violates the desired/intended semantics.

(2) Improves documentation/commentary about the forward-declaration of level-vars, and about the meaning of the `bool` parameter to `parseLvlSpec`.

(3) Removes the `VarEnv::addVars` method, and instead has `DimLvlMapParser` handle the conversion issues directly.  In particular, the parser now stores and maintains the `{dims,lvls}AndSymbols` arrays, thereby avoiding the O(n^2) behavior of scanning through the entire `VarEnv` for each `parse{Dim,Lvl}Spec` call.  Unfortunately, there still remains another source of O(n^2) behavior, namely: the `AsmParser::parseAffineExpr` method will copy the `DimLvlMapParser::{dims,lvls}AndSymbols` arrays into `AffineParser::dimsAndSymbols` on each `parse{Dim,Lvl}Spec` call; but fixing that would require extensive changes to `AffineParser` itself.

Depends On D155532

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D155533

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
    mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 5d3e70e2fb2eb3..3b6cedd6596297 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -12,8 +12,19 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 using namespace mlir::sparse_tensor::ir_detail;
 
-#define FAILURE_IF_FAILED(STMT)                                                \
-  if (failed(STMT)) {                                                          \
+#define FAILURE_IF_FAILED(RES)                                                 \
+  if (failed(RES)) {                                                           \
+    return failure();                                                          \
+  }
+
+/// Helper for `FAILURE_IF_NULLOPT_OR_FAILED`, so that the macro expands
+/// its `RES` argument only once (rather than evaluating it twice).
+static inline bool didntSucceed(OptionalParseResult res) {
+  return !res.has_value() || failed(*res);
+}
+
+#define FAILURE_IF_NULLOPT_OR_FAILED(RES)                                      \
+  if (didntSucceed(RES)) {                                                     \
     return failure();                                                          \
   }
 
@@ -80,37 +91,70 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
   llvm_unreachable("unknown Policy");
 }
 
-FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
-  VarInfo::ID varID;
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk,
+                                                      bool requireKnown) {
+  VarInfo::ID id;
   bool didCreate;
-  const auto res =
-      parseVar(vk, /*isOptional=*/false, Policy::MustNot, varID, didCreate);
-  if (!res.has_value() || failed(*res))
-    return failure();
-  return varID;
+  const bool isOptional = false;
+  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::May;
+  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+  FAILURE_IF_NULLOPT_OR_FAILED(res)
+  assert(requireKnown ? !didCreate : true);
+  return id;
+}
+
+FailureOr<VarInfo::ID> DimLvlMapParser::parseVarBinding(VarKind vk,
+                                                        bool requireKnown) {
+  const auto loc = parser.getCurrentLocation();
+  VarInfo::ID id;
+  bool didCreate;
+  const bool isOptional = false;
+  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
+  FAILURE_IF_NULLOPT_OR_FAILED(res)
+  assert(requireKnown ? !didCreate : didCreate);
+  bindVar(loc, id);
+  return id;
 }
 
 FailureOr<std::pair<Var, bool>>
-DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
+DimLvlMapParser::parseOptionalVarBinding(VarKind vk, bool requireKnown) {
+  const auto loc = parser.getCurrentLocation();
   VarInfo::ID id;
   bool didCreate;
-  const auto res = parseVar(vk, isOptional, Policy::Must, id, didCreate);
+  const bool isOptional = true;
+  const auto creationPolicy = requireKnown ? Policy::MustNot : Policy::Must;
+  const auto res = parseVar(vk, isOptional, creationPolicy, id, didCreate);
   if (res.has_value()) {
     FAILURE_IF_FAILED(*res)
-    return std::make_pair(env.bindVar(id), true);
+    assert(requireKnown ? !didCreate : didCreate);
+    return std::make_pair(bindVar(loc, id), true);
   }
+  assert(!didCreate);
   return std::make_pair(env.bindUnusedVar(vk), false);
 }
 
-FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
-  // Nothing to parse, create a new lvl var right away.
-  if (directAffine)
-    return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
-  // Parse a lvl var, always pulling from the existing pool.
-  const auto use = parseVarUsage(VarKind::Level);
-  FAILURE_IF_FAILED(use)
-  FAILURE_IF_FAILED(parser.parseEqual())
-  return env.toVar(*use);
+Var DimLvlMapParser::bindVar(llvm::SMLoc loc, VarInfo::ID id) {
+  MLIRContext *context = parser.getContext();
+  const auto var = env.bindVar(id);
+  const auto &info = std::as_const(env).access(id);
+  const auto name = info.getName();
+  const auto num = *info.getNum();
+  switch (info.getKind()) {
+  case VarKind::Symbol: {
+    const auto affine = getAffineSymbolExpr(num, context);
+    dimsAndSymbols.emplace_back(name, affine);
+    lvlsAndSymbols.emplace_back(name, affine);
+    return var;
+  }
+  case VarKind::Dimension:
+    dimsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+    return var;
+  case VarKind::Level:
+    lvlsAndSymbols.emplace_back(name, getAffineDimExpr(num, context));
+    return var;
+  }
+  llvm_unreachable("unknown VarKind");
 }
 
 //===----------------------------------------------------------------------===//
@@ -118,10 +162,8 @@ FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
 //===----------------------------------------------------------------------===//
 
 FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
-  FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol,
-                                        OpAsmParser::Delimiter::OptionalSquare))
-  FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level,
-                                        OpAsmParser::Delimiter::OptionalBraces))
+  FAILURE_IF_FAILED(parseSymbolBindingList())
+  FAILURE_IF_FAILED(parseLvlVarBindingList())
   FAILURE_IF_FAILED(parseDimSpecList())
   FAILURE_IF_FAILED(parser.parseArrow())
   FAILURE_IF_FAILED(parseLvlSpecList())
@@ -133,14 +175,41 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
   return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
 }
 
-ParseResult
-DimLvlMapParser::parseOptionalIdList(VarKind vk,
-                                     OpAsmParser::Delimiter delimiter) {
-  const auto parseIdBinding = [&]() -> ParseResult {
-    return ParseResult(parseVarBinding(vk, /*isOptional=*/false));
-  };
-  return parser.parseCommaSeparatedList(delimiter, parseIdBinding,
-                                        " in id list");
+ParseResult DimLvlMapParser::parseSymbolBindingList() {
+  return parser.parseCommaSeparatedList(
+      OpAsmParser::Delimiter::OptionalSquare,
+      [this]() { return ParseResult(parseVarBinding(VarKind::Symbol)); },
+      " in symbol binding list");
+}
+
+// FIXME: The forward-declaration of level-vars is a stop-gap workaround
+// so that we can reuse `AsmParser::parseAffineExpr` in the definition of
+// `DimLvlMapParser::parseDimSpec`.  (In particular, note that all the
+// variables must be bound before entering `AsmParser::parseAffineExpr`,
+// since that method requires every variable to already have a fixed/known
+// `Var::Num`.)
+//
+// However, the forward-declaration list duplicates information which is
+// already encoded by the level-var bindings in `parseLvlSpecList` (namely:
+// the names of the variables themselves, and the order in which the names
+// are bound).  This redundancy causes bad UX, and also means we must be
+// sure to verify consistency between the two sources of information.
+//
+// Therefore, it would be best to remove the forward-declaration list from
+// the syntax.  This can be achieved by implementing our own version of
+// `AffineParser::parseAffineExpr` which calls
+// `parseVarUsage(_,requireKnown=false)` for variables and stores the resulting
+// `VarInfo::ID` in the expression tree (instead of demanding it be resolved to
+// some `Var::Num` immediately).  This would also enable us to use the `VarEnv`
+// directly, rather than building the `{dims,lvls}AndSymbols` lists on the
+// side, and thus would also enable us to avoid the O(n^2) behavior of copying
+// `DimLvlMapParser::{dims,lvls}AndSymbols` into `AffineParser::dimsAndSymbols`
+// every time `AsmParser::parseAffineExpr` is called.
+ParseResult DimLvlMapParser::parseLvlVarBindingList() {
+  return parser.parseCommaSeparatedList(
+      OpAsmParser::Delimiter::OptionalBraces,
+      [this]() { return ParseResult(parseVarBinding(VarKind::Level)); },
+      " in level declaration list");
 }
 
 //===----------------------------------------------------------------------===//
@@ -150,22 +219,24 @@ DimLvlMapParser::parseOptionalIdList(VarKind vk,
 ParseResult DimLvlMapParser::parseDimSpecList() {
   return parser.parseCommaSeparatedList(
       OpAsmParser::Delimiter::Paren,
-      [&]() -> ParseResult { return parseDimSpec(); },
+      [this]() -> ParseResult { return parseDimSpec(); },
       " in dimension-specifier list");
 }
 
 ParseResult DimLvlMapParser::parseDimSpec() {
-  const auto res = parseVarBinding(VarKind::Dimension, /*isOptional=*/false);
-  FAILURE_IF_FAILED(res)
-  const DimVar var = res->first.cast<DimVar>();
+  // Parse the requisite dim-var binding.
+  const auto varID = parseVarBinding(VarKind::Dimension);
+  FAILURE_IF_FAILED(varID)
+  const DimVar var = env.getVar(*varID).cast<DimVar>();
 
   // Parse an optional dimension expression.
   AffineExpr affine;
   if (succeeded(parser.parseOptionalEqual())) {
     // Parse the dim affine expr, with only any lvl-vars in scope.
-    SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
-    env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext());
-    FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
+    // FIXME(wrengr): This still has the O(n^2) behavior of copying
+    // our `lvlsAndSymbols` into the `AffineParser::dimsAndSymbols`
+    // field every time `parseDimSpec` is called.
+    FAILURE_IF_FAILED(parser.parseAffineExpr(lvlsAndSymbols, affine))
   }
   DimExpr expr{affine};
 
@@ -188,32 +259,98 @@ ParseResult DimLvlMapParser::parseDimSpec() {
 //===----------------------------------------------------------------------===//
 
 ParseResult DimLvlMapParser::parseLvlSpecList() {
-  // If no level variable is declared at this point, the following level
-  // specification consists of direct affine expressions only, as in:
-  //    (d0, d1) -> (d0 : dense, d1 : compressed)
-  // Otherwise, we are looking for a leading lvl-var, as in:
-  //    {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed)
-  const bool directAffine = env.getRanks().getLvlRank() == 0;
-  return parser.parseCommaSeparatedList(
+  // This method currently only supports two syntaxes:
+  //
+  // (1) There are no forward-declarations, and no lvl-var bindings:
+  //        (d0, d1) -> (d0 : dense, d1 : compressed)
+  // Therefore `parseLvlVarBindingList` didn't bind any lvl-vars, and thus
+  // `parseLvlSpec` will need to use `VarEnv::bindUnusedVar` to ensure that
+  // the level-rank is correct at the end of parsing.
+  //
+  // (2) There are forward-declarations, and every lvl-spec must have
+  // a lvl-var binding:
+  //    {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  // However, this introduces duplicate information since the order of
+  // the lvl-vars in `parseLvlVarBindingList` must agree with their order
+  // in the list of lvl-specs.  Therefore, `parseLvlSpec` will not call
+  // `VarEnv::bindVar` (since `parseLvlVarBindingList` already did so),
+  // and must also validate the consistency between the two lvl-var orders.
+  const auto declaredLvlRank = env.getRanks().getLvlRank();
+  const bool requireLvlVarBinding = declaredLvlRank != 0;
+  // Have `ERROR_IF` point to the start of the list.
+  const auto loc = parser.getCurrentLocation();
+  const auto res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::Paren,
-      [&]() -> ParseResult { return parseLvlSpec(directAffine); },
+      [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); },
       " in level-specifier list");
+  FAILURE_IF_FAILED(res)
+  const auto specLvlRank = lvlSpecs.size();
+  ERROR_IF(requireLvlVarBinding && specLvlRank != declaredLvlRank,
+           "Level-rank mismatch between forward-declarations and specifiers. "
+           "Declared " +
+               Twine(declaredLvlRank) + " level-variables; but got " +
+               Twine(specLvlRank) + " level-specifiers.")
+  return success();
+}
+
+static inline std::string nth(Var::Num n) {
+  // English ordinal suffix; 11th, 12th, and 13th are the exceptions.
+  const auto tens = n % 100, ones = n % 10;
+  const char *suffix = (tens >= 11 && tens <= 13) ? "th"
+                       : ones == 1                ? "st"
+                       : ones == 2                ? "nd"
+                       : ones == 3                ? "rd"
+                                                  : "th";
+  return std::to_string(n) + suffix;
 }
 
-ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) {
-  auto res = parseLvlVarBinding(directAffine);
-  FAILURE_IF_FAILED(res);
-  LvlVar var = res->cast<LvlVar>();
+// NOTE: This is factored out as a separate method only because `Var`
+// lacks a default-ctor, which makes this conditional difficult to inline
+// at the one call-site.
+FailureOr<LvlVar>
+DimLvlMapParser::parseLvlVarBinding(bool requireLvlVarBinding) {
+  // Nothing to parse, just bind an unnamed variable.
+  if (!requireLvlVarBinding)
+    return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
+
+  const auto loc = parser.getCurrentLocation();
+  // NOTE: Calling `parseVarUsage` here is semantically inappropriate,
+  // since the thing we're parsing is supposed to be a variable *binding*
+  // rather than a variable *use*.  However, the call to `VarEnv::bindVar`
+  // (via `DimLvlMapParser::bindVar`, which also updates `lvlsAndSymbols`)
+  // already occurred in `parseLvlVarBindingList`, and therefore we must
+  // use `parseVarUsage` here in order to operationally do the right thing.
+  const auto varID = parseVarUsage(VarKind::Level, /*requireKnown=*/true);
+  FAILURE_IF_FAILED(varID)
+  const auto &info = std::as_const(env).access(*varID);
+  const auto var = info.getVar().cast<LvlVar>();
+  const auto forwardNum = var.getNum();
+  const auto specNum = lvlSpecs.size();
+  ERROR_IF(forwardNum != specNum,
+           "Level-variable ordering mismatch. The variable '" + info.getName() +
+               "' was forward-declared as the " + nth(forwardNum) +
+               " level; but is bound by the " + nth(specNum) +
+               " specification.")
+  FAILURE_IF_FAILED(parser.parseEqual())
+  return var;
+}
+
+ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
+  // Parse the optional lvl-var binding.  (Actually, `requireLvlVarBinding`
+  // specifies whether that "optional" is actually Must or MustNot.)
+  const auto varRes = parseLvlVarBinding(requireLvlVarBinding);
+  FAILURE_IF_FAILED(varRes)
+  const LvlVar var = *varRes;
 
   // Parse the lvl affine expr, with only the dim-vars in scope.
   AffineExpr affine;
-  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
-  env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext());
+  // FIXME(wrengr): This still has the O(n^2) behavior of copying
+  // our `dimsAndSymbols` into the `AffineParser::dimsAndSymbols`
+  // field every time `parseLvlSpec` is called.
   FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
   LvlExpr expr{affine};
 
   FAILURE_IF_FAILED(parser.parseColon())
-
   const auto type = lvlTypeParser.parseLvlType(parser);
   FAILURE_IF_FAILED(type)
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index b14ef370270d6b..013a89ea172b0b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -42,22 +42,59 @@ class DimLvlMapParser final {
   FailureOr<DimLvlMap> parseDimLvlMap();
 
 private:
+  /// The core code for parsing `Var`.  This method abstracts out a lot
+  /// of complex details to avoid code duplication; however, client code
+  /// should prefer using `parseVarUsage` and `parseVarBinding` rather than
+  /// calling this method directly.
   OptionalParseResult parseVar(VarKind vk, bool isOptional,
                                Policy creationPolicy, VarInfo::ID &id,
                                bool &didCreate);
-  FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
-  FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
-  FailureOr<Var> parseLvlVarBinding(bool directAffine);
 
-  ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter);
+  /// Parse a variable occurrence which is a *use* of that variable.
+  /// The `requireKnown` parameter specifies how to handle the case of
+  /// encountering a valid variable name which is currently unknown: when
+  /// `requireKnown=true`, an error is raised; when `requireKnown=false`,
+  /// a new unbound variable will be created.
+  ///
+  /// NOTE: Just because a variable is *known* (i.e., the name has been
+  /// associated with a `VarInfo::ID`) does not mean that the variable
+  /// is actually *in scope*.
+  FailureOr<VarInfo::ID> parseVarUsage(VarKind vk, bool requireKnown);
+
+  /// Parse a variable occurrence which is a *binding* of that variable.
+  /// The `requireKnown` parameter is for handling the binding of
+  /// forward-declared variables.
+  FailureOr<VarInfo::ID> parseVarBinding(VarKind vk, bool requireKnown = false);
+
+  /// Parse an optional variable binding.  When the next token is
+  /// not a valid variable name, this will bind a new unnamed variable.
+  /// The returned `bool` indicates whether a variable name was parsed.
+  FailureOr<std::pair<Var, bool>>
+  parseOptionalVarBinding(VarKind vk, bool requireKnown = false);
+
+  /// Binds the given variable: both updating the `VarEnv` itself, and
+  /// also updating the `{dims,lvls}AndSymbols` lists (which will be passed
+  /// to `AsmParser::parseAffineExpr`).  This method is already called by the
+  /// `parseVarBinding`/`parseOptionalVarBinding` methods, and therefore
+  /// should not need to be called elsewhere.
+  Var bindVar(llvm::SMLoc loc, VarInfo::ID id);
+
+  ParseResult parseSymbolBindingList();
+  ParseResult parseLvlVarBindingList();
   ParseResult parseDimSpec();
   ParseResult parseDimSpecList();
-  ParseResult parseLvlSpec(bool directAffine);
+  FailureOr<LvlVar> parseLvlVarBinding(bool requireLvlVarBinding);
+  ParseResult parseLvlSpec(bool requireLvlVarBinding);
   ParseResult parseLvlSpecList();
 
   AsmParser &parser;
   LvlTypeParser lvlTypeParser;
   VarEnv env;
+  // The parser maintains the `{dims,lvls}AndSymbols` lists to avoid
+  // the O(n^2) cost of repeatedly constructing them inside of the
+  // `parse{Dim,Lvl}Spec` methods.
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> lvlsAndSymbols;
   SmallVector<DimSpec> dimSpecs;
   SmallVector<LvlSpec> lvlSpecs;
 };

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index e126dab02b6a02..7250d44b53d0d6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -296,16 +296,4 @@ InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
   return {};
 }
 
-void VarEnv::addVars(
-    SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
-    VarKind vk, MLIRContext *context) const {
-  for (const auto &var : vars) {
-    if (var.getKind() == vk) {
-      assert(var.hasNum());
-      dimsAndSymbols.push_back(std::make_pair(
-          var.getName(), getAffineDimExpr(*var.getNum(), context)));
-    }
-  }
-}
-
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index 8365ff2ae54318..313972b3ca79b6 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -424,8 +424,6 @@ class VarEnv final {
     return oid ? &access(*oid) : nullptr;
   }
 
-  Var toVar(VarInfo::ID id) const { return vars[to_underlying(id)].getVar(); }
-
 private:
   VarInfo &access(VarInfo::ID id) {
     return const_cast<VarInfo &>(std::as_const(*this).access(id));
@@ -472,12 +470,20 @@ class VarEnv final {
 
   InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
 
+  /// Returns the ranks of the variables bound so far.  This method should
+  /// only be used after the environment is "finished", since binding new
+  /// variables will (semantically) invalidate any previously returned `Ranks`.
   Ranks getRanks() const { return Ranks(nextNum); }
 
-  /// Adds all variables of given kind to the vector.
-  void
-  addVars(SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
-          VarKind vk, MLIRContext *context) const;
+  /// Gets the `Var` identified by the `VarInfo::ID`, raising an assertion
+  /// failure if the variable is not bound.
+  Var getVar(VarInfo::ID id) const { return access(id).getVar(); }
+
+  /// Gets the `Var` identified by the `VarInfo::ID`, returning nullopt
+  /// if the variable is not bound.
+  std::optional<Var> tryGetVar(VarInfo::ID id) const {
+    return access(id).tryGetVar();
+  }
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index e76df6551c2e1c..2500a9d244cd77 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -69,3 +69,48 @@ func.func private @tensor_invalid_key(%arg0: tensor<16x32xf32, #a>) -> ()
   dimSlices = [ (-1, ?, 1), (?, 4, 2) ] // expected-error{{expect positive value or ? for slice offset/size/stride}}
 }>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
+
+///////////////////////////////////////////////////////////////////////////////
+// Migration plan for new STEA surface syntax,
+// use the NEW_SYNTAX on selected examples
+// and then TODO: remove when fully migrated
+///////////////////////////////////////////////////////////////////////////////
+
+// -----
+
+// expected-error at +3 {{Level-rank mismatch between forward-declarations and specifiers. Declared 3 level-variables; but got 2 level-specifiers.}}
+#TooManyLvlDecl = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  {l0, l1, l2} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @too_many_lvl_decl(%arg0: tensor<?x?xf64, #TooManyLvlDecl>) {
+  return
+}
+
+// -----
+
+// NOTE: We don't get the "level-rank mismatch" error here, because this
+// "undeclared identifier" error occurs first.  The error message is a bit
+// misleading because `parseLvlVarBinding` calls `parseVarUsage` rather
+// than `parseVarBinding` (and the error message generated by `parseVar`
+// assumes that `parseVarUsage` is only called for *uses* of variables).
+// expected-error at +3 {{use of undeclared identifier 'l1'}}
+#TooFewLvlDecl = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  {l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @too_few_lvl_decl(%arg0: tensor<?x?xf64, #TooFewLvlDecl>) {
+  return
+}
+
+// -----
+
+// expected-error at +3 {{Level-variable ordering mismatch. The variable 'l0' was forward-declared as the 1st level; but is bound by the 0th specification.}}
+#WrongOrderLvlDecl = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  {l1, l0} (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+}>
+func.func private @wrong_order_lvl_decl(%arg0: tensor<?x?xf64, #WrongOrderLvlDecl>) {
+  return
+}
+


        


More information about the Mlir-commits mailing list