[Mlir-commits] [mlir] [mlir] Add `areLoopIterArgTypesCompatible` to `LoopLikeOpInterface` (PR #184116)

Ivan Butygin llvmlistbot at llvm.org
Mon Mar 2 05:20:59 PST 2026


https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/184116

LoopLikeOpInterface's verifier hardcodes type equality checks for init/iter_arg/yield/result edges, with no override mechanism. This forces downstream ops with compatible-but-unequal types to disable the pass verifier entirely.

RegionBranchOpInterface already provides a way to override this behavior (`areTypesCompatible`), so replicate it for LoopLikeOpInterface.

Add an `areLoopIterArgTypesCompatible(Type, Type)` hook (default: `lhs == rhs`) that the verifier uses instead of `!=`. Named distinctly from `RegionBranchOpInterface::areTypesCompatible` to avoid ambiguity for ops implementing both interfaces.

Add dedicated tests for the type-compatibility hooks of RegionBranchOpInterface and LoopLikeOpInterface.

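Also refactor the verifier's loop-result check to use structured bindings with `llvm::enumerate` and remove dead variables. This also fixes the index reported in loop-result type mismatch errors, which was always 0 because the old counter was incremented outside the loop.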

>From 99089109e120834bbb3a1e7b3bf79acac47e533f Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 2 Mar 2026 13:56:55 +0100
Subject: [PATCH] [mlir] Add areLoopIterArgTypesCompatible to
 LoopLikeOpInterface

LoopLikeOpInterface's verifier hardcodes type equality checks for
init/iter_arg/yield/result edges, with no override mechanism. This
forces downstream ops with compatible-but-unequal types to disable
the pass verifier entirely.

RegionBranchOpInterface already provides a way to override this
behavior (`areTypesCompatible`), so replicate it for
LoopLikeOpInterface.

Add an `areLoopIterArgTypesCompatible(Type, Type)` hook (default:
`lhs == rhs`) that the verifier uses instead of `!=`. Named distinctly
from `RegionBranchOpInterface::areTypesCompatible` to avoid ambiguity
for ops implementing both interfaces.

Add dedicated tests for the type-compatibility hooks of
RegionBranchOpInterface and LoopLikeOpInterface.

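Also refactor the verifier's loop-result check to use structured
bindings with `llvm::enumerate` and remove dead variables. This also
fixes the index reported in loop-result type mismatch errors, which
was always 0 because the old counter was incremented outside the loop.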
---
 .../mlir/Interfaces/LoopLikeInterface.td      |  9 +++
 mlir/lib/Interfaces/LoopLikeInterface.cpp     | 24 +++---
 .../test/IR/test-branch-types-compatible.mlir | 68 ++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 80 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 35 ++++++++
 5 files changed, 203 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/IR/test-branch-types-compatible.mlir

diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 89526e92a4c92..d478fff431228 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -291,6 +291,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*defaultImplementation=*/[{
         return ::mlir::failure();
       }]
+    >,
+    InterfaceMethod<[{
+        Returns true if `lhs` and `rhs` are compatible types for the two ends
+        of a loop-carried value: init / region iter_arg, region iter_arg /
+        yielded value, and region iter_arg / loop result. The interface
+        verifier uses this hook instead of checking for type equality. By
+        default, the types must be equal.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"areLoopIterArgTypesCompatible",
+      /*args=*/(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
+      /*defaultImplementation=*/[{ return lhs == rhs; }]
     >
   ];
 
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index ae2ac13187dff..173b386dd9383 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -77,37 +77,38 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
            << " != " << loopLikeOp.getRegionIterArgs().size();
 
   // Verify types of inits/iter_args/yielded values/loop results.
-  int64_t i = 0;
+  // `areLoopIterArgTypesCompatible` is called with the source type first and
+  // the destination type second: init -> iter_arg, yielded value -> iter_arg,
+  // iter_arg -> loop result.
   auto yieldedValues = loopLikeOp.getYieldedValues();
   for (const auto [index, init, regionIterArg] :
        llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
-    if (init.getType() != regionIterArg.getType())
+    if (!loopLikeOp.areLoopIterArgTypesCompatible(init.getType(),
+                                                  regionIterArg.getType()))
       return op->emitOpError(std::to_string(index))
              << "-th init and " << index
              << "-th region iter_arg have different type: " << init.getType()
              << " != " << regionIterArg.getType();
     if (!yieldedValues.empty()) {
-      if (regionIterArg.getType() != yieldedValues[index].getType())
+      if (!loopLikeOp.areLoopIterArgTypesCompatible(
+              yieldedValues[index].getType(), regionIterArg.getType()))
         return op->emitOpError(std::to_string(index))
                << "-th region iter_arg and " << index
                << "-th yielded value have different type: "
                << regionIterArg.getType()
                << " != " << yieldedValues[index].getType();
     }
-    ++i;
   }
-  i = 0;
   if (loopLikeOp.getLoopResults()) {
-    for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
-                                         *loopLikeOp.getLoopResults())) {
-      if (std::get<0>(it).getType() != std::get<1>(it).getType())
-        return op->emitOpError(std::to_string(i))
-               << "-th region iter_arg and " << i
+    for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
+             loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
+      if (!loopLikeOp.areLoopIterArgTypesCompatible(regionIterArg.getType(),
+                                                    loopResult.getType()))
+        return op->emitOpError(std::to_string(index))
+               << "-th region iter_arg and " << index
                << "-th loop result have different type: "
-               << std::get<0>(it).getType()
-               << " != " << std::get<1>(it).getType();
+               << regionIterArg.getType() << " != " << loopResult.getType();
     }
-    ++i;
   }
 
   // Verify that all induction variables have valid types.
diff --git a/mlir/test/IR/test-branch-types-compatible.mlir b/mlir/test/IR/test-branch-types-compatible.mlir
new file mode 100644
index 0000000000000..5f38d73c3b161
--- /dev/null
+++ b/mlir/test/IR/test-branch-types-compatible.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+// RegionBranchOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @region_branch_compat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %0 = "test.region_types_compat"(%c0) ({
+  ^bb0(%arg0: i64):
+    %c1 = arith.constant 1 : i64
+    "test.types_compat_yield"(%c1) : (i64) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// RegionBranchOpInterface: incompatible types (i32 <-> f32) should fail.
+func.func @region_branch_incompat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+  // expected-note @+1 {{region branch point}}
+  %0 = "test.region_types_compat"(%c0) ({
+  ^bb0(%arg0: f32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @loop_compat() -> i32 {
+  %c0 = arith.constant 0 : i32
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: i64):
+    %c1 = arith.constant 1 : i64
+    "test.types_compat_yield"(%c1) : (i64) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible init vs iter_arg (i32 <-> f32) should fail.
+func.func @loop_incompat_init() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+1 {{0-th init and 0-th region iter_arg have different type: 'i32' != 'f32'}}
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: f32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible iter_arg vs yield (i32 <-> f32) should fail.
+func.func @loop_incompat_yield() -> i32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+1 {{0-th region iter_arg and 0-th yielded value have different type: 'i32' != 'f32'}}
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: i32):
+    %c1 = arith.constant 1.0 : f32
+    "test.types_compat_yield"(%c1) : (f32) -> ()
+  }) : (i32) -> i32
+  return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible iter_arg vs result (i32 <-> f32) should fail.
+func.func @loop_incompat_result() -> f32 {
+  %c0 = arith.constant 0 : i32
+  // expected-error @+1 {{0-th region iter_arg and 0-th loop result have different type: 'i32' != 'f32'}}
+  %0 = "test.loop_types_compat"(%c0) ({
+  ^bb0(%arg0: i32):
+    %c1 = arith.constant 1 : i32
+    "test.types_compat_yield"(%c1) : (i32) -> ()
+  }) : (i32) -> f32
+  return %0 : f32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c243bd79a44a8..1a916afc983e3 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1401,6 +1401,99 @@ TestStoreWithALoopRegion::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange(getBody().front().getArguments());
 }
 
+//===----------------------------------------------------------------------===//
+// Type-compatibility test ops
+//===----------------------------------------------------------------------===//
+
+/// Compatibility relation shared by the `*_types_compat` test ops: two types
+/// are compatible if they are equal or are both builtin integer types.
+static bool areTestCompatTypes(Type lhs, Type rhs) {
+  return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+//===----------------------------------------------------------------------===//
+// TestRegionTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestRegionTypesCompatOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // The body runs exactly once: parent -> body -> parent.
+  if (point.isParent())
+    regions.emplace_back(&getBody());
+  else
+    regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange
+TestRegionTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+  return getEntries();
+}
+
+ValueRange
+TestRegionTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+  if (successor.isParent())
+    return getResults();
+  return getBody().getArguments();
+}
+
+bool TestRegionTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+  return areTestCompatTypes(lhs, rhs);
+}
+
+//===----------------------------------------------------------------------===//
+// TestLoopTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestLoopTypesCompatOp::getSuccessorRegions(
+    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  // Loop semantics: both entry and the back-edge go to the body, and the body
+  // may also exit to the parent.
+  regions.emplace_back(&getBody());
+  if (!point.isParent())
+    regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange TestLoopTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+  return getInitArgs();
+}
+
+ValueRange
+TestLoopTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+  if (successor.isParent())
+    return getResults();
+  return getBody().getArguments();
+}
+
+MutableArrayRef<OpOperand> TestLoopTypesCompatOp::getInitsMutable() {
+  return getInitArgsMutable();
+}
+
+Block::BlockArgListType TestLoopTypesCompatOp::getRegionIterArgs() {
+  return getBody().getArguments();
+}
+
+std::optional<MutableArrayRef<OpOperand>>
+TestLoopTypesCompatOp::getYieldedValuesMutable() {
+  return cast<TestTypesCompatYieldOp>(getBody().front().getTerminator())
+      .getArgsMutable();
+}
+
+std::optional<ResultRange> TestLoopTypesCompatOp::getLoopResults() {
+  return getResults();
+}
+
+SmallVector<Region *> TestLoopTypesCompatOp::getLoopRegions() {
+  return {&getBody()};
+}
+
+bool TestLoopTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+  return areTestCompatTypes(lhs, rhs);
+}
+
+bool TestLoopTypesCompatOp::areLoopIterArgTypesCompatible(Type lhs, Type rhs) {
+  return areTestCompatTypes(lhs, rhs);
+}
+
 //===----------------------------------------------------------------------===//
 // TestVersionedOpA
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fe02536a1df5b..78d812175b2d6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2717,6 +2717,46 @@ def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test type-compatibility hooks: RegionBranchOpInterface / LoopLikeOpInterface
+//===----------------------------------------------------------------------===//
+
+// Terminator shared by the type-compatibility test ops below.
+def TestTypesCompatYieldOp : TEST_Op<"types_compat_yield",
+    [Pure, ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>:$args);
+  let assemblyFormat = "($args^ `:` type($args))? attr-dict";
+}
+
+// Op whose RegionBranchOpInterface `areTypesCompatible` hook accepts equal
+// types or any two integer types (implemented in TestOpDefs.cpp).
+def TestRegionTypesCompatOp : TEST_Op<"region_types_compat",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+       ["getEntrySuccessorOperands", "getSuccessorInputs",
+        "areTypesCompatible"]>,
+     SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+     RecursiveMemoryEffects]> {
+  let arguments = (ins Variadic<AnyType>:$entries);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
+// Op implementing both LoopLikeOpInterface and RegionBranchOpInterface; both
+// type-compatibility hooks accept equal types or any two integer types.
+def TestLoopTypesCompatOp : TEST_Op<"loop_types_compat",
+    [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+       ["getInitsMutable", "getRegionIterArgs", "getYieldedValuesMutable",
+        "getLoopResults", "getLoopRegions", "areLoopIterArgTypesCompatible"]>,
+     DeclareOpInterfaceMethods<RegionBranchOpInterface,
+       ["getEntrySuccessorOperands", "getSuccessorInputs",
+        "areTypesCompatible"]>,
+     SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+     RecursiveMemoryEffects]> {
+  let arguments = (ins Variadic<AnyType>:$init_args);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list