[Mlir-commits] [mlir] [mlir][emitc] Add a structured for operation (PR #68206)
Gil Rapaport
llvmlistbot at llvm.org
Wed Oct 4 04:53:48 PDT 2023
https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/68206
>From 20a8e1bba49b1355932ba86ee3f010aa6bc2596f Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Tue, 3 Oct 2023 20:27:02 +0300
Subject: [PATCH] [mlir][emitc] Add a structured for operation
Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C. The translator
now handles emitc.for instead.
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 1 +
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 79 +++++++++-
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 128 +++++++++++++---
mlir/lib/Dialect/EmitC/IR/CMakeLists.txt | 1 +
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 142 ++++++++++++++++++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 82 +---------
mlir/test/Conversion/SCFToEmitC/for.mlir | 96 ++++++++++++
mlir/test/Dialect/EmitC/invalid_ops.mlir | 2 +-
mlir/test/Dialect/EmitC/ops.mlir | 24 ++-
mlir/test/Target/Cpp/for.mlir | 56 +++++--
10 files changed, 495 insertions(+), 116 deletions(-)
create mode 100644 mlir/test/Conversion/SCFToEmitC/for.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 4dff26e23c42850..9d2ec0f41a75568 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 827ffc0278fce1c..381247bcdadf27d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -430,7 +431,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
}
-def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
+def EmitC_YieldOp : EmitC_Op<"yield",
+ [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
@@ -444,6 +446,81 @@ def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]>
let assemblyFormat = [{ attr-dict }];
}
+// Counted loop used as the EmitC lowering target for scf.for. It defines no
+// results and no loop-carried block arguments: the body block's only argument
+// is the induction variable.
+def EmitC_ForOp : EmitC_Op<"for",
+ [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInits", "getSingleInductionVar", "getSingleLowerBound",
+ "getSingleStep", "getSingleUpperBound"]>,
+ AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"emitc::YieldOp">,
+ RecursiveMemoryEffects]> {
+ let summary = "for operation";
+ let description = [{
+ The `emitc.for` operation represents a loop taking 3 SSA values as operands
+ that represent the lower bound, upper bound and step respectively. The
+ operation defines an SSA value for its induction variable. It has one
+ region capturing the loop body. The induction variable is represented as an
+ argument of this region. This SSA value is a signless integer or index.
+ The step is a value of same type but required to be positive. The lower and
+ upper bounds specify a half-open range: the range includes the lower bound
+ but does not include the upper bound. This operation has no result.
+
+ The body region must contain exactly one block that terminates with
+ `emitc.yield`. Calling ForOp::build will create such a region and insert
+ the terminator implicitly if none is defined, so will the parsing even in
+ cases when it is absent from the custom format. For example:
+
+ ```mlir
+ // Index case.
+ emitc.for %iv = %lb to %ub step %step {
+ ... // body
+ }
+ ...
+ // Integer case.
+ emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+ ... // body
+ }
+ ```
+ }];
+  // Operand order is (lowerBound, upperBound, step); AllTypesMatch above
+  // forces all three to share one signless-integer-or-index type.
+ let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
+ AnySignlessIntegerOrIndex:$upperBound,
+ AnySignlessIntegerOrIndex:$step);
+ let results = (outs);
+ let regions = (region SizedRegion<1>:$region);
+
+  // Only the custom builder below is generated. Its optional callback is
+  // invoked with the builder, the location and the induction variable.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
+ CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
+ ];
+
+ let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ function_ref<void(OpBuilder &, Location, Value)>;
+    // The induction variable is the first argument of the body block.
+ Value getInductionVar() { return getBody()->getArgument(0); }
+    // emitc.for carries no loop-carried values, so this range is always empty.
+ Block::BlockArgListType getRegionIterArgs() {
+ return Block::BlockArgListType();
+ }
+    // Forwards to the LoopLikeOpInterface default implementation.
+ FailureOr<LoopLikeOpInterface> replaceWithAdditionalYields(
+ RewriterBase &rewriter, ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop,
+ const NewYieldValuesFn &newYieldValuesFn) {
+ return LoopLikeOpInterfaceTrait::replaceWithAdditionalYields(
+ rewriter, newInitOperands, replaceInitOperandUsesInLoop,
+ newYieldValuesFn);
+ }
+    // Operand setters; indices follow the (lowerBound, upperBound, step)
+    // order of `arguments`.
+ void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
+ void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
+ void setStep(Value step) { getOperation()->setOperand(2, step); }
+ }];
+
+  // The C++ hooks (canonicalization patterns, parse/print, verify and
+  // verifyRegions) are implemented in EmitC.cpp; the canonicalization
+  // pattern list is currently empty.
+ let hasCanonicalizer = 1;
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+}
+
def EmitC_IfOp : EmitC_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 5d0d8df8869e313..3d145e307f06aa4 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -37,6 +37,106 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};
+// Rewrites scf.for into emitc.for. Both the values carried across iterations
+// and the loop's results are modeled as emitc.variable ops, written with
+// emitc.assign inside the loop body and after the loop.
+struct ForLowering : public OpRewritePattern<ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Create an uninitialized emitc::variable op for each result of the given op,
+// inserted immediately before the op. Returns the variables in result order.
+// The rewriter's insertion point is left unchanged.
+template <typename T>
+static SmallVector<Value> createVariablesForResults(T op,
+                                                    PatternRewriter &rewriter) {
+  SmallVector<Value> resultVariables;
+
+  if (!op.getNumResults())
+    return resultVariables;
+
+  Location loc = op->getLoc();
+  // An empty opaque value leaves the variable without an initializer. The
+  // attribute is identical for every result, so build it once.
+  auto noInit = emitc::OpaqueAttr::get(op.getContext(), "");
+
+  // Restore the caller's insertion point when leaving this scope.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  resultVariables.reserve(op.getNumResults());
+  for (OpResult result : op.getResults()) {
+    auto var =
+        rewriter.create<emitc::VariableOp>(loc, result.getType(), noInit);
+    resultVariables.push_back(var);
+  }
+
+  return resultVariables;
+}
+
+// Create a series of emitc::assign ops at the current insertion point of the
+// given rewriter, assigning values[i] to variables[i]. Both ranges must have
+// the same length; llvm::zip_equal asserts this instead of silently
+// truncating to the shorter range.
+static void assignValues(ValueRange values, ValueRange variables,
+                         PatternRewriter &rewriter, Location loc) {
+  for (auto [value, var] : llvm::zip_equal(values, variables))
+    rewriter.create<emitc::AssignOp>(loc, var, value);
+}
+
+// Lower an scf::yield in place: assign each yielded value to its matching
+// variable, terminate the block with an emitc::yield, and erase the original
+// yield. The caller's insertion point is restored on return.
+static void lowerYield(SmallVector<Value> &resultVariables,
+                       PatternRewriter &rewriter, scf::YieldOp yield) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(yield);
+
+  Location yieldLoc = yield.getLoc();
+  assignValues(yield.getOperands(), resultVariables, rewriter, yieldLoc);
+  rewriter.create<emitc::YieldOp>(yieldLoc);
+
+  rewriter.eraseOp(yield);
+}
+
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+                                           PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+  rewriter.setInsertionPoint(forOp);
+
+  // Two sets of uninitialized variables are materialized ahead of the loop:
+  // one receives the final results, the other carries values from one
+  // iteration to the next.
+  SmallVector<Value> resultVars = createVariablesForResults(forOp, rewriter);
+  SmallVector<Value> iterVars = createVariablesForResults(forOp, rewriter);
+
+  // Seed the loop-carried variables with the loop's initial values.
+  assignValues(forOp.getInits(), iterVars, rewriter, loc);
+
+  emitc::ForOp newFor = rewriter.create<emitc::ForOp>(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+  Block *newBody = newFor.getBody();
+
+  // The builder inserted an emitc.yield; drop it so the scf.for body, which
+  // ends with its own scf.yield, can be merged in.
+  rewriter.eraseOp(newBody->getTerminator());
+
+  // The scf.for block arguments map to the new induction variable followed by
+  // the loop-carried variables.
+  SmallVector<Value> argReplacements = {newFor.getInductionVar()};
+  argReplacements.append(iterVars.begin(), iterVars.end());
+  rewriter.mergeBlocks(forOp.getBody(), newBody, argReplacements);
+
+  // Turn the merged scf.yield into assignments to the loop-carried variables.
+  auto mergedYield = cast<scf::YieldOp>(newBody->getTerminator());
+  lowerYield(iterVars, rewriter, mergedYield);
+
+  // After the loop, copy the loop-carried values into the result variables.
+  assignValues(iterVars, resultVars, rewriter, loc);
+
+  rewriter.replaceOp(forOp, resultVars);
+  return success();
+}
+
// Lower scf::if to emitc::if, implementing return values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
@@ -52,20 +152,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();
- SmallVector<Value> resultVariables;
-
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
- if (ifOp.getNumResults()) {
- MLIRContext *context = ifOp.getContext();
- rewriter.setInsertionPoint(ifOp);
- for (OpResult result : ifOp.getResults()) {
- Type resultType = result.getType();
- auto noInit = emitc::OpaqueAttr::get(context, "");
- auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
- resultVariables.push_back(var);
- }
- }
+ SmallVector<Value> resultVariables =
+ createVariablesForResults(ifOp, rewriter);
// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
@@ -76,16 +166,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
- Location terminatorLoc = terminator->getLoc();
- ValueRange terminatorOperands = terminator->getOperands();
- rewriter.setInsertionPointToEnd(&loweredRegion.back());
- for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
- Value resultValue = std::get<0>(value2Var);
- Value resultVar = std::get<1>(value2Var);
- rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
- }
- rewriter.create<emitc::YieldOp>(terminatorLoc);
- rewriter.eraseOp(terminator);
+ lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
};
Region &thenRegion = ifOp.getThenRegion();
@@ -109,6 +190,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
}
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  // Register the structured lowerings of scf.for and scf.if into emitc.for
+  // and emitc.if.
+ patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
}
@@ -118,7 +200,7 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::IfOp>();
+ target.addIllegalOp<scf::IfOp, scf::ForOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
index 4665c41a62e80b8..50e79d22d57e681 100644
--- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
@@ -12,5 +12,6 @@ add_mlir_dialect_library(MLIREmitCDialect
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRIR
+ MLIRLoopLikeInterface
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 961a52a70a2a168..740504cc9db2bcd 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -189,6 +190,147 @@ LogicalResult emitc::ConstantOp::verify() {
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+// Build an emitc.for over [lb, ub) with the given step. The body is a single
+// block whose only argument is the induction variable, typed like `lb`.
+// Without a body builder, only the implicit emitc.yield terminator is
+// created. Otherwise the callback populates the body from its start, and no
+// terminator is added automatically.
+void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
+                  Value ub, Value step, BodyBuilderFn bodyBuilder) {
+  // Operand order is (lowerBound, upperBound, step).
+  result.addOperands({lb, ub, step});
+
+  Region *bodyRegion = result.addRegion();
+  Block *bodyBlock = new Block();
+  bodyRegion->push_back(bodyBlock);
+  Value inductionVar = bodyBlock->addArgument(lb.getType(), result.location);
+
+  if (!bodyBuilder) {
+    ForOp::ensureTerminator(*bodyRegion, builder, result.location);
+    return;
+  }
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, result.location, inductionVar);
+}
+
+// No canonicalization patterns are registered for emitc.for yet.
+void ForOp::getCanonicalizationPatterns(RewritePatternSet & /*results*/,
+                                        MLIRContext * /*context*/) {}
+
+// Entry operands come from the RegionBranchOpInterface default.
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return RegionBranchOpInterfaceTrait::getEntrySuccessorOperands(point);
+}
+
+// Loop inits come from the LoopLikeOpInterface default.
+OperandRange ForOp::getInits() {
+  return LoopLikeOpInterfaceTrait::getInits();
+}
+
+// The loop body is the op's only region.
+SmallVector<Region *> ForOp::getLoopRegions() {
+  SmallVector<Region *> loopRegions;
+  loopRegions.push_back(&getRegion());
+  return loopRegions;
+}
+
+std::optional<Value> ForOp::getSingleInductionVar() {
+  Value iv = getInductionVar();
+  return iv;
+}
+
+std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
+  OpFoldResult lowerBound(getLowerBound());
+  return lowerBound;
+}
+
+std::optional<OpFoldResult> ForOp::getSingleStep() {
+  OpFoldResult stepValue(getStep());
+  return stepValue;
+}
+
+std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
+  OpFoldResult upperBound(getUpperBound());
+  return upperBound;
+}
+
+// From any branch point, control may (re-)enter the body or return to the
+// parent op; the body may also be skipped entirely. Neither successor
+// receives values: the op has no iteration arguments and no results.
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
+                                SmallVectorImpl<RegionSuccessor> &regions) {
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor(getOperation()->getResults()));
+}
+
+// Parse `%iv = %lb to %ub step %step (`:` type)? region attr-dict`. When the
+// type is omitted, the bounds, step and induction variable are `index`. The
+// op has no iteration arguments, so nothing else is parsed before the region.
+ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
+  Builder &builder = parser.getBuilder();
+
+  OpAsmParser::Argument inductionVariable;
+  OpAsmParser::UnresolvedOperand lb, ub, step;
+
+  // Parse the induction variable followed by '=', then the loop bounds and
+  // the step.
+  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
+      parser.parseOperand(lb) || parser.parseKeyword("to") ||
+      parser.parseOperand(ub) || parser.parseKeyword("step") ||
+      parser.parseOperand(step))
+    return failure();
+
+  // Parse the optional type, else assume index.
+  Type type;
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+
+  // Resolve the bound and step operands with the loop type.
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
+    return failure();
+
+  // Parse the body region; its single argument is the induction variable.
+  inductionVariable.type = type;
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, inductionVariable))
+    return failure();
+
+  // Insert the emitc.yield terminator if the custom form omitted it.
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+// Print `%iv = %lb to %ub step %step`, then ` : type` only for non-index loops
+// (the parser defaults to index), then the body without its entry block
+// arguments or terminator, then any attributes. A single space separates each
+// piece; previously a non-index loop printed `step %s  : type`.
+void ForOp::print(OpAsmPrinter &p) {
+  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
+    << getUpperBound() << " step " << getStep();
+
+  if (Type t = getInductionVar().getType(); !t.isIndex())
+    p << " : " << t;
+
+  p << ' ';
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+// Reject a step that is a known constant <= 0. The comparison is done on the
+// APInt directly: IntegerAttr::getInt() asserts for constant values that do
+// not fit in 64 bits (e.g. a large i128 step). The APInt check gives the same
+// signed result for any value that does fit.
+LogicalResult ForOp::verify() {
+  IntegerAttr step;
+  if (matchPattern(getStep(), m_Constant(&step)) &&
+      !step.getValue().isStrictlyPositive())
+    return emitOpError("constant step operand must be positive");
+
+  return success();
+}
+
+// Check the body's block arguments. The region constraint guarantees exactly
+// one block, but the generic form can give it any number of arguments.
+// getInductionVar() reads argument 0, so the count is checked first.
+LogicalResult ForOp::verifyRegions() {
+  if (getBody()->getNumArguments() != 1)
+    return emitOpError("expected body to have a single block argument for the "
+                       "induction variable");
+
+  // The induction variable must have the same type as the bounds and step.
+  if (getInductionVar().getType() != getLowerBound().getType())
+    return emitOpError(
+        "expected induction variable to be same type as bounds and step");
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4645ca4b206e78c..05d472ba50d070f 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -502,30 +502,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
raw_indented_ostream &os = emitter.ostream();
- OperandRange operands = forOp.getInitArgs();
- Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
- Operation::result_range results = forOp.getResults();
-
- if (!emitter.shouldDeclareVariablesAtTop()) {
- for (OpResult result : results) {
- if (failed(emitter.emitVariableDeclaration(result,
- /*trailingSemicolon=*/true)))
- return failure();
- }
- }
-
- for (auto pair : llvm::zip(iterArgs, operands)) {
- if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
- return failure();
- os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
- os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
- os << "\n";
- }
-
os << "for (";
if (failed(
emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
@@ -548,35 +528,14 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
Region &forRegion = forOp.getRegion();
auto regionOps = forRegion.getOps();
- // We skip the trailing yield op because this updates the result variables
- // of the for op in the generated code. Instead we update the iterArgs at
- // the end of a loop iteration and set the result variables after the for
- // loop.
+ // We skip the trailing yield op.
for (auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
return failure();
}
- Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
- // Copy yield operands into iterArgs at the end of a loop iteration.
- for (auto pair : llvm::zip(iterArgs, yieldOp->getOperands())) {
- BlockArgument iterArg = std::get<0>(pair);
- Value operand = std::get<1>(pair);
- os << emitter.getOrCreateName(iterArg) << " = "
- << emitter.getOrCreateName(operand) << ";\n";
- }
-
os.unindent() << "}";
- // Copy iterArgs into results after the for loop.
- for (auto pair : llvm::zip(results, iterArgs)) {
- OpResult result = std::get<0>(pair);
- BlockArgument iterArg = std::get<1>(pair);
- os << "\n"
- << emitter.getOrCreateName(result) << " = "
- << emitter.getOrCreateName(iterArg) << ";";
- }
-
return success();
}
@@ -617,33 +576,6 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
- raw_ostream &os = emitter.ostream();
- Operation &parentOp = *yieldOp.getOperation()->getParentOp();
-
- if (yieldOp.getNumOperands() != parentOp.getNumResults()) {
- return yieldOp.emitError("number of operands does not to match the number "
- "of the parent op's results");
- }
-
- if (failed(interleaveWithError(
- llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
- [&](auto pair) -> LogicalResult {
- auto result = std::get<0>(pair);
- auto operand = std::get<1>(pair);
- os << emitter.getOrCreateName(result) << " = ";
-
- if (!emitter.hasValueInScope(operand))
- return yieldOp.emitError("operand value not in scope");
- os << emitter.getOrCreateName(operand);
- return success();
- },
- [&]() { os << ";\n"; })))
- return failure();
-
- return success();
-}
-
static LogicalResult printOperation(CppEmitter &emitter,
func::ReturnOp returnOp) {
raw_ostream &os = emitter.ostream();
@@ -751,7 +683,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
// When generating code for an scf.for op, printing a trailing semicolon
// is handled within the printOperation function.
bool trailingSemicolon =
- !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, scf::ForOp>(op);
+ !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, emitc::ForOp>(
+ op);
if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1015,15 +948,12 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
emitc::CastOp, emitc::CmpOp, emitc::ConstantOp, emitc::DivOp,
- emitc::IfOp, emitc::IncludeOp, emitc::MulOp, emitc::RemOp,
- emitc::SubOp, emitc::VariableOp>(
+ emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
+ emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
- // SCF ops.
- .Case<scf::ForOp, scf::YieldOp>(
- [&](auto op) { return printOperation(*this, op); })
// Arithmetic ops.
.Case<arith::ConstantOp>(
[&](auto op) { return printOperation(*this, op); })
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
new file mode 100644
index 000000000000000..7f90310af218942
--- /dev/null
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
+
+// A loop without iter_args maps directly onto emitc.for: same bounds and
+// step, body moved over unchanged, no variables created.
+func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) {
+ scf.for %i0 = %arg0 to %arg1 step %arg2 {
+ %c1 = arith.constant 1 : index
+ }
+ return
+}
+// CHECK-LABEL: func.func @simple_std_for_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
+// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// Nested loops without iter_args are each lowered to emitc.for, preserving
+// the nesting.
+func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) {
+ scf.for %i0 = %arg0 to %arg1 step %arg2 {
+ %c1 = arith.constant 1 : index
+ scf.for %i1 = %arg0 to %arg1 step %arg2 {
+ %c1_0 = arith.constant 1 : index
+ }
+ }
+ return
+}
+// CHECK-LABEL: func.func @simple_std_2_for_loops(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) {
+// CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// iter_args and results become emitc.variable ops created before the loop:
+// first one per result, then one per iter_arg. The iter_arg variables are
+// initialized with emitc.assign, the scf.yield becomes assignments to them,
+// and after the loop they are copied into the result variables.
+func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) {
+ %s0 = arith.constant 0.0 : f32
+ %s1 = arith.constant 1.0 : f32
+ %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %sj = %s1) -> (f32, f32) {
+ %sn = arith.addf %si, %sj : f32
+ scf.yield %sn, %sn : f32, f32
+ }
+ return %result#0, %result#1 : f32, f32
+}
+// CHECK-LABEL: func.func @for_yield(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
+// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_4]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_6]] : f32
+// CHECK-NEXT: return %[[VAL_5]], %[[VAL_6]] : f32, f32
+// CHECK-NEXT: }
+
+// Nested loops with iter_args: the inner loop's variables are created inside
+// the outer emitc.for body, immediately before the inner emitc.for.
+func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
+ %s0 = arith.constant 1.0 : f32
+ %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iter = %s0) -> (f32) {
+ %result = scf.for %i1 = %arg0 to %arg1 step %arg2 iter_args(%si = %iter) -> (f32) {
+ %sn = arith.addf %si, %si : f32
+ scf.yield %sn : f32
+ }
+ scf.yield %result : f32
+ }
+ return %r : f32
+}
+// CHECK-LABEL: func.func @nested_for_yield(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
+// CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT: emitc.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT: %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_8]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: emitc.assign %[[VAL_8]] : f32 to %[[VAL_7]] : f32
+// CHECK-NEXT: emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT: }
+// CHECK-NEXT: emitc.assign %[[VAL_5]] : f32 to %[[VAL_4]] : f32
+// CHECK-NEXT: return %[[VAL_4]] : f32
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 9e8f0bf0bf8bdcd..53d88adf4305ff8 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -203,7 +203,7 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
// -----
func.func @test_misplaced_yield() {
- // expected-error @+1 {{'emitc.yield' op expects parent op 'emitc.if'}}
+  // emitc.yield is only accepted when its direct parent is emitc.if or
+  // emitc.for; a yield directly inside a function must be rejected.
+ // expected-error @+1 {{'emitc.yield' op expects parent op to be one of 'emitc.if, emitc.for'}}
emitc.yield
return
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 0817945e3b1e0bc..6c8398680980466 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -105,7 +105,7 @@ func.func @test_if(%arg0: i1, %arg1: f32) {
return
}
-func.func @test_explicit_yield(%arg0: i1, %arg1: f32) {
+func.func @test_if_explicit_yield(%arg0: i1, %arg1: f32) {
emitc.if %arg0 {
%0 = emitc.call "func_const"(%arg1) : (f32) -> i32
emitc.yield
@@ -127,3 +127,25 @@ func.func @test_assign(%arg1: f32) {
emitc.assign %arg1 : f32 to %v : f32
return
}
+
+// Custom form without a type: bounds, step and induction variable are index,
+// and the emitc.yield terminator is left implicit.
+func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
+ emitc.for %i0 = %arg0 to %arg1 step %arg2 {
+ %0 = emitc.call "func_const"(%i0) : (index) -> i32
+ }
+ return
+}
+
+// The emitc.yield terminator may also be written out explicitly.
+func.func @test_for_explicit_yield(%arg0 : index, %arg1 : index, %arg2 : index) {
+ emitc.for %i0 = %arg0 to %arg1 step %arg2 {
+ %0 = emitc.call "func_const"(%i0) : (index) -> i32
+ emitc.yield
+ }
+ return
+}
+
+// A non-index loop type is given after a colon and applies to the bounds,
+// the step and the induction variable.
+func.func @test_for_not_index_induction(%arg0 : i16, %arg1 : i16, %arg2 : i16) {
+ emitc.for %i0 = %arg0 to %arg1 step %arg2 : i16 {
+ %0 = emitc.call "func_const"(%i0) : (i16) -> i32
+ }
+ return
+}
diff --git a/mlir/test/Target/Cpp/for.mlir b/mlir/test/Target/Cpp/for.mlir
index e904c99820ad846..c02c8b1ac33e371 100644
--- a/mlir/test/Target/Cpp/for.mlir
+++ b/mlir/test/Target/Cpp/for.mlir
@@ -2,7 +2,7 @@
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
- scf.for %i0 = %arg0 to %arg1 step %arg2 {
+ emitc.for %i0 = %arg0 to %arg1 step %arg2 {
%0 = emitc.call "f"() : () -> i32
}
return
@@ -28,11 +28,21 @@ func.func @test_for_yield() {
%s0 = arith.constant 0 : i32
%p0 = arith.constant 1.0 : f32
- %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
- %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
- %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
- scf.yield %sn, %pn : i32, f32
+ %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ %2 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ emitc.assign %s0 : i32 to %2 : i32
+ emitc.assign %p0 : f32 to %3 : f32
+ emitc.for %iter = %start to %stop step %step {
+ %sn = emitc.call "add"(%2, %iter) : (i32, index) -> i32
+ %pn = emitc.call "mul"(%3, %iter) : (f32, index) -> f32
+ emitc.assign %sn : i32 to %2 : i32
+ emitc.assign %pn : f32 to %3 : f32
+ emitc.yield
}
+ emitc.assign %2 : i32 to %0 : i32
+ emitc.assign %3 : f32 to %1 : f32
return
}
@@ -44,8 +54,10 @@ func.func @test_for_yield() {
// CPP-DEFAULT-NEXT: float [[P0:[^ ]*]] = (float)1.000000000e+00;
// CPP-DEFAULT-NEXT: int32_t [[SE:[^ ]*]];
// CPP-DEFAULT-NEXT: float [[PE:[^ ]*]];
-// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// CPP-DEFAULT-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DEFAULT-NEXT: float [[PI:[^ ]*]];
+// Reuse the names captured on the declarations so the initial assignments
+// are checked to target those same variables.
+// CPP-DEFAULT-NEXT: [[SI]] = [[S0]];
+// CPP-DEFAULT-NEXT: [[PI]] = [[P0]];
// CPP-DEFAULT-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
// CPP-DEFAULT-NEXT: int32_t [[SN:[^ ]*]] = add([[SI]], [[ITER]]);
// CPP-DEFAULT-NEXT: float [[PN:[^ ]*]] = mul([[PI]], [[ITER]]);
@@ -64,6 +76,8 @@ func.func @test_for_yield() {
// CPP-DECLTOP-NEXT: float [[P0:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t [[SE:[^ ]*]];
// CPP-DECLTOP-NEXT: float [[PE:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]];
+// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t [[SN:[^ ]*]];
// CPP-DECLTOP-NEXT: float [[PN:[^ ]*]];
// CPP-DECLTOP-NEXT: [[START]] = 0;
@@ -71,8 +85,12 @@ func.func @test_for_yield() {
// CPP-DECLTOP-NEXT: [[STEP]] = 1;
// CPP-DECLTOP-NEXT: [[S0]] = 0;
// CPP-DECLTOP-NEXT: [[P0]] = (float)1.000000000e+00;
-// CPP-DECLTOP-NEXT: int32_t [[SI:[^ ]*]] = [[S0]];
-// CPP-DECLTOP-NEXT: float [[PI:[^ ]*]] = [[P0]];
+// With declarations hoisted to the top, the four uninitialized variables
+// leave only empty statements here.
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
+// Reuse the names captured on the hoisted declarations so the assignments
+// are checked to target those same variables.
+// CPP-DECLTOP-NEXT: [[SI]] = [[S0]];
+// CPP-DECLTOP-NEXT: [[PI]] = [[P0]];
// CPP-DECLTOP-NEXT: for (size_t [[ITER:[^ ]*]] = [[START]]; [[ITER]] < [[STOP]]; [[ITER]] += [[STEP]]) {
// CPP-DECLTOP-NEXT: [[SN]] = add([[SI]], [[ITER]]);
// CPP-DECLTOP-NEXT: [[PN]] = mul([[PI]], [[ITER]]);
@@ -91,14 +109,24 @@ func.func @test_for_yield_2() {
%s0 = emitc.literal "0" : i32
%p0 = emitc.literal "M_PI" : f32
- %result:2 = scf.for %iter = %start to %stop step %step iter_args(%si = %s0, %pi = %p0) -> (i32, f32) {
- %sn = emitc.call "add"(%si, %iter) : (i32, index) -> i32
- %pn = emitc.call "mul"(%pi, %iter) : (f32, index) -> f32
- scf.yield %sn, %pn : i32, f32
+ %0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ %2 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %3 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ emitc.assign %s0 : i32 to %2 : i32
+ emitc.assign %p0 : f32 to %3 : f32
+ emitc.for %iter = %start to %stop step %step {
+ %sn = emitc.call "add"(%2, %iter) : (i32, index) -> i32
+ %pn = emitc.call "mul"(%3, %iter) : (f32, index) -> f32
+ emitc.assign %sn : i32 to %2 : i32
+ emitc.assign %pn : f32 to %3 : f32
+ emitc.yield
}
+ emitc.assign %2 : i32 to %0 : i32
+ emitc.assign %3 : f32 to %1 : f32
return
}
// CPP-DEFAULT: void test_for_yield_2() {
-// CPP-DEFAULT: float{{.*}}= M_PI
+// CPP-DEFAULT: {{.*}}= M_PI
// CPP-DEFAULT: for (size_t [[IN:.*]] = 0; [[IN]] < 10; [[IN]] += 1) {
More information about the Mlir-commits
mailing list