[Mlir-commits] [mlir] [mlir][Interfaces] Simplify and improve errors of `RegionBranchOpInterface` verifier (PR #174805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 7 09:10:23 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

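Simplify the `RegionBranchOpInterface` verifier by using new API functions such as `getAllRegionBranchPoints`, and rename the verifier entry point `detail::verifyTypesAlongControlFlowEdges` to `detail::verifyRegionBranchOpInterface`.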

Also improve the error messages by using the same terms that are used in the interface definition: `region branch point`, `region successor`, `successor operand`, `successor input`. Each error now also carries a note that points to the region branch point.

---
Full diff: https://github.com/llvm/llvm-project/pull/174805.diff


4 Files Affected:

- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+1-1) 
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+1-1) 
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+61-104) 
- (modified) mlir/test/Dialect/SCF/invalid.mlir (+8-4) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 566f4b8fadb5d..b76c2891fad5a 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -173,7 +173,7 @@ LogicalResult verifyRegionBranchWeights(Operation *op);
 
 namespace detail {
 /// Verify that types match along control flow edges described the given op.
-LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
+LogicalResult verifyRegionBranchOpInterface(Operation *op);
 } //  namespace detail
 
 /// A mapping from successor operands to successor inputs.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 2e654ba04ffe5..ecad424e30c75 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -323,7 +323,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
   let verify = [{
     static_assert(!ConcreteOp::template hasTrait<OpTrait::ZeroRegions>(),
                   "expected operation to have non-zero regions");
-    return detail::verifyTypesAlongControlFlowEdges($_op);
+    return detail::verifyRegionBranchOpInterface($_op);
   }];
   let verifyWithRegions = 1;
 
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index d393ddb8d8336..2574f4e73d311 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -152,115 +152,68 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
 
-static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
-                                               RegionBranchPoint sourceNo,
-                                               RegionSuccessor succRegionNo) {
-  diag << "from ";
-  if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
-    diag << "Operation " << op->getName();
-  else
-    diag << "parent operands";
-
-  diag << " to ";
-  if (Region *region = succRegionNo.getSuccessor())
-    diag << "Region #" << region->getRegionNumber();
-  else
-    diag << "parent results";
-  return diag;
-}
-
-/// Verify that types match along all region control flow edges originating from
-/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
-/// types of the inputs that flow to a successor region.
-static LogicalResult
-verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
-                         RegionBranchPoint sourcePoint,
-                         function_ref<FailureOr<TypeRange>(RegionSuccessor)>
-                             getInputsTypesForRegion) {
-  SmallVector<RegionSuccessor, 2> successors;
-  branchOp.getSuccessorRegions(sourcePoint, successors);
-
-  for (RegionSuccessor &succ : successors) {
-    FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
-    if (failed(sourceTypes))
-      return failure();
-
-    TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
-    if (sourceTypes->size() != succInputsTypes.size()) {
-      InFlightDiagnostic diag =
-          branchOp->emitOpError("region control flow edge ");
-      std::string succStr;
-      llvm::raw_string_ostream os(succStr);
-      os << succ;
-      return printRegionEdgeName(diag, sourcePoint, succ)
-             << ": source has " << sourceTypes->size()
-             << " operands, but target successor " << os.str() << " needs "
-             << succInputsTypes.size();
-    }
-
-    for (const auto &typesIdx :
-         llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
-      Type sourceType = std::get<0>(typesIdx.value());
-      Type inputType = std::get<1>(typesIdx.value());
-
-      if (!branchOp.areTypesCompatible(sourceType, inputType)) {
-        InFlightDiagnostic diag =
-            branchOp->emitOpError("along control flow edge ");
-        return printRegionEdgeName(diag, sourcePoint, succ)
-               << ": source type #" << typesIdx.index() << " " << sourceType
-               << " should match input type #" << typesIdx.index() << " "
-               << inputType;
-      }
-    }
-  }
-
-  return success();
-}
-
 /// Verify that types match along control flow edges described the given op.
-LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
+LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
-  auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
-    return regionInterface.getEntrySuccessorOperands(successor).getTypes();
-  };
-
-  // Verify types along control flow edges originating from the parent.
-  if (failed(verifyTypesAlongAllEdges(
-          regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
-    return failure();
-
-  // Verify types along control flow edges originating from each region.
-  for (Region &region : op->getRegions()) {
-    // Collect all return-like terminators in the region.
-    SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
-    for (Block &block : region)
-      if (!block.empty())
-        if (auto terminator =
-                dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
-          regionReturnOps.push_back(terminator);
-
-    // If there is no return-like terminator, the op itself should verify
-    // type consistency.
-    if (regionReturnOps.empty())
-      continue;
+  // Verify all control flow edges from region branch points to region
+  // successors.
+  SmallVector<RegionBranchPoint> regionBranchPoints =
+      regionInterface.getAllRegionBranchPoints();
+  for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
+    SmallVector<RegionSuccessor> successors;
+    regionInterface.getSuccessorRegions(branchPoint, successors);
+    for (const RegionSuccessor &successor : successors) {
+      // Emit an error describing the edge from `branchPoint` to `successor`
+      // and attach a "region branch point" note. Callers append details.
+      auto emitRegionEdgeError = [&]() {
+        InFlightDiagnostic diag =
+            regionInterface->emitOpError("along control flow edge from ");
+        if (branchPoint.isParent()) {
+          diag << "parent";
+          diag.attachNote(op->getLoc()) << "region branch point";
+        } else {
+          diag << "Operation "
+               << branchPoint.getTerminatorPredecessorOrNull()->getName();
+          diag.attachNote(
+              branchPoint.getTerminatorPredecessorOrNull()->getLoc())
+              << "region branch point";
+        }
+        diag << " to ";
+        if (Region *region = successor.getSuccessor()) {
+          diag << "Region #" << region->getRegionNumber();
+        } else {
+          diag << "parent";
+        }
+        return diag;
+      };
 
-    // Verify types along control flow edges originating from each return-like
-    // terminator.
-    for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+      // Verify that the successor operand and successor input counts match.
+      OperandRange succOperands =
+          regionInterface.getSuccessorOperands(branchPoint, successor);
+      ValueRange succInputs = successor.getSuccessorInputs();
+      if (succOperands.size() != succInputs.size()) {
+        return emitRegionEdgeError()
+               << ": region branch point has " << succOperands.size()
+               << " operands, but region successor needs " << succInputs.size()
+               << " inputs";
+      }
 
-      auto inputTypesForRegion =
-          [&](RegionSuccessor successor) -> FailureOr<TypeRange> {
-        OperandRange terminatorOperands =
-            regionReturnOp.getSuccessorOperands(successor);
-        return TypeRange(terminatorOperands.getTypes());
-      };
-      if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
-                                          inputTypesForRegion)))
-        return failure();
+      // Verify pairwise compatibility of successor operand and input types.
+      TypeRange succInputTypes = succInputs.getTypes();
+      TypeRange succOperandTypes = succOperands.getTypes();
+      for (const auto &typesIdx :
+           llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
+        Type succOperandType = std::get<0>(typesIdx.value());
+        Type succInputType = std::get<1>(typesIdx.value());
+        if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
+          return emitRegionEdgeError()
+                 << ": successor operand type #" << typesIdx.index() << " "
+                 << succOperandType << " should match successor input type #"
+                 << typesIdx.index() << " " << succInputType;
+      }
     }
   }
-
   return success();
 }

@@ -525,11 +478,15 @@ SmallVector<RegionBranchPoint>
 RegionBranchOpInterface::getAllRegionBranchPoints() {
   SmallVector<RegionBranchPoint> branchPoints;
   branchPoints.push_back(RegionBranchPoint::parent());
-  for (Region &region : getOperation()->getRegions())
-    for (Block &block : region)
+  for (Region &region : getOperation()->getRegions()) {
+    for (Block &block : region) {
+      if (block.empty())
+        continue;
       if (auto terminator =
               dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
         branchPoints.push_back(RegionBranchPoint(terminator));
+    }
+  }
   return branchPoints;
 }
 
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 6db43ffd4b81b..394b133b088f8 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -404,9 +404,10 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
 
 func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
 {
-  // expected-error at +1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}}
+  // expected-error at +1 {{along control flow edge from Operation scf.yield to parent: region branch point has 1 operands, but region successor needs 2 inputs}}
   %x, %y = scf.if %arg0 -> (f32, f32) {
     %0 = arith.addf %arg1, %arg1 : f32
+    // expected-note at +1 {{region branch point}}
     scf.yield %0 : f32
   } else {
     %0 = arith.subf %arg1, %arg1 : f32
@@ -575,8 +576,9 @@ func.func @while_invalid_terminator() {
 
 func.func @while_cross_region_type_mismatch() {
   %true = arith.constant true
-  // expected-error at +1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}}
+  // expected-error at +1 {{along control flow edge from Operation scf.condition to Region #1: region branch point has 0 operands, but region successor needs 1 inputs}}
   scf.while : () -> () {
+    // expected-note at +1 {{region branch point}}
     scf.condition(%true)
   } do {
   ^bb0(%arg0: i32):
@@ -588,8 +590,9 @@ func.func @while_cross_region_type_mismatch() {
 
 func.func @while_cross_region_type_mismatch() {
   %true = arith.constant true
-  // expected-error at +1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}}
+  // expected-error at +1 {{op along control flow edge from Operation scf.condition to Region #1: successor operand type #0 'i1' should match successor input type #0 'i32'}}
   %0 = scf.while : () -> (i1) {
+    // expected-note at +1 {{region branch point}}
     scf.condition(%true) %true : i1
   } do {
   ^bb0(%arg0: i32):
@@ -601,8 +604,9 @@ func.func @while_cross_region_type_mismatch() {
 
 func.func @while_result_type_mismatch() {
   %true = arith.constant true
-  // expected-error at +1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}}
+  // expected-error at +1 {{along control flow edge from Operation scf.condition to parent: region branch point has 1 operands, but region successor needs 0 inputs}}
   scf.while : () -> () {
+    // expected-note at +1 {{region branch point}}
     scf.condition(%true) %true : i1
   } do {
   ^bb0(%arg0: i1):

``````````

</details>


https://github.com/llvm/llvm-project/pull/174805


More information about the Mlir-commits mailing list