[Mlir-commits] [mlir] 9bdf1ab - [mlir-tblgen] Slightly improve the diagnostic message in pattern match

Chia-hung Duan llvmlistbot at llvm.org
Sun Jul 18 18:22:04 PDT 2021


Author: Chia-hung Duan
Date: 2021-07-19T09:19:51+08:00
New Revision: 9bdf1ab70be76fa0d6ac27077b7f2284d30929ac

URL: https://github.com/llvm/llvm-project/commit/9bdf1ab70be76fa0d6ac27077b7f2284d30929ac
DIFF: https://github.com/llvm/llvm-project/commit/9bdf1ab70be76fa0d6ac27077b7f2284d30929ac.diff

LOG: [mlir-tblgen] Slightly improve the diagnostic message in pattern match

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D105883

Added:
    (none)

Modified: 
    mlir/tools/mlir-tblgen/RewriterGen.cpp

Removed:
    (none)


################################################################################
diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index e0112af6b5b03..611bc5c1c05ea 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -255,7 +255,6 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
 
   raw_indented_ostream::DelimitedScope scope(os);
 
-  os << "if(!" << opName << ") return ::mlir::failure();\n";
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     std::string argName = formatv("arg{0}_{1}", depth, i);
     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
@@ -277,15 +276,15 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
 
   auto fmt = tree.getNativeCodeTemplate();
-  if (fmt.count("$_self") != 1) {
+  if (fmt.count("$_self") != 1)
     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
                          "passing the defining Operation");
-  }
 
   auto nativeCodeCall = std::string(tgfmt(
       fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));
 
-  os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n";
+  emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
+                 formatv("\"{0} return failure\"", nativeCodeCall));
 
   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
     auto name = tree.getArgName(i);
@@ -338,20 +337,21 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
                           << '\n');
 
   std::string castedName = formatv("castedOp{0}", depth);
-  os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); "
+  os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                 "(void){0};\n",
                 castedName, opName, op.getQualCppClassName());
+
   // Skip the operand matching at depth 0 as the pattern rewriter already does.
-  if (depth != 0) {
-    // Skip if there is no defining operation (e.g., arguments to function).
-    os << formatv("if (!{0}) return ::mlir::failure();\n", castedName);
-  }
-  if (tree.getNumArgs() != op.getNumArgs()) {
+  if (depth != 0)
+    emitMatchCheck(opName, /*matchStr=*/castedName,
+                   formatv("\"{0} is not {1} type\"", castedName,
+                           op.getQualCppClassName()));
+
+  if (tree.getNumArgs() != op.getNumArgs())
     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
                                  "pattern vs. {2} in definition",
                                  op.getOperationName(), tree.getNumArgs(),
                                  op.getNumArgs()));
-  }
 
   // If the operand's name is set, set to that variable.
   auto name = tree.getSymbol();
@@ -379,7 +379,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
       os.indent() << formatv(
           "auto *{0} = "
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
-          argName, castedName, nextOperand++);
+          argName, castedName, nextOperand);
+      // Null check of operand's definingOp
+      emitMatchCheck(castedName, /*matchStr=*/argName,
+                     formatv("\"Operand {0} of {1} has null definingOp\"",
+                             nextOperand++, castedName));
       emitMatch(argTree, argName, depth + 1);
       os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
       os.unindent() << "}\n";


        


More information about the Mlir-commits mailing list