[Mlir-commits] [mlir] 7971655 - [mlir] Add a generic while/do-while loop to the SCF dialect

Alex Zinenko llvmlistbot at llvm.org
Wed Nov 4 00:44:32 PST 2020


Author: Alex Zinenko
Date: 2020-11-04T09:43:13+01:00
New Revision: 79716559b5acee891d5664315d7862c5b5c1d34f

URL: https://github.com/llvm/llvm-project/commit/79716559b5acee891d5664315d7862c5b5c1d34f
DIFF: https://github.com/llvm/llvm-project/commit/79716559b5acee891d5664315d7862c5b5c1d34f.diff

LOG: [mlir] Add a generic while/do-while loop to the SCF dialect

The new construct represents a generic loop with two regions: one executed
before the loop condition is verified and another after that. This construct
can be used to express both a "while" loop and a "do-while" loop, depending on
where the main payload is located. It is intended as an intermediate
abstraction for lowering, which will be added later. This form is relatively
easy to target from higher-level abstractions and supports transformations such
as loop rotation and LICM.

Differential Revision: https://reviews.llvm.org/D90255

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/SCFOps.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/SCF/SCF.cpp
    mlir/lib/Interfaces/ControlFlowInterfaces.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/Dialect/SCF/invalid.mlir
    mlir/test/Dialect/SCF/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
index a58af941965c..bf81d7ff2177 100644
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -36,6 +36,25 @@ class SCF_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
+def ConditionOp : SCF_Op<"condition",
+                         [HasParent<"WhileOp">, NoSideEffect, Terminator]> {
+  let summary = "loop continuation condition";
+  let description = [{
+    This operation accepts the continuation (i.e., inverse of exit) condition
+    of the `scf.while` construct. If its first argument is true, the "after"
+    region of `scf.while` is executed, with the remaining arguments forwarded
+    to the entry block of the region. Otherwise, the loop terminates.
+  }];
+
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$args);
+
+  let assemblyFormat =
+      [{ `(` $condition `)` attr-dict ($args^ `:` type($args))? }];
+
+  // No custom verifier needed: traits and ODS constraints check everything.
+  let verifier = ?;
+}
+
 def ForOp : SCF_Op<"for",
       [DeclareOpInterfaceMethods<LoopLikeOpInterface>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -413,8 +432,135 @@ def ReduceReturnOp :
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+def WhileOp : SCF_Op<"while",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+     RecursiveSideEffects]> {
+  let summary = "a generic 'while' loop";
+  let description = [{
+    This operation represents a generic "while"/"do-while" loop that keeps
+    iterating as long as a condition is satisfied. There is no restriction on
+    the complexity of the condition. It consists of two regions (with a single
+    block each): "before" region and "after" region. The names of regions
+    indicate whether they execute before or after the condition check.
+    Therefore, if the main loop payload is located in the "before" region, the
+    operation is a "do-while" loop. Otherwise, it is a "while" loop.
+
+    The "before" region terminates with a special operation, `scf.condition`,
+    that accepts as its first operand an `i1` value indicating whether to
+    proceed to the "after" region (value is `true`) or not. The two regions
+    communicate by means of region arguments. Initially, the "before" region
+    accepts as arguments the operands of the `scf.while` operation and uses them
+    to evaluate the condition. It forwards the trailing, non-condition operands
+    of the `scf.condition` terminator either to the "after" region if the
+    control flow is transferred there or to results of the `scf.while` operation
+    otherwise. The "after" region takes as arguments the values produced by the
+    "before" region and uses `scf.yield` to supply new arguments for the "before"
+    region, into which it transfers the control flow unconditionally.
+
+    A simple "while" loop can be represented as follows.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
+      /* "Before" region.
+       * In a "while" loop, this region computes the condition. */
+      %condition = call @evaluate_condition(%arg1) : (f32) -> i1
+
+      /* Forward the argument (as result or "after" region argument). */
+      scf.condition(%condition) %arg1 : f32
+
+    } do {
+    ^bb0(%arg2: f32):
+      /* "After" region.
+       * In a "while" loop, this region is the loop body. */
+      %next = call @payload(%arg2) : (f32) -> f32
+
+      /* Forward the new value to the "before" region.
+       * The operand types must match the types of the `scf.while` operands. */
+      scf.yield %next : f32
+    }
+    ```
+
+    A simple "do-while" loop can be represented by reducing the "after" block
+    to a simple forwarder.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> f32 {
+      /* "Before" region.
+       * In a "do-while" loop, this region contains the loop body. */
+      %next = call @payload(%arg1) : (f32) -> f32
+
+      /* And also evaluates the condition. */
+      %condition = call @evaluate_condition(%arg1) : (f32) -> i1
+
+      /* Loop through the "after" region. */
+      scf.condition(%condition) %next : f32
+
+    } do {
+    ^bb0(%arg2: f32):
+      /* "After" region.
+       * Forwards the values back to "before" region unmodified. */
+      scf.yield %arg2 : f32
+    }
+    ```
+
+    Note that the types of region arguments need not match each other.
+    The op expects the operand types to match with argument types of the
+    "before" region; the result types to match with the trailing operand types
+    of the terminator of the "before" region, and with the argument types of the
+    "after" region. The following scheme can be used to share the results of
+    some operations executed in the "before" region with the "after" region,
+    avoiding the need to recompute them.
+
+    ```mlir
+    %res = scf.while (%arg1 = %init1) : (f32) -> i64 {
+      /* One can perform some computations, e.g., necessary to evaluate the
+       * condition, in the "before" region and forward their results to the
+       * "after" region. */
+      %shared = call @shared_compute(%arg1) : (f32) -> i64
+
+      /* Evaluate the condition. */
+      %condition = call @evaluate_condition(%arg1, %shared) : (f32, i64) -> i1
+
+      /* Forward the result of the shared computation to the "after" region.
+       * The types must match the arguments of the "after" region as well as
+       * those of the `scf.while` results. */
+      scf.condition(%condition) %shared : i64
+
+    } do {
+    ^bb0(%arg2: i64):
+      /* Use the partial result to compute the rest of the payload in the
+       * "after" region. */
+      %next = call @payload(%arg2) : (i64) -> f32
+
+      /* Forward the new value to the "before" region.
+       * The operand types must match the types of the `scf.while` operands. */
+      scf.yield %next : f32
+    }
+    ```
+
+    The custom syntax for this operation is as follows.
+
+    ```
+    op ::= `scf.while` initializer `:` function-type region `do` region
+           (`attributes` attribute-dict)?
+    initializer ::= /* empty */ | `(` assignment-list `)`
+    assignment-list ::= assignment | assignment `,` assignment-list
+    assignment ::= ssa-value `=` ssa-value
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$inits);
+  let results = (outs Variadic<AnyType>:$results);
+  let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+
+  let extraClassDeclaration = [{
+    OperandRange getSuccessorEntryOperands(unsigned index);
+  }];
+}
+
 def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
-                               ParentOneOf<["IfOp, ForOp", "ParallelOp"]>]> {
+                               ParentOneOf<["IfOp, ForOp", "ParallelOp",
+                                            "WhileOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
     "scf.yield" yields an SSA value from the SCF dialect op region and
@@ -434,4 +580,5 @@ def YieldOp : SCF_Op<"yield", [NoSideEffect, ReturnLike, Terminator,
   // needed.
   let verifier = ?;
 }
+
 #endif // MLIR_DIALECT_SCF_SCFOPS

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0813cde0256d..d24561333a3d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -755,11 +755,18 @@ class OpAsmParser {
   parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
   /// Parse a list of assignments of the form
-  /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
-  /// The list must contain at least one entry
-  virtual ParseResult
-  parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
-                      SmallVectorImpl<OperandType> &rhs) = 0;
+  ///   (%x1 = %y1, %x2 = %y2, ...)
+  ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                                  SmallVectorImpl<OperandType> &rhs) {
+    OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
+    if (!result.hasValue())
+      return emitError(getCurrentLocation(), "expected '('");
+    return result.getValue();
+  }
+  /// Optional form of the above: returns llvm::None (no error) if no '('.
+  virtual OptionalParseResult
+  parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                              SmallVectorImpl<OperandType> &rhs) = 0;
 
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {

diff  --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 39f6e2d88162..56932ff1f30e 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -140,26 +140,37 @@ static LogicalResult verify(ForOp op) {
   return RegionBranchOpInterface::verifyTypes(op);
 }
 
+/// Prints `<prefix>(%inner = %outer, %inner2 = %outer2, <...>)`, where 'inner'
+/// values are region (block) arguments and 'outer' values are the SSA values
+/// initializing them. Prints nothing at all, not even `prefix`, when there are
+/// no initializers. Both ranges are asserted to have the same length.
+static void printInitializationList(OpAsmPrinter &p,
+                                    Block::BlockArgListType blocksArgs,
+                                    ValueRange initializers,
+                                    StringRef prefix = "") {
+  assert(blocksArgs.size() == initializers.size() &&
+         "expected same length of arguments and initializers");
+  if (initializers.empty())
+    return;
+
+  p << prefix << '(';
+  llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
+    p << std::get<0>(it) << " = " << std::get<1>(it);
+  });
+  p << ")";
+}
+
 static void print(OpAsmPrinter &p, ForOp op) {
-  bool printBlockTerminators = false;
   p << op.getOperationName() << " " << op.getInductionVar() << " = "
     << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
 
-  if (op.hasIterOperands()) {
-    p << " iter_args(";
-    auto regionArgs = op.getRegionIterArgs();
-    auto operands = op.getIterOperands();
-
-    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
-      p << std::get<0>(it) << " = " << std::get<1>(it);
-    });
-    p << ")";
-    p << " -> (" << op.getResultTypes() << ")";
-    printBlockTerminators = true;
-  }
+  printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
+                          " iter_args");
+  if (op.hasIterOperands()) // `->` lists the op's own result types.
+    p << " -> (" << op.getResultTypes() << ')';
   p.printRegion(op.region(),
                 /*printEntryBlockArgs=*/false,
-                /*printBlockTerminators=*/printBlockTerminators);
+                /*printBlockTerminators=*/op.hasIterOperands());
   p.printOptionalAttrDict(op.getAttrs());
 }
 
@@ -933,6 +944,158 @@ static LogicalResult verify(ReduceReturnOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WhileOp
+//===----------------------------------------------------------------------===//
+
+OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
+  assert(index == 0 &&
+         "WhileOp is expected to branch only to the first region");
+  // Only the "before" region is entered from the op; it receives `inits`.
+  return inits();
+}
+
+void WhileOp::getSuccessorRegions(Optional<unsigned> index,
+                                  ArrayRef<Attribute> operands,
+                                  SmallVectorImpl<RegionSuccessor> &regions) {
+  (void)operands;
+  // Entering from the parent op: control goes to the "before" region.
+  if (!index.hasValue()) {
+    regions.emplace_back(&before(), before().getArguments());
+    return;
+  }
+  // From "before": branch into "after" or exit to the op results.
+  assert(*index < 2 && "there are only two regions in a WhileOp");
+  if (*index == 0) {
+    regions.emplace_back(&after(), after().getArguments());
+    regions.emplace_back(getResults());
+    return;
+  }
+  // From "after": control flows back to the "before" region.
+  regions.emplace_back(&before(), before().getArguments());
+}
+
+/// Parses a `while` op.
+///
+/// op ::= `scf.while` initializer `:` function-type region `do` region
+///         (`attributes` attribute-dict)?
+/// initializer ::= /* empty */ | `(` assignment-list `)`
+/// assignment-list ::= assignment | assignment `,` assignment-list
+/// assignment ::= ssa-value `=` ssa-value
+static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
+  Region *before = result.addRegion();
+  Region *after = result.addRegion();
+  // Optional `(%region_arg = %init, ...)` list; absent means no inits.
+  OptionalParseResult listResult =
+      parser.parseOptionalAssignmentList(regionArgs, operands);
+  if (listResult.hasValue() && failed(listResult.getValue()))
+    return failure();
+
+  FunctionType functionType;
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  if (failed(parser.parseColonType(functionType)))
+    return failure();
+
+  result.addTypes(functionType.getResults());
+
+  if (functionType.getNumInputs() != operands.size()) {
+    return parser.emitError(typeLoc)
+           << "expected as many input types as operands "
+           << "(expected " << operands.size() << " got "
+           << functionType.getNumInputs() << ")";
+  }
+
+  // Resolve input operands.
+  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+                                    parser.getCurrentLocation(),
+                                    result.operands)))
+    return failure();
+  // The "before" entry block takes the initializer names as its arguments.
+  return failure(
+      parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
+      parser.parseKeyword("do") || parser.parseRegion(*after) ||
+      parser.parseOptionalAttrDictWithKeyword(result.attributes));
+}
+
+/// Prints a `while` op; "before" entry block args appear in the init list.
+static void print(OpAsmPrinter &p, scf::WhileOp op) {
+  p << op.getOperationName();
+  printInitializationList(p, op.before().front().getArguments(), op.inits(),
+                          " ");
+  p << " : ";
+  p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
+  p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
+  p << " do";
+  p.printRegion(op.after());
+  p.printOptionalAttrDictWithKeyword(op.getAttrs());
+}
+
+/// Verifies that two ranges of types match, i.e. have the same number of
+/// entries and that types are pairwise equal. Reports errors on the given
+/// operation in case of mismatch; `message` names the two ranges compared.
+template <typename OpTy>
+static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
+                                           TypeRange right, StringRef message) {
+  if (left.size() != right.size())
+    return op.emitOpError("expects the same number of ") << message;
+
+  for (unsigned i = 0, e = left.size(); i < e; ++i) {
+    if (left[i] != right[i]) {
+      InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
+                                << message;
+      diag.attachNote() << "for argument " << i << ", found " << left[i]
+                        << " and " << right[i];
+      return diag;
+    }
+  }
+
+  return success();
+}
+
+/// Verifies that the first block of the given `region` ends with an op of type
+/// `TerminatorTy`; otherwise reports an error on `op` and returns null.
+template <typename TerminatorTy>
+static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
+                                           StringRef errorMessage) {
+  Block &block = region.front(); // Safe: SizedRegion<1> is checked first.
+  Operation *terminatorOperation = block.empty() ? nullptr : &block.back();
+  if (auto terminator = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
+    return terminator;
+  auto diag = op.emitOpError(errorMessage);
+  if (terminatorOperation)
+    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
+  return nullptr;
+}
+
+static LogicalResult verify(scf::WhileOp op) {
+  if (failed(RegionBranchOpInterface::verifyTypes(op)))
+    return failure();
+  // The 'before' region must end in scf.condition.
+  auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
+      op, op.before(),
+      "expects the 'before' region to terminate with 'scf.condition'");
+  if (!beforeTerminator)
+    return failure();
+  // Values forwarded by scf.condition must match 'after' args and op results.
+  TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
+  if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
+                                   op.after().getArgumentTypes(),
+                                   "trailing operands of the 'before' block "
+                                   "terminator and 'after' region arguments")))
+    return failure();
+
+  if (failed(verifyTypeRangesMatch(
+          op, trailingTerminatorOperands, op.getResultTypes(),
+          "trailing operands of the 'before' block terminator and op results")))
+    return failure();
+  // The 'after' region must end in scf.yield.
+  auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
+      op, op.after(),
+      "expects the 'after' region to terminate with 'scf.yield'");
+  return success(afterTerminator != nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 194a64576986..99e86bcb5057 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -76,10 +76,13 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 /// Verify that types match along all region control flow edges originating from
 /// `sourceNo` (region # if source is a region, llvm::None if source is parent
 /// op). `getInputsTypesForRegion` is a function that returns the types of the
-/// inputs that flow from `sourceIndex' to the given region.
-static LogicalResult verifyTypesAlongAllEdges(
-    Operation *op, Optional<unsigned> sourceNo,
-    function_ref<TypeRange(Optional<unsigned>)> getInputsTypesForRegion) {
+/// inputs that flow from `sourceIndex' to the given region, or llvm::None if
+/// the exact type match verification is not necessary (e.g., if the Op verifies
+/// the match itself).
+static LogicalResult
+verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
+                         function_ref<Optional<TypeRange>(Optional<unsigned>)>
+                             getInputsTypesForRegion) {
   auto regionInterface = cast<RegionBranchOpInterface>(op);
 
   SmallVector<RegionSuccessor, 2> successors;
@@ -113,17 +116,20 @@ static LogicalResult verifyTypesAlongAllEdges(
       return diag;
     };
 
-    TypeRange sourceTypes = getInputsTypesForRegion(succRegionNo);
+    Optional<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo);
+    if (!sourceTypes.hasValue())
+      continue;
+
     TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
-    if (sourceTypes.size() != succInputsTypes.size()) {
+    if (sourceTypes->size() != succInputsTypes.size()) {
       InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
-      return printEdgeName(diag) << ": source has " << sourceTypes.size()
+      return printEdgeName(diag) << ": source has " << sourceTypes->size()
                                  << " operands, but target successor needs "
                                  << succInputsTypes.size();
     }
 
     for (auto typesIdx :
-         llvm::enumerate(llvm::zip(sourceTypes, succInputsTypes))) {
+         llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
       Type sourceType = std::get<0>(typesIdx.value());
       Type inputType = std::get<1>(typesIdx.value());
       if (sourceType != inputType) {
@@ -191,10 +197,15 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
                << " operands mismatch between return-like terminators";
     }
 
-    auto inputTypesFromRegion = [&](Optional<unsigned> regionNo) -> TypeRange {
+    auto inputTypesFromRegion =
+        [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
+      // If there is no return-like terminator, the op itself should verify
+      // type consistency.
+      if (!regionReturn)
+        return llvm::None;
+
       // All successors get the same set of operands.
-      return regionReturn ? TypeRange(regionReturn->getOperands().getTypes())
-                          : TypeRange();
+      return TypeRange(regionReturn->getOperands().getTypes());
     };
 
     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index a824687aefb2..8c581ad7bc0c 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1480,10 +1480,13 @@ class CustomOpAsmParser : public OpAsmParser {
   }
 
   /// Parse a list of assignments of the form
-  /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...).
-  /// The list must contain at least one entry
-  ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
-                                  SmallVectorImpl<OperandType> &rhs) override {
+  ///   (%x1 = %y1, %x2 = %y2, ...).
+  OptionalParseResult
+  parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
+                              SmallVectorImpl<OperandType> &rhs) override {
+    if (failed(parseOptionalLParen()))
+      return llvm::None;
+
     auto parseElt = [&]() -> ParseResult {
       OperandType regionArg, operand;
       if (parseRegionArgument(regionArg) || parseEqual() ||
@@ -1493,8 +1496,6 @@ class CustomOpAsmParser : public OpAsmParser {
       rhs.push_back(operand);
       return success();
     };
-    if (parseLParen())
-      return failure();
     return parser.parseCommaSeparatedListUntil(Token::r_paren, parseElt);
   }
 

diff  --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 06b902da781c..4eaef611cb16 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -425,10 +425,88 @@ func @parallel_invalid_yield(
 }
 
 // -----
+
 func @yield_invalid_parent_op() {
   "my.op"() ({
-   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel'}}
+   // expected-error at +1 {{'scf.yield' op expects parent op to be one of 'scf.if, scf.for, scf.parallel, scf.while'}}
    scf.yield
   }) : () -> ()
   return
 }
+
+// -----
+
+func @while_parser_type_mismatch() {
+  %true = constant true
+  // expected-error at +1 {{expected as many input types as operands (expected 0 got 1)}}
+  scf.while : (i32) -> () {
+    scf.condition(%true)
+  } do {
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_bad_terminator() {
+  // expected-error at +1 {{expects the 'before' region to terminate with 'scf.condition'}}
+  scf.while : () -> () {
+    // expected-note at +1 {{terminator here}}
+    "some.other_terminator"() : () -> ()
+  } do {
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_cross_region_count_mismatch() {
+  %true = constant true
+  // expected-error at +1 {{expects the same number of trailing operands of the 'before' block terminator and 'after' region arguments}}
+  scf.while : () -> () {
+    scf.condition(%true)
+  } do {
+  ^bb0(%arg0: i32):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_cross_region_type_mismatch() {
+  %true = constant true
+  // expected-error at +2 {{expects the same types for trailing operands of the 'before' block terminator and 'after' region arguments}}
+  // expected-note at +1 {{for argument 0, found 'i1' and 'i32'}}
+  scf.while : () -> () {
+    scf.condition(%true) %true : i1
+  } do {
+  ^bb0(%arg0: i32):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_result_type_mismatch() {
+  %true = constant true
+  // expected-error at +1 {{expects the same number of trailing operands of the 'before' block terminator and op results}}
+  scf.while : () -> () {
+    scf.condition(%true) %true : i1
+  } do {
+  ^bb0(%arg0: i1):
+    scf.yield
+  }
+}
+
+// -----
+
+func @while_bad_after_terminator() {
+  %true = constant true
+  // expected-error at +1 {{expects the 'after' region to terminate with 'scf.yield'}}
+  scf.while : () -> () {
+    scf.condition(%true)
+  } do {
+    // expected-note at +1 {{terminator here}}
+    "some.other_terminator"() : () -> ()
+  }
+}

diff  --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 1058983f5fb9..8e9f6a0ed33d 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -240,3 +240,42 @@ func @conditional_reduce(%buffer: memref<1024xf32>, %lb: index, %ub: index, %ste
 //  CHECK-NEXT: scf.yield %[[IFRES]] : f32
 //  CHECK-NEXT: }
 //  CHECK-NEXT: return %[[RESULT]]
+
+// CHECK-LABEL: @while
+func @while() {
+  %0 = "test.get_some_value"() : () -> i32
+  %1 = "test.get_some_value"() : () -> f32
+  // Loop-carried values change type between the two regions.
+  // CHECK: = scf.while (%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, f32) -> (i64, f64) {
+  %2:2 = scf.while (%arg0 = %0, %arg1 = %1) : (i32, f32) -> (i64, f64) {
+    %3:2 = "test.some_operation"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
+    %4 = "test.some_condition"(%arg0, %arg1) : (i32, f32) -> i1
+    // CHECK: scf.condition(%{{.*}}) %{{.*}}, %{{.*}} : i64, f64
+    scf.condition(%4) %3#0, %3#1 : i64, f64
+  // CHECK: } do {
+  } do {
+  // CHECK: ^{{.*}}(%{{.*}}: i64, %{{.*}}: f64):
+  ^bb0(%arg2: i64, %arg3: f64):
+    %5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
+    // CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
+    scf.yield %5#0, %5#1 : i32, f32
+  // CHECK: attributes {foo = "bar"}
+  } attributes {foo="bar"}
+  return
+}
+
+// CHECK-LABEL: @infinite_while
+func @infinite_while() {
+  %true = constant true
+  // No loop-carried values; the condition is the constant `true`.
+  // CHECK: scf.while  : () -> () {
+  scf.while : () -> () {
+    // CHECK: scf.condition(%{{.*}})
+    scf.condition(%true)
+  // CHECK: } do {
+  } do {
+    // CHECK: scf.yield
+    scf.yield
+  }
+  return
+}


        


More information about the Mlir-commits mailing list