[Mlir-commits] [mlir] b939c01 - [mlir][sparse] add affine parsing to new surface syntax for STEA

Aart Bik llvmlistbot at llvm.org
Fri Jun 30 14:48:52 PDT 2023


Author: Aart Bik
Date: 2023-06-30T14:48:23-07:00
New Revision: b939c015a4ad1f1d07f93d322e7dbe2feb0a13bc

URL: https://github.com/llvm/llvm-project/commit/b939c015a4ad1f1d07f93d322e7dbe2feb0a13bc
DIFF: https://github.com/llvm/llvm-project/commit/b939c015a4ad1f1d07f93d322e7dbe2feb0a13bc.diff

LOG: [mlir][sparse] add affine parsing to new surface syntax for STEA

(1) uses the previously introduced API to reuse the AffineExpr parser without code duplication
(2) solves the look-ahead problem when parsing level specs

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D154254

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
index b8eeb1a4a67e1b..d87258756ed87b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp
@@ -241,7 +241,7 @@ void LvlSpec::print(llvm::raw_ostream &os, bool wantElision) const {
   if (!wantElision || !elideVar)
     os << var << " = ";
   os << expr;
-  os << ": \"" << toMLIRString(type) << "\"";
+  os << ": " << toMLIRString(type);
 }
 
 //===----------------------------------------------------------------------===//
@@ -264,10 +264,10 @@ DimLvlMap::DimLvlMap(unsigned symRank, ArrayRef<DimSpec> dimSpecs,
   // Third, we set every `LvlSpec::elideVar` according to whether that
   // LvlVar occurs in a non-elided DimExpr (TODO: or CountingExpr).
   VarSet usedVars(getRanks());
-  for (const auto &dimSpec : dimSpecs)
-    // NOTE TO Wren: bypassed for empty
-    if (dimSpec.hasExpr() && !dimSpec.canElideExpr())
-      usedVars.add(dimSpec.getExpr());
+  // NOTE TO Wren: bypassed for now
+  // for (const auto &dimSpec : dimSpecs)
+  //  if (!dimSpec.canElideExpr())
+  //    usedVars.add(dimSpec.getExpr());
   for (auto &lvlSpec : this->lvlSpecs)
     lvlSpec.setElideVar(!usedVars.contains(lvlSpec.getBoundVar()));
 }

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 6ff72c11d06e11..361425de73cc21 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -1,8 +1,4 @@
 //===- DimLvlMapParser.cpp - `DimLvlMap` parser implementation ------------===//
-// These two lookup methods are probably small enough to benefit from
-// being defined inline/in-class, expecially since doing so may allow the
-// compiler to optimize the `std::optional` away.  But we put the defns
-// here until benchmarks prove the benefit of doing otherwise.
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -16,32 +12,12 @@ using namespace mlir;
 using namespace mlir::sparse_tensor;
 using namespace mlir::sparse_tensor::ir_detail;
 
-//===----------------------------------------------------------------------===//
-// TODO(wrengr): rephrase these to do the trick for gobbling up any trailing
-// semicolon
-//
-// NOTE: There's no way for `FAILURE_IF_FAILED` to simultaneously support
-// both `OptionalParseResult` and `InFlightDiagnostic` return types.
-// We can get the compiler to accept the code if we returned "`{}`",
-// however for `OptionalParseResult` that would become the nullopt result,
-// whereas for `InFlightDiagnostic` it would become a result that can
-// be implicitly converted to success.  By using "`failure()`" we ensure
-// that `OptionalParseResult` behaves as intended, however that means the
-// macro cannot be used for `InFlightDiagnostic` since there's no implicit
-// conversion.
 #define FAILURE_IF_FAILED(STMT)                                                \
   if (failed(STMT)) {                                                          \
     return failure();                                                          \
   }
 
-// Although `ERROR_IF` is phrased to return `InFlightDiagnostic`, that type
-// can be implicitly converted to all four of `LogicalResult, `FailureOr`,
-// `ParseResult`, and `OptionalParseResult`.  (However, beware that the
-// conversion to `OptionalParseResult` doesn't properly delegate to
-// `InFlightDiagnostic::operator ParseResult`.)
-//
 // NOTE: this macro assumes `AsmParser parser` and `SMLoc loc` are in scope.
-// NOTE_TO_SELF(wrengr): The LOC used to always be `parser.getNameLoc()`
 #define ERROR_IF(COND, MSG)                                                    \
   if (COND) {                                                                  \
     return parser.emitError(loc, MSG);                                         \
@@ -107,11 +83,8 @@ OptionalParseResult DimLvlMapParser::parseVar(VarKind vk, bool isOptional,
 FailureOr<VarInfo::ID> DimLvlMapParser::parseVarUsage(VarKind vk) {
   VarInfo::ID varID;
   bool didCreate;
-  // We use the policy `May` because we want to allow parsing free/unbound
-  // variables.  If we wanted to distinguish between parsing free-var uses
-  // vs bound-var uses, then the latter should use `MustNot`.
-  const auto res =
-      parseVar(vk, /*isOptional=*/false, CreationPolicy::May, varID, didCreate);
+  const auto res = parseVar(vk, /*isOptional=*/false, CreationPolicy::MustNot,
+                            varID, didCreate);
   if (!res.has_value() || failed(*res))
     return failure();
   return varID;
@@ -126,9 +99,19 @@ DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
   if (res.has_value()) {
     FAILURE_IF_FAILED(*res)
     return std::make_pair(env.bindVar(id), true);
-  } else {
-    return std::make_pair(env.bindUnusedVar(vk), false);
   }
+  return std::make_pair(env.bindUnusedVar(vk), false);
+}
+
+FailureOr<Var> DimLvlMapParser::parseLvlVarBinding(bool directAffine) {
+  // Nothing to parse, create a new lvl var right away.
+  if (directAffine)
+    return env.bindUnusedVar(VarKind::Level).cast<LvlVar>();
+  // Parse a lvl var, always pulling from the existing pool.
+  const auto use = parseVarUsage(VarKind::Level);
+  FAILURE_IF_FAILED(use)
+  FAILURE_IF_FAILED(parser.parseEqual())
+  return env.toVar(*use);
 }
 
 //===----------------------------------------------------------------------===//
@@ -136,7 +119,10 @@ DimLvlMapParser::parseVarBinding(VarKind vk, bool isOptional) {
 //===----------------------------------------------------------------------===//
 
 FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
-  FAILURE_IF_FAILED(parseOptionalSymbolIdList())
+  FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Symbol,
+                                        OpAsmParser::Delimiter::OptionalSquare))
+  FAILURE_IF_FAILED(parseOptionalIdList(VarKind::Level,
+                                        OpAsmParser::Delimiter::OptionalBraces))
   FAILURE_IF_FAILED(parseDimSpecList())
   FAILURE_IF_FAILED(parser.parseArrow())
   FAILURE_IF_FAILED(parseLvlSpecList())
@@ -148,19 +134,14 @@ FailureOr<DimLvlMap> DimLvlMapParser::parseDimLvlMap() {
   return DimLvlMap(env.getRanks().getSymRank(), dimSpecs, lvlSpecs);
 }
 
-using Delimiter = mlir::OpAsmParser::Delimiter;
-
-ParseResult DimLvlMapParser::parseOptionalSymbolIdList() {
-  const auto parseSymVarBinding = [&]() -> ParseResult {
-    return ParseResult(parseVarBinding(VarKind::Symbol, /*isOptional=*/false));
+ParseResult
+DimLvlMapParser::parseOptionalIdList(VarKind vk,
+                                     OpAsmParser::Delimiter delimiter) {
+  const auto parseIdBinding = [&]() -> ParseResult {
+    return ParseResult(parseVarBinding(vk, /*isOptional=*/false));
   };
-  // If I've correctly unpacked how exactly `Parser::parseCommaSeparatedList`
-  // handles the "optional" delimiters vs the non-optional ones, then
-  // the following call to `AsmParser::parseCommaSeparatedList` should
-  // be equivalent to the whole `AffineParse::parseOptionalSymbolIdList`
-  // method (which uses `Parser` methods to handle the optionality instead).
-  return parser.parseCommaSeparatedList(Delimiter::OptionalSquare,
-                                        parseSymVarBinding, " in symbol list");
+  return parser.parseCommaSeparatedList(delimiter, parseIdBinding,
+                                        " in id list");
 }
 
 //===----------------------------------------------------------------------===//
@@ -169,7 +150,8 @@ ParseResult DimLvlMapParser::parseOptionalSymbolIdList() {
 
 ParseResult DimLvlMapParser::parseDimSpecList() {
   return parser.parseCommaSeparatedList(
-      Delimiter::Paren, [&]() -> ParseResult { return parseDimSpec(); },
+      OpAsmParser::Delimiter::Paren,
+      [&]() -> ParseResult { return parseDimSpec(); },
       " in dimension-specifier list");
 }
 
@@ -178,22 +160,17 @@ ParseResult DimLvlMapParser::parseDimSpec() {
   FAILURE_IF_FAILED(res)
   const DimVar var = res->first.cast<DimVar>();
 
-  DimExpr expr{AffineExpr()};
+  // Parse an optional dimension expression.
+  AffineExpr affine;
   if (succeeded(parser.parseOptionalEqual())) {
-    // FIXME(wrengr): I don't think there's any way to implement this
-    // without replicating the bulk of `AffineParser::parseAffineExpr`
-    // TODO(wrengr): Also, need to make sure the parser uses
-    // `parseVarUsage(VarKind::Level)` so that every `AffineDimExpr`
-    // necessarily corresponds to a `LvlVar` (never a `DimVar`).
-    //
-    // FIXME: proof of concept, parse trivial level vars (viz d0 = l0).
-    auto use = parseVarUsage(VarKind::Level);
-    FAILURE_IF_FAILED(use)
-    AffineExpr a = getAffineDimExpr(var.getNum(), parser.getContext());
-    DimExpr dexpr{a};
-    expr = dexpr;
+    // Parse the dim affine expr, with only any lvl-vars in scope.
+    SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+    env.addVars(dimsAndSymbols, VarKind::Level, parser.getContext());
+    FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
   }
+  DimExpr expr{affine};
 
+  // Parse an optional slice.
   SparseTensorDimSliceAttr slice;
   if (succeeded(parser.parseOptionalColon())) {
     const auto loc = parser.getCurrentLocation();
@@ -212,40 +189,29 @@ ParseResult DimLvlMapParser::parseDimSpec() {
 //===----------------------------------------------------------------------===//
 
 ParseResult DimLvlMapParser::parseLvlSpecList() {
+  // If no level variable is declared at this point, the following level
+  // specification consists of direct affine expressions only, as in:
+  //    (d0, d1) -> (d0 : dense, d1 : compressed)
+  // Otherwise, we are looking for a leading lvl-var, as in:
+  //    {l0, l1} ( d0 = l0, d1 = l1) -> ( l0 = d0 : dense, l1 = d1: compressed)
+  const bool directAffine = env.getRanks().getLvlRank() == 0;
   return parser.parseCommaSeparatedList(
-      Delimiter::Paren, [&]() -> ParseResult { return parseLvlSpec(); },
+      mlir::OpAsmParser::Delimiter::Paren,
+      [&]() -> ParseResult { return parseLvlSpec(directAffine); },
       " in level-specifier list");
 }
 
-ParseResult DimLvlMapParser::parseLvlSpec() {
-  // FIXME(wrengr): This implementation isn't actually going to work as-is,
-  // due to grammar ambiguity.  That is, assuming the current token is indeed
-  // a variable, we don't yet know whether that variable is supposed to
-  // be a binding vs being a usage that's part of the following AffineExpr.
-  // We can only disambiguate that by peeking at the next token to see whether
-  // it's the equals symbol or not.
-  //
-  // FIXME: proof of concept, assume it is new (viz. l0 = d0).
-  const auto res = parseVarBinding(VarKind::Level, /*isOptional=*/true);
-  FAILURE_IF_FAILED(res)
-  if (res->second) {
-    FAILURE_IF_FAILED(parser.parseEqual())
-  }
-  const LvlVar var = res->first.cast<LvlVar>();
-
-  // FIXME(wrengr): I don't think there's any way to implement this
-  // without replicating the bulk of `AffineParser::parseAffineExpr`
-  //
-  // TODO(wrengr): Also, need to make sure the parser uses
-  // `parseVarUsage(VarKind::Dimension)` so that every `AffineDimExpr`
-  // necessarily corresponds to a `DimVar` (never a `LvlVar`).
-  //
-  // FIXME: proof of concept, parse trivial dim vars (viz l0 = d0).
-  auto use = parseVarUsage(VarKind::Dimension);
-  FAILURE_IF_FAILED(use)
-  AffineExpr a =
-      getAffineDimExpr(env.toVar(*use).getNum(), parser.getContext());
-  LvlExpr expr{a};
+ParseResult DimLvlMapParser::parseLvlSpec(bool directAffine) {
+  auto res = parseLvlVarBinding(directAffine);
+  FAILURE_IF_FAILED(res);
+  LvlVar var = res->cast<LvlVar>();
+
+  // Parse the lvl affine expr, with only the dim-vars in scope.
+  AffineExpr affine;
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+  env.addVars(dimsAndSymbols, VarKind::Dimension, parser.getContext());
+  FAILURE_IF_FAILED(parser.parseAffineExpr(dimsAndSymbols, affine))
+  LvlExpr expr{affine};
 
   FAILURE_IF_FAILED(parser.parseColon())
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
index a77879e6aac128..2cdf66b8de4b1c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.h
@@ -16,14 +16,23 @@ namespace mlir {
 namespace sparse_tensor {
 namespace ir_detail {
 
-//===----------------------------------------------------------------------===//
-// NOTE(wrengr): The idea here was originally based on the
-// "lib/AsmParser/AffineParser.cpp"-static class `AffineParser`.
-// Unfortunately, we can't use that class directly since it's file-local.
-// Even worse, both `mlir::detail::Parser` and `mlir::detail::ParserState`
-// are also file-local classes.  I've been attempting to convert things
-// over to using `AsmParser` wherever possible, though it's not clear that
-// that'll work...
+///
+/// Parses the Sparse Tensor Encoding Attribute (STEA).
+///
+/// General syntax is as follows,
+///
+///   [s0, ...]     // optional forward decl sym-vars
+///   {l0, ...}     // optional forward decl lvl-vars
+///   (
+///     d0 = ...,   // dim-var = dim-exp
+///     ...
+///   ) -> (
+///     l0 = ...,   // lvl-var = lvl-exp
+///     ...
+///   )
+///
+/// with simplifications when variables are implicit.
+///
 class DimLvlMapParser final {
 public:
   explicit DimLvlMapParser(AsmParser &parser) : parser(parser) {}
@@ -33,18 +42,17 @@ class DimLvlMapParser final {
   FailureOr<DimLvlMap> parseDimLvlMap();
 
 private:
-  // TODO(wrengr): rather than using `OptionalParseResult` and two
-  // out-parameters, should we define a type to encapsulate all that?
   OptionalParseResult parseVar(VarKind vk, bool isOptional,
                                CreationPolicy creationPolicy, VarInfo::ID &id,
                                bool &didCreate);
   FailureOr<VarInfo::ID> parseVarUsage(VarKind vk);
   FailureOr<std::pair<Var, bool>> parseVarBinding(VarKind vk, bool isOptional);
+  FailureOr<Var> parseLvlVarBinding(bool directAffine);
 
-  ParseResult parseOptionalSymbolIdList();
+  ParseResult parseOptionalIdList(VarKind vk, OpAsmParser::Delimiter delimiter);
   ParseResult parseDimSpec();
   ParseResult parseDimSpecList();
-  ParseResult parseLvlSpec();
+  ParseResult parseLvlSpec(bool directAffine);
   ParseResult parseLvlSpecList();
 
   AsmParser &parser;
@@ -54,8 +62,6 @@ class DimLvlMapParser final {
   SmallVector<LvlSpec> lvlSpecs;
 };
 
-//===----------------------------------------------------------------------===//
-
 } // namespace ir_detail
 } // namespace sparse_tensor
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 5cc9c2258a9a6f..35022d7cfa1b09 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -289,4 +289,16 @@ InFlightDiagnostic VarEnv::emitErrorIfAnyUnbound(AsmParser &parser) const {
   return {};
 }
 
+void VarEnv::addVars(
+    SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
+    VarKind vk, MLIRContext *context) const {
+  for (const auto &var : vars) {
+    if (var.getKind() == vk) {
+      assert(var.hasNum());
+      dimsAndSymbols.push_back(std::make_pair(
+          var.getName(), getAffineDimExpr(*var.getNum(), context)));
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index e3b1038de696dc..0c9a2cf348c139 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -162,11 +162,6 @@ class SymVar final : public Var {
 };
 static_assert(IsZeroCostAbstraction<SymVar>);
 
-// TODO(wrengr): I'd like to give the ctors the types `DimVar(Dimension)`
-// and `LvlVar(Level)`, instead of their current types using `Num`;
-// however, that'd require importing "IR/SparseTensor.h" which nothing else
-// in this file requires.  Also beware the issues about implicit-conversion
-// from `uint64_t` to `Num`.
 class DimVar final : public Var {
 public:
   static constexpr VarKind Kind = VarKind::Dimension;
@@ -189,32 +184,6 @@ class LvlVar final : public Var {
 };
 static_assert(IsZeroCostAbstraction<LvlVar>);
 
-// FIXME(wrengr): In order to get the `llvm::{isa,cast,dyn_cast}`
-// free-functions to work (instead of using our hand-rolled methods),
-// we'll need to define something like this:
-// ```
-// namespace llvm {
-// template <typename U> struct CastInfo<U, Var> : OptionalValueCast<U, Var> {};
-// template <> struct ValueIsPresent<Var> {
-//   using UnwrappedType = Var;
-//   static inline bool isPresent(Var const&) { return true; }
-// };
-// } // namespace llvm
-// ```
-// The above will enable the type `llvm::dyn_cast<U>(Var) -> std::optional<U>`.
-//
-// FIXME(wrengr): The default `OptionalValueCast<U,Var>::doCast(Var const&)`
-// implementation uses the expression "`U(var)`", which means that all the
-// subclasses will need to define that upcasting-copy-ctor, and to ensure
-// safety/correctness will need to mark that ctor as private/protected,
-// which in turn means they'll need make the `CastInfo`/`OptionalValueCast`
-// classes friends.
-//
-// We run into similar issues with our hand-rolled methods, the only
-// difference is that the upcasting-copy-ctor would have type `U(Impl)`
-// instead of `U(Var)` and that we'd need to make the `Var` class a friend
-// rather than the `CastInfo`/`OptionalValueCast` classes.
-//
 template <typename U>
 constexpr bool Var::isa() const {
   if constexpr (std::is_same_v<U, SymVar>)
@@ -257,8 +226,6 @@ class Ranks final {
   }
 
 public:
-  // NOTE_TO_SELF(wrengr): According to <https://stackoverflow.com/a/34465458>
-  // we should be able to do this just fine, even though `constexpr`
   constexpr Ranks(unsigned symRank, unsigned dimRank, unsigned lvlRank)
       : impl() {
     impl[to_index(VarKind::Symbol)] = symRank;
@@ -304,16 +271,6 @@ class VarSet final {
   void add(DimLvlExpr expr);
 };
 
-//===----------------------------------------------------------------------===//
-// TODO(wrengr): For good error messages we'll need to define something like:
-// ```class LocatedVar final { llvm::SMLoc loc; VarInfo::ID id; };```
-// to be the actual thing occuring in our variant of AffineExpr.
-// Though we may also want that struct to contain a pointer back to the
-// `VarEnv` which contains the `VarInfo` for that `VarInfo::ID`.
-//
-// To go along with this, the `VarInfo` record should drop its own `SMLoc`
-// field.
-
 //===----------------------------------------------------------------------===//
 /// A record of metadata for/about a variable, used by `VarEnv`.
 /// The principal goal of this record is to enable `VarEnv` to be used for
@@ -456,6 +413,11 @@ class VarEnv final {
   InFlightDiagnostic emitErrorIfAnyUnbound(AsmParser &parser) const;
 
   Ranks getRanks() const { return Ranks(nextNum); }
+
+  /// Adds all variables of given kind to the vector.
+  void
+  addVars(SmallVectorImpl<std::pair<StringRef, AffineExpr>> &dimsAndSymbols,
+          VarKind vk, MLIRContext *context) const;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index e9efbd071925fa..5a7c5a641a653c 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -129,19 +129,78 @@ func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ], dimSlices = [ (1, ?, 1), (?, 4, 2) ] }>>
 func.func private @sparse_slice(tensor<?x?xf64, #CSR_SLICE>)
 
-// -----
 
+///////////////////////////////////////////////////////////////////////////////
 // Migration plan for new STEA surface syntax,
 // use the NEW_SYNTAX on selected examples
 // and then TODO: remove when fully migrated
+///////////////////////////////////////////////////////////////////////////////
+
+// -----
+
+#CSR_implicit = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+// CHECK-LABEL: func private @foo(
+// CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
+func.func private @foo(%arg0: tensor<?x?xf64, #CSR_implicit>) {
+  return
+}
+
+// -----
 
-#NewSurfaceSyntax = #sparse_tensor.encoding<{
+#CSR_explicit = #sparse_tensor.encoding<{
   NEW_SYNTAX =
-  (d0, d1) -> (l0 = d0 : dense, l1 = d1 : compressed)
+  {l0, l1} (d0 = l0, d1 = l1) -> (l0 = d0 : dense, l1 = d1 : compressed)
 }>
 
 // CHECK-LABEL: func private @foo(
 // CHECK-SAME: tensor<?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>
-func.func private @foo(%arg0: tensor<?x?xf64, #NewSurfaceSyntax>) {
+func.func private @foo(%arg0: tensor<?x?xf64, #CSR_explicit>) {
+  return
+}
+
+// -----
+
+#BCSR_implicit = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  ( i, j ) ->
+  ( i floordiv 2 : compressed,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+// FIXME: should not have to use 4 dims ;-)
+//
+// CHECK-LABEL: func private @foo(
+// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
+func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_implicit>) {
+  return
+}
+
+// -----
+
+#BCSR_explicit = #sparse_tensor.encoding<{
+  NEW_SYNTAX =
+  {il, jl, ii, jj}
+  ( i = il * 2 + ii,
+    j = jl * 3 + jj
+  ) ->
+  ( il = i floordiv 2 : compressed,
+    jl = j floordiv 3 : compressed,
+    ii = i mod 2      : dense,
+    jj = j mod 3      : dense
+  )
+}>
+
+// FIXME: should not have to use 4 dims ;-)
+//
+// CHECK-LABEL: func private @foo(
+// CHECK-SAME: tensor<?x?x?x?xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "dense", "dense" ] }>>
+func.func private @foo(%arg0: tensor<?x?x?x?xf64, #BCSR_explicit>) {
   return
 }


        


More information about the Mlir-commits mailing list