[Mlir-commits] [mlir] [mlir][emitc] Add a structured for operation (PR #68206)

Gil Rapaport llvmlistbot at llvm.org
Wed Oct 4 04:07:34 PDT 2023


https://github.com/aniragil created https://github.com/llvm/llvm-project/pull/68206

Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C. The translator
now handles emitc.for instead.


>From b5c03d16ae6c3176bcaf34a0d2e091c41c3d3ff3 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 3 Oct 2023 20:27:02 +0300
Subject: [PATCH] [mlir][emitc] Add a structured for operation

Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C. The translator
now handles emitc.for instead.
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.h    |   1 +
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  79 ++++++++-
 mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 150 ++++++++++++++----
 mlir/lib/Dialect/EmitC/IR/CMakeLists.txt      |   1 +
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 142 +++++++++++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  82 +---------
 mlir/test/Conversion/SCFToEmitC/for.mlir      |  96 +++++++++++
 mlir/test/Dialect/EmitC/invalid_ops.mlir      |   2 +-
 mlir/test/Dialect/EmitC/ops.mlir              |  24 ++-
 mlir/test/Target/Cpp/for.mlir                 |  56 +++++--
 10 files changed, 506 insertions(+), 127 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToEmitC/for.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 4dff26e23c42850..9d2ec0f41a75568 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 827ffc0278fce1c..381247bcdadf27d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
 
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -430,7 +431,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
   let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
 }
 
-def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
+def EmitC_YieldOp : EmitC_Op<"yield",
+      [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
   let summary = "block termination operation";
   let description = [{
     "yield" terminates blocks within EmitC control-flow operations. Since
@@ -444,6 +446,81 @@ def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]>
   let assemblyFormat = [{ attr-dict }];
 }
 
+def EmitC_ForOp : EmitC_Op<"for",
+      [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
+       ["getInits", "getSingleInductionVar", "getSingleLowerBound",
+        "getSingleStep", "getSingleUpperBound"]>,
+       AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+       DeclareOpInterfaceMethods<RegionBranchOpInterface,
+        ["getEntrySuccessorOperands"]>,
+       SingleBlockImplicitTerminator<"emitc::YieldOp">,
+       RecursiveMemoryEffects]> {
+  let summary = "for operation";
+  let description = [{
+    The `emitc.for` operation represents a loop taking 3 SSA values as operands
+    that represent the lower bound, upper bound and step respectively. The
+    operation defines an SSA value for its induction variable. It has one
+    region capturing the loop body. The induction variable is represented as an
+    argument of this region. This SSA value is a signless integer or index.
+    The step is a value of same type but required to be positive. The lower and
+    upper bounds specify a half-open range: the range includes the lower bound
+    but does not include the upper bound. This operation has no result.
+
+    The body region must contain exactly one block that terminates with
+    `emitc.yield`. Calling ForOp::build without a body builder creates such a
+    region with the terminator inserted implicitly, and the parser likewise
+    inserts it when it is absent from the custom format. For example:
+
+    ```mlir
+    // Index case.
+    emitc.for %iv = %lb to %ub step %step {
+      ... // body
+    }
+    ...
+    // Integer case.
+    emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+      ... // body
+    }
+    ```
+  }];
+  let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
+                       AnySignlessIntegerOrIndex:$upperBound,
+                       AnySignlessIntegerOrIndex:$step);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
+      CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
+  ];
+
+  let extraClassDeclaration = [{
+    using BodyBuilderFn =
+        function_ref<void(OpBuilder &, Location, Value)>;
+    Value getInductionVar() { return getBody()->getArgument(0); }
+    Block::BlockArgListType getRegionIterArgs() {
+      return Block::BlockArgListType();
+    }
+    FailureOr<LoopLikeOpInterface> replaceWithAdditionalYields(
+        RewriterBase &rewriter, ValueRange newInitOperands,
+        bool replaceInitOperandUsesInLoop,
+        const NewYieldValuesFn &newYieldValuesFn) {
+      return LoopLikeOpInterfaceTrait::replaceWithAdditionalYields(
+          rewriter, newInitOperands, replaceInitOperandUsesInLoop,
+          newYieldValuesFn);
+    }
+    void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
+    void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
+    void setStep(Value step) { getOperation()->setOperand(2, step); }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+}
+
 def EmitC_IfOp : EmitC_Op<"if",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 5d0d8df8869e313..22041344ba3d9a3 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -37,6 +37,106 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
   void runOnOperation() override;
 };
 
+// Lowers scf::for to emitc::for. Each loop-carried value and each result is
+// modeled by an emitc::variable: iter_arg variables are updated within the
+// loop body, while result variables are set once the loop has finished.
+struct ForLowering : OpRewritePattern<ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Create an uninitialized emitc::variable op for each result of given op,
+// inserted right before the op. Returns them in result order (empty if the op
+// has no results); the rewriter's insertion point is left unchanged.
+template <typename T>
+static SmallVector<Value> createVariablesForResults(T op,
+                                                    PatternRewriter &rewriter) {
+  SmallVector<Value> resultVariables;
+  if (!op.getNumResults())
+    return resultVariables;
+
+  // Restores the caller's insertion point on return.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  Location loc = op->getLoc();
+  // An empty opaque value leaves the variable uninitialized; build it once.
+  auto noInit = emitc::OpaqueAttr::get(op.getContext(), "");
+  resultVariables.reserve(op.getNumResults());
+  for (OpResult result : op.getResults()) {
+    auto var =
+        rewriter.create<emitc::VariableOp>(loc, result.getType(), noInit);
+    resultVariables.push_back(var);
+  }
+
+  return resultVariables;
+}
+
+// Create a series of assign ops assigning given values to given variables at
+// the current insertion point of given rewriter.
+//
+// `values` and `variables` are paired positionally and must have the same
+// length, which llvm::zip_equal asserts (plain zip would silently truncate).
+static void assignValues(ValueRange values, ValueRange variables,
+                         PatternRewriter &rewriter, Location loc) {
+  for (auto [value, var] : llvm::zip_equal(values, variables))
+    rewriter.create<emitc::AssignOp>(loc, var, value);
+}
+
+// Replace `yield` with emitc::assign ops copying its operands, positionally,
+// into `resultVariables`, followed by an emitc::yield.
+static void lowerYield(SmallVector<Value> &resultVariables,
+                       PatternRewriter &rewriter, scf::YieldOp yield) {
+  Location loc = yield.getLoc();
+  {
+    // Restore the caller's insertion point before `yield` is erased.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(yield);
+    assignValues(yield.getOperands(), resultVariables, rewriter, loc);
+    rewriter.create<emitc::YieldOp>(loc);
+  }
+  rewriter.eraseOp(yield);
+}
+
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+                                           PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+  // Ops created at this level are inserted right before `forOp`.
+  rewriter.setInsertionPoint(forOp);
+  // Create two emitc::variable ops per result: one for the final result
+  // (assigned after the loop) and one standing in for the matching iter_arg
+  // (assigned before the loop and at the end of each iteration).
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(forOp, rewriter);
+  SmallVector<Value> iterArgsVariables =
+      createVariablesForResults(forOp, rewriter);
+  // Seed the iter_arg variables with the loop's initial values.
+  assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
+
+  auto loweredFor = rewriter.create<emitc::ForOp>(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+
+  Block *loweredBody = loweredFor.getBody();
+
+  // Erase the auto-generated terminator for the lowered for op.
+  rewriter.eraseOp(loweredBody->getTerminator());
+  // Replacements for the scf.for block arguments: the IV, then iter_args.
+  SmallVector<Value> replacingValues;
+  replacingValues.push_back(loweredFor.getInductionVar());
+  replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
+  // Move the body over; its scf.yield becomes assigns to iter_arg variables.
+  rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
+  lowerYield(iterArgsVariables, rewriter,
+             cast<scf::YieldOp>(loweredBody->getTerminator()));
+
+  // Copy iterArgs into results after the for loop.
+  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+
+  rewriter.replaceOp(forOp, resultVariables);
+  return success();
+}
+
 // Lower scf::if to emitc::if, implementing return values as emitc::variable's
 // updated within the then and else regions.
 struct IfLowering : public OpRewritePattern<IfOp> {
@@ -52,41 +152,22 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
                                           PatternRewriter &rewriter) const {
   Location loc = ifOp.getLoc();
 
-  SmallVector<Value> resultVariables;
-
   // Create an emitc::variable op for each result. These variables will be
   // assigned to by emitc::assign ops within the then & else regions.
-  if (ifOp.getNumResults()) {
-    MLIRContext *context = ifOp.getContext();
-    rewriter.setInsertionPoint(ifOp);
-    for (OpResult result : ifOp.getResults()) {
-      Type resultType = result.getType();
-      auto noInit = emitc::OpaqueAttr::get(context, "");
-      auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
-      resultVariables.push_back(var);
-    }
-  }
-
-  // Utility function to lower the contents of an scf::if region to an emitc::if
-  // region. The contents of the scf::if regions is moved into the respective
-  // emitc::if regions, but the scf::yield is replaced not only with an
-  // emitc::yield, but also with a sequence of emitc::assign ops that set the
-  // yielded values into the result variables.
-  auto lowerRegion = [&resultVariables, &rewriter](Region &region,
-                                                   Region &loweredRegion) {
-    rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
-    Operation *terminator = loweredRegion.back().getTerminator();
-    Location terminatorLoc = terminator->getLoc();
-    ValueRange terminatorOperands = terminator->getOperands();
-    rewriter.setInsertionPointToEnd(&loweredRegion.back());
-    for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
-      Value resultValue = std::get<0>(value2Var);
-      Value resultVar = std::get<1>(value2Var);
-      rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
-    }
-    rewriter.create<emitc::YieldOp>(terminatorLoc);
-    rewriter.eraseOp(terminator);
-  };
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(ifOp, rewriter);
+
+   // Utility function to lower the contents of an scf::if region to an emitc::if
+   // region. The contents of the scf::if regions is moved into the respective
+   // emitc::if regions, but the scf::yield is replaced not only with an
+   // emitc::yield, but also with a sequence of emitc::assign ops that set the
+   // yielded values into the result variables.
+   auto lowerRegion = [&resultVariables, &rewriter](Region &region,
+                                                    Region &loweredRegion) {
+     rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
+     Operation *terminator = loweredRegion.back().getTerminator();
+     lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
+   };
 
   Region &thenRegion = ifOp.getThenRegion();
   Region &elseRegion = ifOp.getElseRegion();
@@ -109,6 +190,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
 }
 
 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ForLowering>(patterns.getContext()); // scf.for -> emitc.for
   patterns.add<IfLowering>(patterns.getContext());
 }
 
@@ -118,7 +200,7 @@ void SCFToEmitCPass::runOnOperation() {
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::IfOp>();
+  target.addIllegalOp<scf::IfOp, scf::ForOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
index 4665c41a62e80b8..50e79d22d57e681 100644
--- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
@@ -12,5 +12,6 @@ add_mlir_dialect_library(MLIREmitCDialect
   MLIRCastInterfaces
   MLIRControlFlowInterfaces
   MLIRIR
+  MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 961a52a70a2a168..740504cc9db2bcd 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -189,6 +190,147 @@ LogicalResult emitc::ConstantOp::verify() {
 
 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+// Builds an emitc.for whose body takes one argument: the induction variable.
+void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
+                  Value ub, Value step, BodyBuilderFn bodyBuilder) {
+  result.addOperands({lb, ub, step});
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  Value inductionVar = bodyBlock.addArgument(lb.getType(), result.location);
+  // Without a body builder only the implicit emitc.yield is created; a body
+  // builder is responsible for all contents, including the terminator.
+  if (!bodyBuilder) {
+    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
+    return;
+  }
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&bodyBlock);
+  bodyBuilder(builder, result.location, inductionVar);
+}
+
+// emitc.for has no loop-carried values; defer to the interface defaults for
+// the region entry operands and the loop inits (expected to be empty).
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return RegionBranchOpInterfaceTrait::getEntrySuccessorOperands(point);
+}
+OperandRange ForOp::getInits() { return LoopLikeOpInterfaceTrait::getInits(); }
+// No canonicalization patterns are registered for emitc.for yet.
+void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
+
+// The loop body is the op's only region.
+SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
+
+// Bounds and step are reported as the SSA operands themselves.
+std::optional<Value> ForOp::getSingleInductionVar() {
+  return getInductionVar();
+}
+std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
+  return OpFoldResult(getLowerBound());
+}
+std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
+  return OpFoldResult(getUpperBound());
+}
+std::optional<OpFoldResult> ForOp::getSingleStep() {
+  return OpFoldResult(getStep());
+}
+
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
+                                SmallVectorImpl<RegionSuccessor> &regions) {
+  // Independently of `point`, control may enter or re-enter the body, or exit
+  // to the parent op (so the body may run zero times). No values are passed.
+  RegionSuccessor body(&getRegion(), getRegionIterArgs());
+  regions.push_back(body);
+  regions.push_back(RegionSuccessor({}));
+}
+
+// Parses the custom form
+//   %iv `=` %lb `to` %ub `step` %step (`:` type)? region attr-dict
+// where the optional type defaults to index and is shared by the bounds, the
+// step and the induction variable.
+ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+  Type type;
+
+  OpAsmParser::Argument inductionVariable;
+  OpAsmParser::UnresolvedOperand lb, ub, step;
+
+  // Parse the induction variable followed by '='.
+  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
+      // Parse loop bounds.
+      parser.parseOperand(lb) || parser.parseKeyword("to") ||
+      parser.parseOperand(ub) || parser.parseKeyword("step") ||
+      parser.parseOperand(step))
+    return failure();
+
+  // Parse optional type, else assume Index.
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+
+  // Resolve input operands.
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
+    return failure();
+
+  // emitc.for has no loop-carried values and defines no results, so there is
+  // no iter_args list to parse or to check against the op's results: the
+  // induction variable is the body region's only block argument, and it has
+  // the same type as the bounds and step.
+  inductionVariable.type = type;
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, inductionVariable))
+    return failure();
+
+  // The custom form may omit the terminator; add the implicit emitc.yield.
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+// Prints `%iv = %lb to %ub step %step (: type)? region attr-dict`.
+void ForOp::print(OpAsmPrinter &p) {
+  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
+    << getUpperBound() << " step " << getStep();
+  // Elide the type when it is index, the parser's default.
+  if (Type t = getInductionVar().getType(); !t.isIndex())
+    p << " : " << t;
+  p << ' ';
+  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult ForOp::verify() {
+  // A non-constant step cannot be checked statically.
+  IntegerAttr stepAttr;
+  if (!matchPattern(getStep(), m_Constant(&stepAttr)) || stepAttr.getInt() > 0)
+    return success();
+  return emitOpError("constant step operand must be positive");
+}
+
+LogicalResult ForOp::verifyRegions() {
+  // The body must take exactly one argument: the induction variable.
+  if (getBody()->getNumArguments() != 1)
+    return emitOpError("expected body to have a single induction variable");
+  if (getInductionVar().getType() != getLowerBound().getType())
+    return emitOpError(
+        "expected induction variable to be same type as bounds and step");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4645ca4b206e78c..05d472ba50d070f 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -502,30 +502,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
 
   raw_indented_ostream &os = emitter.ostream();
 
-  OperandRange operands = forOp.getInitArgs();
-  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
-  Operation::result_range results = forOp.getResults();
-
-  if (!emitter.shouldDeclareVariablesAtTop()) {
-    for (OpResult result : results) {
-      if (failed(emitter.emitVariableDeclaration(result,
-                                                 /*trailingSemicolon=*/true)))
-        return failure();
-    }
-  }
-
-  for (auto pair : llvm::zip(iterArgs, operands)) {
-    if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
-      return failure();
-    os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
-    os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
-    os << "\n";
-  }
-
   os << "for (";
   if (failed(
           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
@@ -548,35 +528,14 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
   Region &forRegion = forOp.getRegion();
   auto regionOps = forRegion.getOps();
 
-  // We skip the trailing yield op because this updates the result variables
-  // of the for op in the generated code. Instead we update the iterArgs at
-  // the end of a loop iteration and set the result variables after the for
-  // loop.
+  // We skip the trailing yield op.
   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
       return failure();
   }
 
-  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
-  // Copy yield operands into iterArgs at the end of a loop iteration.
-  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
-    BlockArgument iterArg = std::get<0>(pair);
-    Value operand = std::get<1>(pair);
-    os << emitter.getOrCreateName(iterArg) << " = "
-       << emitter.getOrCreateName(operand) << ";\n";
-  }
-
   os.unindent() << "}";
 
-  // Copy iterArgs into results after the for loop.
-  for (auto pair : llvm::zip(results, iterArgs)) {
-    OpResult result = std::get<0>(pair);
-    BlockArgument iterArg = std::get<1>(pair);
-    os << "\n"
-       << emitter.getOrCreateName(result) << " = "
-       << emitter.getOrCreateName(iterArg) << ";";
-  }
-
   return success();
 }
 
@@ -617,33 +576,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
-  raw_ostream &os = emitter.ostream();
-  Operation &parentOp = *yieldOp.getOperation()->getParentOp();
-
-  if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
-    return yieldOp.emitError("number of operands does not to match the number "
-                             "of the parent op's results");
-  }
-
-  if (failed(interleaveWithError(
-          llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
-          [&](auto pair) -> LogicalResult {
-            auto result = std::get<0>(pair);
-            auto operand = std::get<1>(pair);
-            os << emitter.getOrCreateName(result) << " = ";
-
-            if (!emitter.hasValueInScope(operand))
-              return yieldOp.emitError("operand value not in scope");
-            os << emitter.getOrCreateName(operand);
-            return success();
-          },
-          [&]() { os << ";\n"; })))
-    return failure();
-
-  return success();
-}
-
 static LogicalResult printOperation(CppEmitter &emitter,
                                     func::ReturnOp returnOp) {
   raw_ostream &os = emitter.ostream();
@@ -751,7 +683,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
       // When generating code for an scf.for op, printing a trailing semicolon
       // is handled within the printOperation function.
       bool trailingSemicolon =
-          !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, scf::ForOp>(op);
+          !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, emitc::ForOp>(
+              op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1015,15 +948,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           // EmitC ops.
           .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
                 emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
-                emitc::IfOp, emitc::IncludeOp, emitc::MulOp, emitc::RemOp,
-                emitc::SubOp, emitc::VariableOp>(
+                emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
+                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
               [&](auto op) { return printOperation(*this, op); })
-          // SCF ops.
-          .Case<scf::ForOp, scf::YieldOp>(
-              [&](auto op) { return printOperation(*this, op); })
           // Arithmetic ops.
           .Case<arith::ConstantOp>(
               [&](auto op) { return printOperation(*this, op); })
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
new file mode 100644
index 000000000000000..7f90310af218942
--- /dev/null
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
+
+func.func @simple_std_for_loop(%lb : index, %ub : index, %step : index) {
+  scf.for %iv = %lb to %ub step %step {
+    %c1 = arith.constant 1 : index
+  }
+  return
+}
+// CHECK-LABEL: func.func @simple_std_for_loop(
+// CHECK-SAME:      %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) {
+// CHECK-NEXT:    emitc.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:      %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+func.func @simple_std_2_for_loops(%lb : index, %ub : index, %step : index) {
+  scf.for %iv0 = %lb to %ub step %step {
+    %c1 = arith.constant 1 : index
+    scf.for %iv1 = %lb to %ub step %step {
+      %c1_0 = arith.constant 1 : index
+    }
+  }
+  return
+}
+// CHECK-LABEL: func.func @simple_std_2_for_loops(
+// CHECK-SAME:      %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) {
+// CHECK-NEXT:    emitc.for %[[OUTER_IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:      %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT:      emitc.for %[[INNER_IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:        %[[C1_0:.*]] = arith.constant 1 : index
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+func.func @for_yield(%lb : index, %ub : index, %step : index) -> (f32, f32) {
+  %s0 = arith.constant 0.0 : f32
+  %s1 = arith.constant 1.0 : f32
+  %result:2 = scf.for %iv = %lb to %ub step %step iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+    %sn = arith.addf %si, %sj : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  return %result#0, %result#1 : f32, f32
+}
+// CHECK-LABEL: func.func @for_yield(
+// CHECK-SAME:      %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) -> (f32, f32) {
+// CHECK-NEXT:    %[[S0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:    %[[S1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:    %[[RES0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[RES1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[ITER0:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[ITER1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    emitc.assign %[[S0]] : f32 to %[[ITER0]] : f32
+// CHECK-NEXT:    emitc.assign %[[S1]] : f32 to %[[ITER1]] : f32
+// CHECK-NEXT:    emitc.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:      %[[SUM:.*]] = arith.addf %[[ITER0]], %[[ITER1]] : f32
+// CHECK-NEXT:      emitc.assign %[[SUM]] : f32 to %[[ITER0]] : f32
+// CHECK-NEXT:      emitc.assign %[[SUM]] : f32 to %[[ITER1]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    emitc.assign %[[ITER0]] : f32 to %[[RES0]] : f32
+// CHECK-NEXT:    emitc.assign %[[ITER1]] : f32 to %[[RES1]] : f32
+// CHECK-NEXT:    return %[[RES0]], %[[RES1]] : f32, f32
+// CHECK-NEXT:  }
+
+func.func @nested_for_yield(%lb : index, %ub : index, %step : index) -> f32 {
+  %s0 = arith.constant 1.0 : f32
+  %r = scf.for %iv0 = %lb to %ub step %step iter_args(%iter = %s0) -> (f32) {
+    %result = scf.for %iv1 = %lb to %ub step %step iter_args(%si = %iter) -> (f32) {
+      %sn = arith.addf %si, %si : f32
+      scf.yield %sn : f32
+    }
+    scf.yield %result : f32
+  }
+  return %r : f32
+}
+// CHECK-LABEL: func.func @nested_for_yield(
+// CHECK-SAME:      %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) -> f32 {
+// CHECK-NEXT:    %[[S0:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:    %[[OUTER_RES:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[OUTER_ITER:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    emitc.assign %[[S0]] : f32 to %[[OUTER_ITER]] : f32
+// CHECK-NEXT:    emitc.for %[[OUTER_IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:      %[[INNER_RES:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:      %[[INNER_ITER:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:      emitc.assign %[[OUTER_ITER]] : f32 to %[[INNER_ITER]] : f32
+// CHECK-NEXT:      emitc.for %[[INNER_IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+// CHECK-NEXT:        %[[SUM:.*]] = arith.addf %[[INNER_ITER]], %[[INNER_ITER]] : f32
+// CHECK-NEXT:        emitc.assign %[[SUM]] : f32 to %[[INNER_ITER]] : f32
+// CHECK-NEXT:      }
+// CHECK-NEXT:      emitc.assign %[[INNER_ITER]] : f32 to %[[INNER_RES]] : f32
+// CHECK-NEXT:      emitc.assign %[[INNER_RES]] : f32 to %[[OUTER_ITER]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    emitc.assign %[[OUTER_ITER]] : f32 to %[[OUTER_RES]] : f32
+// CHECK-NEXT:    return %[[OUTER_RES]] : f32
+// CHECK-NEXT:  }
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 9e8f0bf0bf8bdcd..53d88adf4305ff8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
 // -----
 
 func.func @test_misplaced_yield() {
-  // expected-error @+1 {{'emitc.yield' op expects parent op 'emitc.if'}}
+  // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}}
   emitc.yield
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 0817945e3b1e0bc..6c8398680980466 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -105,7 +105,7 @@ func.func @test_if(%arg0: i1, %arg1: f32) {
   return
 }
 
-func.func @test_explicit_yield(%arg0: i1, %arg1: f32) {
+func.func @test_if_explicit_yield(%arg0: i1, %arg1: f32) {
   emitc.if %arg0 {
      %0 = emitc.call "func_const"(%arg1) : (f32) -> i32
      emitc.yield
@@ -127,3 +127,25 @@ func.func @test_assign(%arg1: f32) {
   emitc.assign %arg1 : f32 to %v : f32
   return
 }
+
+func.func @test_for(%lb : index, %ub : index, %step : index) {
+  emitc.for %iv = %lb to %ub step %step {
+    %r = emitc.call "func_const"(%iv) : (index) -> i32
+  }
+  return
+}
+
+func.func @test_for_explicit_yield(%lb : index, %ub : index, %step : index) {
+  emitc.for %iv = %lb to %ub step %step {
+    %r = emitc.call "func_const"(%iv) : (index) -> i32
+    emitc.yield
+  }
+  return
+}
+
+func.func @test_for_not_index_induction(%lb : i16, %ub : i16, %step : i16) {
+  emitc.for %iv = %lb to %ub step %step : i16 {
+    %r = emitc.call "func_const"(%iv) : (i16) -> i32
+  }
+  return
+}
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index e904c99820ad846..c02c8b1ac33e371 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -2,7 +2,7 @@
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
 
 func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
-  scf.for %i0 = %arg0 to %arg1 step %arg2 {
+  emitc.for %iv = %arg0 to %arg1 step %arg2 {
     %0 = emitc.call "f"() : () -> i32
   }
   return
@@ -28,11 +28,21 @@ func.func @test_for_yield() {
   %s0 = arith.constant 0 : i32
   %p0 = arith.constant 1.0 : f32
 
-  %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
-    %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
-    %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
-    scf.yield %sn, %pn : i32, f32
+  %s_res = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %p_res = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  %s_iter = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %p_iter = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  emitc.assign %s0 : i32 to %s_iter : i32
+  emitc.assign %p0 : f32 to %p_iter : f32
+  emitc.for %iv = %start to %stop step %step {
+    %sn = emitc.call "add"(%s_iter, %iv) : (i32, index) -> i32
+    %pn = emitc.call "mul"(%p_iter, %iv) : (f32, index) -> f32
+    emitc.assign %sn : i32 to %s_iter : i32
+    emitc.assign %pn : f32 to %p_iter : f32
+    emitc.yield
   }
+  emitc.assign %s_iter : i32 to %s_res : i32
+  emitc.assign %p_iter : f32 to %p_res : f32
 
   return
 }
@@ -44,8 +54,10 @@ func.func @test_for_yield() {
 // CPP-DEFAULT-NEXT: float [[P0:[^ ]*]] = (float)1.000000000e+00;
 // CPP-DEFAULT-NEXT: int32_t [[SE:[^ ]*]];
 // CPP-DEFAULT-NEXT: float [[PE:[^ ]*]];
-// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]];
+// CPP-DEFAULT-NEXT: [[SI]] = [[S0]];
+// CPP-DEFAULT-NEXT: [[PI]] = [[P0]];
 // CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
 // CPP-DEFAULT-NEXT: int32_t [[SN:[^ ]*]] = add([[SI]], [[ITER]]);
 // CPP-DEFAULT-NEXT: float [[PN:[^ ]*]] = mul([[PI]], [[ITER]]);
@@ -64,6 +76,8 @@ func.func @test_for_yield() {
 // CPP-DECLTOP-NEXT: float [[P0:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t [[SE:[^ ]*]];
 // CPP-DECLTOP-NEXT: float [[PE:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t [[SN:[^ ]*]];
 // CPP-DECLTOP-NEXT: float [[PN:[^ ]*]];
 // CPP-DECLTOP-NEXT: [[START]] = 0;
@@ -71,8 +85,12 @@ func.func @test_for_yield() {
 // CPP-DECLTOP-NEXT: [[STEP]] = 1;
 // CPP-DECLTOP-NEXT: [[S0]] = 0;
 // CPP-DECLTOP-NEXT: [[P0]] = (float)1.000000000e+00;
-// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: [[SI]] = [[S0]];
+// CPP-DECLTOP-NEXT: [[PI]] = [[P0]];
 // CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
 // CPP-DECLTOP-NEXT: [[SN]] = add([[SI]], [[ITER]]);
 // CPP-DECLTOP-NEXT: [[PN]] = mul([[PI]], [[ITER]]);
@@ -91,14 +109,24 @@ func.func @test_for_yield_2() {
   %s0 = emitc.literal "0" : i32
   %p0 = emitc.literal "M_PI" : f32
 
-  %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
-    %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
-    %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
-    scf.yield %sn, %pn : i32, f32
+  %s_res = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %p_res = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  %s_iter = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %p_iter = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  emitc.assign %s0 : i32 to %s_iter : i32
+  emitc.assign %p0 : f32 to %p_iter : f32
+  emitc.for %iv = %start to %stop step %step {
+    %sn = emitc.call "add"(%s_iter, %iv) : (i32, index) -> i32
+    %pn = emitc.call "mul"(%p_iter, %iv) : (f32, index) -> f32
+    emitc.assign %sn : i32 to %s_iter : i32
+    emitc.assign %pn : f32 to %p_iter : f32
+    emitc.yield
   }
+  emitc.assign %s_iter : i32 to %s_res : i32
+  emitc.assign %p_iter : f32 to %p_res : f32
 
   return
 }
 // CPP-DEFAULT: void test_for_yield_2() {
-// CPP-DEFAULT: float{{.*}}= M_PI
+// CPP-DEFAULT: {{.*}}= M_PI
 // CPP-DEFAULT: for (size_t [[IN:.*]] = 0; [[IN]] < 10; [[IN]] += 1) {



More information about the Mlir-commits mailing list