[Mlir-commits] [mlir] 27df715 - [mlir] Fix dumping invalid ops

Sergei Grechanik llvmlistbot at llvm.org
Mon Mar 7 08:44:42 PST 2022


Author: Sergei Grechanik
Date: 2022-03-07T08:32:31-08:00
New Revision: 27df7158feb2f073c3716b2920ebdee2f5a0eeb6

URL: https://github.com/llvm/llvm-project/commit/27df7158feb2f073c3716b2920ebdee2f5a0eeb6
DIFF: https://github.com/llvm/llvm-project/commit/27df7158feb2f073c3716b2920ebdee2f5a0eeb6.diff

LOG: [mlir] Fix dumping invalid ops

This patch fixes the crash when printing some ops (like affine.for and
scf.for) when they are dumped in an invalid state, e.g. during pattern
application. Now the AsmState constructor verifies the operation
first and switches to generic operation printing when the verification
fails. Also, operations are now printed in generic form when emitting
diagnostics with severity level Error.

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D117834

Added: 
    mlir/test/IR/print-ir-invalid.mlir
    mlir/test/lib/IR/TestPrintInvalid.cpp

Modified: 
    mlir/docs/Diagnostics.md
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OperationSupport.h
    mlir/include/mlir/IR/Value.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Diagnostics.cpp
    mlir/lib/IR/Operation.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md
index 221f11a8ece62..a41b98ebc2d30 100644
--- a/mlir/docs/Diagnostics.md
+++ b/mlir/docs/Diagnostics.md
@@ -107,6 +107,18 @@ op->emitError() << "Compose an interesting error: " << fooAttr << ", " << fooTyp
 "Compose an interesting error: @foo, i32, (0, 1, 2)"
 ```
 
+Operations attached to a diagnostic are printed in generic form if the
+severity level is `Error`; otherwise, custom operation printers are used.
+```c++
+// `anotherOp` will be printed in generic form,
+// e.g. %3 = "arith.addf"(%arg4, %2) : (f32, f32) -> f32
+op->emitError() << anotherOp;
+
+// `anotherOp` will be printed using the custom printer,
+// e.g. %3 = arith.addf %arg4, %2 : f32
+op->emitRemark() << anotherOp;
+```
+
 ### Attaching notes
 
 Unlike many other compiler frameworks, notes in MLIR cannot be emitted directly.

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 87c33b5574937..dc25ef0de5f7c 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -601,6 +601,15 @@ Note that the second phase will be run after the operations in the region are
 verified. Verifiers further down the order can rely on certain invariants being
 verified by a previous verifier and do not need to re-verify them.
 
+#### Emitting diagnostics in custom verifiers
+
+Custom verifiers should avoid printing operations using custom operation
+printers, because those printers require the printed operation (and sometimes
+its parent operation) to be verified first. In particular, when emitting
+diagnostics, custom verifiers should use the `Error` severity level, which
+prints operations in generic form by default, and avoid lower severity levels
+(`Note`, `Remark`, `Warning`).
+
 ### Declarative Assembly Format
 
 The custom assembly form of the operation may be specified in a declarative

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index f72c24480d71c..3707747b5bff0 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -726,6 +726,9 @@ class OpPrintingFlags {
   /// Always print operations in the generic form.
   OpPrintingFlags &printGenericOpForm();
 
+  /// Do not verify the operation when using custom operation printers.
+  OpPrintingFlags &assumeVerified();
+
   /// Use local scope when printing the operation. This allows for using the
   /// printer in a more localized and thread-safe setting, but may not
   /// necessarily be identical to what the IR will look like when dumping
@@ -747,6 +750,9 @@ class OpPrintingFlags {
   /// Return if operations should be printed in the generic form.
   bool shouldPrintGenericOpForm() const;
 
+  /// Return if operation verification should be skipped.
+  bool shouldAssumeVerified() const;
+
   /// Return if the printer should use local scope when dumping the IR.
   bool shouldUseLocalScope() const;
 
@@ -762,6 +768,9 @@ class OpPrintingFlags {
   /// Print operations in the generic form.
   bool printGenericOpFormFlag : 1;
 
+  /// Skip operation verification.
+  bool assumeVerifiedFlag : 1;
+
   /// Print operations with numberings local to the current operation.
   bool printLocalScope : 1;
 };

diff  --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 65f5b8cb1eab9..2f0524c4da7df 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -24,6 +24,7 @@ class Block;
 class BlockArgument;
 class Operation;
 class OpOperand;
+class OpPrintingFlags;
 class OpResult;
 class Region;
 class Value;
@@ -215,6 +216,7 @@ class Value {
   // Utilities
 
   void print(raw_ostream &os);
+  void print(raw_ostream &os, const OpPrintingFlags &flags);
   void print(raw_ostream &os, AsmState &state);
   void dump();
 

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3fc07f37c4ac8..980c886744ad5 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SubElementInterfaces.h"
+#include "mlir/IR/Verifier.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -40,6 +41,7 @@
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/Threading.h"
 
 #include <tuple>
 
@@ -141,6 +143,11 @@ struct AsmPrinterOptions {
       "mlir-print-op-generic", llvm::cl::init(false),
       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
 
+  llvm::cl::opt<bool> assumeVerifiedOpt{
+      "mlir-print-assume-verified", llvm::cl::init(false),
+      llvm::cl::desc("Skip op verification when using custom printers"),
+      llvm::cl::Hidden};
+
   llvm::cl::opt<bool> printLocalScopeOpt{
       "mlir-print-local-scope", llvm::cl::init(false),
       llvm::cl::desc("Print with local scope and inline information (eliding "
@@ -160,7 +167,8 @@ void mlir::registerAsmPrinterCLOptions() {
 /// Initialize the printing flags with default supplied by the cl::opts above.
 OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
-      printGenericOpFormFlag(false), printLocalScope(false) {
+      printGenericOpFormFlag(false), assumeVerifiedFlag(false),
+      printLocalScope(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -169,6 +177,7 @@ OpPrintingFlags::OpPrintingFlags()
   printDebugInfoFlag = clOptions->printDebugInfoOpt;
   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
+  assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
   printLocalScope = clOptions->printLocalScopeOpt;
 }
 
@@ -196,6 +205,12 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
   return *this;
 }
 
+/// Skip pre-print verification; custom printers may then fail on invalid ops.
+OpPrintingFlags &OpPrintingFlags::assumeVerified() {
+  assumeVerifiedFlag = true;
+  return *this;
+}
+
 /// Use local scope when printing the operation. This allows for using the
 /// printer in a more localized and thread-safe setting, but may not necessarily
 /// be identical of what the IR will look like when dumping the full module.
@@ -231,6 +246,11 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
   return printGenericOpFormFlag;
 }
 
+/// Return true if pre-print operation verification should be skipped.
+bool OpPrintingFlags::shouldAssumeVerified() const {
+  return assumeVerifiedFlag;
+}
+
 /// Return if the printer should use local scope when dumping the IR.
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
@@ -1245,9 +1265,31 @@ class AsmStateImpl {
 } // namespace detail
 } // namespace mlir
 
+/// Verifies `op` and returns `printerFlags` switched to generic op printing if
+/// verification fails, since custom print functions may fail for invalid ops.
+/// Verification is skipped if generic form or assume-verified is already set.
+static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
+                                              OpPrintingFlags printerFlags) {
+  if (printerFlags.shouldPrintGenericOpForm() ||
+      printerFlags.shouldAssumeVerified())
+    return printerFlags;
+
+  // Hide this thread's verifier diagnostics; a thread-id check avoids consuming
+  // other threads' ones. NOTE(review): if verify() is parallel, those leak.
+  auto parentThreadId = llvm::get_threadid();
+  ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) {
+    return success(parentThreadId == llvm::get_threadid());
+  });
+  if (failed(verify(op)))
+    printerFlags.printGenericOpForm();
+
+  return printerFlags;
+}
+
 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                    LocationMap *locationMap)
-    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
+    : impl(std::make_unique<AsmStateImpl>(
+          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -2853,14 +2895,15 @@ void IntegerSet::print(raw_ostream &os) const {
   AsmPrinter::Impl(os).printIntegerSet(*this);
 }
 
-void Value::print(raw_ostream &os) {
+void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
+void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
   if (!impl) {
     os << "<<NULL VALUE>>";
     return;
   }
 
   if (auto *op = getDefiningOp())
-    return op->print(os);
+    return op->print(os, flags);
   // TODO: Improve BlockArgument print'ing.
   BlockArgument arg = this->cast<BlockArgument>();
   os << "<block argument> of type '" << arg.getType()

diff  --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index ea0ff5aa74955..975f6943fc673 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -121,6 +121,17 @@ Diagnostic &Diagnostic::operator<<(OperationName val) {
   return *this;
 }
 
+/// Returns `flags` with local scope and large-elements-attr elision enabled;
+/// Error-severity diagnostics additionally print ops in generic form.
+static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
+                                           DiagnosticSeverity severity) {
+  flags.useLocalScope();
+  flags.elideLargeElementsAttrs();
+  if (severity == DiagnosticSeverity::Error)
+    flags.printGenericOpForm();
+  return flags;
+}
+
 /// Stream in an Operation.
 Diagnostic &Diagnostic::operator<<(Operation &val) {
   return appendOp(val, OpPrintingFlags());
@@ -128,8 +139,7 @@ Diagnostic &Diagnostic::operator<<(Operation &val) {
 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
   std::string str;
   llvm::raw_string_ostream os(str);
-  val.print(os,
-            OpPrintingFlags(flags).useLocalScope().elideLargeElementsAttrs());
+  val.print(os, adjustPrintingFlags(flags, severity));
   return *this << os.str();
 }
 
@@ -137,7 +147,7 @@ Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
 Diagnostic &Diagnostic::operator<<(Value val) {
   std::string str;
   llvm::raw_string_ostream os(str);
-  val.print(os);
+  val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
   return *this << os.str();
 }
 

diff  --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 89a6b27c5feaa..ea68f11e126a4 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1097,6 +1097,8 @@ LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) {
           // Check that any value that is used by an operation is defined in the
           // same region as either an operation result.
           auto *operandRegion = operand.getParentRegion();
+          if (!operandRegion)
+            return op.emitError("operation's operand is unlinked");
           if (!region.isAncestor(operandRegion)) {
             return op.emitOpError("using value defined outside the region")
                        .attachNote(isolatedOp->getLoc())

diff  --git a/mlir/test/IR/print-ir-invalid.mlir b/mlir/test/IR/print-ir-invalid.mlir
new file mode 100644
index 0000000000000..e83353433c7e4
--- /dev/null
+++ b/mlir/test/IR/print-ir-invalid.mlir
@@ -0,0 +1,33 @@
+// # RUN: mlir-opt -test-print-invalid %s | FileCheck %s
+// # RUN: mlir-opt -test-print-invalid %s --mlir-print-assume-verified  | FileCheck %s --check-prefix=ASSUME-VERIFIED
+
+// The pass creates some ops and prints them to stdout; the input is just an
+// empty module.
+module {}
+
+// The operation is invalid because the body has no terminator, so the
+// generic form is printed.
+// CHECK:      Invalid operation:
+// CHECK-NEXT: "builtin.func"() ({
+// CHECK-NEXT: ^bb0:
+// CHECK-NEXT: })
+// CHECK-SAME: sym_name = "test"
+
+// The operation is valid because the body has a terminator, so the custom
+// form is printed.
+// CHECK:      Valid operation:
+// CHECK-NEXT: func @test() {
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// With --mlir-print-assume-verified the custom form is printed in both cases.
+// This works in this particular case, but may crash in general.
+
+// ASSUME-VERIFIED:      Invalid operation:
+// ASSUME-VERIFIED-NEXT: func @test() {
+// ASSUME-VERIFIED-NEXT: }
+
+// ASSUME-VERIFIED:      Valid operation:
+// ASSUME-VERIFIED-NEXT: func @test() {
+// ASSUME-VERIFIED-NEXT:   return
+// ASSUME-VERIFIED-NEXT: }

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index f656a4e6934ef..a195817f4fe9f 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
   TestOpaqueLoc.cpp
   TestOperationEquals.cpp
   TestPrintDefUse.cpp
+  TestPrintInvalid.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp
   TestSlicing.cpp

diff  --git a/mlir/test/lib/IR/TestPrintInvalid.cpp b/mlir/test/lib/IR/TestPrintInvalid.cpp
new file mode 100644
index 0000000000000..537af8be069b1
--- /dev/null
+++ b/mlir/test/lib/IR/TestPrintInvalid.cpp
@@ -0,0 +1,56 @@
+//===- TestPrintInvalid.cpp - Test printing invalid ops -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass creates and prints to the standard output an invalid operation and
+// a valid operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+/// Builds a function with no terminator (invalid), prints it, adds a return op
+/// (now valid), prints it again to stdout, and finally erases the function.
+struct TestPrintInvalidPass
+    : public PassWrapper<TestPrintInvalidPass, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-print-invalid"; }
+  StringRef getDescription() const final {
+    return "Test printing invalid ops.";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect>();
+  }
+
+  void runOnOperation() override {
+    Location loc = getOperation().getLoc();
+    OpBuilder builder(getOperation().body());
+    auto funcOp = builder.create<FuncOp>(
+        loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
+    funcOp.addEntryBlock();
+    // The created function is invalid because there is no return op.
+    llvm::outs() << "Invalid operation:\n" << funcOp << "\n";
+    builder.setInsertionPointToEnd(&funcOp.getBody().front());
+    builder.create<func::ReturnOp>(loc);
+    // Now this function is valid.
+    llvm::outs() << "Valid operation:\n" << funcOp << "\n";
+    // Erase the temporary function so the module is left unchanged.
+    funcOp.erase();
+  }
+};
+} // namespace
+
+namespace mlir {
+/// Registers the `test-print-invalid` pass.
+void registerTestPrintInvalidPass() {
+  PassRegistration<TestPrintInvalidPass>{};
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 7336980583bb2..9c6431779354f 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -45,6 +45,7 @@ void registerTestLoopPermutationPass();
 void registerTestMatchers();
 void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
+void registerTestPrintInvalidPass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
@@ -132,6 +133,7 @@ void registerTestPasses() {
   registerTestMatchers();
   registerTestOperationEqualPass();
   registerTestPrintDefUsePass();
+  registerTestPrintInvalidPass();
   registerTestPrintNestingPass();
   registerTestReducer();
   registerTestSpirvEntryPointABIPass();


        


More information about the Mlir-commits mailing list