[Mlir-commits] [mlir] [mlir] Defer `LoopLikeOpInterface` type checks to `RegionBranchOpInterface` (PR #184116)

Ivan Butygin llvmlistbot at llvm.org
Thu Mar 5 07:18:16 PST 2026


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/184116

>From 99089109e120834bbb3a1e7b3bf79acac47e533f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 2 Mar 2026 13:56:55 +0100
Subject: [PATCH 1/3] [mlir] Add areLoopIterArgTypesCompatible to
 LoopLikeOpInterface

LoopLikeOpInterface's verifier hardcodes type equality checks for
init/iter_arg/yield/result edges, with no override mechanism. This
forces downstream ops with compatible-but-unequal types to disable
the pass verifier entirely.

RegionBranchOpInterface already has a way to override this behavior,
so replicate it for LoopLikeOpInterface.

Add an `areLoopIterArgTypesCompatible(Type, Type)` hook (default:
`lhs == rhs`) that the verifier uses instead of `!=`. Named distinctly
from `RegionBranchOpInterface::areTypesCompatible` to avoid ambiguity
for ops implementing both interfaces.

Add dedicated tests for RegionBranchOpInterface and
LoopLikeOpInterface.

Also refactor the verifier's loop-result check to use structured
bindings with `llvm::enumerate` and remove dead variables.
---
 .../mlir/Interfaces/LoopLikeInterface.td      |  9 +++
 mlir/lib/Interfaces/LoopLikeInterface.cpp     | 24 +++---
 .../test/IR/test-branch-types-compatible.mlir | 68 ++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 80 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 35 ++++++++
 5 files changed, 203 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/IR/test-branch-types-compatible.mlir

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 89526e92a4c92..d478fff431228 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -291,6 +291,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        This method is called to compare types along control-flow edges. By
+        default, the types are checked as equal.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"areLoopIterArgTypesCompatible",
+      /*args=*/(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
+      /*defaultImplementation=*/[{ return lhs == rhs; }]
     >
   ];
 
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index ae2ac13187dff..173b386dd9383 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -77,37 +77,35 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
            << " != " << loopLikeOp.getRegionIterArgs().size();
 
   // Verify types of inits/iter_args/yielded values/loop results.
-  int64_t i = 0;
   auto yieldedValues = loopLikeOp.getYieldedValues();
   for (const auto [index, init, regionIterArg] :
        llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
-    if (init.getType() != regionIterArg.getType())
+    if (!loopLikeOp.areLoopIterArgTypesCompatible(init.getType(),
+                                                  regionIterArg.getType()))
       return op->emitOpError(std::to_string(index))
              << "-th init and " << index
              << "-th region iter_arg have different type: " << init.getType()
              << " != " << regionIterArg.getType();
     if (!yieldedValues.empty()) {
-      if (regionIterArg.getType() != yieldedValues[index].getType())
+      if (!loopLikeOp.areLoopIterArgTypesCompatible(
+              regionIterArg.getType(), yieldedValues[index].getType()))
         return op->emitOpError(std::to_string(index))
                << "-th region iter_arg and " << index
                << "-th yielded value have different type: "
                << regionIterArg.getType()
                << " != " << yieldedValues[index].getType();
     }
-    ++i;
   }
-  i = 0;
   if (loopLikeOp.getLoopResults()) {
-    for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
-                                         *loopLikeOp.getLoopResults())) {
-      if (std::get<0>(it).getType() != std::get<1>(it).getType())
-        return op->emitOpError(std::to_string(i))
-               << "-th region iter_arg and " << i
+    for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
+             loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
+      if (!loopLikeOp.areLoopIterArgTypesCompatible(regionIterArg.getType(),
+                                                    loopResult.getType()))
+        return op->emitOpError(std::to_string(index))
+               << "-th region iter_arg and " << index
                << "-th loop result have different type: "
-               << std::get<0>(it).getType()
-               << " != " << std::get<1>(it).getType();
+               << regionIterArg.getType() << " != " << loopResult.getType();
     }
-    ++i;
   }
 
   // Verify that all induction variables have valid types.
diff --git a/mlir/test/IR/test-branch-types-compatible.mlir b/mlir/test/IR/test-branch-types-compatible.mlir
new file mode 100644
index 0000000000000..5f38d73c3b161
--- /dev/null
+++ b/mlir/test/IR/test-branch-types-compatible.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+// RegionBranchOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @region_branch_compat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %0 = "test.region_types_compat"(%c0) ({
+  ^bb0(%arg0: i64):
+    %c1 = arith.constant 1 : i64
+    "test.types_compat_yield"(%c1) : (i64) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// RegionBranchOpInterface: incompatible types (i32 <-> f32) should fail.
+func.func @region_branch_incompat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+  // expected-note @+1 {{region branch point}}
+  %0 = "test.region_types_compat"(%c0) ({
+  ^bb0(%arg0: f32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @loop_compat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: i64):
+    %c1 = arith.constant 1 : i64
+    "test.types_compat_yield"(%c1) : (i64) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible init vs iter_arg (i32 <-> f32) should fail.
+func.func @loop_incompat_init() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+1 {{0-th init and 0-th region iter_arg have different type: 'i32' != 'f32'}}
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: f32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible iter_arg vs yield (i32 <-> f32) should fail.
+func.func @loop_incompat_yield() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+1 {{0-th region iter_arg and 0-th yielded value have different type: 'i32' != 'f32'}}
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: i32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c243bd79a44a8..1a916afc983e3 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1401,6 +1401,86 @@ TestStoreWithALoopRegion::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange(getBody().front().getArguments());
 }
 
+//===----------------------------------------------------------------------===//
+// TestRegionTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestRegionTypesCompatOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent())
+    regions.emplace_back(&getBody());
+  else
+    regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange
+TestRegionTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+  return getEntries();
+}
+
+ValueRange
+TestRegionTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+  if (successor.isParent())
+    return getResults();
+  return getBody().getArguments();
+}
+
+bool TestRegionTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+  return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+//===----------------------------------------------------------------------===//
+// TestLoopTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestLoopTypesCompatOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  regions.emplace_back(&getBody());
+  if (!point.isParent())
+    regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange TestLoopTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+  return getInitArgs();
+}
+
+ValueRange
+TestLoopTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+  if (successor.isParent())
+    return getResults();
+  return getBody().getArguments();
+}
+
+MutableArrayRef<OpOperand> TestLoopTypesCompatOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType TestLoopTypesCompatOp::getRegionIterArgs() {
+  return getBody().getArguments();
+}
+
+std::optional<MutableArrayRef<OpOperand>>
+TestLoopTypesCompatOp::getYieldedValuesMutable() {
+  return cast<TestTypesCompatYieldOp>(getBody().front().getTerminator())
+      .getArgsMutable();
+}
+
+std::optional<ResultRange> TestLoopTypesCompatOp::getLoopResults() {
+  return getResults();
+}
+
+SmallVector<Region *> TestLoopTypesCompatOp::getLoopRegions() {
+  return {&getBody()};
+}
+
+bool TestLoopTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+  return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+bool TestLoopTypesCompatOp::areLoopIterArgTypesCompatible(Type lhs, Type rhs) {
+  return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
 //===----------------------------------------------------------------------===//
 // TestVersionedOpA
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fe02536a1df5b..78d812175b2d6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2717,6 +2717,41 @@ def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test areTypesCompatible for RegionBranchOpInterface / LoopLikeOpInterface
+//===----------------------------------------------------------------------===//
+
+def TestTypesCompatYieldOp : TEST_Op<"types_compat_yield",
+    [Pure, ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let assemblyFormat = "($args^ `:` type($args))? attr-dict";
+}
+
+def TestRegionTypesCompatOp : TEST_Op<"region_types_compat",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+       ["getEntrySuccessorOperands", "getSuccessorInputs",
+        "areTypesCompatible"]>,
+     SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+     RecursiveMemoryEffects]> {
+  let arguments = (ins Variadic<AnyType>:$entries);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
+def TestLoopTypesCompatOp : TEST_Op<"loop_types_compat",
+    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+       ["getInitsMutable", "getRegionIterArgs", "getYieldedValuesMutable",
+        "getLoopResults", "getLoopRegions", "areLoopIterArgTypesCompatible"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+       ["getEntrySuccessorOperands", "getSuccessorInputs",
+        "areTypesCompatible"]>,
+     SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+     RecursiveMemoryEffects]> {
+  let arguments = (ins Variadic<AnyType>:$init_args);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//

>From 0f5b4e86281609cd855170403f6f0a876aa6eab2 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 5 Mar 2026 14:16:19 +0100
Subject: [PATCH 2/3] [mlir] Defer LoopLikeOpInterface type checks to
 RegionBranchOpInterface

Instead of adding a new areLoopIterArgTypesCompatible hook to
LoopLikeOpInterface, skip the type compatibility checks entirely when
the op also implements RegionBranchOpInterface (which already verifies
types along all control-flow edges and provides an overridable
areTypesCompatible hook).

For ops that only implement LoopLikeOpInterface (e.g. affine.parallel,
tosa.while_loop), the existing strict type equality checks are retained
as a fallback.
---
 .../mlir/Interfaces/LoopLikeInterface.td      |  9 ---
 mlir/lib/Interfaces/CMakeLists.txt            |  1 +
 mlir/lib/Interfaces/LoopLikeInterface.cpp     | 57 ++++++++++---------
 mlir/test/Dialect/SCF/invalid.mlir            |  6 +-
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  2 +-
 .../test/IR/test-branch-types-compatible.mlir | 12 ++--
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  4 --
 mlir/test/lib/Dialect/Test/TestOps.td         |  2 +-
 8 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index d478fff431228..89526e92a4c92 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -291,15 +291,6 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
-    >,
-    InterfaceMethod<[{
-        This method is called to compare types along control-flow edges. By
-        default, the types are checked as equal.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"areLoopIterArgTypesCompatible",
-      /*args=*/(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
-      /*defaultImplementation=*/[{ return lhs == rhs; }]
     >
   ];
 
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..41e890cb408ba 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -95,6 +95,7 @@ add_mlir_library(MLIRLoopLikeInterface
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRControlFlowInterfaces
   MLIRFunctionInterfaces
 )
 
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 173b386dd9383..201efdba8e037 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Interfaces/LoopLikeInterface.h"
 
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 using namespace mlir;
@@ -53,8 +54,6 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
 }
 
 LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
-  // Note: These invariants are also verified by the RegionBranchOpInterface,
-  // but the LoopLikeOpInterface provides better error messages.
   auto loopLikeOp = cast<LoopLikeOpInterface>(op);
 
   // Verify number of inits/iter_args/yielded values/loop results.
@@ -77,34 +76,36 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
            << " != " << loopLikeOp.getRegionIterArgs().size();
 
   // Verify types of inits/iter_args/yielded values/loop results.
-  auto yieldedValues = loopLikeOp.getYieldedValues();
-  for (const auto [index, init, regionIterArg] :
-       llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
-    if (!loopLikeOp.areLoopIterArgTypesCompatible(init.getType(),
-                                                  regionIterArg.getType()))
-      return op->emitOpError(std::to_string(index))
-             << "-th init and " << index
-             << "-th region iter_arg have different type: " << init.getType()
-             << " != " << regionIterArg.getType();
-    if (!yieldedValues.empty()) {
-      if (!loopLikeOp.areLoopIterArgTypesCompatible(
-              regionIterArg.getType(), yieldedValues[index].getType()))
+  // If the op also implements RegionBranchOpInterface, type compatibility is
+  // already verified by that interface's verifier (which also provides an
+  // overridable areTypesCompatible hook), so skip the check here.
+  if (!isa<RegionBranchOpInterface>(op)) {
+    auto yieldedValues = loopLikeOp.getYieldedValues();
+    for (const auto [index, init, regionIterArg] : llvm::enumerate(
+             loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
+      if (init.getType() != regionIterArg.getType())
         return op->emitOpError(std::to_string(index))
-               << "-th region iter_arg and " << index
-               << "-th yielded value have different type: "
-               << regionIterArg.getType()
-               << " != " << yieldedValues[index].getType();
+               << "-th init and " << index
+               << "-th region iter_arg have different type: " << init.getType()
+               << " != " << regionIterArg.getType();
+      if (!yieldedValues.empty()) {
+        if (regionIterArg.getType() != yieldedValues[index].getType())
+          return op->emitOpError(std::to_string(index))
+                 << "-th region iter_arg and " << index
+                 << "-th yielded value have different type: "
+                 << regionIterArg.getType()
+                 << " != " << yieldedValues[index].getType();
+      }
     }
-  }
-  if (loopLikeOp.getLoopResults()) {
-    for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
-             loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
-      if (!loopLikeOp.areLoopIterArgTypesCompatible(regionIterArg.getType(),
-                                                    loopResult.getType()))
-        return op->emitOpError(std::to_string(index))
-               << "-th region iter_arg and " << index
-               << "-th loop result have different type: "
-               << regionIterArg.getType() << " != " << loopResult.getType();
+    if (loopLikeOp.getLoopResults()) {
+      for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
+               loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
+        if (regionIterArg.getType() != loopResult.getType())
+          return op->emitOpError(std::to_string(index))
+                 << "-th region iter_arg and " << index
+                 << "-th loop result have different type: "
+                 << regionIterArg.getType() << " != " << loopResult.getType();
+      }
     }
   }
 
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index e7031b29adf69..33a8921eeb993 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -98,7 +98,8 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
 // -----
 
 func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
-  // expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
+  // expected-error @below{{along control flow edge from parent to parent: successor operand type #0 'f32' should match successor input type #0 'f64'}}
+  // expected-note @below{{region branch point}}
   "scf.for"(%arg0, %arg0, %arg0, %init) (
     {
     ^bb0(%i0 : index, %iter: f32):
@@ -494,11 +495,12 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
 func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
   %s0 = arith.constant 0.0 : f32
   %t0 = arith.constant 1.0 : f32
-  // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
+  // expected-error @below {{along control flow edge from Operation scf.yield to Region #0: successor operand type #1 'i32' should match successor input type #1 'f32'}}
   %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
                     iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
     %sn = arith.addf %si, %si : f32
     %ic = arith.constant 1 : i32
+    // expected-note @below {{region branch point}}
     scf.yield %sn, %ic : f32, i32
   }
   return
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 30c74bbcdbf90..b1d9d29991584 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1198,7 +1198,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
 
 func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
   %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+  // expected-error @+1 {{'sparse_tensor.iterate' op types mismatch between 0th yield value and defined value}}
   %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
     %y = arith.constant 1.0 :  f32
     sparse_tensor.yield %y : f32
diff --git a/mlir/test/IR/test-branch-types-compatible.mlir b/mlir/test/IR/test-branch-types-compatible.mlir
index 5f38d73c3b161..b6840da663f54 100644
--- a/mlir/test/IR/test-branch-types-compatible.mlir
+++ b/mlir/test/IR/test-branch-types-compatible.mlir
@@ -41,10 +41,12 @@ func.func @loop_compat() -> i32 {
 
 // -----
 
-// LoopLikeOpInterface: incompatible init vs iter_arg (i32 <-> f32) should fail.
+// LoopLikeOpInterface + RegionBranchOpInterface: incompatible init vs iter_arg
+// (i32 <-> f32) should fail via RegionBranchOpInterface.
 func.func @loop_incompat_init() -> i32 {
   %c0 = arith.constant 0 : i32
-  // expected-error @+1 {{0-th init and 0-th region iter_arg have different type: 'i32' != 'f32'}}
+  // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+  // expected-note @+1 {{region branch point}}
   %0 = "test.loop_types_compat"(%c0) ({
   ^bb0(%arg0: f32):
     %c1 = arith.constant 1.0 : f32
@@ -55,13 +57,15 @@ func.func @loop_incompat_init() -> i32 {
 
 // -----
 
-// LoopLikeOpInterface: incompatible iter_arg vs yield (i32 <-> f32) should fail.
+// LoopLikeOpInterface + RegionBranchOpInterface: incompatible iter_arg vs yield
+// (i32 <-> f32) should fail via RegionBranchOpInterface.
 func.func @loop_incompat_yield() -> i32 {
   %c0 = arith.constant 0 : i32
-  // expected-error @+1 {{0-th region iter_arg and 0-th yielded value have different type: 'i32' != 'f32'}}
+  // expected-error @+1 {{along control flow edge from Operation test.types_compat_yield to Region #0: successor operand type #0 'f32' should match successor input type #0 'i32'}}
   %0 = "test.loop_types_compat"(%c0) ({
   ^bb0(%arg0: i32):
     %c1 = arith.constant 1.0 : f32
+    // expected-note @+1 {{region branch point}}
     "test.types_compat_yield"(%c1) : (f32) -> ()
   }) : (i32) -> i32
   return %0 : i32
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 1a916afc983e3..a06a11e169d74 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1477,10 +1477,6 @@ bool TestLoopTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
   return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
 }
 
-bool TestLoopTypesCompatOp::areLoopIterArgTypesCompatible(Type lhs, Type rhs) {
-  return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
-}
-
 //===----------------------------------------------------------------------===//
 // TestVersionedOpA
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 78d812175b2d6..b9438362cf5a3 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2741,7 +2741,7 @@ def TestRegionTypesCompatOp : TEST_Op<"region_types_compat",
 def TestLoopTypesCompatOp : TEST_Op<"loop_types_compat",
     [DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInitsMutable", "getRegionIterArgs", "getYieldedValuesMutable",
-        "getLoopResults", "getLoopRegions", "areLoopIterArgTypesCompatible"]>,
+        "getLoopResults"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
        ["getEntrySuccessorOperands", "getSuccessorInputs",
         "areTypesCompatible"]>,

>From 471ce75881e7c2d06d41ca03b805215d3673b34e Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Thu, 5 Mar 2026 16:10:38 +0100
Subject: [PATCH 3/3] update doc

---
 mlir/include/mlir/Interfaces/LoopLikeInterface.td | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 89526e92a4c92..5fb897339ffde 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -35,8 +35,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       serve as the region iter_arg of the next iteration.
 
     If one of the respective interface methods is implemented, so must the other
-    two. The interface verifier ensures that the number of types of the region
-    iter_args, init values and yielded values match.
+    two. The interface verifier ensures that the number of region iter_args,
+    init values and yielded values match. Their types must also match exactly,
+    unless the op also implements `RegionBranchOpInterface`, in which case type
+    compatibility is deferred to that interface's `areTypesCompatible` hook.
 
     Optionally, "loop results" can be exposed through this interface. These are
     the values that are returned from the loop op when there are no more



More information about the Mlir-commits mailing list