[Mlir-commits] [mlir] 5b29f86 - [mlir] Fix verifier of `RegionBranchOpInterface`

Markus Böck llvmlistbot at llvm.org
Thu Aug 10 03:52:22 PDT 2023


Author: Markus Böck
Date: 2023-08-10T12:38:54+02:00
New Revision: 5b29f86b42bab223ed8b9f2738802138aade74e0

URL: https://github.com/llvm/llvm-project/commit/5b29f86b42bab223ed8b9f2738802138aade74e0
DIFF: https://github.com/llvm/llvm-project/commit/5b29f86b42bab223ed8b9f2738802138aade74e0.diff

LOG: [mlir] Fix verifier of `RegionBranchOpInterface`

The verifier incorrectly passed the region number of the predecessor region instead of the successor region to `getSuccessorOperands`. This went unnoticed because none of the upstream `RegionBranchTerminatorOpInterface` implementations made use of the `index` parameter.
Adding an assert to, e.g., `scf.condition` that checks the index is valid, or adding a region terminator that passes different operands to different successors, immediately causes the verifier to fail because it receives the operand types for the wrong successor.

This patch fixes the implementation to correctly pass the successor region index.

Differential Revision: https://reviews.llvm.org/D157507

Added: 
    mlir/test/IR/test-region-branch-op-verifier.mlir

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
    mlir/lib/Dialect/SCF/IR/SCF.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index b468a6bb3f9097..98184718611a54 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -445,6 +445,10 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
           llvm::find(regionSuccessor->getSuccessorInputs(), argValue)
               .getIndex();
 
+      std::optional<unsigned> successorRegionNumber;
+      if (Region *successorRegion = regionSuccessor->getSuccessor())
+        successorRegionNumber = successorRegion->getRegionNumber();
+
       // Iterate over all immediate terminator operations to introduce
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
@@ -453,7 +457,7 @@ class BufferDeallocation : public BufferPlacementTransformationBase {
                 // Get the actual mutable operands for this terminator op.
                 auto terminatorOperands =
                     terminator.getMutableSuccessorOperands(
-                        region.getRegionNumber());
+                        successorRegionNumber);
                 // Extract the source value from the current terminator.
                 // This conversion needs to exist on a separate line due to a
                 // bug in GCC conversion analysis.

diff  --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2cc9e2c895666e..750fe9c021673f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -283,6 +283,10 @@ void ExecuteRegionOp::getSuccessorRegions(
 
 MutableOperandRange
 ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
+  // `index` is the successor region; std::nullopt means exiting the loop.
+  assert((!index || index == getParentOp().getAfter().getRegionNumber()) &&
+         "condition op can only exit the loop or branch to the after "
+         "region");
   // Pass all operands except the condition to the successor region.
   return getArgsMutable();
 }

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index b3690ab8961555..cc90da370de693 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -84,6 +84,23 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
+/// Appends "from X to Y" to `diag`, where std::nullopt denotes the parent op.
+static InFlightDiagnostic &
+printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo,
+                    std::optional<unsigned> succRegionNo) {
+  diag << "from ";
+  if (!sourceNo)
+    diag << "parent operands";
+  else
+    diag << "Region #" << *sourceNo;
+  diag << " to ";
+  if (!succRegionNo)
+    diag << "parent results";
+  else
+    diag << "Region #" << *succRegionNo;
+  return diag;
+}
+
 /// Verify that types match along all region control flow edges originating from
 /// `sourceNo` (region # if source is a region, std::nullopt if source is parent
 /// op). `getInputsTypesForRegion` is a function that returns the types of the
@@ -92,7 +109,7 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 /// the match itself).
 static LogicalResult verifyTypesAlongAllEdges(
     Operation *op, std::optional<unsigned> sourceNo,
-    function_ref<std::optional<TypeRange>(std::optional<unsigned>)>
+    function_ref<FailureOr<TypeRange>(std::optional<unsigned>)>
         getInputsTypesForRegion) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
@@ -104,32 +121,17 @@ static LogicalResult verifyTypesAlongAllEdges(
     if (!succ.isParent())
       succRegionNo = succ.getSuccessor()->getRegionNumber();
 
-    auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
-      diag << "from ";
-      if (sourceNo)
-        diag << "Region #" << sourceNo.value();
-      else
-        diag << "parent operands";
-
-      diag << " to ";
-      if (succRegionNo)
-        diag << "Region #" << succRegionNo.value();
-      else
-        diag << "parent results";
-      return diag;
-    };
-
-    std::optional<TypeRange> sourceTypes =
-        getInputsTypesForRegion(succRegionNo);
-    if (!sourceTypes.has_value())
-      continue;
+    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+    if (failed(sourceTypes))
+      return failure();
 
     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
     if (sourceTypes->size() != succInputsTypes.size()) {
       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
-      return printEdgeName(diag) << ": source has " << sourceTypes->size()
-                                 << " operands, but target successor needs "
-                                 << succInputsTypes.size();
+      return printRegionEdgeName(diag, sourceNo, succRegionNo)
+             << ": source has " << sourceTypes->size()
+             << " operands, but target successor needs "
+             << succInputsTypes.size();
     }
 
     for (const auto &typesIdx :
@@ -138,7 +140,7 @@ static LogicalResult verifyTypesAlongAllEdges(
       Type inputType = std::get<1>(typesIdx.value());
       if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
         InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
-        return printEdgeName(diag)
+        return printRegionEdgeName(diag, sourceNo, succRegionNo)
                << ": source type #" << typesIdx.index() << " " << sourceType
                << " should match input type #" << typesIdx.index() << " "
                << inputType;
@@ -177,45 +179,48 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
     Region &region = op->getRegion(regionNo);
 
-    // Since there can be multiple `ReturnLike` terminators or others
-    // implementing the `RegionBranchTerminatorOpInterface`, all should have the
-    // same operand types when passing them to the same region.
-
-    std::optional<OperandRange> regionReturnOperands;
-    for (Block &block : region) {
-      auto terminator =
-          dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator());
-      if (!terminator)
-        continue;
-
-      OperandRange terminatorOperands =
-          terminator.getSuccessorOperands(regionNo);
-      if (!regionReturnOperands) {
-        regionReturnOperands = terminatorOperands;
-        continue;
-      }
+    // Since there can be multiple terminators implementing the
+    // `RegionBranchTerminatorOpInterface`, all should have the same operand
+    // types when passing them to the same region.
 
-      // Found more than one ReturnLike terminator. Make sure the operand types
-      // match with the first one.
-      if (!areTypesCompatible(regionReturnOperands->getTypes(),
-                              terminatorOperands.getTypes()))
-        return op->emitOpError("Region #")
-               << regionNo
-               << " operands mismatch between return-like terminators";
-    }
+    SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
+    for (Block &block : region)
+      if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
+              block.getTerminator()))
+        regionReturnOps.push_back(terminator);
 
-    auto inputTypesFromRegion =
-        [&](std::optional<unsigned> regionNo) -> std::optional<TypeRange> {
-      // If there is no return-like terminator, the op itself should verify
-      // type consistency.
-      if (!regionReturnOperands)
-        return std::nullopt;
+    // If there is no return-like terminator, the op itself should verify
+    // type consistency.
+    if (regionReturnOps.empty())
+      continue;
+
+    auto inputTypesForRegion =
+        [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> {
+      // Collect the operand types the terminators forward to `succRegionNo`.
+      std::optional<OperandRange> regionReturnOperands;
+      for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+        auto terminatorOperands =
+            regionReturnOp.getSuccessorOperands(succRegionNo);
+        if (!regionReturnOperands) {
+          regionReturnOperands = terminatorOperands;
+          continue;
+        }
+
+        // Every further terminator must forward matching operand types.
+        if (!areTypesCompatible(regionReturnOperands->getTypes(),
+                                terminatorOperands.getTypes())) {
+          InFlightDiagnostic diag =
+              op->emitOpError(" along control flow edge ");
+          return printRegionEdgeName(diag, regionNo, succRegionNo)
+                 << " operands mismatch between return-like terminators";
+        }
+      }
 
       // All successors get the same set of operand types.
       return TypeRange(regionReturnOperands->getTypes());
     };
 
-    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
+    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion)))
       return failure();
   }
 

diff  --git a/mlir/test/IR/test-region-branch-op-verifier.mlir b/mlir/test/IR/test-region-branch-op-verifier.mlir
new file mode 100644
index 00000000000000..f5fb7fc2b25cb9
--- /dev/null
+++ b/mlir/test/IR/test-region-branch-op-verifier.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s
+
+// `test.loop_block_term` forwards an i32 back to the body but an f32 to the
+// parent results. The RegionBranchOpInterface verifier must query the
+// terminator's successor operands per successor region for this to verify.
+func.func @test_ops_verify(%arg: i32) -> f32 {
+  %0 = "test.constant"() { value = 5.3 : f32 } : () -> f32
+  %1 = test.loop_block %arg : (i32) -> f32 {
+  ^bb0(%arg1 : i32):
+    test.loop_block_term iter %arg exit %0
+  }
+  return %1 : f32
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ed97efa462be9f..6f3e33052372e8 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -980,6 +980,44 @@ void AnyCondOp::getRegionInvocationBounds(
   invocationBounds.emplace_back(1, 1);
 }
 
+//===----------------------------------------------------------------------===//
+// LoopBlockOp
+//===----------------------------------------------------------------------===//
+
+/// The body is always a successor: it is entered from the parent op and may
+/// branch back to itself. From inside the body, control may also exit to the
+/// op's results.
+void LoopBlockOp::getSuccessorRegions(
+    std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) {
+  regions.emplace_back(&getBody(), getBody().getArguments());
+  if (!index)
+    return;
+
+  regions.emplace_back((*this)->getResults());
+}
+
+/// On entry, the `init` operand is forwarded to the body's block arguments.
+OperandRange
+LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) {
+  assert(index == 0 && "loop_block can only be entered through its body");
+  return getInitMutable();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopBlockTerminatorOp
+//===----------------------------------------------------------------------===//
+
+/// Forwards `nextIterArg` when branching back to the body (region 0) and
+/// `exitArg` when exiting to the parent op's results (std::nullopt).
+MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(
+    std::optional<unsigned> index) {
+  assert((!index || index == 0) &&
+         "loop_block_term can only branch to the body or exit the loop");
+  if (!index)
+    return getExitArgMutable();
+  return getNextIterArgMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // SingleNoTerminatorCustomAsmOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index b3678868e17dec..0b121d7a185c7d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2084,6 +2084,30 @@ def AnyCondOp : TEST_Op<"any_cond",
   let regions = (region AnyRegion:$region);
 }
 
+def LoopBlockOp : TEST_Op<"loop_block",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getEntrySuccessorOperands"]>, RecursiveMemoryEffects]> {
+  // `init` (i32) is forwarded into the body; the loop yields an f32 result.
+  let results = (outs F32:$floatResult);
+  let arguments = (ins I32:$init);
+  let regions = (region SizedRegion<1>:$body);
+
+  let assemblyFormat = [{
+    $init `:` functional-type($init, $floatResult) $body
+    attr-dict-with-keyword
+  }];
+}
+
+def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
+    [DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, Pure,
+     Terminator]> {
+  let arguments = (ins I32:$nextIterArg, F32:$exitArg);
+  // `nextIterArg` re-enters the body; `exitArg` becomes the loop's result.
+  let assemblyFormat = [{
+    `iter` $nextIterArg `exit` $exitArg attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list