[Mlir-commits] [mlir] 50f82e6 - [mlir] Fix missing verification after running an OpToOpAdaptorPass

River Riddle llvmlistbot at llvm.org
Wed Mar 16 14:54:09 PDT 2022


Author: River Riddle
Date: 2022-03-16T14:53:41-07:00
New Revision: 50f82e68470c3efbb8ceae8f8c8d289a079d7031

URL: https://github.com/llvm/llvm-project/commit/50f82e68470c3efbb8ceae8f8c8d289a079d7031
DIFF: https://github.com/llvm/llvm-project/commit/50f82e68470c3efbb8ceae8f8c8d289a079d7031.diff

LOG: [mlir] Fix missing verification after running an OpToOpAdaptorPass

The decision of when to run the verifier currently rests on the
assumption that nested passes can't affect the validity of the parent
operation, which isn't true. Parent operations may attach any number
of constraints on nested operations, which may not necessarily be
captured (or shouldn't be captured) at a smaller granularity.

This commit rectifies this by properly running the verifier after an
OpToOpPassAdaptor pass. To avoid an explosive increase in compile time,
we only run verification on the parent operation itself. To do this, a
flag is added to mlir::verify to skip recursive verification when it
isn't desired.

Fixes #54288

Differential Revision: https://reviews.llvm.org/D121836

Added: 
    mlir/test/Pass/invalid-parent.mlir

Modified: 
    mlir/include/mlir/IR/Verifier.h
    mlir/lib/IR/Verifier.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Pass/CMakeLists.txt
    mlir/test/lib/Pass/TestPassManager.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Verifier.h b/mlir/include/mlir/IR/Verifier.h
index 3e1a0f858f14e..1fcc99e63206d 100644
--- a/mlir/include/mlir/IR/Verifier.h
+++ b/mlir/include/mlir/IR/Verifier.h
@@ -15,8 +15,12 @@ class Operation;
 
 /// Perform (potentially expensive) checks of invariants, used to detect
 /// compiler bugs, on this operation and any nested operations. On error, this
-/// reports the error through the MLIRContext and returns failure.
-LogicalResult verify(Operation *op);
+/// reports the error through the MLIRContext and returns failure. If
+/// `verifyRecursively` is false, this assumes that nested operations have
+/// already been properly verified, and does not recursively invoke the verifier
+/// on nested operations.
+LogicalResult verify(Operation *op, bool verifyRecursively = true);
+
 } // namespace mlir
 
 #endif

diff  --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 0c8724de8cdab..62212dbaf6070 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -43,6 +43,11 @@ namespace {
 /// This class encapsulates all the state used to verify an operation region.
 class OperationVerifier {
 public:
+  /// If `verifyRecursively` is true, then this will also recursively verify
+  /// nested operations.
+  explicit OperationVerifier(bool verifyRecursively)
+      : verifyRecursively(verifyRecursively) {}
+
   /// Verify the given operation.
   LogicalResult verifyOpAndDominance(Operation &op);
 
@@ -61,6 +66,10 @@ class OperationVerifier {
   /// Operation.
   LogicalResult verifyDominanceOfContainedRegions(Operation &op,
                                                   DominanceInfo &domInfo);
+
+  /// A flag indicating if this verifier should recursively verify nested
+  /// operations.
+  bool verifyRecursively;
 };
 } // namespace
 
@@ -81,8 +90,12 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {
       return failure();
   }
 
-  // Check the dominance properties and invariants of any operations in the
-  // regions contained by the 'opsWithIsolatedRegions' operations.
+  // If we aren't verifying nested operations, then we're done.
+  if (!verifyRecursively)
+    return success();
+
+  // Otherwise, check the dominance properties and invariants of any operations
+  // in the regions contained by the 'opsWithIsolatedRegions' operations.
   return failableParallelForEach(
       op.getContext(), opsWithIsolatedRegions,
       [&](Operation *op) { return verifyOpAndDominance(*op); });
@@ -120,21 +133,25 @@ LogicalResult OperationVerifier::verifyBlock(
 
   // Check each operation, and make sure there are no branches out of the
   // middle of this block.
-  for (auto &op : block) {
+  for (Operation &op : block) {
     // Only the last instructions is allowed to have successors.
     if (op.getNumSuccessors() != 0 && &op != &block.back())
       return op.emitError(
           "operation with block successors must terminate its parent block");
 
+    // If we aren't verifying recursively, there is nothing left to check.
+    if (!verifyRecursively)
+      continue;
+
     // If this operation has regions and is IsolatedFromAbove, we defer
     // checking.  This allows us to parallelize verification better.
     if (op.getNumRegions() != 0 &&
         op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
       opsWithIsolatedRegions.push_back(&op);
-    } else {
+
       // Otherwise, check the operation inline.
-      if (failed(verifyOperation(op, opsWithIsolatedRegions)))
-        return failure();
+    } else if (failed(verifyOperation(op, opsWithIsolatedRegions))) {
+      return failure();
     }
   }
 
@@ -185,8 +202,9 @@ LogicalResult OperationVerifier::verifyOperation(
     auto kindInterface = dyn_cast<RegionKindInterface>(op);
 
     // Verify that all child regions are ok.
+    MutableArrayRef<Region> regions = op.getRegions();
     for (unsigned i = 0; i < numRegions; ++i) {
-      Region &region = op.getRegion(i);
+      Region &region = regions[i];
       RegionKind kind =
           kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
       // Check that Graph Regions only have a single basic block. This is
@@ -210,10 +228,13 @@ LogicalResult OperationVerifier::verifyOperation(
         return emitError(op.getLoc(),
                          "entry block of region may not have predecessors");
 
-      // Verify each of the blocks within the region.
-      for (Block &block : region)
-        if (failed(verifyBlock(block, opsWithIsolatedRegions)))
-          return failure();
+      // Verify each of the blocks within the region if we are verifying
+      // recursively.
+      if (verifyRecursively) {
+        for (Block &block : region)
+          if (failed(verifyBlock(block, opsWithIsolatedRegions)))
+            return failure();
+      }
     }
   }
 
@@ -330,10 +351,10 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
           }
         }
 
-        // Recursively verify dominance within each operation in the
-        // block, even if the block itself is not reachable, or we are in
-        // a region which doesn't respect dominance.
-        if (op.getNumRegions() != 0) {
+        // Recursively verify dominance within each operation in the block, even
+        // if the block itself is not reachable, or we are in a region which
+        // doesn't respect dominance.
+        if (verifyRecursively && op.getNumRegions() != 0) {
           // If this operation is IsolatedFromAbove, then we'll handle it in the
           // outer verification loop.
           if (op.hasTrait<OpTrait::IsIsolatedFromAbove>())
@@ -352,9 +373,7 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
 // Entrypoint
 //===----------------------------------------------------------------------===//
 
-/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs.  On error, this reports the error through the MLIRContext and
-/// returns failure.
-LogicalResult mlir::verify(Operation *op) {
-  return OperationVerifier().verifyOpAndDominance(*op);
+LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
+  OperationVerifier verifier(verifyRecursively);
+  return verifier.verifyOpAndDominance(*op);
 }

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 22a9641bde0e8..7256f44b6adb4 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -408,22 +408,24 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   // failed).
   if (!passFailed && verifyPasses) {
     bool runVerifierNow = true;
+
+    // If the pass is an adaptor pass, we don't run the verifier recursively
+    // because the nested operations should have already been verified after
+    // nested passes had run.
+    bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
+
     // Reduce compile time by avoiding running the verifier if the pass didn't
     // change the IR since the last time the verifier was run:
     //
     //  1) If the pass said that it preserved all analyses then it can't have
     //     permuted the IR.
-    //  2) If we just ran an OpToOpPassAdaptor (e.g. to run function passes
-    //     within a module) then each sub-unit will have been verified on the
-    //     subunit (and those passes aren't allowed to modify the parent).
     //
     // We run these checks in EXPENSIVE_CHECKS mode out of caution.
 #ifndef EXPENSIVE_CHECKS
-    runVerifierNow = !isa<OpToOpPassAdaptor>(pass) &&
-                     !pass->passState->preservedAnalyses.isAll();
+    runVerifierNow = !pass->passState->preservedAnalyses.isAll();
 #endif
     if (runVerifierNow)
-      passFailed = failed(verify(op));
+      passFailed = failed(verify(op, runVerifierRecursively));
   }
 
   // Instrument after the pass has run.

diff  --git a/mlir/test/Pass/invalid-parent.mlir b/mlir/test/Pass/invalid-parent.mlir
new file mode 100644
index 0000000000000..2979ba9e89a4e
--- /dev/null
+++ b/mlir/test/Pass/invalid-parent.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(test-pass-invalid-parent)' -verify-diagnostics
+
+// Test that we properly report errors when the parent becomes invalid after running a pass
+// on a child operation.
+// expected-error@below {{'some_unknown_func' does not reference a valid function}}
+func @TestCreateInvalidCallInPass() {
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e0c5eea373b0d..c5305824cf2a3 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -358,6 +358,21 @@ void TestDialect::getCanonicalizationPatterns(
   results.add(&dialectCanonicalizationPattern);
 }
 
+//===----------------------------------------------------------------------===//
+// TestCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TestFoldToCallOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index da1aa2ffd7106..3f3f812b98ca0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -375,6 +375,14 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
 // Test Call Interfaces
 //===----------------------------------------------------------------------===//
 
+def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
 def ConversionCallOp : TEST_Op<"conversion_call_op",
     [CallOpInterface]> {
   let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);

diff  --git a/mlir/test/lib/Pass/CMakeLists.txt b/mlir/test/lib/Pass/CMakeLists.txt
index 061d40fc978f8..dd90c228cdaf5 100644
--- a/mlir/test/lib/Pass/CMakeLists.txt
+++ b/mlir/test/lib/Pass/CMakeLists.txt
@@ -11,4 +11,11 @@ add_mlir_library(MLIRTestPass
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRTestDialect
+  )
+
+target_include_directories(MLIRTestPass
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
   )

diff  --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 5a759f39412cc..85dc7bf8701ed 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -98,6 +99,27 @@ class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
   }
 };
 
+/// A test pass that inserts a call to a nonexistent function, making the
+/// parent operation invalid.
+class TestInvalidParentPass
+    : public PassWrapper<TestInvalidParentPass,
+                         InterfacePass<FunctionOpInterface>> {
+  StringRef getArgument() const final { return "test-pass-invalid-parent"; }
+  StringRef getDescription() const final {
+    return "Test a pass in the pass manager that makes the parent operation "
+           "invalid";
+  }
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<test::TestDialect>();
+  }
+  void runOnOperation() final {
+    FunctionOpInterface op = getOperation();
+    OpBuilder b(getOperation().getBody());
+    b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
+                               ValueRange());
+  }
+};
+
 /// A test pass that contains a statistic.
 struct TestStatisticPass
     : public PassWrapper<TestStatisticPass, OperationPass<>> {
@@ -144,6 +166,7 @@ void registerPassManagerTestPass() {
 
   PassRegistration<TestCrashRecoveryPass>();
   PassRegistration<TestFailurePass>();
+  PassRegistration<TestInvalidParentPass>();
 
   PassRegistration<TestStatisticPass>();
 


        


More information about the Mlir-commits mailing list