[Mlir-commits] [mlir] 9c30297 - [mlir] Defer `LoopLikeOpInterface` type checks to `RegionBranchOpInterface` (#184116)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 6 04:43:05 PST 2026
Author: Ivan Butygin
Date: 2026-03-06T15:43:00+03:00
New Revision: 9c3029774ad5d30718877286d2df776a9e28e052
URL: https://github.com/llvm/llvm-project/commit/9c3029774ad5d30718877286d2df776a9e28e052
DIFF: https://github.com/llvm/llvm-project/commit/9c3029774ad5d30718877286d2df776a9e28e052.diff
LOG: [mlir] Defer `LoopLikeOpInterface` type checks to `RegionBranchOpInterface` (#184116)
`LoopLikeOpInterface`'s verifier hardcodes type equality checks for
init/iter_arg/yield/result edges. This prevents downstream ops with
compatible-but-unequal types from passing the verifier, since there
is no override mechanism.
`RegionBranchOpInterface` already verifies types along all control-flow
edges and provides an overridable `areTypesCompatible` hook. In practice,
most loop ops implement both interfaces, making the `LoopLikeOpInterface`
type checks redundant (the existing code even had a comment acknowledging
this).
This PR removes the redundant type checks from `LoopLikeOpInterface`
when the op also implements `RegionBranchOpInterface`, letting that
interface handle type compatibility (including any custom
`areTypesCompatible` overrides). For ops that only implement
`LoopLikeOpInterface` (e.g. `affine.parallel`, `tosa.while_loop`), the
strict type equality checks are retained as a fallback. This PR also
refactors the verifier's loop-result check to use structured bindings with
`llvm::enumerate` and removes dead variables.
Added:
mlir/test/IR/test-branch-types-compatible.mlir
Modified:
mlir/include/mlir/Interfaces/LoopLikeInterface.td
mlir/lib/Interfaces/CMakeLists.txt
mlir/lib/Interfaces/LoopLikeInterface.cpp
mlir/test/Dialect/SCF/invalid.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/lib/Dialect/Test/TestOpDefs.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 89526e92a4c92..5fb897339ffde 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -35,8 +35,10 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
serve as the region iter_arg of the next iteration.
If one of the respective interface methods is implemented, so must the other
- two. The interface verifier ensures that the number of types of the region
- iter_args, init values and yielded values match.
+ two. The interface verifier ensures that the number of region iter_args,
+ init values and yielded values match. Note: the types are not required to
+ strictly match; if the op also implements `RegionBranchOpInterface`, type
+ compatibility is deferred to that interface's `areTypesCompatible` hook.
Optionally, "loop results" can be exposed through this interface. These are
the values that are returned from the loop op when there are no more
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..41e890cb408ba 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -95,6 +95,7 @@ add_mlir_library(MLIRLoopLikeInterface
LINK_LIBS PUBLIC
MLIRIR
+ MLIRControlFlowInterfaces
MLIRFunctionInterfaces
)
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index ae2ac13187dff..201efdba8e037 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -8,6 +8,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
using namespace mlir;
@@ -53,8 +54,6 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
}
LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
- // Note: These invariants are also verified by the RegionBranchOpInterface,
- // but the LoopLikeOpInterface provides better error messages.
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
// Verify number of inits/iter_args/yielded values/loop results.
@@ -77,37 +76,37 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
<< " != " << loopLikeOp.getRegionIterArgs().size();
// Verify types of inits/iter_args/yielded values/loop results.
- int64_t i = 0;
- auto yieldedValues = loopLikeOp.getYieldedValues();
- for (const auto [index, init, regionIterArg] :
- llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
- if (init.getType() != regionIterArg.getType())
- return op->emitOpError(std::to_string(index))
- << "-th init and " << index
- << "-th region iter_arg have different type: " << init.getType()
- << " != " << regionIterArg.getType();
- if (!yieldedValues.empty()) {
- if (regionIterArg.getType() != yieldedValues[index].getType())
+ // If the op also implements RegionBranchOpInterface, type compatibility is
+ // already verified by that interface's verifier (which also provides an
+ // overridable areTypesCompatible hook), so skip the check here.
+ if (!isa<RegionBranchOpInterface>(op)) {
+ auto yieldedValues = loopLikeOp.getYieldedValues();
+ for (const auto [index, init, regionIterArg] : llvm::enumerate(
+ loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
+ if (init.getType() != regionIterArg.getType())
return op->emitOpError(std::to_string(index))
- << "-th region iter_arg and " << index
- << "-th yielded value have different type: "
- << regionIterArg.getType()
- << " != " << yieldedValues[index].getType();
+ << "-th init and " << index
+ << "-th region iter_arg have different type: " << init.getType()
+ << " != " << regionIterArg.getType();
+ if (!yieldedValues.empty()) {
+ if (regionIterArg.getType() != yieldedValues[index].getType())
+ return op->emitOpError(std::to_string(index))
+ << "-th region iter_arg and " << index
+ << "-th yielded value have different type: "
+ << regionIterArg.getType()
+ << " != " << yieldedValues[index].getType();
+ }
}
- ++i;
- }
- i = 0;
- if (loopLikeOp.getLoopResults()) {
- for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
- *loopLikeOp.getLoopResults())) {
- if (std::get<0>(it).getType() != std::get<1>(it).getType())
- return op->emitOpError(std::to_string(i))
- << "-th region iter_arg and " << i
- << "-th loop result have different type: "
- << std::get<0>(it).getType()
- << " != " << std::get<1>(it).getType();
+ if (loopLikeOp.getLoopResults()) {
+ for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
+ loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
+ if (regionIterArg.getType() != loopResult.getType())
+ return op->emitOpError(std::to_string(index))
+ << "-th region iter_arg and " << index
+ << "-th loop result have different type: "
+ << regionIterArg.getType() << " != " << loopResult.getType();
+ }
}
- ++i;
}
// Verify that all induction variables have valid types.
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index e7031b29adf69..33a8921eeb993 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -98,7 +98,8 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
// -----
func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
- // expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
+ // expected-error @below{{along control flow edge from parent to parent: successor operand type #0 'f32' should match successor input type #0 'f64'}}
+ // expected-note @below{{region branch point}}
"scf.for"(%arg0, %arg0, %arg0, %init) (
{
^bb0(%i0 : index, %iter: f32):
@@ -494,11 +495,12 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
- // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
+ // expected-error @below {{along control flow edge from Operation scf.yield to Region #0: successor operand type #1 'i32' should match successor input type #1 'f32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32
%ic = arith.constant 1 : i32
+ // expected-note @below {{region branch point}}
scf.yield %sn, %ic : f32, i32
}
return
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 5884aad9f233c..ae706b9b148a6 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -1198,7 +1198,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
- // expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
+ // expected-error @+1 {{'sparse_tensor.iterate' op types mismatch between 0th yield value and defined value}}
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
%y = arith.constant 1.0 : f32
sparse_tensor.yield %y : f32
diff --git a/mlir/test/IR/test-branch-types-compatible.mlir b/mlir/test/IR/test-branch-types-compatible.mlir
new file mode 100644
index 0000000000000..b6840da663f54
--- /dev/null
+++ b/mlir/test/IR/test-branch-types-compatible.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+// RegionBranchOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @region_branch_compat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %0 = "test.region_types_compat"(%c0) ({
+ ^bb0(%arg0: i64):
+ %c1 = arith.constant 1 : i64
+ "test.types_compat_yield"(%c1) : (i64) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// RegionBranchOpInterface: incompatible types (i32 <-> f32) should fail.
+func.func @region_branch_incompat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+ // expected-note @+1 {{region branch point}}
+ %0 = "test.region_types_compat"(%c0) ({
+ ^bb0(%arg0: f32):
+ %c1 = arith.constant 1.0 : f32
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @loop_compat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: i64):
+ %c1 = arith.constant 1 : i64
+ "test.types_compat_yield"(%c1) : (i64) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface + RegionBranchOpInterface: incompatible init vs iter_arg
+// (i32 <-> f32) should fail via RegionBranchOpInterface.
+func.func @loop_incompat_init() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+ // expected-note @+1 {{region branch point}}
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: f32):
+ %c1 = arith.constant 1.0 : f32
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface + RegionBranchOpInterface: incompatible iter_arg vs yield
+// (i32 <-> f32) should fail via RegionBranchOpInterface.
+func.func @loop_incompat_yield() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+1 {{along control flow edge from Operation test.types_compat_yield to Region #0: successor operand type #0 'f32' should match successor input type #0 'i32'}}
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: i32):
+ %c1 = arith.constant 1.0 : f32
+ // expected-note @+1 {{region branch point}}
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 7cf728f933395..b0fc5a6acc647 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1419,6 +1419,82 @@ TestStoreWithALoopRegion::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange(getBody().front().getArguments());
}
+//===----------------------------------------------------------------------===//
+// TestRegionTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestRegionTypesCompatOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
+ regions.emplace_back(&getBody());
+ else
+ regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange
+TestRegionTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+ return getEntries();
+}
+
+ValueRange
+TestRegionTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+ if (successor.isParent())
+ return getResults();
+ return getBody().getArguments();
+}
+
+bool TestRegionTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+ return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+//===----------------------------------------------------------------------===//
+// TestLoopTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestLoopTypesCompatOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ regions.emplace_back(&getBody());
+ if (!point.isParent())
+ regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange TestLoopTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+ return getInitArgs();
+}
+
+ValueRange
+TestLoopTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+ if (successor.isParent())
+ return getResults();
+ return getBody().getArguments();
+}
+
+MutableArrayRef<OpOperand> TestLoopTypesCompatOp::getInitsMutable() {
+ return getInitArgsMutable();
+}
+
+Block::BlockArgListType TestLoopTypesCompatOp::getRegionIterArgs() {
+ return getBody().getArguments();
+}
+
+std::optional<MutableArrayRef<OpOperand>>
+TestLoopTypesCompatOp::getYieldedValuesMutable() {
+ return cast<TestTypesCompatYieldOp>(getBody().front().getTerminator())
+ .getArgsMutable();
+}
+
+std::optional<ResultRange> TestLoopTypesCompatOp::getLoopResults() {
+ return getResults();
+}
+
+SmallVector<Region *> TestLoopTypesCompatOp::getLoopRegions() {
+ return {&getBody()};
+}
+
+bool TestLoopTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+ return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
//===----------------------------------------------------------------------===//
// TestVersionedOpA
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bd0b6e25efa53..a25b9d270de16 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2728,6 +2728,41 @@ def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
}];
}
+//===----------------------------------------------------------------------===//
+// Test areTypesCompatible for RegionBranchOpInterface / LoopLikeOpInterface
+//===----------------------------------------------------------------------===//
+
+def TestTypesCompatYieldOp : TEST_Op<"types_compat_yield",
+ [Pure, ReturnLike, Terminator]> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let assemblyFormat = "($args^ `:` type($args))? attr-dict";
+}
+
+def TestRegionTypesCompatOp : TEST_Op<"region_types_compat",
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorInputs",
+ "areTypesCompatible"]>,
+ SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+ RecursiveMemoryEffects]> {
+ let arguments = (ins Variadic<AnyType>:$entries);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body);
+}
+
+def TestLoopTypesCompatOp : TEST_Op<"loop_types_compat",
+ [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getRegionIterArgs", "getYieldedValuesMutable",
+ "getLoopResults"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorInputs",
+ "areTypesCompatible"]>,
+ SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+ RecursiveMemoryEffects]> {
+ let arguments = (ins Variadic<AnyType>:$init_args);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body);
+}
+
//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list