[Mlir-commits] [mlir] 8466eb7 - [mlir][sparse] Add more error messages and avoid crashing in new parser (#67034)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 22 09:03:12 PDT 2023


Author: Yinying Li
Date: 2023-09-22T12:03:08-04:00
New Revision: 8466eb7d031cb8df9585e4cfa29d2af08b9a0c01

URL: https://github.com/llvm/llvm-project/commit/8466eb7d031cb8df9585e4cfa29d2af08b9a0c01
DIFF: https://github.com/llvm/llvm-project/commit/8466eb7d031cb8df9585e4cfa29d2af08b9a0c01.diff

LOG: [mlir][sparse] Add more error messages and avoid crashing in new parser (#67034)

Updates:
1. Added more invalid encodings to test the robustness of the new syntax.
2. Changed the asserts that caused crashes into checks that return booleans,
so malformed input is reported as a parse failure instead of aborting.
3. Made some error messages clearer, and handled failures to parse quoted
strings as keywords for level formats and properties.
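
To illustrate update 2, here is a minimal self-contained sketch of the
pattern (simplified, hypothetical names; the real code operates on the
VarEnv/VarInfo types shown in the diff below). A debug-only assertion
becomes a predicate whose failure the caller propagates as std::nullopt
instead of crashing:

#include <optional>
#include <string>

using VarID = unsigned;

struct Var {
  VarID id;
  std::string name;
};

// Before: debug-only check that aborts the whole process on failure and
// compiles to nothing in release builds:
//   assert(var.name == name && "found inconsistent name");

// After: a predicate that reports the inconsistency to the caller.
static bool isInternalConsistent(const Var &var, VarID id,
                                 const std::string &name) {
  return var.id == id && var.name == name;
}

// The caller converts an inconsistency into a recoverable lookup failure.
static std::optional<VarID> lookup(const Var &var, VarID id,
                                   const std::string &name) {
  if (!isInternalConsistent(var, id, name))
    return std::nullopt; // parse-error path, not a crash
  return var.id;
}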

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
    mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index b8483f5db130dcf..44a2c7d49619405 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -49,9 +49,10 @@ using namespace mlir::sparse_tensor::ir_detail;
 
 FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
-  uint8_t properties = 0;
   const auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
+           "expected valid level format (e.g. dense, compressed or singleton)")
+  uint8_t properties = 0;
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -73,19 +74,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   } else if (base.compare("singleton") == 0) {
     properties |= static_cast<uint8_t>(LevelFormat::Singleton);
   } else {
-    parser.emitError(loc, "unknown level format");
+    parser.emitError(loc, "unknown level format: ") << base;
     return failure();
   }
 
   ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
-           "invalid level type");
+           "invalid level type: level format doesn't support the properties");
   return properties;
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
                                          uint8_t *properties) const {
   StringRef strVal;
-  FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
+  auto loc = parser.getCurrentLocation();
+  ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
+           "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
@@ -95,7 +98,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   } else if (strVal.compare("block2_4") == 0) {
     *properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
   } else {
-    parser.emitError(parser.getCurrentLocation(), "unknown level property");
+    parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
   }
   return success();
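
(For readers unfamiliar with the helper macros used above: ERROR_IF and
FAILURE_IF_FAILED are defined in the SparseTensor detail sources. A plausible
sketch of their shape, assuming an AsmParser named `parser` and an SMLoc named
`loc` are in scope at the expansion site; this is not the verbatim definition:

// Sketch only; see the actual definitions in the SparseTensor detail code.
#define FAILURE_IF_FAILED(STMT)                                                \
  if (failed(STMT)) {                                                          \
    return failure();                                                          \
  }

// Assumes `parser` and `loc` are in scope at the expansion site.
#define ERROR_IF(COND, MSG)                                                    \
  if (COND) {                                                                  \
    return parser.emitError(loc, MSG);                                         \
  }
)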

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
index 3b00e17657f1f97..44eba668021ba79 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
@@ -196,26 +196,17 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
   return pair1 <= pair2 ? sm1 : sm2;
 }
 
-LLVM_ATTRIBUTE_UNUSED static void
-assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
-#ifndef NDEBUG
+bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
   const auto &var = env.access(id);
-  assert(var.getName() == name && "found inconsistent name");
-  assert(var.getID() == id && "found inconsistent VarInfo::ID");
-#endif // NDEBUG
+  return (var.getName() == name && var.getID() == id);
 }
 
 // NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
 // (or find some other way to convert SMLoc to FileLineColLoc), then this
 // would no longer be `const VarEnv` (and couldn't be a free-function either).
-LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
-                                                         VarInfo::ID id,
-                                                         llvm::SMLoc loc,
-                                                         VarKind vk) {
-#ifndef NDEBUG
+bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
+                       VarKind vk) {
   const auto &var = env.access(id);
-  assert(var.getKind() == vk &&
-         "a variable of that name already exists with a different VarKind");
   // Since the same variable can occur at several locations,
   // it would not be appropriate to do `assert(var.getLoc() == loc)`.
   /* TODO(wrengr):
@@ -223,7 +214,7 @@ LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
   assert(minLoc && "Location mismatch/incompatibility");
   var.loc = minLoc;
   // */
-#endif // NDEBUG
+  return var.getKind() == vk;
 }
 
 std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
@@ -236,24 +227,23 @@ std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
   if (iter == ids.end())
     return std::nullopt;
   const auto id = iter->second;
-#ifndef NDEBUG
-  assertInternalConsistency(*this, id, name);
-#endif // NDEBUG
+  if (!isInternalConsistent(*this, id, name))
+    return std::nullopt;
   return id;
 }
 
-std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
-                                            VarKind vk, bool verifyUsage) {
+std::optional<std::pair<VarInfo::ID, bool>>
+VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
   const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
   const auto id = iter->second;
   if (didInsert) {
     vars.emplace_back(id, name, loc, vk);
   } else {
-#ifndef NDEBUG
-    assertInternalConsistency(*this, id, name);
-    if (verifyUsage)
-      assertUsageConsistency(*this, id, loc, vk);
-#endif // NDEBUG
+    if (!isInternalConsistent(*this, id, name))
+      return std::nullopt;
+    if (verifyUsage)
+      if (!isUsageConsistent(*this, id, loc, vk))
+        return std::nullopt;
   }
   return std::make_pair(id, didInsert);
 }
@@ -265,20 +255,18 @@ VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
   case Policy::MustNot: {
     const auto oid = lookup(name);
     if (!oid)
-      return std::nullopt; // Doesn't exist, but must not create.
-#ifndef NDEBUG
-    assertUsageConsistency(*this, *oid, loc, vk);
-#endif // NDEBUG
+      return std::nullopt;  // Doesn't exist, but must not create.
+    if (!isUsageConsistent(*this, *oid, loc, vk))
+      return std::nullopt;
     return std::make_pair(*oid, false);
   }
   case Policy::May:
     return create(name, loc, vk, /*verifyUsage=*/true);
   case Policy::Must: {
     const auto res = create(name, loc, vk, /*verifyUsage=*/false);
-    // const auto id = res.first;
-    const auto didCreate = res.second;
+    const auto didCreate = res->second;
     if (!didCreate)
-      return std::nullopt; // Already exists, but must create.
+      return std::nullopt;  // Already exists, but must create.
     return res;
   }
   }

diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
index a488b3ea2d56ba4..145586a83a2528c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
@@ -453,8 +453,8 @@ class VarEnv final {
   /// for the variable with the given name (i.e., either the newly created
   /// variable, or the pre-existing variable), and a bool indicating whether
   /// a new variable was created.
-  std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
-                                      VarKind vk, bool verifyUsage = false);
+  std::optional<std::pair<VarInfo::ID, bool>>
+  create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);
 
   /// Attempts to lookup or create a variable according to the given
   /// `Policy`.  Returns nullopt in one of two circumstances:
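
(As a usage note for the new signature, a minimal sketch of how a
Policy::Must-style caller handles the optional result; the types and names
here are hypothetical simplifications, not the upstream API:

#include <map>
#include <optional>
#include <string>
#include <utility>

using VarID = unsigned;

// Stand-in for VarEnv::create() after this change: nullopt signals an
// inconsistency that previously would have tripped an assert.
static std::optional<std::pair<VarID, bool>> create(const std::string &name) {
  static std::map<std::string, VarID> ids;
  const auto [it, inserted] =
      ids.try_emplace(name, static_cast<VarID>(ids.size()));
  return std::make_pair(it->second, inserted);
}

// Policy::Must-style caller: check the optional before using it.
static std::optional<std::pair<VarID, bool>>
mustCreate(const std::string &name) {
  const auto res = create(name);
  if (!res || !res->second)
    return std::nullopt; // inconsistent, or exists already but must create
  return res;
}
)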

diff --git a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
index 42eb4e0a46182e7..8adf981d00051c5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
@@ -1,7 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
-// expected-error at +1 {{expected a non-empty array for lvlTypes}}
-#a = #sparse_tensor.encoding<{lvlTypes = []}>
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = []}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '->'}}
+#a = #sparse_tensor.encoding<{map = ()}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ')' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0 -> d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map = d0 -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> d0}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
+func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
 func.func private @scalar(%arg0: tensor<f64, #a>) -> ()
 
 // -----
@@ -18,17 +60,61 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = [1]}> // expected-error {{expected a string value in lvlTypes}}
+// expected-error at +1 {{unexpected dimToLvl mapping from 2 to 1}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
+func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected bare identifier}}
+#a = #sparse_tensor.encoding<{map = (1)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{unexpected key: nap}}
+#a = #sparse_tensor.encoding<{nap = (d0) -> (d0 : dense)}>
+func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '(' in dimension-specifier list}}
+#a = #sparse_tensor.encoding<{map =  -> (d0 : dense)}>
 func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{lvlTypes = ["strange"]}> // expected-error {{unexpected level-type: strange}}
+// expected-error at +1 {{unknown level format: strange}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : strange)}>
 func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
 
-#a = #sparse_tensor.encoding<{dimToLvl = "wrong"}> // expected-error {{expected an affine map for dimToLvl}}
+// expected-error at +1 {{expected valid level format (e.g. dense, compressed or singleton)}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected valid level property (e.g. nonordered, nonunique or high)}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{expected ')' in level-specifier list}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed[high])}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{unknown level property: wrong}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed(wrong))}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed, dense)}>
 func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
 
 // -----
@@ -39,6 +125,73 @@ func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()
 
 // -----
 
+// expected-error at +1 {{unexpected character}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed; d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0: d1) -> (d0 : compressed, d1 : dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected ':'}}
+#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 = compressed, d1 = dense)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected attribute value}}
+#a = #sparse_tensor.encoding<{map = (d0 : compressed, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = compressed, d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = (d0 = l0, d1 = l1) {l0, l1} -> (l0 = d0 : dense, l1 = d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
+// expected-error at +1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : d0 = dense, l1 : d1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : l0 = dense, d1 : l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : dense, d1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{expected '='}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : dense, l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = dense, l1 = compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+// expected-error at +1 {{use of undeclared identifier 'd0'}}
+#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 = l0 : dense, d1 = l1 : compressed)}>
+func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()
+
+// -----
+
 #a = #sparse_tensor.encoding<{posWidth = "x"}> // expected-error {{expected an integral position bitwidth}}
 func.func private @tensor_no_int_ptr(%arg0: tensor<16x32xf32, #a>) -> ()
 
More information about the Mlir-commits mailing list