[Mlir-commits] [mlir] [mlir] Add `areLoopIterArgTypesCompatible` to `LoopLikeOpInterface` (PR #184116)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 2 05:21:30 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Ivan Butygin (Hardcode84)
<details>
<summary>Changes</summary>
LoopLikeOpInterface's verifier hardcodes type equality checks for init/iter_arg/yield/result edges, with no override mechanism. This forces downstream ops with compatible-but-unequal types to disable the pass verifier entirely.
RegionBranchOpInterface already provides a way to override this behavior (`areTypesCompatible`), so replicate it for LoopLikeOpInterface.
Add an `areLoopIterArgTypesCompatible(Type, Type)` hook (default: `lhs == rhs`) that the verifier uses instead of `!=`. Named distinctly from `RegionBranchOpInterface::areTypesCompatible` to avoid ambiguity for ops implementing both interfaces.
Add dedicated tests for RegionBranchOpInterface and LoopLikeOpInterface.
Also refactor the verifier's loop-result check to use structured bindings with `llvm::enumerate` and remove dead variables.
---
Full diff: https://github.com/llvm/llvm-project/pull/184116.diff
5 Files Affected:
- (modified) mlir/include/mlir/Interfaces/LoopLikeInterface.td (+9)
- (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (+11-13)
- (added) mlir/test/IR/test-branch-types-compatible.mlir (+68)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+80)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+35)
``````````diff
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 89526e92a4c92..d478fff431228 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -291,6 +291,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*defaultImplementation=*/[{
return ::mlir::failure();
}]
+ >,
+ InterfaceMethod<[{
+ This method is called to compare the types of loop-carried values (init/iter_arg,
+ iter_arg/yielded value, iter_arg/loop result). By default, the types must be equal.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"areLoopIterArgTypesCompatible",
+ /*args=*/(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
+ /*defaultImplementation=*/[{ return lhs == rhs; }]
>
];
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index ae2ac13187dff..173b386dd9383 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -77,37 +77,35 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
<< " != " << loopLikeOp.getRegionIterArgs().size();
// Verify types of inits/iter_args/yielded values/loop results.
- int64_t i = 0;
auto yieldedValues = loopLikeOp.getYieldedValues();
for (const auto [index, init, regionIterArg] :
llvm::enumerate(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs())) {
- if (init.getType() != regionIterArg.getType())
+ if (!loopLikeOp.areLoopIterArgTypesCompatible(init.getType(),
+ regionIterArg.getType()))
return op->emitOpError(std::to_string(index))
<< "-th init and " << index
<< "-th region iter_arg have different type: " << init.getType()
<< " != " << regionIterArg.getType();
if (!yieldedValues.empty()) {
- if (regionIterArg.getType() != yieldedValues[index].getType())
+ if (!loopLikeOp.areLoopIterArgTypesCompatible(
+ regionIterArg.getType(), yieldedValues[index].getType()))
return op->emitOpError(std::to_string(index))
<< "-th region iter_arg and " << index
<< "-th yielded value have different type: "
<< regionIterArg.getType()
<< " != " << yieldedValues[index].getType();
}
- ++i;
}
- i = 0;
if (loopLikeOp.getLoopResults()) {
- for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
- *loopLikeOp.getLoopResults())) {
- if (std::get<0>(it).getType() != std::get<1>(it).getType())
- return op->emitOpError(std::to_string(i))
- << "-th region iter_arg and " << i
+ for (const auto [index, regionIterArg, loopResult] : llvm::enumerate(
+ loopLikeOp.getRegionIterArgs(), *loopLikeOp.getLoopResults())) {
+ if (!loopLikeOp.areLoopIterArgTypesCompatible(regionIterArg.getType(),
+ loopResult.getType()))
+ return op->emitOpError(std::to_string(index))
+ << "-th region iter_arg and " << index
<< "-th loop result have different type: "
- << std::get<0>(it).getType()
- << " != " << std::get<1>(it).getType();
+ << regionIterArg.getType() << " != " << loopResult.getType();
}
- ++i;
}
// Verify that all induction variables have valid types.
diff --git a/mlir/test/IR/test-branch-types-compatible.mlir b/mlir/test/IR/test-branch-types-compatible.mlir
new file mode 100644
index 0000000000000..5f38d73c3b161
--- /dev/null
+++ b/mlir/test/IR/test-branch-types-compatible.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+// RegionBranchOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @region_branch_compat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %0 = "test.region_types_compat"(%c0) ({
+ ^bb0(%arg0: i64):
+ %c1 = arith.constant 1 : i64
+ "test.types_compat_yield"(%c1) : (i64) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// RegionBranchOpInterface: incompatible types (i32 <-> f32) should fail.
+func.func @region_branch_incompat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+2 {{along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+ // expected-note @+1 {{region branch point}}
+ %0 = "test.region_types_compat"(%c0) ({
+ ^bb0(%arg0: f32):
+ %c1 = arith.constant 1.0 : f32
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: compatible integer types (i32 <-> i64) should pass.
+func.func @loop_compat() -> i32 {
+ %c0 = arith.constant 0 : i32
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: i64):
+ %c1 = arith.constant 1 : i64
+ "test.types_compat_yield"(%c1) : (i64) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible init vs iter_arg (i32 <-> f32) should fail.
+func.func @loop_incompat_init() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+1 {{0-th init and 0-th region iter_arg have different type: 'i32' != 'f32'}}
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: f32):
+ %c1 = arith.constant 1.0 : f32
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
+
+// -----
+
+// LoopLikeOpInterface: incompatible iter_arg vs yield (i32 <-> f32) should fail.
+func.func @loop_incompat_yield() -> i32 {
+ %c0 = arith.constant 0 : i32
+ // expected-error @+1 {{0-th region iter_arg and 0-th yielded value have different type: 'i32' != 'f32'}}
+ %0 = "test.loop_types_compat"(%c0) ({
+ ^bb0(%arg0: i32):
+ %c1 = arith.constant 1.0 : f32
+ "test.types_compat_yield"(%c1) : (f32) -> ()
+ }) : (i32) -> i32
+ return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index c243bd79a44a8..1a916afc983e3 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1401,6 +1401,86 @@ TestStoreWithALoopRegion::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange(getBody().front().getArguments());
}
+//===----------------------------------------------------------------------===//
+// TestRegionTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestRegionTypesCompatOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent())
+ regions.emplace_back(&getBody());
+ else
+ regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange
+TestRegionTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+ return getEntries();
+}
+
+ValueRange
+TestRegionTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+ if (successor.isParent())
+ return getResults();
+ return getBody().getArguments();
+}
+
+bool TestRegionTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+ return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+//===----------------------------------------------------------------------===//
+// TestLoopTypesCompatOp
+//===----------------------------------------------------------------------===//
+
+void TestLoopTypesCompatOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ regions.emplace_back(&getBody());
+ if (!point.isParent())
+ regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange TestLoopTypesCompatOp::getEntrySuccessorOperands(RegionSuccessor) {
+ return getInitArgs();
+}
+
+ValueRange
+TestLoopTypesCompatOp::getSuccessorInputs(RegionSuccessor successor) {
+ if (successor.isParent())
+ return getResults();
+ return getBody().getArguments();
+}
+
+MutableArrayRef<OpOperand> TestLoopTypesCompatOp::getInitsMutable() {
+ return getInitArgsMutable();
+}
+
+Block::BlockArgListType TestLoopTypesCompatOp::getRegionIterArgs() {
+ return getBody().getArguments();
+}
+
+std::optional<MutableArrayRef<OpOperand>>
+TestLoopTypesCompatOp::getYieldedValuesMutable() {
+ return cast<TestTypesCompatYieldOp>(getBody().front().getTerminator())
+ .getArgsMutable();
+}
+
+std::optional<ResultRange> TestLoopTypesCompatOp::getLoopResults() {
+ return getResults();
+}
+
+SmallVector<Region *> TestLoopTypesCompatOp::getLoopRegions() {
+ return {&getBody()};
+}
+
+bool TestLoopTypesCompatOp::areTypesCompatible(Type lhs, Type rhs) {
+ return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
+bool TestLoopTypesCompatOp::areLoopIterArgTypesCompatible(Type lhs, Type rhs) {
+ return lhs == rhs || (isa<IntegerType>(lhs) && isa<IntegerType>(rhs));
+}
+
//===----------------------------------------------------------------------===//
// TestVersionedOpA
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index fe02536a1df5b..78d812175b2d6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2717,6 +2717,41 @@ def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
}];
}
+//===----------------------------------------------------------------------===//
+// Test areTypesCompatible / areLoopIterArgTypesCompatible type-compatibility hooks
+//===----------------------------------------------------------------------===//
+
+def TestTypesCompatYieldOp : TEST_Op<"types_compat_yield",
+ [Pure, ReturnLike, Terminator]> {
+ let arguments = (ins Variadic<AnyType>:$args);
+ let assemblyFormat = "($args^ `:` type($args))? attr-dict";
+}
+
+def TestRegionTypesCompatOp : TEST_Op<"region_types_compat",
+ [DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorInputs",
+ "areTypesCompatible"]>,
+ SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+ RecursiveMemoryEffects]> {
+ let arguments = (ins Variadic<AnyType>:$entries);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body);
+}
+
+def TestLoopTypesCompatOp : TEST_Op<"loop_types_compat",
+ [DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getRegionIterArgs", "getYieldedValuesMutable",
+ "getLoopResults", "getLoopRegions", "areLoopIterArgTypesCompatible"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorInputs",
+ "areTypesCompatible"]>,
+ SingleBlockImplicitTerminator<"TestTypesCompatYieldOp">,
+ RecursiveMemoryEffects]> {
+ let arguments = (ins Variadic<AnyType>:$init_args);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$body);
+}
+
//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/184116
More information about the Mlir-commits
mailing list