[Mlir-commits] [mlir] Add a structured if operation (PR #67234)
Gil Rapaport
llvmlistbot at llvm.org
Sun Sep 24 14:39:22 PDT 2023
https://github.com/aniragil updated https://github.com/llvm/llvm-project/pull/67234
>From 7a892c6f232cba4aad88e8f3eeff339512ce6a52 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Thu, 21 Sep 2023 21:04:10 +0300
Subject: [PATCH 1/2] [mlir][emitc] Add a structured if operation
Add an emitc.if op to the EmitC dialect. A new convert-scf-to-emitc
pass replaces the existing direct translation of scf.if to C; the
translator now handles emitc.if instead.
The emitc.if op doesn't return any value and its then/else regions are
terminated with a new emitc.yield op. Values returned by scf.if are
lowered using emitc.variable ops, assigned to in the then/else regions
using a new emitc.assign op.
---
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 11 ++
.../mlir/Conversion/SCFToEmitC/SCFToEmitC.h | 29 +++
mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 8 +
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 100 ++++++++++
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt | 18 ++
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 131 ++++++++++++
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 187 ++++++++++++++++++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 48 ++---
mlir/test/Conversion/SCFToEmitC/if.mlir | 70 +++++++
mlir/test/Dialect/EmitC/invalid_ops.mlir | 25 +++
mlir/test/Dialect/EmitC/ops.mlir | 30 +++
mlir/test/Target/Cpp/if.mlir | 22 +--
14 files changed, 648 insertions(+), 33 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
create mode 100644 mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
create mode 100644 mlir/test/Conversion/SCFToEmitC/if.mlir
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index fc5e9adba114405..41806004fc1dca8 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -48,6 +48,7 @@
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 11008baa0160efe..cca1e262df7c121 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -931,6 +931,17 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let dependentDialects = ["affine::AffineDialect", "gpu::GPUDialect"];
}
+//===----------------------------------------------------------------------===//
+// SCFToEmitC
+//===----------------------------------------------------------------------===//
+
+// Pass definition for `convert-scf-to-emitc`; the pass object is created by
+// mlir::createConvertSCFToEmitCPass() and loads the EmitC dialect.
+def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
+ let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
+ " control flow";
+ let constructor = "mlir::createConvertSCFToEmitCPass()";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ShapeToStandard
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
new file mode 100644
index 000000000000000..ec7a9c5de634496
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
@@ -0,0 +1,29 @@
+//===- SCFToEmitC.h - SCF to EmitC Pass entrypoint --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
+#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_SCFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Adds to `patterns` the rewrite patterns that lower SCF ops to EmitC ops.
+void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
+
+/// Creates the `convert-scf-to-emitc` pass, which lowers SCF ops to EmitC ops.
+std::unique_ptr<Pass> createConvertSCFToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H_
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index b3c1170eefdab90..4dff26e23c42850 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -14,15 +14,23 @@
#define MLIR_DIALECT_EMITC_IR_EMITC_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
+namespace mlir::emitc {
+/// Body-builder callback used as the default `thenBuilder` of emitc::IfOp.
+void buildTerminatedBody(OpBuilder &builder, Location loc);
+} // namespace mlir::emitc
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 40a4896aca4b633..ad28763c89f67a5 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -17,6 +17,7 @@ include "mlir/Dialect/EmitC/IR/EmitCAttributes.td"
include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -402,4 +403,103 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}
+def EmitC_AssignOp : EmitC_Op<"assign", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Assign operation";
+  let description = [{
+    The `assign` operation stores an SSA value to the location designated by an
+    EmitC variable. This operation doesn't return any value. The assigned value
+    must be of the same type as the variable being assigned. The operation is
+    emitted as a C/C++ '=' operator.
+
+    Example:
+
+    ```mlir
+    // Integer variable
+    %0 = "emitc.variable"(){value = 42 : i32} : () -> i32
+    %1 = emitc.call "foo"() : () -> (i32)
+
+    // Assign emitted as `... = ...;`
+    "emitc.assign"(%0, %1) : (i32, i32) -> ()
+    ```
+  }];
+
+  // `$var` must be the result of an emitc.variable op and `$value` must have
+  // the same type; both constraints are enforced by the C++ verifier.
+  let arguments = (ins AnyType:$var, AnyType:$value);
+  let results = (outs);
+
+  let hasVerifier = 1;
+  let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
+}
+
+// Operand-less terminator; ParentOneOf restricts it to the body of emitc.if.
+def YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
+ let summary = "block termination operation";
+ let description = [{
+ "yield" terminates blocks within EmitC control-flow operations. Since
+ control-flow constructs in C do not return values, this operation doesn't
+ take any arguments.
+ }];
+
+ let arguments = (ins);
+ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+ let assemblyFormat = [{ attr-dict }];
+}
+
+def EmitC_IfOp : EmitC_Op<"if",
+    [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
+    "getNumRegionInvocations", "getRegionInvocationBounds",
+    "getEntrySuccessorRegions"]>, SingleBlock,
+    SingleBlockImplicitTerminator<"emitc::YieldOp">,
+    RecursiveMemoryEffects, NoRegionArguments]> {
+  let summary = "if-then-else operation";
+  let description = [{
+    The `if` operation represents an if-then-else construct for
+    conditionally executing two regions of code. The operand to an if operation
+    is a boolean value. For example:
+
+    ```mlir
+    emitc.if %b  {
+      ...
+    } else {
+      ...
+    }
+    ```
+
+    The "then" region has exactly 1 block. The "else" region may have 0 or 1
+    blocks. The blocks are always terminated with `emitc.yield`, which can be
+    left out to be inserted implicitly. This operation doesn't produce any
+    results.
+  }];
+  let arguments = (ins I1:$condition);
+  let results = (outs);
+  let regions = (region SizedRegion<1>:$thenRegion,
+                        MaxSizedRegion<1>:$elseRegion);
+
+  // Each builder signature is listed exactly once: every entry becomes a
+  // `build` member declaration, so a repeated entry would redeclare it.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$cond)>,
+    OpBuilder<(ins "Value":$cond, "bool":$addThenBlock, "bool":$addElseBlock)>,
+    OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
+    OpBuilder<(ins "Value":$cond,
+      CArg<"function_ref<void(OpBuilder &, Location)>",
+           "buildTerminatedBody">:$thenBuilder,
+      CArg<"function_ref<void(OpBuilder &, Location)>",
+           "nullptr">:$elseBuilder)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Builders positioned at the end of the then (region 0) / else (region 1)
+    // block.
+    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
+      Block* body = getBody(0);
+      return OpBuilder::atBlockEnd(body, listener);
+    }
+    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
+      Block* body = getBody(1);
+      return OpBuilder::atBlockEnd(body, listener);
+    }
+    Block* thenBlock();
+    Block* elseBlock();
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 275e095245e89ce..660e48768c4ff34 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -38,6 +38,7 @@ add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
+add_subdirectory(SCFToEmitC)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToOpenMP)
add_subdirectory(SCFToSPIRV)
diff --git a/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
new file mode 100644
index 000000000000000..79119d374f7a5e9
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
@@ -0,0 +1,18 @@
+add_mlir_conversion_library(MLIRSCFToEmitC
+ SCFToEmitC.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIREmitCDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ )
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
new file mode 100644
index 000000000000000..1fa518bad7cfb29
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -0,0 +1,131 @@
+//===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert scf.if ops into emitc.if ops,
+// preserving structured control flow.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+
+/// Pass driver for convert-scf-to-emitc; runOnOperation() is defined
+/// out-of-line.
+struct SCFToEmitCPass : impl::SCFToEmitCBase<SCFToEmitCPass> {
+  void runOnOperation() override;
+};
+
+/// Rewrites scf::if as emitc::if. Results of the scf::if are modelled as
+/// emitc::variable ops that the then and else regions assign to.
+struct IfLowering : OpRewritePattern<IfOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp ifOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+/// Replaces `ifOp` (an scf::if) with an emitc::if that has the same condition
+/// and the same region contents. One emitc::variable is created per result of
+/// `ifOp`; each scf::yield is replaced by emitc::assign ops into those
+/// variables followed by an emitc::yield. Uses of the original results are
+/// replaced with the variables. Always returns success().
+LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
+                                          PatternRewriter &rewriter) const {
+  Location loc = ifOp.getLoc();
+
+  // Create an emitc::variable op for each result. These variables will be
+  // assigned to by emitc::assign ops within the then & else regions.
+  SmallVector<Value> resultVariables;
+  if (ifOp.getNumResults()) {
+    MLIRContext *context = ifOp.getContext();
+    rewriter.setInsertionPoint(ifOp);
+    for (OpResult result : ifOp.getResults()) {
+      Type resultType = result.getType();
+      // The variables carry an empty opaque attribute as their value.
+      auto noInit = emitc::OpaqueAttr::get(context, "");
+      auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
+      resultVariables.push_back(var);
+    }
+  }
+
+  // Utility function to lower the contents of an scf::if region to an emitc::if
+  // region. The contents of the scf::if region are moved into the respective
+  // emitc::if region, and the scf::yield is replaced with a sequence of
+  // emitc::assign ops (one per yielded value, writing into the matching result
+  // variable) followed by an emitc::yield.
+  auto lowerRegion = [&resultVariables, &rewriter](Region &region,
+                                                   Region &loweredRegion) {
+    rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
+    Operation *terminator = loweredRegion.back().getTerminator();
+    Location terminatorLoc = terminator->getLoc();
+    ValueRange terminatorOperands = terminator->getOperands();
+    rewriter.setInsertionPointToEnd(&loweredRegion.back());
+    for (auto [resultValue, resultVar] :
+         llvm::zip(terminatorOperands, resultVariables))
+      rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
+    rewriter.create<emitc::YieldOp>(terminatorLoc);
+    rewriter.eraseOp(terminator);
+  };
+
+  Region &thenRegion = ifOp.getThenRegion();
+  Region &elseRegion = ifOp.getElseRegion();
+
+  bool hasElseBlock = !elseRegion.empty();
+
+  // Create the emitc::if without any blocks; the blocks are moved over from
+  // the scf::if regions by lowerRegion.
+  auto loweredIf =
+      rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
+
+  Region &loweredThenRegion = loweredIf.getThenRegion();
+  lowerRegion(thenRegion, loweredThenRegion);
+
+  if (hasElseBlock) {
+    Region &loweredElseRegion = loweredIf.getElseRegion();
+    lowerRegion(elseRegion, loweredElseRegion);
+  }
+
+  rewriter.replaceOp(ifOp, resultVariables);
+  return success();
+}
+
+/// Registers the SCF-to-EmitC rewrite patterns (currently IfLowering only) in
+/// `patterns`.
+void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<IfLowering>(context);
+}
+
+/// Runs a partial conversion in which scf::IfOp is illegal and every other op
+/// is accepted as-is; signals pass failure if the conversion fails.
+void SCFToEmitCPass::runOnOperation() {
+  MLIRContext &context = getContext();
+
+  // Configure conversion to lower out SCF operations.
+  ConversionTarget target(context);
+  target.addIllegalOp<scf::IfOp>();
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  RewritePatternSet patterns(&context);
+  populateSCFToEmitCConversionPatterns(patterns);
+
+  LogicalResult status =
+      applyPartialConversion(getOperation(), target, std::move(patterns));
+  if (failed(status))
+    signalPassFailure();
+}
+
+/// Factory for the convert-scf-to-emitc pass.
+std::unique_ptr<Pass> mlir::createConvertSCFToEmitCPass() {
+  auto pass = std::make_unique<SCFToEmitCPass>();
+  return pass;
+}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 1c95ed07702d1db..447c7b748a90015 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -44,6 +44,12 @@ Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
return builder.create<emitc::ConstantOp>(loc, type, value);
}
+/// Body-builder callback that emits nothing but an operand-less emitc::YieldOp
+/// at the builder's current insertion point. Serves as the default for
+/// builders of region-carrying ops.
+void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
+  builder.create<emitc::YieldOp>(loc);
+}
+
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
@@ -248,6 +254,187 @@ LogicalResult emitc::VariableOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// AssignOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the assigned-to operand is produced by an emitc::VariableOp
+/// and that its type equals the type of the assigned value.
+LogicalResult emitc::AssignOp::verify() {
+  Value variable = getVar();
+  // Block arguments and results of any other op are rejected.
+  if (!variable.getDefiningOp<emitc::VariableOp>())
+    return emitOpError() << "requires first operand (" << variable
+                         << ") to be a Variable";
+
+  Type variableType = variable.getType();
+  Type valueType = getValue().getType();
+  if (variableType == valueType)
+    return success();
+
+  return emitOpError() << "requires value's type (" << valueType
+                       << ") to match variable's type (" << variableType
+                       << ")";
+}
+
+//===----------------------------------------------------------------------===//
+// IfOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an emitc.if on `cond` with both regions attached; a block is created
+/// in the then/else region only when the corresponding flag is set. The
+/// builder's insertion point is restored on return.
+void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
+                 bool addThenBlock, bool addElseBlock) {
+  assert((!addElseBlock || addThenBlock) &&
+         "must not create else block w/o then block");
+  result.addOperands(cond);
+
+  // Attach both regions first, then populate them in then/else order.
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+
+  OpBuilder::InsertionGuard guard(builder);
+  if (addThenBlock)
+    builder.createBlock(thenRegion);
+  if (addElseBlock)
+    builder.createBlock(elseRegion);
+}
+
+/// Builds an emitc.if on `cond` whose then region always gets a block and
+/// whose else region gets one only if `withElseRegion` is set. The builder's
+/// insertion point is restored on return.
+void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
+                 bool withElseRegion) {
+  result.addOperands(cond);
+
+  OpBuilder::InsertionGuard guard(builder);
+
+  // The then region is never left empty.
+  builder.createBlock(result.addRegion());
+
+  // The else region is always attached but may stay empty.
+  Region *elseRegion = result.addRegion();
+  if (!withElseRegion)
+    return;
+  builder.createBlock(elseRegion);
+}
+
+/// Builds an emitc.if on `cond`, populating the then block via `thenBuilder`
+/// (required) and, when `elseBuilder` is non-null, the else block via
+/// `elseBuilder`. Both callbacks receive the op's location. The builder's
+/// insertion point is restored on return.
+void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
+                 function_ref<void(OpBuilder &, Location)> thenBuilder,
+                 function_ref<void(OpBuilder &, Location)> elseBuilder) {
+  assert(thenBuilder && "the builder callback for 'then' must be present");
+  result.addOperands(cond);
+
+  OpBuilder::InsertionGuard guard(builder);
+  Location loc = result.location;
+
+  // Then region: one block, filled in by the callback.
+  builder.createBlock(result.addRegion());
+  thenBuilder(builder, loc);
+
+  // Else region: attached unconditionally, populated only on request.
+  Region *elseRegion = result.addRegion();
+  if (!elseBuilder)
+    return;
+  builder.createBlock(elseRegion);
+  elseBuilder(builder, loc);
+}
+
+/// Parses `emitc.if <cond> <then-region> (else <else-region>)? <attr-dict>?`.
+/// The condition is resolved as i1 and a missing terminator is added to every
+/// parsed region.
+ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Both regions are attached up front; the else region stays empty unless an
+  // `else` keyword follows the then region.
+  result.regions.reserve(2);
+  Region *thenRegion = result.addRegion();
+  Region *elseRegion = result.addRegion();
+
+  Builder &builder = parser.getBuilder();
+  Type condType = builder.getI1Type();
+  OpAsmParser::UnresolvedOperand cond;
+  if (parser.parseOperand(cond) ||
+      parser.resolveOperand(cond, condType, result.operands))
+    return failure();
+
+  // Mandatory 'then' region, taking no arguments.
+  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+  IfOp::ensureTerminator(*thenRegion, builder, result.location);
+
+  // Optional 'else' region, introduced by the `else` keyword.
+  if (succeeded(parser.parseOptionalKeyword("else"))) {
+    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
+      return failure();
+    IfOp::ensureTerminator(*elseRegion, builder, result.location);
+  }
+
+  // Optional trailing attribute dictionary.
+  return parser.parseOptionalAttrDict(result.attributes);
+}
+
+/// Prints the op as `<cond> <then-region> (else <else-region>)? <attr-dict>`,
+/// omitting block arguments and block terminators from the regions.
+void IfOp::print(OpAsmPrinter &p) {
+  auto printBody = [&p](Region &body) {
+    p.printRegion(body,
+                  /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/false);
+  };
+
+  p << " " << getCondition() << ' ';
+  printBody(getThenRegion());
+
+  // The else keyword and region appear only when the region has a block.
+  if (Region &elseRegion = getElseRegion(); !elseRegion.empty()) {
+    p << " else ";
+    printBody(elseRegion);
+  }
+
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+/// Appends to `regions` the possible control-flow successors of `point`.
+/// From inside either region, control branches back to the parent operation.
+/// From the parent operation, control may enter the `then` region and either
+/// the `else` region or, when the `else` region is empty, the parent
+/// operation's continuation.
+void IfOp::getSuccessorRegions(RegionBranchPoint point,
+                               SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` region branch back to the parent operation.
+  if (!point.isParent()) {
+    regions.push_back(RegionSuccessor());
+    return;
+  }
+
+  regions.push_back(RegionSuccessor(&getThenRegion()));
+
+  // Don't consider the else region if it is empty.
+  Region *elseRegion = &this->getElseRegion();
+  if (elseRegion->empty())
+    regions.push_back(RegionSuccessor());
+  else
+    regions.push_back(RegionSuccessor(elseRegion));
+}
+
+/// Appends to `regions` the regions that may be entered from the parent op,
+/// given `operands` (constant attributes for the op's operands, null where
+/// unknown). With a constant condition only the selected side is reported;
+/// otherwise both sides are.
+void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  FoldAdaptor adaptor(operands, *this);
+  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+  if (!boolAttr || boolAttr.getValue())
+    regions.emplace_back(&getThenRegion());
+
+  // If the else region is empty, execution continues after the parent op.
+  if (!boolAttr || !boolAttr.getValue()) {
+    if (!getElseRegion().empty())
+      regions.emplace_back(&getElseRegion());
+    else
+      regions.emplace_back();
+  }
+}
+
+/// Appends one InvocationBounds entry per region (then, else). `operands[0]`
+/// is the constant condition attribute, or null when it is not known.
+void IfOp::getRegionInvocationBounds(
+    ArrayRef<Attribute> operands,
+    SmallVectorImpl<InvocationBounds> &invocationBounds) {
+  auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0]);
+  if (!cond) {
+    // Non-constant condition. Each region may be executed 0 or 1 times.
+    invocationBounds.assign(2, {0, 1});
+    return;
+  }
+
+  // Known condition: the upper bound is 1 for the selected region and 0 for
+  // the other one.
+  const bool takesThen = cond.getValue();
+  invocationBounds.emplace_back(0, takesThen ? 1 : 0);
+  invocationBounds.emplace_back(0, takesThen ? 0 : 1);
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 832dd8f2013fa4d..12d794419621931 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -246,6 +246,19 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
+/// Emits an emitc.assign: the assignment to the variable (via
+/// emitVariableAssignment) followed by the name of the assigned value.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::AssignOp assignOp) {
+  // cast<> asserts that the assigned-to operand is produced by emitc.variable.
+  Operation *definingOp = assignOp.getVar().getDefiningOp();
+  OpResult assignee = cast<emitc::VariableOp>(definingOp)->getResult(0);
+
+  // NOTE(review): emitVariableAssignment presumably writes `<name> = `;
+  // confirm against CppEmitter.
+  if (failed(emitter.emitVariableAssignment(assignee)))
+    return failure();
+
+  emitter.ostream() << emitter.getOrCreateName(assignOp.getValue());
+  return success();
+}
+
static LogicalResult printBinaryOperation(CppEmitter &emitter,
Operation *operation,
StringRef binaryOperator) {
@@ -567,17 +580,9 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
raw_indented_ostream &os = emitter.ostream();
- if (!emitter.shouldDeclareVariablesAtTop()) {
- for (OpResult result : ifOp.getResults()) {
- if (failed(emitter.emitVariableDeclaration(result,
- /*trailingSemicolon=*/true)))
- return failure();
- }
- }
-
os << "if (";
if (failed(emitter.emitOperands(*ifOp.getOperation())))
return failure();
@@ -585,10 +590,9 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
os.indent();
Region &thenRegion = ifOp.getThenRegion();
- for (Operation &op : thenRegion.getOps()) {
- // Note: This prints a superfluous semicolon if the terminating yield op has
- // zero results.
- if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
+ auto thenOps = thenRegion.getOps();
+ for (auto it = thenOps.begin(); std::next(it) != thenOps.end(); ++it) {
+ if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
return failure();
}
@@ -599,10 +603,9 @@ static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
os << " else {\n";
os.indent();
- for (Operation &op : elseRegion.getOps()) {
- // Note: This prints a superfluous semicolon if the terminating yield op
- // has zero results.
- if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
+ auto elseOps = elseRegion.getOps();
+ for (auto it = elseOps.begin(); std::next(it) != elseOps.end(); ++it) {
+ if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
return failure();
}
@@ -741,12 +744,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
return failure();
}
for (Operation &op : block.getOperations()) {
- // When generating code for an scf.if or cf.cond_br op no semicolon needs
- // to be printed after the closing brace.
+ // When generating code for an emitc.if or cf.cond_br op no semicolon
+ // needs to be printed after the closing brace.
// When generating code for an scf.for op, printing a trailing semicolon
// is handled within the printOperation function.
bool trailingSemicolon =
- !isa<cf::CondBranchOp, emitc::LiteralOp, scf::IfOp, scf::ForOp>(op);
+ !isa<cf::CondBranchOp, emitc::LiteralOp, emitc::IfOp, scf::ForOp>(op);
if (failed(emitter.emitOperation(
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1010,13 +1013,14 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
// EmitC ops.
.Case<emitc::AddOp, emitc::ApplyOp, emitc::CallOp, emitc::CastOp,
emitc::CmpOp, emitc::ConstantOp, emitc::DivOp, emitc::IncludeOp,
- emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp>(
+ emitc::MulOp, emitc::RemOp, emitc::SubOp, emitc::VariableOp,
+ emitc::AssignOp, emitc::IfOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
[&](auto op) { return printOperation(*this, op); })
// SCF ops.
- .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
+ .Case<scf::ForOp, scf::YieldOp>(
[&](auto op) { return printOperation(*this, op); })
// Arithmetic ops.
.Case<arith::ConstantOp>(
diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir
new file mode 100644
index 000000000000000..738d73fdbe869e8
--- /dev/null
+++ b/mlir/test/Conversion/SCFToEmitC/if.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s
+
+// The input uses scf.if so that the convert-scf-to-emitc pass has something to
+// lower; the expected output below is the corresponding emitc.if.
+func.func @test_if(%arg0: i1, %arg1: f32) {
+  scf.if %arg0 {
+    %0 = emitc.call "func_const"(%arg1) : (f32) -> i32
+  }
+  return
+}
+// CHECK-LABEL: func.func @test_if(
+// CHECK-SAME: %[[VAL_0:.*]]: i1,
+// CHECK-SAME: %[[VAL_1:.*]]: f32) {
+// CHECK-NEXT: emitc.if %[[VAL_0]] {
+// CHECK-NEXT: %[[VAL_2:.*]] = emitc.call "func_const"(%[[VAL_1]]) : (f32) -> i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+// The input uses scf.if so that the convert-scf-to-emitc pass has something to
+// lower; the expected output below is the corresponding emitc.if/else.
+func.func @test_if_else(%arg0: i1, %arg1: f32) {
+  scf.if %arg0 {
+    %0 = emitc.call "func_true"(%arg1) : (f32) -> i32
+  } else {
+    %0 = emitc.call "func_false"(%arg1) : (f32) -> i32
+  }
+  return
+}
+// CHECK-LABEL: func.func @test_if_else(
+// CHECK-SAME: %[[VAL_0:.*]]: i1,
+// CHECK-SAME: %[[VAL_1:.*]]: f32) {
+// CHECK-NEXT: emitc.if %[[VAL_0]] {
+// CHECK-NEXT: %[[VAL_2:.*]] = emitc.call "func_true"(%[[VAL_1]]) : (f32) -> i32
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[VAL_3:.*]] = emitc.call "func_false"(%[[VAL_1]]) : (f32) -> i32
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+
+func.func @test_if_yield(%arg0: i1, %arg1: f32) {
+ %0 = arith.constant 0 : i8
+ %x, %y = scf.if %arg0 -> (i32, f64) {
+ %1 = emitc.call "func_true_1"(%arg1) : (f32) -> i32
+ %2 = emitc.call "func_true_2"(%arg1) : (f32) -> f64
+ scf.yield %1, %2 : i32, f64
+ } else {
+ %1 = emitc.call "func_false_1"(%arg1) : (f32) -> i32
+ %2 = emitc.call "func_false_2"(%arg1) : (f32) -> f64
+ scf.yield %1, %2 : i32, f64
+ }
+ return
+}
+// CHECK-LABEL: func.func @test_if_yield(
+// CHECK-SAME: %[[VAL_0:.*]]: i1,
+// CHECK-SAME: %[[VAL_1:.*]]: f32) {
+// CHECK-NEXT: %[[VAL_2:.*]] = arith.constant 0 : i8
+// CHECK-NEXT: %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+// CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
+// CHECK-NEXT: emitc.if %[[VAL_0]] {
+// CHECK-NEXT: %[[VAL_5:.*]] = emitc.call "func_true_1"(%[[VAL_1]]) : (f32) -> i32
+// CHECK-NEXT: %[[VAL_6:.*]] = emitc.call "func_true_2"(%[[VAL_1]]) : (f32) -> f64
+// CHECK-NEXT: emitc.assign %[[VAL_5]] : i32 to %[[VAL_3]] : i32
+// CHECK-NEXT: emitc.assign %[[VAL_6]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT: } else {
+// CHECK-NEXT: %[[VAL_7:.*]] = emitc.call "func_false_1"(%[[VAL_1]]) : (f32) -> i32
+// CHECK-NEXT: %[[VAL_8:.*]] = emitc.call "func_false_2"(%[[VAL_1]]) : (f32) -> f64
+// CHECK-NEXT: emitc.assign %[[VAL_7]] : i32 to %[[VAL_3]] : i32
+// CHECK-NEXT: emitc.assign %[[VAL_8]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 3978f906525b884..9e8f0bf0bf8bdcd 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -199,3 +199,28 @@ func.func @sub_pointer_pointer(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
%1 = "emitc.sub" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
return
}
+
+// -----
+
+func.func @test_misplaced_yield() {
+ // expected-error @+1 {{'emitc.yield' op expects parent op 'emitc.if'}}
+ emitc.yield
+ return
+}
+
+// -----
+
+func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
+ // expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable}}
+ emitc.assign %arg1 : f32 to %arg2 : f32
+ return
+}
+
+// -----
+
+func.func @test_assign_type_mismatch(%arg1: f32) {
+ %v = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ // expected-error @+1 {{'emitc.assign' op requires value's type ('f32') to match variable's type ('i32')}}
+ emitc.assign %arg1 : f32 to %v : i32
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 279fe13229c594e..0817945e3b1e0bc 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -97,3 +97,33 @@ func.func @cmp(%arg0 : i32, %arg1 : f32, %arg2 : i64, %arg3 : f64, %arg4 : !emit
%14 = emitc.cmp three_way, %arg6, %arg6 : (!emitc.opaque<"custom">, !emitc.opaque<"custom">) -> !emitc.opaque<"custom">
return
}
+
+func.func @test_if(%arg0: i1, %arg1: f32) {
+ emitc.if %arg0 {
+ %0 = emitc.call "func_const"(%arg1) : (f32) -> i32
+ }
+ return
+}
+
+func.func @test_explicit_yield(%arg0: i1, %arg1: f32) {
+ emitc.if %arg0 {
+ %0 = emitc.call "func_const"(%arg1) : (f32) -> i32
+ emitc.yield
+ }
+ return
+}
+
+func.func @test_if_else(%arg0: i1, %arg1: f32) {
+ emitc.if %arg0 {
+ %0 = emitc.call "func_true"(%arg1) : (f32) -> i32
+ } else {
+ %0 = emitc.call "func_false"(%arg1) : (f32) -> i32
+ }
+ return
+}
+
+func.func @test_assign(%arg1: f32) {
+ %v = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+ emitc.assign %arg1 : f32 to %v : f32
+ return
+}
diff --git a/mlir/test/Target/Cpp/if.mlir b/mlir/test/Target/Cpp/if.mlir
index 74fcb7104228f08..beff2182777b42d 100644
--- a/mlir/test/Target/Cpp/if.mlir
+++ b/mlir/test/Target/Cpp/if.mlir
@@ -2,7 +2,7 @@
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
func.func @test_if(%arg0: i1, %arg1: f32) {
- scf.if %arg0 {
+ emitc.if %arg0 {
%0 = emitc.call "func_const"(%arg1) : (f32) -> i32
}
return
@@ -10,7 +10,6 @@ func.func @test_if(%arg0: i1, %arg1: f32) {
// CPP-DEFAULT: void test_if(bool [[V0:[^ ]*]], float [[V1:[^ ]*]]) {
// CPP-DEFAULT-NEXT: if ([[V0]]) {
// CPP-DEFAULT-NEXT: int32_t [[V2:[^ ]*]] = func_const([[V1]]);
-// CPP-DEFAULT-NEXT: ;
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
@@ -18,13 +17,12 @@ func.func @test_if(%arg0: i1, %arg1: f32) {
// CPP-DECLTOP-NEXT: int32_t [[V2:[^ ]*]];
// CPP-DECLTOP-NEXT: if ([[V0]]) {
// CPP-DECLTOP-NEXT: [[V2]] = func_const([[V1]]);
-// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: }
// CPP-DECLTOP-NEXT: return;
func.func @test_if_else(%arg0: i1, %arg1: f32) {
- scf.if %arg0 {
+ emitc.if %arg0 {
%0 = emitc.call "func_true"(%arg1) : (f32) -> i32
} else {
%0 = emitc.call "func_false"(%arg1) : (f32) -> i32
@@ -34,10 +32,8 @@ func.func @test_if_else(%arg0: i1, %arg1: f32) {
// CPP-DEFAULT: void test_if_else(bool [[V0:[^ ]*]], float [[V1:[^ ]*]]) {
// CPP-DEFAULT-NEXT: if ([[V0]]) {
// CPP-DEFAULT-NEXT: int32_t [[V2:[^ ]*]] = func_true([[V1]]);
-// CPP-DEFAULT-NEXT: ;
// CPP-DEFAULT-NEXT: } else {
// CPP-DEFAULT-NEXT: int32_t [[V3:[^ ]*]] = func_false([[V1]]);
-// CPP-DEFAULT-NEXT: ;
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
@@ -46,24 +42,26 @@ func.func @test_if_else(%arg0: i1, %arg1: f32) {
// CPP-DECLTOP-NEXT: int32_t [[V3:[^ ]*]];
// CPP-DECLTOP-NEXT: if ([[V0]]) {
// CPP-DECLTOP-NEXT: [[V2]] = func_true([[V1]]);
-// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: } else {
// CPP-DECLTOP-NEXT: [[V3]] = func_false([[V1]]);
-// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: }
// CPP-DECLTOP-NEXT: return;
func.func @test_if_yield(%arg0: i1, %arg1: f32) {
%0 = arith.constant 0 : i8
- %x, %y = scf.if %arg0 -> (i32, f64) {
+ %x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
+ %y = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
+ emitc.if %arg0 {
%1 = emitc.call "func_true_1"(%arg1) : (f32) -> i32
%2 = emitc.call "func_true_2"(%arg1) : (f32) -> f64
- scf.yield %1, %2 : i32, f64
+ emitc.assign %1 : i32 to %x : i32
+ emitc.assign %2 : f64 to %y : f64
} else {
%1 = emitc.call "func_false_1"(%arg1) : (f32) -> i32
%2 = emitc.call "func_false_2"(%arg1) : (f32) -> f64
- scf.yield %1, %2 : i32, f64
+ emitc.assign %1 : i32 to %x : i32
+ emitc.assign %2 : f64 to %y : f64
}
return
}
@@ -93,6 +91,8 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]];
// CPP-DECLTOP-NEXT: double [[V8:[^ ]*]];
// CPP-DECLTOP-NEXT: [[V2]] = 0;
+// CPP-DECLTOP-NEXT: ;
+// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: if ([[V0]]) {
// CPP-DECLTOP-NEXT: [[V5]] = func_true_1([[V1]]);
// CPP-DECLTOP-NEXT: [[V6]] = func_true_2([[V1]]);
>From a65a2773806879a465f3665044939383eef35c57 Mon Sep 17 00:00:00 2001
From: Gil Rapaport <gil.rapaport at mobileye.com>
Date: Sun, 24 Sep 2023 11:53:06 +0300
Subject: [PATCH 2/2] fixup! [mlir][emitc] Add a structured if operation
Addressed review comments
---
mlir/include/mlir/Conversion/Passes.td | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index cca1e262df7c121..a17fabbaf28239e 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -938,7 +938,6 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
" control flow";
- let constructor = "mlir::createConvertSCFToEmitCPass()";
let dependentDialects = ["emitc::EmitCDialect"];
}
More information about the Mlir-commits
mailing list