[Mlir-commits] [mlir] [mlir][sparse] Add more error messages and avoid crashing in new parser (PR #67034)

Yinying Li llvmlistbot at llvm.org
Thu Sep 21 14:47:10 PDT 2023


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/67034

>From 6b0cb1fc2ee4b9c68313fe91d4728fce7981eb01 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 21 Sep 2023 15:40:07 +0000
Subject: [PATCH 1/2] [mlir][sparse] Add more error messages and avoid crashing
 in the new parser

Updates:
1. Added more invalid encodings to test the robustness of the new syntax
2. Replaced the asserts that caused crashes with checks that return booleans
3. Clarified some error messages, and handled the failure case where a quoted string is supplied instead of a keyword for level properties.
---
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp  |  15 +-
 .../Dialect/SparseTensor/IR/Detail/Var.cpp    |  50 ++----
 mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h |   4 +-
 .../SparseTensor/invalid_encoding.mlir        | 163 +++++++++++++++++-
 4 files changed, 188 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index b8483f5db130dcf..020e0640d988cfc 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -49,9 +49,10 @@ using namespace mlir::sparse_tensor::ir_detail;
 
 FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
-  uint8_t properties = 0;
   const auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
+           "expected valid keyword, such as compressed without quotes")
+  uint8_t properties = 0;
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -73,19 +74,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   } else if (base.compare("singleton") == 0) {
     properties |= static_cast<uint8_t>(LevelFormat::Singleton);
   } else {
-    parser.emitError(loc, "unknown level format");
+    parser.emitError(loc, "unknown level format: ") << base;
     return failure();
   }
 
   ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
-           "invalid level type");
+           "invalid level type: level format doesn't support the properties");
   return properties;
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
                                          uint8_t *properties) const {
   StringRef strVal;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
+  auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
+           "expected valid keyword, such as nonordered without quotes.")
   if (strVal.compare("nonunique") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
@@ -95,7 +98,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   } else if (strVal.compare("block2_4") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
   } else {
-    parser.emitError(parser.getCurrentLocation(), "unknown level property");
+    parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
   }
   return success();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 3b00e17657f1f97..44eba668021ba79 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -196,26 +196,17 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
   return pair1 <= pair2 ? sm1 : sm2;
 }
 
-LLVM_ATTRIBUTE_UNUSED static void
-assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
-#ifndef NDEBUG
+bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
   const auto &var = env.access(id);
-  assert(var.getName() == name && "found inconsistent name");
-  assert(var.getID() == id && "found inconsistent VarInfo::ID");
-#endif // NDEBUG
+  return (var.getName() == name && var.getID() == id);
 }
 
 // NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
 // (or find some other way to convert SMLoc to FileLineColLoc), then this
 // would no longer be `const VarEnv` (and couldn't be a free-function either).
-LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
-                                                         VarInfo::ID id,
-                                                         llvm::SMLoc loc,
-                                                         VarKind vk) {
-#ifndef NDEBUG
+bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
+                       VarKind vk) {
   const auto &var = env.access(id);
-  assert(var.getKind() == vk &&
-         "a variable of that name already exists with a different VarKind");
   // Since the same variable can occur at several locations,
   // it would not be appropriate to do `assert(var.getLoc() == loc)`.
   /* TODO(wrengr):
@@ -223,7 +214,7 @@ LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
   assert(minLoc && "Location mismatch/incompatibility");
   var.loc = minLoc;
   // */
-#endif // NDEBUG
+  return var.getKind() == vk;
 }
 
 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
@@ -236,24 +227,23 @@ std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
   if (iter == ids.end())
     return std::nullopt;
   const auto id = iter->second;
-#ifndef NDEBUG
-  assertInternalConsistency(*this, id, name);
-#endif // NDEBUG
+  if (!isInternalConsistent(*this, id, name))
+    return std::nullopt;
   return id;
 }
 
-std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
-                                            VarKind vk, bool verifyUsage) {
+std::optional<std::pair<VarInfo::ID, bool>>
+VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
   const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
   const auto id = iter->second;
   if (didInsert) {
     vars.emplace_back(id, name, loc, vk);
   } else {
-#ifndef NDEBUG
-    assertInternalConsistency(*this, id, name);
-    if (verifyUsage)
-      assertUsageConsistency(*this, id, loc, vk);
-#endif // NDEBUG
+  if (!isInternalConsistent(*this, id, name))
+    return std::nullopt;
+  if (verifyUsage)
+    if (!isUsageConsistent(*this, id, loc, vk))
+      return std::nullopt;
   }
   return std::make_pair(id, didInsert);
 }
@@ -265,20 +255,18 @@ VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
   case Policy::MustNot: {
     const auto oid = lookup(name);
     if (!oid)
-      return std::nullopt; // Doesn't exist, but must not create.
-#ifndef NDEBUG
-    assertUsageConsistency(*this, *oid, loc, vk);
-#endif // NDEBUG
+      return std::nullopt;  // Doesn't exist, but must not create.
+    if (!isUsageConsistent(*this, *oid, loc, vk))
+      return std::nullopt;
     return std::make_pair(*oid, false);
   }
   case Policy::May:
     return create(name, loc, vk, /*verifyUsage=*/true);
   case Policy::Must: {
     const auto res = create(name, loc, vk, /*verifyUsage=*/false);
-    // const auto id = res.first;
-    const auto didCreate = res.second;
+    const auto didCreate = res->second;
     if (!didCreate)
-      return std::nullopt; // Already exists, but must create.
+      return std::nullopt;  // Already exists, but must create.
     return res;
   }
   }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index a488b3ea2d56ba4..145586a83a2528c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -453,8 +453,8 @@ class VarEnv final {
   /// for the variable with the given name (i.e., either the newly created
   /// variable, or the pre-existing variable), and a bool indicating whether
   /// a new variable was created.
-  std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
-                                      VarKind vk, bool verifyUsage = false);
+  std::optional<std::pair<VarInfo::ID, bool>>
+  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
 
   /// Attempts to lookup or create a variable according to the given
   /// `Policy`.  Returns nullopt in one of two circumstances:
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 42eb4e0a46182e7..883ba9cc81fd8f0 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -1,7 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-// expected-error at +1 {{expected a non-empty array for lvlTypes}}
-#a = #sparse_tensor.encoding<{lvlTypes = []}>
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = []}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '->'}}
+#a = #sparse_tensor.encoding<{map = ()}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ')' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0 -> d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = d0 -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid keyword}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid keyword}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
 func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
 // -----
@@ -18,17 +60,61 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = [1]}> // expected-error {{expected a string value in lvlTypes}}
+// expected-error at +1 {{unexpected dimToLvl mapping from 2 to 1}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
+func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected bare identifier}}
+#a = #sparse_tensor.encoding<{map = (1)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{unexpected key: nap}}
+#a = #sparse_tensor.encoding<{nap = (d0) -> (d0 : dense)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map =  -> (d0 : dense)}>
 func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = ["strange"]}> // expected-error {{unexpected level-type: strange}}
+// expected-error at +1 {{unknown level format: strange}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : strange)}>
 func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{dimToLvl = "wrong"}> // expected-error {{expected an affine map for dimToLvl}}
+// expected-error at +1 {{expected valid keyword, such as compressed without quotes}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid keyword, such as nonordered without quotes}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{expected ')' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed[high])}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{unknown level property: wrong}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed(wrong))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed, dense)}>
 func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
@@ -39,6 +125,73 @@ func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()
 
 // -----
 
+// expected-error at +1 {{unexpected character}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed; d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0: d1) -> (d0 : compressed, d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 = compressed, d1 = dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0 : compressed, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = compressed, d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = l0, d1 = l1) {l0, l1} -> (l0 = d0 : dense, l1 = d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : d0 = dense, l1 : d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : l0 = dense, d1 : l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : dense, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : dense, l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = dense, l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 = l0 : dense, d1 = l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
 #a = #sparse_tensor.encoding<{posWidth = "x"}> // expected-error {{expected an integral position bitwidth}}
 func.func private @tensor_no_int_ptr(%arg0: tensor<16x32xf32, #a>) -> ()
 

>From 00c589efca1694bc2b1f42842829c19491c351d3 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 21 Sep 2023 21:46:36 +0000
Subject: [PATCH 2/2] modify error message

---
 mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp | 4 ++--
 mlir/test/Dialect/SparseTensor/invalid_encoding.mlir      | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 020e0640d988cfc..44a2c7d49619405 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -51,7 +51,7 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
   const auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
-           "expected valid keyword, such as compressed without quotes")
+           "expected valid level format (e.g. dense, compressed or singleton)")
   uint8_t properties = 0;
 
   ParseResult res = parser.parseCommaSeparatedList(
@@ -88,7 +88,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   StringRef strVal;
   auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
-           "expected valid keyword, such as nonordered without quotes.")
+           "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 883ba9cc81fd8f0..8adf981d00051c5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -36,13 +36,13 @@ func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
 // -----
 
-// expected-error at +1 {{expected valid keyword}}
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
 func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
 // -----
 
-// expected-error at +1 {{expected valid keyword}}
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
 func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
@@ -90,13 +90,13 @@ func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-// expected-error at +1 {{expected valid keyword, such as compressed without quotes}}
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
 func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-// expected-error at +1 {{expected valid keyword, such as nonordered without quotes}}
+// expected-error at +1 {{expected valid level property (e.g. nonordered, nonunique or high)}}
 #a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
 func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 



More information about the Mlir-commits mailing list