[Mlir-commits] [mlir] [mlir][emitc] Add a structured for operation (PR #68206)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 4 04:08:48 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
<details>
<summary>Changes</summary>
Add an emitc.for op to the EmitC dialect as a lowering target for
scf.for, replacing its current direct translation to C. The translator
now handles emitc.for instead.
---
Patch is 35.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68206.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.h (+1)
- (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+78-1)
- (modified) mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp (+116-34)
- (modified) mlir/lib/Dialect/EmitC/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/EmitC/IR/EmitC.cpp (+142)
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+6-76)
- (added) mlir/test/Conversion/SCFToEmitC/for.mlir (+96)
- (modified) mlir/test/Dialect/EmitC/invalid_ops.mlir (+1-1)
- (modified) mlir/test/Dialect/EmitC/ops.mlir (+23-1)
- (modified) mlir/test/Target/Cpp/for.mlir (+42-14)
``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index 4dff26e23c42850..9d2ec0f41a75568 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -20,6 +20,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 827ffc0278fce1c..381247bcdadf27d 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -18,6 +18,7 @@ include "mlir/Dialect/EmitC/IR/EmitCTypes.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
@@ -430,7 +431,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
}
-def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
+def EmitC_YieldOp : EmitC_Op<"yield",
+ [Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
let summary = "block termination operation";
let description = [{
"yield" terminates blocks within EmitC control-flow operations. Since
@@ -444,6 +446,81 @@ def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]>
let assemblyFormat = [{ attr-dict }];
}
+// emitc.for: a counted loop over [lowerBound, upperBound) with a single
+// induction-variable block argument. It has no results and no loop-carried
+// values; the SCFToEmitC lowering models scf.for iter_args with
+// emitc.variable / emitc.assign ops instead.
+def EmitC_ForOp : EmitC_Op<"for",
+ [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInits", "getSingleInductionVar", "getSingleLowerBound",
+ "getSingleStep", "getSingleUpperBound"]>,
+ AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"emitc::YieldOp">,
+ RecursiveMemoryEffects]> {
+ let summary = "for operation";
+ let description = [{
+ The `emitc.for` operation represents a loop taking 3 SSA values as operands
+ that represent the lower bound, upper bound and step respectively. The
+ operation defines an SSA value for its induction variable. It has one
+ region capturing the loop body. The induction variable is represented as an
+ argument of this region. This SSA value is a signless integer or index.
+ The step is a value of same type but required to be positive. The lower and
+ upper bounds specify a half-open range: the range includes the lower bound
+ but does not include the upper bound. This operation has no result.
+
+ The body region must contain exactly one block that terminates with
+ `emitc.yield`. Calling ForOp::build will create such a region and insert
+ the terminator implicitly if none is defined, so will the parsing even in
+ cases when it is absent from the custom format. For example:
+
+ ```mlir
+ // Index case.
+ emitc.for %iv = %lb to %ub step %step {
+ ... // body
+ }
+ ...
+ // Integer case.
+ emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
+ ... // body
+ }
+ ```
+ }];
+ let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
+ AnySignlessIntegerOrIndex:$upperBound,
+ AnySignlessIntegerOrIndex:$step);
+ let results = (outs);
+ let regions = (region SizedRegion<1>:$region);
+
+ // The only builder is the custom one implemented in EmitC.cpp: it creates
+ // the body block with the induction-variable argument and optionally runs
+ // a callback to populate the body.
+ let skipDefaultBuilders = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
+ CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
+ ];
+
+ // getRegionIterArgs() is always empty: emitc.for carries no values across
+ // iterations.
+ let extraClassDeclaration = [{
+ using BodyBuilderFn =
+ function_ref<void(OpBuilder &, Location, Value)>;
+ Value getInductionVar() { return getBody()->getArgument(0); }
+ Block::BlockArgListType getRegionIterArgs() {
+ return Block::BlockArgListType();
+ }
+ FailureOr<LoopLikeOpInterface> replaceWithAdditionalYields(
+ RewriterBase &rewriter, ValueRange newInitOperands,
+ bool replaceInitOperandUsesInLoop,
+ const NewYieldValuesFn &newYieldValuesFn) {
+ return LoopLikeOpInterfaceTrait::replaceWithAdditionalYields(
+ rewriter, newInitOperands, replaceInitOperandUsesInLoop,
+ newYieldValuesFn);
+ }
+ void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
+ void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
+ void setStep(Value step) { getOperation()->setOperand(2, step); }
+ }];
+
+ let hasCanonicalizer = 1;
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+}
+
def EmitC_IfOp : EmitC_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 5d0d8df8869e313..22041344ba3d9a3 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -37,6 +37,106 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};
+// Lower scf::for to emitc::for, implementing loop-carried values and results
+// with emitc::variable ops that are updated within the loop body.
+struct ForLowering : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Create an uninitialized emitc::variable op for each result of the given op,
+// inserted immediately before the op. Returns the variables in result order
+// (empty if the op has no results). The rewriter's insertion point is left
+// unchanged.
+template <typename T>
+static SmallVector<Value> createVariablesForResults(T op,
+                                                    PatternRewriter &rewriter) {
+  SmallVector<Value> resultVariables;
+
+  if (!op.getNumResults())
+    return resultVariables;
+
+  Location loc = op->getLoc();
+  MLIRContext *context = op.getContext();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  for (OpResult result : op.getResults()) {
+    Type resultType = result.getType();
+    auto noInit = emitc::OpaqueAttr::get(context, "");
+    auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
+    resultVariables.push_back(var);
+  }
+
+  return resultVariables;
+}
+
+// Create a series of assign ops assigning the given values to the given
+// variables, pairwise, at the current insertion point of the rewriter.
+static void assignValues(ValueRange values, ValueRange variables,
+                         PatternRewriter &rewriter, Location loc) {
+  for (auto value2Var : llvm::zip(values, variables)) {
+    Value value = std::get<0>(value2Var);
+    Value var = std::get<1>(value2Var);
+    rewriter.create<emitc::AssignOp>(loc, var, value);
+  }
+}
+
+// Replace the given scf::yield with emitc::assign ops that store each yielded
+// value into the corresponding variable of `resultVariables`, followed by an
+// emitc::yield.
+//
+// The assignments are emitted in sequence. A yielded value can itself be one
+// of the target variables, as with loop-carried values of an scf::for whose
+// iter args were replaced by those variables, e.g. a loop that swaps two
+// carried values. Reading such a variable after an earlier assignment in the
+// sequence would observe the updated value. Such operands are therefore copied
+// into fresh temporaries before any target variable is written.
+static void lowerYield(ValueRange resultVariables, PatternRewriter &rewriter,
+                       scf::YieldOp yield) {
+  Location loc = yield.getLoc();
+  MLIRContext *context = yield.getContext();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(yield);
+
+  SmallVector<Value> yieldedValues;
+  yieldedValues.reserve(yield.getNumOperands());
+  for (Value operand : yield.getOperands()) {
+    if (!llvm::is_contained(resultVariables, operand)) {
+      yieldedValues.push_back(operand);
+      continue;
+    }
+    // Snapshot the current value of an aliased variable.
+    auto noInit = emitc::OpaqueAttr::get(context, "");
+    Value snapshot =
+        rewriter.create<emitc::VariableOp>(loc, operand.getType(), noInit);
+    rewriter.create<emitc::AssignOp>(loc, snapshot, operand);
+    yieldedValues.push_back(snapshot);
+  }
+
+  assignValues(yieldedValues, resultVariables, rewriter, loc);
+
+  rewriter.create<emitc::YieldOp>(loc);
+  rewriter.eraseOp(yield);
+}
+
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+                                           PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+
+  rewriter.setInsertionPoint(forOp);
+
+  // Create an emitc::variable op for each result. These receive the final
+  // loop-carried values after the loop.
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(forOp, rewriter);
+  // A second set of variables holds the loop-carried values. They are
+  // initialized from the init operands before the loop and are updated by
+  // emitc::assign ops at the end of every iteration.
+  SmallVector<Value> iterArgsVariables =
+      createVariablesForResults(forOp, rewriter);
+
+  assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
+
+  auto loweredFor = rewriter.create<emitc::ForOp>(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+
+  Block *loweredBody = loweredFor.getBody();
+
+  // Erase the auto-generated terminator. The scf::yield of the merged body is
+  // lowered into a new emitc::yield below.
+  rewriter.eraseOp(loweredBody->getTerminator());
+
+  // Inside the body, the induction variable maps to the emitc::for block
+  // argument, and each region iter arg maps to its emitc::variable.
+  SmallVector<Value> replacingValues;
+  replacingValues.reserve(iterArgsVariables.size() + 1);
+  replacingValues.push_back(loweredFor.getInductionVar());
+  replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
+
+  rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
+  lowerYield(iterArgsVariables, rewriter,
+             cast<scf::YieldOp>(loweredBody->getTerminator()));
+
+  // Copy the loop-carried values into the result variables after the loop.
+  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+
+  rewriter.replaceOp(forOp, resultVariables);
+  return success();
+}
+
// Lower scf::if to emitc::if, implementing return values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
@@ -52,41 +152,22 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();
- SmallVector<Value> resultVariables;
-
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
- if (ifOp.getNumResults()) {
- MLIRContext *context = ifOp.getContext();
- rewriter.setInsertionPoint(ifOp);
- for (OpResult result : ifOp.getResults()) {
- Type resultType = result.getType();
- auto noInit = emitc::OpaqueAttr::get(context, "");
- auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
- resultVariables.push_back(var);
- }
- }
-
- // Utility function to lower the contents of an scf::if region to an emitc::if
- // region. The contents of the scf::if regions is moved into the respective
- // emitc::if regions, but the scf::yield is replaced not only with an
- // emitc::yield, but also with a sequence of emitc::assign ops that set the
- // yielded values into the result variables.
- auto lowerRegion = [&resultVariables, &rewriter](Region ®ion,
- Region &loweredRegion) {
- rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
- Operation *terminator = loweredRegion.back().getTerminator();
- Location terminatorLoc = terminator->getLoc();
- ValueRange terminatorOperands = terminator->getOperands();
- rewriter.setInsertionPointToEnd(&loweredRegion.back());
- for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
- Value resultValue = std::get<0>(value2Var);
- Value resultVar = std::get<1>(value2Var);
- rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
- }
- rewriter.create<emitc::YieldOp>(terminatorLoc);
- rewriter.eraseOp(terminator);
- };
+ SmallVector<Value> resultVariables =
+ createVariablesForResults(ifOp, rewriter);
+
+ // Utility function to lower the contents of an scf::if region to an emitc::if
+ // region. The contents of the scf::if regions is moved into the respective
+ // emitc::if regions, but the scf::yield is replaced not only with an
+ // emitc::yield, but also with a sequence of emitc::assign ops that set the
+ // yielded values into the result variables.
+ auto lowerRegion = [&resultVariables, &rewriter](Region ®ion,
+ Region &loweredRegion) {
+ rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
+ Operation *terminator = loweredRegion.back().getTerminator();
+ lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
+ };
Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();
@@ -109,6 +190,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
}
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
+ // Register the scf.for and scf.if lowerings to their emitc counterparts.
+ patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
}
@@ -118,7 +200,7 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::IfOp>();
+ target.addIllegalOp<scf::IfOp, scf::ForOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
index 4665c41a62e80b8..50e79d22d57e681 100644
--- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt
@@ -12,5 +12,6 @@ add_mlir_dialect_library(MLIREmitCDialect
MLIRCastInterfaces
MLIRControlFlowInterfaces
MLIRIR
+ MLIRLoopLikeInterface
MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 961a52a70a2a168..740504cc9db2bcd 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Matchers.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -189,6 +190,147 @@ LogicalResult emitc::ConstantOp::verify() {
// Fold an emitc.constant to its own value, as returned by getValue().
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// ForOp
+//===----------------------------------------------------------------------===//
+
+// Build an emitc.for over [lb, ub) with the given step. The body block gets a
+// single argument, the induction variable, of the same type as `lb`. If a
+// body builder is given, it is invoked at the start of the body with the
+// induction variable. The implicit emitc.yield terminator is added unless the
+// body builder already created a terminator.
+void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
+                  Value ub, Value step, BodyBuilderFn bodyBuilder) {
+  result.addOperands({lb, ub, step});
+  Type t = lb.getType();
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArgument(t, result.location);
+
+  if (bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+  }
+
+  // ensureTerminator is a no-op when the block already ends in a terminator.
+  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
+}
+
+// No canonicalization patterns are registered for emitc.for yet.
+void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
+
+OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+  return RegionBranchOpInterfaceTrait::getEntrySuccessorOperands(point);
+}
+
+OperandRange ForOp::getInits() { return LoopLikeOpInterfaceTrait::getInits(); }
+
+// The loop body is the op's only region.
+SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
+
+std::optional<Value> ForOp::getSingleInductionVar() {
+  return getInductionVar();
+}
+
+std::optional<OpFoldResult> ForOp::getSingleLowerBound() {
+  return OpFoldResult(getLowerBound());
+}
+
+std::optional<OpFoldResult> ForOp::getSingleStep() {
+  return OpFoldResult(getStep());
+}
+
+std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
+  return OpFoldResult(getUpperBound());
+}
+
+void ForOp::getSuccessorRegions(RegionBranchPoint point,
+                                SmallVectorImpl<RegionSuccessor> &regions) {
+  // Both the operation itself and the end of the body may branch into the
+  // body or back to the parent operation; the loop may also execute zero
+  // iterations. No values are forwarded because getRegionIterArgs() is empty.
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor({}));
+}
+
+// Parse `%iv = %lb to %ub step %step (`:` type)? region attr-dict`.
+// The type defaults to `index` when omitted. The implicit emitc.yield
+// terminator is added if the parsed body lacks one.
+ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
+  auto &builder = parser.getBuilder();
+  Type type;
+
+  OpAsmParser::Argument inductionVariable;
+  OpAsmParser::UnresolvedOperand lb, ub, step;
+
+  // Parse the induction variable followed by '=' and the loop bounds.
+  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
+      parser.parseOperand(lb) || parser.parseKeyword("to") ||
+      parser.parseOperand(ub) || parser.parseKeyword("step") ||
+      parser.parseOperand(step))
+    return failure();
+
+  // Parse optional type, else assume Index.
+  if (parser.parseOptionalColon())
+    type = builder.getIndexType();
+  else if (parser.parseType(type))
+    return failure();
+
+  // Bounds, step and induction variable all share the parsed type.
+  inductionVariable.type = type;
+  if (parser.resolveOperand(lb, type, result.operands) ||
+      parser.resolveOperand(ub, type, result.operands) ||
+      parser.resolveOperand(step, type, result.operands))
+    return failure();
+
+  // Parse the body region. emitc.for has no loop-carried values, so the
+  // induction variable is the only region argument.
+  SmallVector<OpAsmParser::Argument, 1> regionArgs{inductionVariable};
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+// Print in the custom form accepted by ForOp::parse. The type is elided when
+// it is `index` (the parser's default), and the implicit terminator is not
+// printed.
+void ForOp::print(OpAsmPrinter &p) {
+  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
+    << getUpperBound() << " step " << getStep();
+
+  // Emit exactly one space before the region, whether or not a type is shown.
+  if (Type t = getInductionVar().getType(); !t.isIndex())
+    p << " : " << t;
+  p << ' ';
+  p.printRegion(getRegion(),
+                /*printEntryBlockArgs=*/false,
+                /*printBlockTerminators=*/false);
+  p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+// Reject a constant step that is zero or negative (signed interpretation).
+// The check uses an APInt comparison instead of IntegerAttr::getInt(), which
+// asserts for integer types wider than 64 bits.
+LogicalResult ForOp::verify() {
+  IntegerAttr step;
+  if (matchPattern(getStep(), m_Constant(&step)) &&
+      step.getValue().isNonPositive())
+    return emitOpError("constant step operand must be positive");
+
+  return success();
+}
+
+// Check that the body defines a single block argument for the induction
+// variable, of the same type as the bounds and step. The argument count is
+// checked first because getInductionVar() reads argument 0 unconditionally.
+LogicalResult ForOp::verifyRegions() {
+  if (getBody()->getNumArguments() != 1)
+    return emitOpError(
+        "expected body to have a single block argument for the induction "
+        "variable");
+
+  if (getInductionVar().getType() != getLowerBound().getType())
+    return emitOpError(
+        "expected induction variable to be same type as bounds and step");
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// IfOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4645ca4b206e78c..05d472ba50d070f 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -502,30 +502,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
+static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
raw_indented_ostream &os = emitter.ostream();
- OperandRange operands = forOp.getInitArgs();
- Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
- Operation::result_range results = forOp.getResults();
-
- if (!emitter.shouldDeclareVariablesAtTop()) {
- for (OpResult result : results) {
- if (failed(emitter.emitVariableDeclaration(result,
- /*trailingSemicolon=*/true)))
- return failure();
- }
- }
-
- for (auto pair : llvm::zip(iterArgs, operands)) {
- if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
- return failure();
- os << " " << emitter.getOrCreateName(std::get<0>(pair)) << " = ";
- os << emitter.getOrCreateName(std::get<1>(pair)) << ";";
- os << "\n";
- }
-
os << "for (";
if (failed(
emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
@@ -548,35 +528,1...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/68206
More information about the Mlir-commits
mailing list