[flang-commits] [flang] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to get yielded values (PR #67305)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Thu Sep 28 03:42:43 PDT 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/67305

>From 8780fcce6f2f0c970a4117ea9fcd62d800191227 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 Sep 2023 12:29:13 +0200
Subject: [PATCH] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to get
 yielded values

Add a new interface method that returns the yielded values.

Also add a verifier that checks the number of inits/iter_args/yielded values. Most of the checked invariants (but not all of them) are already covered by the `RegionBranchOpInterface`, but the `LoopLikeOpInterface` now provides (additional) error messages that are easier to read.
---
 .../include/flang/Optimizer/Dialect/FIROps.td |  4 +-
 flang/lib/Lower/Bridge.cpp                    |  6 +--
 flang/lib/Lower/ConvertExpr.cpp               |  4 +-
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  8 ++--
 .../Transforms/SimplifyIntrinsics.cpp         |  6 +--
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  3 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 11 +++++-
 .../mlir/Interfaces/LoopLikeInterface.h       |  5 +++
 .../mlir/Interfaces/LoopLikeInterface.td      | 30 +++++++++++++++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  4 ++
 .../Linalg/Transforms/HoistPadding.cpp        |  5 +--
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 27 ++++++++------
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../SCF/Transforms/LoopCanonicalization.cpp   |  5 +--
 mlir/lib/Interfaces/LoopLikeInterface.cpp     | 37 +++++++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir            | 30 ++++++++++++++-
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |  5 +--
 17 files changed, 151 insertions(+), 42 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index a57add9f731979d..483c679eb623ce6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2106,7 +2106,7 @@ def fir_DoLoopOp : region_Op<"do_loop",
     mlir::OpBuilder getBodyBuilder() {
       return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
     }
-    mlir::Block::BlockArgListType getRegionIterArgs() {
+    mlir::Block::BlockArgListType getIterArgs() {
       return getBody()->getArguments().drop_front();
     }
     mlir::Operation::operand_range getIterOperands() {
@@ -2257,7 +2257,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
     mlir::OpBuilder getBodyBuilder() {
       return mlir::OpBuilder(getBody(), std::prev(getBody()->end()));
     }
-    mlir::Block::BlockArgListType getRegionIterArgs() {
+    mlir::Block::BlockArgListType getIterArgs() {
       return getBody()->getArguments().drop_front();
     }
     mlir::Operation::operand_range getIterOperands() {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ee838b3b4a546b9..5e702109a40039a 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1789,7 +1789,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               /*finalCountValue=*/true,
               builder->createConvert(loc, loopVarType, lowerValue));
           builder->setInsertionPointToStart(info.doLoop.getBody());
-          loopValue = info.doLoop.getRegionIterArgs()[0];
+          loopValue = info.doLoop.getIterArgs()[0];
         }
         // Update the loop variable value in case it has non-index references.
         builder->create<fir::StoreOp>(loc, loopValue, info.loopVariable);
@@ -2105,9 +2105,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         auto lp = builder->create<fir::DoLoopOp>(
             loc, lb, ub, by, /*unordered=*/true,
             /*finalCount=*/false, explicitIterSpace.getInnerArgs());
-        if ((!loops.empty() || !outermost) && !lp.getRegionIterArgs().empty())
+        if ((!loops.empty() || !outermost) && !lp.getIterArgs().empty())
           builder->create<fir::ResultOp>(loc, lp.getResults());
-        explicitIterSpace.setInnerArgs(lp.getRegionIterArgs());
+        explicitIterSpace.setInnerArgs(lp.getIterArgs());
         builder->setInsertionPointToStart(lp.getBody());
         forceControlVariableBinding(ctrlVar, lp.getInductionVar());
         loops.push_back(lp);
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
index 8788e82b59a8df0..9871e4e0dedd3be 100644
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -4342,7 +4342,7 @@ class ArrayExprLowering {
         loop = builder.create<fir::DoLoopOp>(
             loc, zero, i.value(), one, isUnordered(),
             /*finalCount=*/false, mlir::ValueRange{innerArg});
-        innerArg = loop.getRegionIterArgs().front();
+        innerArg = loop.getIterArgs().front();
         if (explicitSpaceIsActive())
           explicitSpace->setInnerArg(0, innerArg);
       } else {
@@ -6328,7 +6328,7 @@ class ArrayExprLowering {
     auto insPt = builder.saveInsertionPoint();
     builder.setInsertionPointToStart(loop.getBody());
     // Thread mem inside the loop via loop argument.
-    mem = loop.getRegionIterArgs()[0];
+    mem = loop.getIterArgs()[0];
 
     mlir::Type eleRefTy = builder.getRefType(eleTy);
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 962b87acd5a8050..1764c7812a78598 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1905,7 +1905,7 @@ mlir::LogicalResult fir::IterWhileOp::verify() {
     return emitOpError(
         "mismatch in number of basic block args and defined values");
   auto iterOperands = getIterOperands();
-  auto iterArgs = getRegionIterArgs();
+  auto iterArgs = getIterArgs();
   auto opResults = getFinalValue() ? getResults().drop_front() : getResults();
   unsigned i = 0u;
   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
@@ -1925,7 +1925,7 @@ void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
   p << " (" << getInductionVar() << " = " << getLowerBound() << " to "
     << getUpperBound() << " step " << getStep() << ") and (";
   assert(hasIterOperands());
-  auto regionArgs = getRegionIterArgs();
+  auto regionArgs = getIterArgs();
   auto operands = getIterOperands();
   p << regionArgs.front() << " = " << *operands.begin() << ")";
   if (regionArgs.size() > 1) {
@@ -2194,7 +2194,7 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
     return emitOpError(
         "mismatch in number of basic block args and defined values");
   auto iterOperands = getIterOperands();
-  auto iterArgs = getRegionIterArgs();
+  auto iterArgs = getIterArgs();
   auto opResults = getFinalValue() ? getResults().drop_front() : getResults();
   unsigned i = 0u;
   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
@@ -2218,7 +2218,7 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
     p << " unordered";
   if (hasIterOperands()) {
     p << " iter_args(";
-    auto regionArgs = getRegionIterArgs();
+    auto regionArgs = getIterArgs();
     auto operands = getIterOperands();
     llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
       p << std::get<0>(it) << " = " << std::get<1>(it);
diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
index 3eddb9e61ae3b3d..7e6f668a012a0a0 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp
@@ -316,7 +316,7 @@ genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
     auto loop = builder.create<OP>(loc, zeroIdx, loopCount, step,
                                    unorderedOrInitialLoopCond,
                                    /*finalCountValue=*/false, init);
-    init = loop.getRegionIterArgs()[resultIndex];
+    init = loop.getIterArgs()[resultIndex];
     indices.push_back(loop.getInductionVar());
     // Set insertion point to the loop body so that the next loop
     // is inserted inside the current one.
@@ -422,7 +422,7 @@ genMinlocReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
     auto loop =
         builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step, false,
                                       /*finalCountValue=*/false, init);
-    init = loop.getRegionIterArgs()[0];
+    init = loop.getIterArgs()[0];
     indices.push_back(loop.getInductionVar());
     // Set insertion point to the loop body so that the next loop
     // is inserted inside the current one.
@@ -952,7 +952,7 @@ static void genRuntimeDotBody(fir::FirOpBuilder &builder,
   auto loop = builder.create<fir::DoLoopOp>(loc, zeroIdx, loopCount, step,
                                             /*unordered=*/false,
                                             /*finalCountValue=*/false, zero);
-  mlir::Value sumVal = loop.getRegionIterArgs()[0];
+  mlir::Value sumVal = loop.getIterArgs()[0];
 
   // Begin loop code
   mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d8ef0506d0822d7..6eb0a4468fd91b5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -120,7 +120,8 @@ def AffineForOp : Affine_Op<"for",
     [AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "replaceWithAdditionalYields"]>,
+      "getSingleUpperBound", "getYieldedValues",
+      "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
   let summary = "for operation";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e1a604a88715f0e..021f57f75aaa326 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -122,8 +122,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInits", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
-        "replaceWithAdditionalYields"]>,
+        "getSingleStep", "getSingleUpperBound", "getYieldedValues",
+        "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -243,9 +243,11 @@ def ForOp : SCF_Op<"for",
         function_ref<void(OpBuilder &, Location, Value, ValueRange)>;
 
     Value getInductionVar() { return getBody()->getArgument(0); }
+
     Block::BlockArgListType getRegionIterArgs() {
       return getBody()->getArguments().drop_front(getNumInductionVars());
     }
+
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
       assert(index < getNumRegionIterArgs() &&
@@ -1086,6 +1088,11 @@ def WhileOp : SCF_Op<"while",
 
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
+
+    /// Return the values that are yielded from the "after" region (by the
+    /// YieldOp).
+    ValueRange getYieldedValues();
+
     Block::BlockArgListType getBeforeArguments();
     Block::BlockArgListType getAfterArguments();
     Block *getBeforeBody() { return &getBefore().front(); }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 0eebb984e5897ae..7c7d378d0590ab1 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -24,6 +24,11 @@ class RewriterBase;
 /// arguments in `newBbArgs`.
 using NewYieldValuesFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
+
+namespace detail {
+/// Verify invariants of the LoopLikeOpInterface.
+LogicalResult verifyLoopLikeOpInterface(Operation *op);
+} // namespace detail
 } // namespace mlir
 
 /// Include the generated interface declarations.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index ded0a29292ff6f0..5e3c0db965156c4 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -20,6 +20,19 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     Contains helper functions to query properties and perform transformations
     of a loop. Operations that implement this interface will be considered by
     loop-invariant code motion.
+
+    Loop-carried variables can be exposed through this interface. There are
+    3 components to a loop-carried variable.
+    - The "region iter_arg" is the block argument of the entry block that
+      represents the loop-carried variable in each iteration.
+    - The "init value" is an operand of the loop op that serves as the initial
+      region iter_arg value for the first iteration (if any).
+    - The "yielded" value is the value that is forwarded from one iteration to
+      serve as the region iter_arg of the next iteration.
+
+    If one of the respective interface methods is implemented, so must the other
+    two. The interface verifier ensures that the number and types of the region
+    iter_args, init values and yielded values match.
   }];
   let cppNamespace = "::mlir";
 
@@ -141,6 +154,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return ::mlir::Block::BlockArgListType();
       }]
     >,
+    InterfaceMethod<[{
+        Return the range of values that are yielded to the next iteration.
+      }],
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getYieldedValues",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::ValueRange();
+      }]
+    >,
     InterfaceMethod<[{
         Append the specified additional "init" operands: replace this loop with
         a new loop that has the additional init operands. The loop body of
@@ -192,6 +216,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
           });
     }
   }];
+
+  let verifyWithRegions = 1;
+
+  let verify = [{
+    return detail::verifyLoopLikeOpInterface($_op);
+  }];
 }
 
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6c060c90e24af82..4109167de44e860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2193,6 +2193,10 @@ unsigned AffineForOp::getNumIterOperands() {
   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
 }
 
+ValueRange AffineForOp::getYieldedValues() {
+  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
+}
+
 void AffineForOp::print(OpAsmPrinter &p) {
   p << ' ';
   p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 72bd2b409f5d52b..8fef99bb375095a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
   unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
-  auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
+  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
   initArgs[iterArgNumber] = hoistedPackedTensor;
-  SmallVector<Value> yieldOperands = yieldOp.getOperands();
+  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
   yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8d8481421e18d57..508227d6e7ce4b0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
 
   // Replace all results with the yielded values.
   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
-  rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
+  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
 
   // Replace block arguments with lower bound (replacement for IV) and
   // iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   LogicalResult matchAndRewrite(scf::ForOp forOp,
                                 PatternRewriter &rewriter) const final {
     bool canonicalize = false;
-    Block &block = forOp.getRegion().front();
-    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
 
     // An internal flat vector of block transfer
     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
     // transformed block argument mappings. This plays the role of a
     // IRMapping for the particular use case of calling into
     // `inlineBlockBefore`.
+    int64_t numResults = forOp.getNumResults();
     SmallVector<bool, 4> keepMask;
-    keepMask.reserve(yieldOp.getNumOperands());
+    keepMask.reserve(numResults);
     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
         newResultValues;
-    newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
+    newBlockTransferArgs.reserve(1 + numResults);
     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
     newIterArgs.reserve(forOp.getInitArgs().size());
-    newYieldValues.reserve(yieldOp.getNumOperands());
-    newResultValues.reserve(forOp.getNumResults());
+    newYieldValues.reserve(numResults);
+    newResultValues.reserve(numResults);
     for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
                              forOp.getRegionIterArgs(), // iter inside region
                              forOp.getResults(),        // op results
-                             yieldOp.getOperands()      // iter yield
+                             forOp.getYieldedValues()   // iter yield
                              )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       return failure();
     // If the loop is empty, iterates at least once, and only returns values
     // defined outside of the loop, remove it and replace it with yield values.
-    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
-    auto yieldOperands = yieldOp.getOperands();
-    if (llvm::any_of(yieldOperands,
+    if (llvm::any_of(op.getYieldedValues(),
                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
       return failure();
-    rewriter.replaceOp(op, yieldOperands);
+    rewriter.replaceOp(op, op.getYieldedValues());
     return success();
   }
 };
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
   return {};
 }
 
+ValueRange ForOp::getYieldedValues() {
+  return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
+}
+
 Speculation::Speculatability ForOp::getSpeculatability() {
   // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
   // and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
   return cast<YieldOp>(getAfterBody()->getTerminator());
 }
 
+ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }
+
 Block::BlockArgListType WhileOp::getBeforeArguments() {
   return getBeforeBody()->getArguments();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index bcbc693a9742ccc..22c893457c6f59e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -489,9 +489,8 @@ struct ForOpInterface
     auto forOp = cast<scf::ForOp>(op);
     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     bool equivalentYield = state.areEquivalentBufferizedValues(
-        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
+        bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
     return equivalentYield ? BufferRelation::Equivalent
                            : BufferRelation::Unknown;
   }
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 0cd19fbefa8ef98..43e79d309c66784 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -36,10 +36,9 @@ using namespace mlir::scf;
 /// type of the corresponding basic block argument of the loop.
 /// Note: This function handles only simple cases. Expand as needed.
 static bool isShapePreserving(ForOp forOp, int64_t arg) {
-  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
-  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
+  assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
          "arg is out of bounds");
-  Value value = yieldOp.getResults()[arg];
+  Value value = forOp.getYieldedValues()[arg];
   while (value) {
     if (value == forOp.getRegionIterArgs()[arg])
       return true;
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 781a21bb3ecd3a0..15a816f4e448839 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -52,3 +52,40 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
   }
   return false;
 }
+
+LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
+  // Note: These invariants are also verified by the RegionBranchOpInterface,
+  // but the LoopLikeOpInterface provides better error messages.
+  auto loopLikeOp = cast<LoopLikeOpInterface>(op);
+
+  // Verify number of inits/iter_args/yielded values.
+  if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
+    return op->emitOpError("different number of inits and region iter_args: ")
+           << loopLikeOp.getInits().size()
+           << " != " << loopLikeOp.getRegionIterArgs().size();
+  if (loopLikeOp.getRegionIterArgs().size() !=
+      loopLikeOp.getYieldedValues().size())
+    return op->emitOpError(
+               "different number of region iter_args and yielded values: ")
+           << loopLikeOp.getRegionIterArgs().size()
+           << " != " << loopLikeOp.getYieldedValues().size();
+
+  // Verify types of inits/iter_args/yielded values.
+  int64_t i = 0;
+  for (const auto it :
+       llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
+                       loopLikeOp.getYieldedValues())) {
+    if (std::get<0>(it).getType() != std::get<1>(it).getType())
+      op->emitOpError(std::to_string(i))
+          << "-th init and " << i << "-th region iter_arg have different type: "
+          << std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
+    if (std::get<1>(it).getType() != std::get<2>(it).getType())
+      op->emitOpError(std::to_string(i))
+          << "-th region iter_arg and " << i
+          << "-th yielded value have different type: "
+          << std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
+    ++i;
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index f6044ad10829227..1b2c3f563195c52 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
 
 // -----
 
+func.func @too_many_iter_args(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of inits and region iter_args: 1 != 2}}
+  %x = "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter, %iter : f32, f32
+    }
+  ) : (index, index, index, f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @too_few_yielded_values(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of region iter_args and yielded values: 2 != 1}}
+  %x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32, f32) -> (f32, f32)
+  return
+}
+
+// -----
+
 func.func @loop_if_not_i1(%arg0: index) {
   // expected-error@+1 {{operand #0 must be 1-bit signless integer}}
   "scf.if"(%arg0) ({}, {}) : (index) -> ()
@@ -422,7 +448,8 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
 func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
   %s0 = arith.constant 0.0 : f32
   %t0 = arith.constant 1.0 : f32
-  // expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
+  // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
+  // expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
   %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
                     iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
     %sn = arith.addf %si, %si : f32
@@ -432,7 +459,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
   return
 }
 
-
 // -----
 
 func.func @parallel_invalid_yield(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 1d40615305c02c4..565d07669792f1b 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -50,9 +50,8 @@ struct TestSCFForUtilsPass
         auto newInitValues = forOp.getInitArgs();
         if (newInitValues.empty())
           return;
-        auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-        SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
-                                          yieldOp.getResults().end());
+        SmallVector<Value> oldYieldValues =
+            llvm::to_vector(forOp.getYieldedValues());
         NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
                                   ArrayRef<BlockArgument> newBBArgs) {
           SmallVector<Value> newYieldValues;



More information about the flang-commits mailing list