[Mlir-commits] [mlir] [mlir][SCF] Add `scf.loop` op and terminators (PR #199535)

Matthias Springer llvmlistbot at llvm.org
Tue May 26 09:09:16 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/199535

>From 5e615c87fed96a222e306ef59fbb52e6e3c03e64 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 May 2026 14:40:25 +0000
Subject: [PATCH] [mlir][SCF] Add `scf.loop` op and terminators

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 150 +++++++++++++++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 162 +++++++++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir         | 101 +++++++++++++
 mlir/test/Dialect/SCF/ops.mlir             |  73 ++++++++++
 4 files changed, 486 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..57c07fa0a50fc 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -147,6 +147,156 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+def LoopOp : SCF_Op<"loop", [
+    AutomaticAllocationScope,
+    RecursiveMemoryEffects,
+    SingleBlock
+  ]> {
+  let summary = "Loop until a break operation";
+  let description = [{
+    The `scf.loop` operation represents an infinite loop that executes until an
+    `scf.break` is reached.
+
+    The loop consists of (1) a set of loop-carried values which are initialized
+    by `initValues` and updated by each iteration of the loop, and (2) a region
+    which represents the loop body.
+
+    The loop body must end with an explicit terminator, which must be one of:
+
+    - `scf.continue`: re-enters the loop, supplying the next iteration's value
+      for each loop-carried variable. Terminator operand types and loop operand
+      types must match. If the loop has results, their values are undefined.
+    - `scf.break`: terminates the loop, supplying the final values for the
+      `scf.loop` results. Terminator operand types and loop op result types
+      must match.
+
+    Note: This operation will be extended in the future to support breaking and
+    continuing from nested regions. For now, `scf.break` and `scf.continue`
+    must be terminators of the loop body. In practice this means that an
+    `scf.loop` either runs forever (terminator is `scf.continue`) or executes
+    exactly one iteration (terminator is `scf.break`).
+
+    Examples:
+
+    ```mlir
+    // Loop with loop-carried values updated by `scf.continue`.
+    scf.loop iter_args(%i = %init) : i32 {
+      %v = "some.compute"(%i) : (i32) -> (i32)
+      scf.continue %v : i32
+    }
+    ```
+
+    ```mlir
+    // Loop with both a loop-carried value and a result. The iter_arg
+    // and result types may differ.
+    %r = scf.loop iter_args(%i = %init) : i32 -> i64 {
+      %v = "some.compute"(%i) : (i32) -> (i64)
+      scf.break %v : i64
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$initValues);
+  let results = (outs Variadic<AnyType>:$resultValues);
+  let regions = (region SizedRegion<1>:$region);
+
+  let builders = [
+    OpBuilder<(ins
+      CArg<"::mlir::TypeRange", "{}">:$resultTypes,
+      CArg<"::mlir::ValueRange", "{}">:$initValues,
+      CArg<"::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location, "
+           "::mlir::ValueRange)>", "nullptr">:$bodyBuilder)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns the loop-carried values, i.e. the loop body's block arguments.
+    Block::BlockArgListType getRegionIterValues() {
+      return getRegion().getArguments();
+    }
+
+    /// Returns the `index`-th loop-carried block argument.
+    BlockArgument getRegionIterValue(unsigned index) {
+      return getRegionIterValues()[index];
+    }
+
+    /// Returns the number of region arguments for loop-carried values.
+    unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); }
+
+    /// Returns the loop body block.
+    Block *getBody() { return &getRegion().front(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasRegionVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// BreakOp
+//===----------------------------------------------------------------------===//
+
+def BreakOp : SCF_Op<"break", [ // Terminator that exits an scf.loop.
+    Pure, ReturnLike, Terminator, ParentOneOf<["LoopOp"]> // Parent: scf.loop.
+  ]> {
+  let summary = "Break from an `scf.loop`";
+  let description = [{
+    The `scf.break` operation terminates the immediately enclosing `scf.loop`.
+    Its operands become the loop's result values; their types must match the
+    result types of the enclosing `scf.loop` (verified by the loop).
+
+    Example:
+
+    ```mlir
+    %r = scf.loop -> i32 {
+      ...
+      scf.break %v : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands); // Become loop results.
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; // No operands.
+
+  let assemblyFormat = [{
+    attr-dict ($operands^ `:` type($operands))?
+  }]; // e.g. `scf.break %v : i32`, or a bare `scf.break`.
+}
+
+//===----------------------------------------------------------------------===//
+// ContinueOp
+//===----------------------------------------------------------------------===//
+
+def ContinueOp : SCF_Op<"continue", [ // Terminator that re-enters an scf.loop.
+    Pure, Terminator, ParentOneOf<["LoopOp"]> // Parent: scf.loop.
+  ]> {
+  let summary = "Continue to the next iteration of an `scf.loop`";
+  let description = [{
+    The `scf.continue` operation re-enters the immediately enclosing `scf.loop`
+    for its next iteration. Its operands become the loop-carried values
+    (`iter_args`) for the next iteration; their types must match the loop's
+    iter_arg types (verified by the loop).
+
+    Example:
+
+    ```mlir
+    scf.loop iter_args(%i = %init) : i32 {
+      %next = arith.addi %i, %one : i32
+      scf.continue %next : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands); // Next iter_args values.
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; // No operands.
+
+  let assemblyFormat = [{
+    attr-dict ($operands^ `:` type($operands))?
+  }]; // e.g. `scf.continue %v : i32`, or a bare `scf.continue`.
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..60e5975f4ec48 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -282,6 +282,168 @@ ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    ValueRange initValues,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+  // `createBlock` moves the insertion point into the new block; restore it on
+  // exit so that the caller's `create<LoopOp>` inserts the loop where intended.
+  OpBuilder::InsertionGuard guard(builder);
+  result.addOperands(initValues);
+  result.addTypes(resultTypes);
+
+  // Build the body region with a single entry block, one argument per init
+  // value. The caller-supplied `bodyBuilder` (if any) is responsible for
+  // terminating the block with either `scf.continue` or `scf.break`.
+  Region *bodyRegion = result.addRegion();
+  SmallVector<Location> argLocs(initValues.size(), result.location);
+  Block *bodyBlock = builder.createBlock(
+      bodyRegion, /*insertPt=*/{}, initValues.getTypes(), argLocs);
+
+  // The insertion point is now at the start of the (empty) body block.
+  if (bodyBuilder)
+    bodyBuilder(builder, result.location, bodyBlock->getArguments());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+  if (getRegion().empty())
+    return emitOpError("region cannot be empty");
+  Block &body = getRegion().front();
+
+  // Entry block arguments (the iter_args) must mirror the init operands.
+  unsigned numInits = getNumOperands();
+  if (body.getNumArguments() != numInits)
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  for (unsigned i = 0; i < numInits; ++i) {
+    Type initType = getOperand(i).getType();
+    Type argType = body.getArgument(i).getType();
+    if (argType != initType)
+      return emitOpError() << "type mismatch between " << i
+                           << "th iter operand (" << initType
+                           << ") and region argument (" << argType << ")";
+  }
+
+  // Checks that terminator `term` forwards exactly the `expected` types.
+  // `countVerb`/`countNoun` and `elemNoun` phrase the diagnostics.
+  auto verifyForwarded = [&](Operation *term, TypeRange expected,
+                             StringRef countVerb, StringRef countNoun,
+                             StringRef elemNoun) -> LogicalResult {
+    if (term->getNumOperands() != expected.size())
+      return term->emitOpError()
+             << "has " << term->getNumOperands()
+             << " operands, but enclosing scf.loop " << countVerb
+             << expected.size() << countNoun;
+    for (auto [i, operandType, expectedType] :
+         llvm::enumerate(term->getOperandTypes(), expected)) {
+      if (operandType != expectedType)
+        return term->emitOpError()
+               << "type mismatch between " << i << "th operand ("
+               << operandType << ") and " << i << "th " << elemNoun
+               << " of enclosing scf.loop (" << expectedType << ")";
+    }
+    return success();
+  };
+
+  // The body must end in `scf.break` (feeding the loop results) or
+  // `scf.continue` (feeding the next iteration's iter_args).
+  Operation *terminator = body.getTerminator();
+  if (isa<BreakOp>(terminator))
+    return verifyForwarded(terminator, getResultTypes(), "returns ",
+                           " result(s)", "result");
+  if (isa<ContinueOp>(terminator))
+    return verifyForwarded(terminator, body.getArgumentTypes(), "has ",
+                           " iter_args", "iter_arg");
+  return emitOpError("body must be terminated by 'scf.break' or "
+                     "'scf.continue', got '")
+         << terminator->getName() << "'";
+}
+
+/// Print a type list in functional-return-type style: a single bare type, or
+/// a parenthesized comma-separated list. A lone function type is also wrapped
+/// so that its own `(...) -> ...` is not misparsed as the list plus an arrow.
+static void printFunctionalTypeList(OpAsmPrinter &p, TypeRange types) {
+  bool wrap = types.size() != 1 || isa<FunctionType>(types.front());
+  if (wrap)
+    p << "(";
+  llvm::interleaveComma(types, p);
+  if (wrap)
+    p << ")";
+}
+
+void LoopOp::print(OpAsmPrinter &p) {
+  // Emits `[iter_args(%a = %init, ...) : types] [-> types] {...} [attrs]`.
+  ValueRange inits = getInitValues();
+  if (!inits.empty()) {
+    p << " iter_args(";
+    auto printAssignment = [&](auto argAndInit) {
+      p.printRegionArgument(std::get<0>(argAndInit), /*argAttrs=*/{},
+                            /*omitType=*/true);
+      p << " = " << std::get<1>(argAndInit);
+    };
+    llvm::interleaveComma(llvm::zip(getRegionIterValues(), inits), p,
+                          printAssignment);
+    p << ") : ";
+    printFunctionalTypeList(p, inits.getTypes());
+  }
+  if (TypeRange resultTypes = getResultTypes(); !resultTypes.empty()) {
+    p << " -> ";
+    printFunctionalTypeList(p, resultTypes);
+  }
+  p << " ";
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/true);
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::Argument, 4> bodyArgs;
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> initOperands;
+  SmallVector<Type, 4> initTypes;
+
+  // Optional `iter_args(%arg = %init, ...) : <type or (types)>` clause.
+  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
+    if (parser.parseAssignmentList(bodyArgs, initOperands) ||
+        parser.parseColon())
+      return failure();
+    bool parenthesized = succeeded(parser.parseOptionalLParen());
+    if (parenthesized) {
+      if (parser.parseTypeList(initTypes) || parser.parseRParen())
+        return failure();
+    } else {
+      // A single type may be written without surrounding parentheses.
+      Type singleType;
+      if (parser.parseType(singleType))
+        return failure();
+      initTypes.push_back(singleType);
+    }
+    if (bodyArgs.size() != initTypes.size())
+      return parser.emitError(parser.getCurrentLocation(),
+                              "found different number of iter_args and types");
+    // Attach the declared types to the loop body's entry block arguments.
+    for (size_t i = 0, e = bodyArgs.size(); i < e; ++i)
+      bodyArgs[i].type = initTypes[i];
+  }
+
+  // Optional `-> <type or (types)>` clause listing the loop results.
+  if (parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // Loop body (entry block arguments are the iter_args), then attributes.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, bodyArgs) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Init operands are resolved against the declared iter_arg types.
+  return parser.resolveOperands(initOperands, initTypes, parser.getNameLoc(),
+                                result.operands);
+}
+
 //===----------------------------------------------------------------------===//
 // ConditionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 33a8921eeb993..099c02631804f 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -852,3 +852,104 @@ func.func @for_missing_induction_var(%arg0: index, %arg1: index) {
   }) : (index, index, index) -> ()
   return
 }
+
+// -----
+
+func.func @break_outside_loop(%v: i32) {
+  // expected-error@+1 {{'scf.break' op expects parent op 'scf.loop'}}
+  scf.break %v : i32
+}
+
+// -----
+
+func.func @continue_outside_loop() {
+  // expected-error@+1 {{'scf.continue' op expects parent op 'scf.loop'}}
+  scf.continue
+}
+
+// -----
+
+func.func @loop_bad_terminator() {
+  // expected-error@+1 {{'scf.loop' op body must be terminated by 'scf.break' or 'scf.continue'}}
+  "scf.loop"() ({
+  ^bb0:
+    "test.foo"() : () -> ()
+    "test.terminator"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @loop_init_arg_count_mismatch(%init: i32) {
+  // expected-error@+1 {{'scf.loop' op mismatch in number of loop-carried values and defined values}}
+  "scf.loop"(%init) ({
+  ^bb0:
+    scf.continue
+  }) : (i32) -> ()
+  return
+}
+
+// -----
+
+func.func @loop_init_arg_type_mismatch(%init: i32) {
+  // expected-error@+1 {{'scf.loop' op type mismatch between 0th iter operand ('i32') and region argument ('i64')}}
+  "scf.loop"(%init) ({
+  ^bb0(%i: i64):
+    scf.continue %i : i64
+  }) : (i32) -> ()
+  return
+}
+
+// -----
+
+func.func @loop_break_count_mismatch(%v: i32) -> (i32, i32) {
+  // expected-error@+2 {{'scf.break' op has 1 operands, but enclosing scf.loop returns 2 result(s)}}
+  %r:2 = scf.loop -> (i32, i32) {
+    scf.break %v : i32
+  }
+  return %r#0, %r#1 : i32, i32
+}
+
+// -----
+
+func.func @loop_break_type_mismatch(%v: i32) -> i64 {
+  // expected-error@+2 {{'scf.break' op type mismatch between 0th operand ('i32') and 0th result of enclosing scf.loop ('i64')}}
+  %r = scf.loop -> i64 {
+    scf.break %v : i32
+  }
+  return %r : i64
+}
+
+// -----
+
+func.func @loop_continue_count_mismatch(%init: i32) {
+  // expected-error@+2 {{'scf.continue' op has 0 operands, but enclosing scf.loop has 1 iter_args}}
+  scf.loop iter_args(%i = %init) : i32 {
+    scf.continue
+  }
+  return
+}
+
+// -----
+
+func.func @loop_continue_type_mismatch(%init: i32, %v: i64) {
+  // expected-error@+2 {{'scf.continue' op type mismatch between 0th operand ('i64') and 0th iter_arg of enclosing scf.loop ('i32')}}
+  scf.loop iter_args(%i = %init) : i32 {
+    scf.continue %v : i64
+  }
+  return
+}
+
+// -----
+
+func.func @loop_more_than_one_block(%v: i32) -> i32 {
+  // expected-error@+1 {{'scf.loop' op expects region #0 to have 0 or 1 blocks}}
+  %r = "scf.loop"() ({
+  ^bb0:
+    "test.unreachable"() [^bb1] : () -> ()
+  ^bb1:
+    scf.break %v : i32
+  }) : () -> i32
+  return %r : i32
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 5930a1df04266..e8f5294b40a4d 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -441,3 +441,76 @@ func.func @switch(%arg0: index) -> i32 {
 
   return %0 : i32
 }
+
+// CHECK-LABEL: @loop_infinite
+func.func @loop_infinite() {
+  // CHECK: scf.loop {
+  scf.loop { // Body ends in scf.continue, so this loop never exits.
+    // CHECK-NEXT: "test.foo"
+    "test.foo"() : () -> ()
+    // CHECK-NEXT: scf.continue
+    scf.continue
+  }
+  return
+}
+
+// CHECK-LABEL: @loop_break_no_operands
+func.func @loop_break_no_operands() {
+  // CHECK: scf.loop {
+  scf.loop { // Single iteration: body ends in an operand-less scf.break.
+    // CHECK-NEXT: scf.break
+    scf.break
+  }
+  return
+}
+
+// CHECK-LABEL: @loop_break_single
+func.func @loop_break_single(%value: i32) -> i32 {
+  // CHECK: %{{.*}} = scf.loop -> i32 {
+  %result = scf.loop -> i32 {
+    // CHECK-NEXT: scf.break %{{.*}} : i32
+    scf.break %value : i32
+  }
+  return %result : i32
+}
+
+// CHECK-LABEL: @loop_break_multi
+func.func @loop_break_multi(%lhs: i32, %rhs: i64) -> (i32, i64) {
+  // CHECK: %{{.*}}:2 = scf.loop -> (i32, i64) {
+  %res:2 = scf.loop -> (i32, i64) {
+    // CHECK-NEXT: scf.break %{{.*}}, %{{.*}} : i32, i64
+    scf.break %lhs, %rhs : i32, i64
+  }
+  return %res#0, %res#1 : i32, i64
+}
+
+// CHECK-LABEL: @loop_iter_single
+func.func @loop_iter_single(%init: i32) {
+  // CHECK: scf.loop iter_args(%[[ITER:.*]] = %{{.*}}) : i32 {
+  scf.loop iter_args(%i = %init) : i32 {
+    // CHECK-NEXT: scf.continue %[[ITER]] : i32
+    scf.continue %i : i32
+  }
+  return
+}
+
+// CHECK-LABEL: @loop_iter_multi
+func.func @loop_iter_multi(%init0: i32, %init1: i64) {
+  // CHECK: scf.loop iter_args(%[[I:.*]] = %{{.*}}, %[[J:.*]] = %{{.*}}) : (i32, i64) {
+  scf.loop iter_args(%i = %init0, %j = %init1) : (i32, i64) {
+    // CHECK-NEXT: scf.continue %[[I]], %[[J]] : i32, i64
+    scf.continue %i, %j : i32, i64
+  }
+  return
+}
+
+// The iter_arg (i32) and the single loop result (i64) have different types.
+// CHECK-LABEL: @loop_iter_and_result
+func.func @loop_iter_and_result(%seed: i32, %val: i64) -> i64 {
+  // CHECK: %{{.*}} = scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 -> i64 {
+  %res = scf.loop iter_args(%iter = %seed) : i32 -> i64 {
+    // CHECK: scf.break %{{.*}} : i64
+    scf.break %val : i64
+  }
+  return %res : i64
+}



More information about the Mlir-commits mailing list