[Mlir-commits] [mlir] [mlir][emitc] Add a structured for operation (PR #68206)

Gil Rapaport llvmlistbot at llvm.org
Thu Oct 26 06:39:37 PDT 2023


https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/68206

>From 70caf277b6017c2662ee44ac19ac2ebc674c9c59 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 3 Oct 2023 20:27:02 +0300
Subject: [PATCH] [mlir][emitc] Add a structured for operation

Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C; the translator
now handles emitc.for instead.
---
 mlir/docs/Dialects/emitc.md                   |   3 -
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  64 ++++++++-
 mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 123 ++++++++++++++----
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  95 ++++++++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  85 +-----------
 mlir/test/Conversion/SCFToEmitC/for.mlir      |  96 ++++++++++++++
 mlir/test/Dialect/EmitC/invalid_ops.mlir      |   2 +-
 mlir/test/Dialect/EmitC/ops.mlir              |  24 +++-
 mlir/test/Target/Cpp/for.mlir                 |  56 ++++++--
 9 files changed, 426 insertions(+), 122 deletions(-)
 create mode 100644 mlir/test/Conversion/SCFToEmitC/for.mlir

diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md
index 4d9f04ab11c8f0c..03b85611ee3cd0e 100644
--- a/mlir/docs/Dialects/emitc.md
+++ b/mlir/docs/Dialects/emitc.md
@@ -31,8 +31,5 @@ translating the following operations:
     *   `func.constant`
     *   `func.func`
     *   `func.return`
-*   'scf' Dialect
-    *   `scf.for`
-    *   `scf.yield`
 *   'arith' Dialect
     *   `arith.constant`
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 827ffc0278fce1c..2edeb6f8a9cf01e 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -246,6 +246,67 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
   let results = (outs FloatIntegerIndexOrOpaqueType);
 }
 
+def EmitC_ForOp : EmitC_Op<"for",
+      [AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+       SingleBlockImplicitTerminator<"emitc::YieldOp">,
+       RecursiveMemoryEffects]> {
+  let summary = "for operation";
+  let description = [{
+    The `emitc.for` operation represents a C loop of the following form:
+
+    ```c++
+    for (T i = lb; i < ub; i += step) { /* ... */ } // where T is typeof(lb)
+    ```
+
+    The operation takes 3 SSA values as operands that represent the lower bound,
+    upper bound and step respectively, and defines an SSA value for its
+    induction variable. It has one region capturing the loop body. The induction
+    variable is represented as an argument of this region. This SSA value is a
+    signless integer or index. The step is a value of same type.
+
+    This operation has no result. The body region must contain exactly one block
+    that terminates with `emitc.yield`. Calling ForOp::build will create such a
+    region and insert the terminator implicitly if none is defined, so will the
+    parsing even in cases when it is absent from the custom format. For example:
+
+    ```mlir
+    // Index case.
+    emitc.for %iv = %lb to %ub step %step {
+      ... // body
+    }
+    ...
+    // Integer case.
+    emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+      ... // body
+    }
+    ```
+  }];
+  let arguments = (ins IntegerIndexOrOpaqueType:$lowerBound,
+                       IntegerIndexOrOpaqueType:$upperBound,
+                       IntegerIndexOrOpaqueType:$step);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
+      CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
+  ];
+
+  let extraClassDeclaration = [{
+    using BodyBuilderFn =
+        function_ref<void(OpBuilder &, Location, Value)>;
+    Value getInductionVar() { return getBody()->getArgument(0); }
+    void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
+    void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
+    void setStep(Value step) { getOperation()->setOperand(2, step); }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasCustomAssemblyFormat = 1;
+  let hasRegionVerifier = 1;
+}
+
 def EmitC_IncludeOp
     : EmitC_Op<"include", [HasParent<"ModuleOp">]> {
   let summary = "Include operation";
@@ -430,7 +491,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
   let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
 }
 
-def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
+def EmitC_YieldOp : EmitC_Op<"yield",
+      [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
   let summary = "block termination operation";
   let description = [{
     "yield" terminates blocks within EmitC control-flow operations. Since
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 5d0d8df8869e313..bf69ba503f4e6b1 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
   void runOnOperation() override;
 };
 
-// Lower scf::if to emitc::if, implementing return values as emitc::variable's
+// Lower scf::for to emitc::for, implementing result values using
+// emitc::variable's updated within the loop body.
+struct ForLowering : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Create an uninitialized emitc::variable op for each result of the given op.
+template <typename T>
+static SmallVector<Value> createVariablesForResults(T op,
+                                                    PatternRewriter &rewriter) {
+  SmallVector<Value> resultVariables;
+
+  if (!op.getNumResults())
+    return resultVariables;
+
+  Location loc = op->getLoc();
+  MLIRContext *context = op.getContext();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  for (OpResult result : op.getResults()) {
+    Type resultType = result.getType();
+    emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+    emitc::VariableOp var =
+        rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
+    resultVariables.push_back(var);
+  }
+
+  return resultVariables;
+}
+
+// Create a series of assign ops assigning given values to given variables at
+// the current insertion point of given rewriter.
+static void assignValues(ValueRange values, SmallVector<Value> &variables,
+                         PatternRewriter &rewriter, Location loc) {
+  for (auto [value, var] : llvm::zip(values, variables))
+    rewriter.create<emitc::AssignOp>(loc, var, value);
+}
+
+static void lowerYield(SmallVector<Value> &resultVariables,
+                       PatternRewriter &rewriter, scf::YieldOp yield) {
+  Location loc = yield.getLoc();
+  ValueRange operands = yield.getOperands();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(yield);
+
+  assignValues(operands, resultVariables, rewriter, loc);
+
+  rewriter.create<emitc::YieldOp>(loc);
+  rewriter.eraseOp(yield);
+}
+
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+                                           PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+
+  // Create an emitc::variable op for each result. These variables will be
+  // assigned to by emitc::assign ops within the loop body.
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(forOp, rewriter);
+  SmallVector<Value> iterArgsVariables =
+      createVariablesForResults(forOp, rewriter);
+
+  assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
+
+  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+
+  Block *loweredBody = loweredFor.getBody();
+
+  // Erase the auto-generated terminator for the lowered for op.
+  rewriter.eraseOp(loweredBody->getTerminator());
+
+  SmallVector<Value> replacingValues;
+  replacingValues.push_back(loweredFor.getInductionVar());
+  replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
+
+  rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
+  lowerYield(iterArgsVariables, rewriter,
+             cast<scf::YieldOp>(loweredBody->getTerminator()));
+
+  // Copy iterArgs into results after the for loop.
+  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+
+  rewriter.replaceOp(forOp, resultVariables);
+  return success();
+}
+
+// Lower scf::if to emitc::if, implementing result values as emitc::variable's
 // updated within the then and else regions.
 struct IfLowering : public OpRewritePattern<IfOp> {
   using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
                                           PatternRewriter &rewriter) const {
   Location loc = ifOp.getLoc();
 
-  SmallVector<Value> resultVariables;
-
   // Create an emitc::variable op for each result. These variables will be
   // assigned to by emitc::assign ops within the then & else regions.
-  if (ifOp.getNumResults()) {
-    MLIRContext *context = ifOp.getContext();
-    rewriter.setInsertionPoint(ifOp);
-    for (OpResult result : ifOp.getResults()) {
-      Type resultType = result.getType();
-      auto noInit = emitc::OpaqueAttr::get(context, "");
-      auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
-      resultVariables.push_back(var);
-    }
-  }
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(ifOp, rewriter);
 
   // Utility function to lower the contents of an scf::if region to an emitc::if
   // region. The contents of the scf::if regions is moved into the respective
@@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
                                                    Region &loweredRegion) {
     rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
     Operation *terminator = loweredRegion.back().getTerminator();
-    Location terminatorLoc = terminator->getLoc();
-    ValueRange terminatorOperands = terminator->getOperands();
-    rewriter.setInsertionPointToEnd(&loweredRegion.back());
-    for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
-      Value resultValue = std::get<0>(value2Var);
-      Value resultVar = std::get<1>(value2Var);
-      rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
-    }
-    rewriter.create<emitc::YieldOp>(terminatorLoc);
-    rewriter.eraseOp(terminator);
+    lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
   };
 
   Region &thenRegion = ifOp.getThenRegion();
@@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
 }
 
 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ForLowering>(patterns.getContext());
   patterns.add<IfLowering>(patterns.getContext());
 }
 
@@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {
 
   // Configure conversion to lower out SCF operations.
   ConversionTarget target(getContext());
-  target.addIllegalOp<scf::IfOp>();
+  target.addIllegalOp<scf::ForOp, scf::IfOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 961a52a70a2a168..d06381b7ddad3dc 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -189,6 +189,101 @@ LogicalResult emitc::ConstantOp::verify() {
 
 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
+                  Value ub, Value step, BodyBuilderFn bodyBuilder) {
+  result.addOperands({lb, ub, step});
+  Type t = lb.getType();
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArgument(t, result.location);
+
+  // Create the default terminator if the builder is not provided.
+  if (!bodyBuilder) {
+    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
+  } else {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+  }
+}
+
+void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
+
+ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+  Type type;
+
+  OpAsmParser::Argument inductionVariable;
+  OpAsmParser::UnresolvedOperand lb, ub, step;
+
+  // Parse the induction variable followed by '='.
+  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
+      // Parse loop bounds.
+      parser.parseOperand(lb) || parser.parseKeyword("to") ||
+      parser.parseOperand(ub) || parser.parseKeyword("step") ||
+      parser.parseOperand(step))
+    return failure();
+
+  // No iteration arguments are parsed for this op; the induction variable is
+  // the only argument handed to the body region.
+  SmallVector<OpAsmParser::Argument, 4> regionArgs;
+  regionArgs.push_back(inductionVariable);
+
+  // Parse optional type, else assume Index.
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+
+  // Resolve input operands.
+  regionArgs.front().type = type;
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+void ForOp::print(OpAsmPrinter &p) {
+  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
+    << getUpperBound() << " step " << getStep();
+  // One separator space is printed below, so the type suffix must not add one.
+  p << ' ';
+  if (Type t = getInductionVar().getType(); !t.isIndex())
+    p << ": " << t << ' ';
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+LogicalResult ForOp::verifyRegions() {
+  // Check that the induction variable (the body's first block argument) has
+  // the same type as the lower bound.
+  if (getInductionVar().getType() != getLowerBound().getType())
+    return emitOpError(
+        "expected induction variable to be same type as bounds and step");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4645ca4b206e78c..8ffea4d5b7b3248 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -502,30 +501,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
 
   raw_indented_ostream &os = emitter.ostream();
 
-  OperandRange operands = forOp.getInitArgs();
-  Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
-  Operation::result_range results = forOp.getResults();
-
-  if (!emitter.shouldDeclareVariablesAtTop()) {
-    for (OpResult result : results) {
-      if (failed(emitter.emitVariableDeclaration(result,
-                                                 /*trailingSemicolon=*/true)))
-        return failure();
-    }
-  }
-
-  for (auto pair : llvm::zip(iterArgs, operands)) {
-    if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
-      return failure();
-    os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
-    os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
-    os << "\n";
-  }
-
   os << "for (";
   if (failed(
           emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
@@ -548,35 +527,14 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
   Region &forRegion = forOp.getRegion();
   auto regionOps = forRegion.getOps();
 
-  // We skip the trailing yield op because this updates the result variables
-  // of the for op in the generated code. Instead we update the iterArgs at
-  // the end of a loop iteration and set the result variables after the for
-  // loop.
+  // We skip the trailing yield op.
   for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
     if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
       return failure();
   }
 
-  Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
-  // Copy yield operands into iterArgs at the end of a loop iteration.
-  for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
-    BlockArgument iterArg = std::get<0>(pair);
-    Value operand = std::get<1>(pair);
-    os << emitter.getOrCreateName(iterArg) << " = "
-       << emitter.getOrCreateName(operand) << ";\n";
-  }
-
   os.unindent() << "}";
 
-  // Copy iterArgs into results after the for loop.
-  for (auto pair : llvm::zip(results, iterArgs)) {
-    OpResult result = std::get<0>(pair);
-    BlockArgument iterArg = std::get<1>(pair);
-    os << "\n"
-       << emitter.getOrCreateName(result) << " = "
-       << emitter.getOrCreateName(iterArg) << ";";
-  }
-
   return success();
 }
 
@@ -617,33 +575,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
-  raw_ostream &os = emitter.ostream();
-  Operation &parentOp = *yieldOp.getOperation()->getParentOp();
-
-  if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
-    return yieldOp.emitError("number of operands does not to match the number "
-                             "of the parent op's results");
-  }
-
-  if (failed(interleaveWithError(
-          llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
-          [&](auto pair) -> LogicalResult {
-            auto result = std::get<0>(pair);
-            auto operand = std::get<1>(pair);
-            os << emitter.getOrCreateName(result) << " = ";
-
-            if (!emitter.hasValueInScope(operand))
-              return yieldOp.emitError("operand value not in scope");
-            os << emitter.getOrCreateName(operand);
-            return success();
-          },
-          [&]() { os << ";\n"; })))
-    return failure();
-
-  return success();
-}
-
 static LogicalResult printOperation(CppEmitter &emitter,
                                     func::ReturnOp returnOp) {
   raw_ostream &os = emitter.ostream();
@@ -748,10 +679,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
     for (Operation &op : block.getOperations()) {
       // When generating code for an emitc.if or cf.cond_br op no semicolon
       // needs to be printed after the closing brace.
-      // When generating code for an scf.for op, printing a trailing semicolon
+      // When generating code for an emitc.for op, printing a trailing semicolon
       // is handled within the printOperation function.
       bool trailingSemicolon =
-          !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, scf::ForOp>(op);
+          !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp, emitc::LiteralOp>(
+              op);
 
       if (failed(emitter.emitOperation(
               op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1015,15 +947,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
           // EmitC ops.
           .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
                 emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
-                emitc::IfOp, emitc::IncludeOp, emitc::MulOp, emitc::RemOp,
-                emitc::SubOp, emitc::VariableOp>(
+                emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
+                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
               [&](auto op) { return printOperation(*this, op); })
-          // SCF ops.
-          .Case<scf::ForOp, scf::YieldOp>(
-              [&](auto op) { return printOperation(*this, op); })
           // Arithmetic ops.
           .Case<arith::ConstantOp>(
               [&](auto op) { return printOperation(*this, op); })
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
new file mode 100644
index 000000000000000..7f90310af218942
--- /dev/null
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
+
+func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
+  scf.for %i0 = %arg0 to %arg1 step %arg2 {
+    %c1 = arith.constant 1 : index
+  }
+  return
+}
+// CHECK-LABEL: func.func @simple_std_for_loop(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
+// CHECK-NEXT:    emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
+  scf.for %i0 = %arg0 to %arg1 step %arg2 {
+    %c1 = arith.constant 1 : index
+    scf.for %i1 = %arg0 to %arg1 step %arg2 {
+      %c1_0 = arith.constant 1 : index
+    }
+  }
+  return
+}
+// CHECK-LABEL: func.func @simple_std_2_for_loops(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
+// CHECK-NEXT:    emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-NEXT:      emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:        %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-NEXT:      }
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+
+func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
+  %s0 = arith.constant 0.0 : f32
+  %s1 = arith.constant 1.0 : f32
+  %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+    %sn = arith.addf %si, %sj : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  return %result#0, %result#1 : f32, f32
+}
+// CHECK-LABEL: func.func @for_yield(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
+// CHECK-NEXT:    %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT:    %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:    %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    emitc.assign %[[VAL_3]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT:    emitc.assign %[[VAL_4]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:    emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
+// CHECK-NEXT:      emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT:      emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT:    emitc.assign %[[VAL_8]] : f32 to %[[VAL_6]] : f32
+// CHECK-NEXT:    return %[[VAL_5]], %[[VAL_6]] : f32, f32
+// CHECK-NEXT:  }
+
+func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
+  %s0 = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iter = %s0) -> (f32) {
+    %result = scf.for %i1 = %arg0 to %arg1 step %arg2 iter_args(%si = %iter) -> (f32) {
+      %sn = arith.addf %si, %si : f32
+      scf.yield %sn : f32
+    }
+    scf.yield %result : f32
+  }
+  return %r : f32
+}
+// CHECK-LABEL: func.func @nested_for_yield(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
+// CHECK-NEXT:    %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT:    %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:    emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT:    emitc.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:      %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT:      emitc.assign %[[VAL_5]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:      emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:        %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_8]] : f32
+// CHECK-NEXT:        emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:      }
+// CHECK-NEXT:      emitc.assign %[[VAL_8]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT:      emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT:    }
+// CHECK-NEXT:    emitc.assign %[[VAL_5]] : f32 to %[[VAL_4]] : f32
+// CHECK-NEXT:    return %[[VAL_4]] : f32
+// CHECK-NEXT:  }
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 9e8f0bf0bf8bdcd..53d88adf4305ff8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
 // -----
 
 func.func @test_misplaced_yield() {
-  // expected-error @+1 {{'emitc.yield' op expects parent op 'emitc.if'}}
+  // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}}
   emitc.yield
   return
 }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 0817945e3b1e0bc..6c8398680980466 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -105,7 +105,7 @@ func.func @test_if(%arg0: i1, %arg1: f32) {
   return
 }
 
-func.func @test_explicit_yield(%arg0: i1, %arg1: f32) {
+func.func @test_if_explicit_yield(%arg0: i1, %arg1: f32) {
   emitc.if %arg0 {
      %0 = emitc.call "func_const"(%arg1) : (f32) -> i32
      emitc.yield
@@ -127,3 +127,25 @@ func.func @test_assign(%arg1: f32) {
   emitc.assign %arg1 : f32 to %v : f32
   return
 }
+
+func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
+  emitc.for %i0 = %arg0 to %arg1 step %arg2 {
+    %0 = emitc.call "func_const"(%i0) : (index) -> i32
+  }
+  return
+}
+
+func.func @test_for_explicit_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
+  emitc.for %i0 = %arg0 to %arg1 step %arg2 {
+    %0 = emitc.call "func_const"(%i0) : (index) -> i32
+    emitc.yield
+  }
+  return
+}
+
+func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
+  emitc.for %i0 = %arg0 to %arg1 step %arg2 : i16 {
+    %0 = emitc.call "func_const"(%i0) : (i16) -> i32
+  }
+  return
+}
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index e904c99820ad846..c02c8b1ac33e371 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -2,7 +2,7 @@
 // RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
 
 func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
-  scf.for %i0 = %arg0 to %arg1 step %arg2 {
+  emitc.for %i0 = %arg0 to %arg1 step %arg2 {
     %0 = emitc.call "f"() : () -> i32
   }
   return
@@ -28,11 +28,21 @@ func.func @test_for_yield() {
   %s0 = arith.constant 0 : i32
   %p0 = arith.constant 1.0 : f32
 
-  %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
-    %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
-    %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
-    scf.yield %sn, %pn : i32, f32
+  %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  %2 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  emitc.assign %s0 : i32 to %2 : i32
+  emitc.assign %p0 : f32 to %3 : f32
+  emitc.for %iter = %start to %stop step %step {
+    %sn = emitc.call "add"(%2, %iter) : (i32, index) -> i32
+    %pn = emitc.call "mul"(%3, %iter) : (f32, index) -> f32
+    emitc.assign %sn : i32 to %2 : i32
+    emitc.assign %pn : f32 to %3 : f32
+    emitc.yield
   }
+  emitc.assign %2 : i32 to %0 : i32
+  emitc.assign %3 : f32 to %1 : f32
 
   return
 }
@@ -44,8 +54,10 @@ func.func @test_for_yield() {
 // CPP-DEFAULT-NEXT: float [[P0:[^ ]*]] = (float)1.000000000e+00;
 // CPP-DEFAULT-NEXT: int32_t [[SE:[^ ]*]];
 // CPP-DEFAULT-NEXT: float [[PE:[^ ]*]];
-// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]];
+// CPP-DEFAULT-NEXT: [[SI]] = [[S0]];
+// CPP-DEFAULT-NEXT: [[PI]] = [[P0]];
 // CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
 // CPP-DEFAULT-NEXT: int32_t [[SN:[^ ]*]] = add([[SI]], [[ITER]]);
 // CPP-DEFAULT-NEXT: float [[PN:[^ ]*]] = mul([[PI]], [[ITER]]);
@@ -64,6 +76,8 @@ func.func @test_for_yield() {
 // CPP-DECLTOP-NEXT: float [[P0:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t [[SE:[^ ]*]];
 // CPP-DECLTOP-NEXT: float [[PE:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t [[SN:[^ ]*]];
 // CPP-DECLTOP-NEXT: float [[PN:[^ ]*]];
 // CPP-DECLTOP-NEXT: [[START]] = 0;
@@ -71,8 +85,12 @@ func.func @test_for_yield() {
 // CPP-DECLTOP-NEXT: [[STEP]] = 1;
 // CPP-DECLTOP-NEXT: [[S0]] = 0;
 // CPP-DECLTOP-NEXT: [[P0]] = (float)1.000000000e+00;
-// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: [[SI]] = [[S0]];
+// CPP-DECLTOP-NEXT: [[PI]] = [[P0]];
 // CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
 // CPP-DECLTOP-NEXT: [[SN]] = add([[SI]], [[ITER]]);
 // CPP-DECLTOP-NEXT: [[PN]] = mul([[PI]], [[ITER]]);
@@ -91,14 +109,24 @@ func.func @test_for_yield_2() {
   %s0 = emitc.literal "0" : i32
   %p0 = emitc.literal "M_PI" : f32
 
-  %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
-    %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
-    %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
-    scf.yield %sn, %pn : i32, f32
+  %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  %2 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+  %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  emitc.assign %s0 : i32 to %2 : i32
+  emitc.assign %p0 : f32 to %3 : f32
+  emitc.for %iter = %start to %stop step %step {
+    %sn = emitc.call "add"(%2, %iter) : (i32, index) -> i32
+    %pn = emitc.call "mul"(%3, %iter) : (f32, index) -> f32
+    emitc.assign %sn : i32 to %2 : i32
+    emitc.assign %pn : f32 to %3 : f32
+    emitc.yield
   }
+  emitc.assign %2 : i32 to %0 : i32
+  emitc.assign %3 : f32 to %1 : f32
 
   return
 }
 // CPP-DEFAULT: void test_for_yield_2() {
-// CPP-DEFAULT: float{{.*}}= M_PI
+// CPP-DEFAULT: {{.*}}= M_PI
 // CPP-DEFAULT: for (size_t [[IN:.*]] = 0; [[IN]] < 10; [[IN]] += 1) {



More information about the Mlir-commits mailing list