[Mlir-commits] [mlir] [mlir][SCF] Add `scf.loop` op and terminators (PR #199535)
Matthias Springer
llvmlistbot at llvm.org
Tue May 26 09:09:16 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/199535
>From 5e615c87fed96a222e306ef59fbb52e6e3c03e64 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 May 2026 14:40:25 +0000
Subject: [PATCH] [mlir][SCF] Add `scf.loop` op and terminators
---
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 150 +++++++++++++++++++
mlir/lib/Dialect/SCF/IR/SCF.cpp | 162 +++++++++++++++++++++
mlir/test/Dialect/SCF/invalid.mlir | 101 +++++++++++++
mlir/test/Dialect/SCF/ops.mlir | 73 ++++++++++
4 files changed, 486 insertions(+)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..57c07fa0a50fc 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -147,6 +147,156 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+def LoopOp : SCF_Op<"loop", [
+ AutomaticAllocationScope,
+ RecursiveMemoryEffects,
+ SingleBlock
+ ]> {
+ let summary = "Loop until a break operation";
+ let description = [{
+ The `scf.loop` operation represents an infinite loop that executes until an
+ `scf.break` is reached.
+
+ The loop consists of (1) a set of loop-carried values which are initialized
+ by `initValues` and updated by each iteration of the loop, and (2) a region
+ which represents the loop body.
+
+ The loop body must end with an explicit terminator, which must be one of:
+
+ - `scf.continue`: re-enters the loop, supplying the next iteration's value
+ for each loop-carried variable. Terminator operand types and loop operand
+ types must match. If the loop has op results, their values are undefined.
+ - `scf.break`: terminates the loop, supplying the final values for the
+ `scf.loop` results. Terminator operand types and loop op result types
+ must match.
+
+ Note: This operation will be extended in the future to support breaking and
+ continuing from nested regions. For now, `scf.break` and `scf.continue`
+ must be terminators of the loop body. In practice this means that an
+ `scf.loop` either runs forever (terminator is `scf.continue`) or executes
+ exactly one iteration (terminator is `scf.break`).
+
+ Examples:
+
+ ```mlir
+ // Loop with iteration-carried values updated by `scf.continue`.
+ scf.loop iter_args(%i = %init) : i32 {
+ %v = "some.compute"(%i) : (i32) -> (i32)
+ scf.continue %v : i32
+ }
+ ```
+
+ ```mlir
+ // Loop with both an iteration-carried value and a result. The iter_arg
+ // and result types may differ.
+ %r = scf.loop iter_args(%i = %init) : i32 -> i64 {
+ %v = "some.compute"(%i) : (i32) -> (i64)
+ scf.break %v : i64
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$initValues);
+ let results = (outs Variadic<AnyType>:$resultValues);
+ let regions = (region SizedRegion<1>:$region);
+
+ let builders = [
+ OpBuilder<(ins
+ CArg<"::mlir::TypeRange", "{}">:$resultTypes,
+ CArg<"::mlir::ValueRange", "{}">:$initValues,
+ CArg<"::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location, "
+ "::mlir::ValueRange)>", "nullptr">:$bodyBuilder)>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return the iteration values of the loop region.
+ Block::BlockArgListType getRegionIterValues() {
+ return getRegion().getArguments();
+ }
+
+ /// Return the `index`-th region iteration value.
+ BlockArgument getRegionIterValue(unsigned index) {
+ return getRegionIterValues()[index];
+ }
+
+ /// Returns the number of region arguments for loop-carried values.
+ unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); }
+
+ /// Returns the loop body block.
+ Block *getBody() { return &getRegion().front(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasRegionVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// BreakOp
+//===----------------------------------------------------------------------===//
+
+def BreakOp : SCF_Op<"break", [
+ Pure, ReturnLike, Terminator, ParentOneOf<["LoopOp"]>
+ ]> {
+ let summary = "Break from an `scf.loop`";
+ let description = [{
+ The `scf.break` operation terminates the immediately enclosing `scf.loop`.
+ Its operands become the loop's result values; their types must match the
+ result types of the enclosing `scf.loop` (verified by the loop).
+
+ Example:
+
+ ```mlir
+ %r = scf.loop -> i32 {
+ ...
+ scf.break %v : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+ let assemblyFormat = [{
+ attr-dict ($operands^ `:` type($operands))?
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ContinueOp
+//===----------------------------------------------------------------------===//
+
+def ContinueOp : SCF_Op<"continue", [
+ Pure, Terminator, ParentOneOf<["LoopOp"]>
+ ]> {
+ let summary = "Continue to the next iteration of an `scf.loop`";
+ let description = [{
+ The `scf.continue` operation re-enters the immediately enclosing `scf.loop`
+ for its next iteration. Its operands become the loop-carried values
+ (`iter_args`) for the next iteration; their types must match the loop's
+ iter_arg types (verified by the loop).
+
+ Example:
+
+ ```mlir
+ scf.loop iter_args(%i = %init) : i32 {
+ %next = arith.addi %i, %one : i32
+ scf.continue %next : i32
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+ let assemblyFormat = [{
+ attr-dict ($operands^ `:` type($operands))?
+ }];
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..60e5975f4ec48 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -282,6 +282,168 @@ ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange();
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    ValueRange initValues,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+  result.addOperands(initValues);
+  result.addTypes(resultTypes);
+
+  // Build the body region with a single entry block, one argument per init
+  // value. `createBlock` moves the insertion point, so guard it beforehand.
+  // `bodyBuilder` must terminate the block with `scf.continue`/`scf.break`.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *bodyRegion = result.addRegion();
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+  SmallVector<Type> argTypes(initValues.getTypes());
+  SmallVector<Location> argLocs(initValues.size(), result.location);
+  bodyBlock->addArguments(argTypes, argLocs);
+
+  if (bodyBuilder) {
+    builder.setInsertionPointToStart(bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock->getArguments());
+  }
+}
+
+LogicalResult LoopOp::verifyRegions() {
+ if (getRegion().empty())
+ return emitOpError("region cannot be empty");
+ Block &body = getRegion().front();
+ if (body.getNumArguments() != getNumOperands())
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+ for (auto [index, regionArg, initOperand] :
+ llvm::enumerate(body.getArguments(), getOperands())) {
+ if (regionArg.getType() != initOperand.getType())
+ return emitOpError() << "type mismatch between " << index
+ << "th iter operand (" << initOperand.getType()
+ << ") and region argument (" << regionArg.getType()
+ << ")";
+ }
+
+ // The loop body must end with an explicit `scf.break` or `scf.continue`.
+ Operation *terminator = body.getTerminator();
+ if (auto breakOp = dyn_cast<BreakOp>(terminator)) {
+ if (breakOp.getNumOperands() != getNumResults())
+ return breakOp.emitOpError()
+ << "has " << breakOp.getNumOperands()
+ << " operands, but enclosing scf.loop returns " << getNumResults()
+ << " result(s)";
+ for (auto [index, operandType, resultType] :
+ llvm::enumerate(breakOp.getOperandTypes(), getResultTypes())) {
+ if (operandType != resultType)
+ return breakOp.emitOpError()
+ << "type mismatch between " << index << "th operand ("
+ << operandType << ") and " << index
+ << "th result of enclosing scf.loop (" << resultType << ")";
+ }
+ } else if (auto continueOp = dyn_cast<ContinueOp>(terminator)) {
+ if (continueOp.getNumOperands() != getNumRegionIterValues())
+ return continueOp.emitOpError()
+ << "has " << continueOp.getNumOperands()
+ << " operands, but enclosing scf.loop has "
+ << getNumRegionIterValues() << " iter_args";
+ for (auto [index, operandType, iterArgType] : llvm::enumerate(
+ continueOp.getOperandTypes(), body.getArgumentTypes())) {
+ if (operandType != iterArgType)
+ return continueOp.emitOpError()
+ << "type mismatch between " << index << "th operand ("
+ << operandType << ") and " << index
+ << "th iter_arg of enclosing scf.loop (" << iterArgType << ")";
+ }
+ } else {
+ return emitOpError("body must be terminated by 'scf.break' or "
+ "'scf.continue', got '")
+ << terminator->getName() << "'";
+ }
+ return success();
+}
+
+/// Print a type list in functional-return-type style: a single bare type or
+/// a parenthesized list. A lone FunctionType is parenthesized to round-trip.
+static void printFunctionalTypeList(OpAsmPrinter &p, TypeRange types) {
+  if (types.size() == 1 && !llvm::isa<FunctionType>(types.front())) {
+    p << types.front();
+    return;
+  }
+  p << "(";
+  llvm::interleaveComma(types, p);
+  p << ")";
+}
+
+void LoopOp::print(OpAsmPrinter &p) {
+ p << " ";
+ if (!getInitValues().empty()) {
+ p << "iter_args(";
+ llvm::interleaveComma(
+ llvm::zip(getRegionIterValues(), getInitValues()), p, [&](auto it) {
+ p.printRegionArgument(std::get<0>(it), /*argAttrs=*/{},
+ /*omitType=*/true);
+ p << " = " << std::get<1>(it);
+ });
+ p << ") : ";
+ printFunctionalTypeList(p, getInitValues().getTypes());
+ p << " ";
+ }
+ if (!getResultTypes().empty()) {
+ p << "-> ";
+ printFunctionalTypeList(p, getResultTypes());
+ p << " ";
+ }
+
+ p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+ p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> iterOperands;
+ SmallVector<Type, 4> iterTypes;
+
+ if (failed(parser.parseOptionalKeyword("iter_args"))) {
+ // No iter_args, but may still have a result type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ } else {
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
+ parser.parseColon())
+ return failure();
+ if (parser.parseOptionalLParen()) {
+ // Single iter_arg type, no parens.
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+ iterTypes.push_back(type);
+ } else {
+ if (parser.parseTypeList(iterTypes) || parser.parseRParen())
+ return failure();
+ }
+ if (regionArgs.size() != iterTypes.size())
+ return parser.emitError(parser.getCurrentLocation(),
+ "found different number of iter_args and types");
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ for (auto [regionArg, type] : llvm::zip_equal(regionArgs, iterTypes))
+ regionArg.type = type;
+ }
+
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ if (parser.resolveOperands(iterOperands, iterTypes, parser.getNameLoc(),
+ result.operands))
+ return failure();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 33a8921eeb993..099c02631804f 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -852,3 +852,104 @@ func.func @for_missing_induction_var(%arg0: index, %arg1: index) {
}) : (index, index, index) -> ()
return
}
+
+// -----
+
+func.func @break_outside_loop(%v: i32) {
+ // expected-error @+1 {{'scf.break' op expects parent op 'scf.loop'}}
+ scf.break %v : i32
+}
+
+// -----
+
+func.func @continue_outside_loop() {
+ // expected-error @+1 {{'scf.continue' op expects parent op 'scf.loop'}}
+ scf.continue
+}
+
+// -----
+
+func.func @loop_bad_terminator() {
+ // expected-error @+1 {{'scf.loop' op body must be terminated by 'scf.break' or 'scf.continue'}}
+ "scf.loop"() ({
+ ^bb0:
+ "test.foo"() : () -> ()
+ "test.terminator"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// -----
+
+func.func @loop_init_arg_count_mismatch(%init: i32) {
+ // expected-error @+1 {{'scf.loop' op mismatch in number of loop-carried values and defined values}}
+ "scf.loop"(%init) ({
+ ^bb0:
+ scf.continue
+ }) : (i32) -> ()
+ return
+}
+
+// -----
+
+func.func @loop_init_arg_type_mismatch(%init: i32) {
+ // expected-error @+1 {{'scf.loop' op type mismatch between 0th iter operand ('i32') and region argument ('i64')}}
+ "scf.loop"(%init) ({
+ ^bb0(%i: i64):
+ scf.continue %i : i64
+ }) : (i32) -> ()
+ return
+}
+
+// -----
+
+func.func @loop_break_count_mismatch(%v: i32) -> (i32, i32) {
+ // expected-error @+2 {{'scf.break' op has 1 operands, but enclosing scf.loop returns 2 result(s)}}
+ %r:2 = scf.loop -> (i32, i32) {
+ scf.break %v : i32
+ }
+ return %r#0, %r#1 : i32, i32
+}
+
+// -----
+
+func.func @loop_break_type_mismatch(%v: i32) -> i64 {
+ // expected-error @+2 {{'scf.break' op type mismatch between 0th operand ('i32') and 0th result of enclosing scf.loop ('i64')}}
+ %r = scf.loop -> i64 {
+ scf.break %v : i32
+ }
+ return %r : i64
+}
+
+// -----
+
+func.func @loop_continue_count_mismatch(%init: i32) {
+ // expected-error @+2 {{'scf.continue' op has 0 operands, but enclosing scf.loop has 1 iter_args}}
+ scf.loop iter_args(%i = %init) : i32 {
+ scf.continue
+ }
+ return
+}
+
+// -----
+
+func.func @loop_continue_type_mismatch(%init: i32, %v: i64) {
+ // expected-error @+2 {{'scf.continue' op type mismatch between 0th operand ('i64') and 0th iter_arg of enclosing scf.loop ('i32')}}
+ scf.loop iter_args(%i = %init) : i32 {
+ scf.continue %v : i64
+ }
+ return
+}
+
+// -----
+
+func.func @loop_more_than_one_block(%v: i32) -> i32 {
+ // expected-error @+1 {{'scf.loop' op expects region #0 to have 0 or 1 blocks}}
+ %r = "scf.loop"() ({
+ ^bb0:
+ "test.unreachable"() [^bb1] : () -> ()
+ ^bb1:
+ scf.break %v : i32
+ }) : () -> i32
+ return %r : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 5930a1df04266..e8f5294b40a4d 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,76 @@ func.func @switch(%arg0: index) -> i32 {
return %0 : i32
}
+
+// CHECK-LABEL: @loop_infinite
+func.func @loop_infinite() {
+ // CHECK: scf.loop {
+ scf.loop {
+ // CHECK-NEXT: "test.foo"
+ "test.foo"() : () -> ()
+ // CHECK-NEXT: scf.continue
+ scf.continue
+ }
+ return
+}
+
+// CHECK-LABEL: @loop_break_no_operands
+func.func @loop_break_no_operands() {
+ // CHECK: scf.loop {
+ scf.loop {
+ // CHECK-NEXT: scf.break
+ scf.break
+ }
+ return
+}
+
+// CHECK-LABEL: @loop_break_single
+func.func @loop_break_single(%v: i32) -> i32 {
+ // CHECK: %{{.*}} = scf.loop -> i32 {
+ %r = scf.loop -> i32 {
+ // CHECK-NEXT: scf.break %{{.*}} : i32
+ scf.break %v : i32
+ }
+ return %r : i32
+}
+
+// CHECK-LABEL: @loop_break_multi
+func.func @loop_break_multi(%v: i32, %w: i64) -> (i32, i64) {
+ // CHECK: %{{.*}}:2 = scf.loop -> (i32, i64) {
+ %r:2 = scf.loop -> (i32, i64) {
+ // CHECK-NEXT: scf.break %{{.*}}, %{{.*}} : i32, i64
+ scf.break %v, %w : i32, i64
+ }
+ return %r#0, %r#1 : i32, i64
+}
+
+// CHECK-LABEL: @loop_iter_single
+func.func @loop_iter_single(%init: i32) {
+ // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 {
+ scf.loop iter_args(%i = %init) : i32 {
+ // CHECK: scf.continue %{{.*}} : i32
+ scf.continue %i : i32
+ }
+ return
+}
+
+// CHECK-LABEL: @loop_iter_multi
+func.func @loop_iter_multi(%init0: i32, %init1: i64) {
+ // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, i64) {
+ scf.loop iter_args(%i = %init0, %j = %init1) : (i32, i64) {
+ // CHECK: scf.continue %{{.*}}, %{{.*}} : i32, i64
+ scf.continue %i, %j : i32, i64
+ }
+ return
+}
+
+// Loop with iter_args of one type and a single result of another type.
+// CHECK-LABEL: @loop_iter_and_result
+func.func @loop_iter_and_result(%init: i32, %v: i64) -> i64 {
+ // CHECK: %{{.*}} = scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 -> i64 {
+ %r = scf.loop iter_args(%i = %init) : i32 -> i64 {
+ // CHECK: scf.break %{{.*}} : i64
+ scf.break %v : i64
+ }
+ return %r : i64
+}
More information about the Mlir-commits
mailing list