[Mlir-commits] [mlir] [mlir][emitc] Add 'emitc.while' and 'emitc.do' ops to the dialect (PR #143008)
Vlad Lazar
llvmlistbot at llvm.org
Tue Aug 19 06:15:25 PDT 2025
================
@@ -332,11 +332,260 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+// Lower scf::while to either emitc::while or emitc::do based on argument usage
+// patterns. Uses mutable variables to maintain loop state across iterations.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = whileOp.getLoc();
+ MLIRContext *context = loc.getContext();
+
+ // Create variable storage for loop-carried values to enable imperative
+ // updates while maintaining SSA semantics at conversion boundaries.
+ SmallVector<Value> variables;
+ if (failed(
+ createInitVariables(whileOp, rewriter, variables, loc, context))) {
+ return failure();
+ }
+
+ // Select lowering strategy based on condition argument usage:
+ // - emitc.while when condition args match region inputs (direct mapping);
+ // - emitc.do when condition args differ (requires state synchronization).
+ Region &beforeRegion = adaptor.getBefore();
+ Block &beforeBlock = beforeRegion.front();
+ auto condOp = cast<scf::ConditionOp>(beforeRegion.back().getTerminator());
+
+ bool isDoOp = !llvm::equal(beforeBlock.getArguments(), condOp.getArgs());
+
+ LogicalResult result =
+ isDoOp ? lowerDoWhile(whileOp, variables, context, rewriter, loc)
+ : lowerWhile(whileOp, variables, context, rewriter, loc);
+
+ if (failed(result))
+ return failure();
+
+ // Create an emitc::variable op for each result. These variables will be
+ // assigned to by emitc::assign ops within the loop body.
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+ resultVariables))) {
+ return rewriter.notifyMatchFailure(whileOp,
+ "Failed to create result variables");
+ }
+
+ rewriter.setInsertionPointAfter(whileOp);
+
+ // Transfer final loop state to result variables and get final SSA results.
+ SmallVector<Value> finalResults =
+ finalizeLoopResults(resultVariables, variables, rewriter, loc);
+
+ rewriter.replaceOp(whileOp, finalResults);
+ return success();
+ }
+
+private:
+ // Initialize variables for loop-carried values to enable state updates
+ // across iterations without SSA argument passing.
+ static LogicalResult createInitVariables(WhileOp whileOp,
+ ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &outVars,
+ Location loc, MLIRContext *context) {
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+
+ for (Value init : whileOp.getInits()) {
+ emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
+ loc, emitc::LValueType::get(init.getType()), noInit);
+ rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ outVars.push_back(var.getResult());
+ }
+
+ return success();
+ }
+
+ // Transition from SSA block arguments to variable-based state management by
+ // replacing argument uses with variable loads and cleaning up block
+ // interface.
+ void replaceBlockArgsWithVarLoads(Block *block, ArrayRef<Value> vars,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ rewriter.setInsertionPointToStart(block);
+
+ for (auto [arg, var] : llvm::zip(block->getArguments(), vars)) {
+ Type loadedType = cast<emitc::LValueType>(var.getType()).getValueType();
+ Value load = rewriter.create<emitc::LoadOp>(loc, loadedType, var);
+ arg.replaceAllUsesWith(load);
+ }
+
+ // Remove arguments after replacement to simplify block structure.
+ block->eraseArguments(0, block->getNumArguments());
+ }
+
+ // Convert SCF yield terminators to imperative assignments to update loop
+ // variables, maintaining loop semantics while transitioning to emitc model.
+ void processYieldTerminator(Operation *terminator, ArrayRef<Value> vars,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ auto yieldOp = cast<scf::YieldOp>(terminator);
+ SmallVector<Value> yields(yieldOp.getOperands());
+ rewriter.eraseOp(yieldOp);
+
+ rewriter.setInsertionPointToEnd(yieldOp->getBlock());
+ for (auto [var, val] : llvm::zip(vars, yields))
+ rewriter.create<emitc::AssignOp>(loc, var, val);
+ }
+
+ // Transfers final loop state from mutable variables to result variables,
+ // then returns the final SSA values to replace the original scf::while
+ // results.
+ static SmallVector<Value>
+ finalizeLoopResults(ArrayRef<Value> resultVariables,
+ ArrayRef<Value> loopVariables,
+ ConversionPatternRewriter &rewriter, Location loc) {
+ // Transfer final loop state to result variables to bridge imperative loop
+ // variables with SSA result expectations of the original op.
+ for (auto [resultVar, var] : llvm::zip(resultVariables, loopVariables)) {
+ Type loadedType = cast<emitc::LValueType>(var.getType()).getValueType();
+ Value load = rewriter.create<emitc::LoadOp>(loc, loadedType, var);
+ rewriter.create<emitc::AssignOp>(loc, resultVar, load);
+ }
+
+ // Replace op with loaded values to integrate with converted SSA graph.
+ SmallVector<Value> finalResults;
+ for (Value resultVar : resultVariables) {
+ Type loadedType =
+ cast<emitc::LValueType>(resultVar.getType()).getValueType();
+ finalResults.push_back(
+ rewriter.create<emitc::LoadOp>(loc, loadedType, resultVar));
+ }
+
+ return finalResults;
+ }
+
+ // Direct lowering to emitc.while when condition arguments match region
+ // inputs.
+ LogicalResult lowerWhile(WhileOp whileOp, ArrayRef<Value> vars,
+ MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ auto loweredWhile = rewriter.create<emitc::WhileOp>(loc);
+
+ // Lower before region to condition region.
+ Region &condRegion = loweredWhile.getConditionRegion();
+ Block *condBlock = rewriter.createBlock(&condRegion);
+ rewriter.setInsertionPointToStart(condBlock);
+
+ Type i1Type = IntegerType::get(context, 1);
+ auto exprOp = rewriter.create<emitc::ExpressionOp>(loc, TypeRange{i1Type});
+ Region &exprRegion = exprOp.getBodyRegion();
+
+ rewriter.inlineRegionBefore(whileOp.getBefore(), exprRegion,
----------------
Vladislave0-0 wrote:
> The `before` block is not guaranteed to contain a valid C expression, e.g.:
>
> ```mlir
> func.func @double_use(%p : !emitc.ptr<i32>) -> i32 {
> %init = emitc.literal "1.0" : i32
> %var = emitc.literal "1.0" : i32
> %exit = emitc.literal "10.0" : i32
> %res = scf.while (%arg1 = %init) : (i32) -> i32 {
> %used_twice = emitc.call @payload_with_side_effect(%arg1, %p) : (i32, !emitc.ptr<i32>) -> i32
> %prod = emitc.add %used_twice, %used_twice : (i32, i32) -> i32
> %sum = emitc.add %arg1, %prod : (i32, i32) -> i32
> %condition = emitc.cmp lt, %sum, %exit : (i32, i32) -> i1
> scf.condition(%condition) %arg1 : i32
> } do {
> ^bb0(%arg2: i32):
> %next_arg1 = emitc.call @payload_do(%arg2) : (i32) -> i32
> scf.yield %next_arg1 : i32
> }
> return %res : i32
> }
> ```
>
> It may also include ops not related to the condition at all. The `emitc.expression` op can in principle be extended to support such sequences using the comma operator, but it currently doesn't (I'm also not sure it'd be very aesthetic)
Thank you for noticing this bug. Given that we can not determine whether it'll really translate to `CExpression` that won't have side effects or translate at all, what do you think about always translating to `do-while` style loop as we do now when handling `scf.while`'s `do-while` style condition,
```
bool isDoOp = !llvm::equal(beforeBlock.getArguments(), condOp.getArgs());
LogicalResult result =
isDoOp ? lowerDoWhile(whileOp, variables, context, rewriter, loc)
: lowerWhile(whileOp, variables, context, rewriter, loc);
```
where such issues are eliminated due to translating into `while` body, which is not required to be a `CExpression`? That way the translation of the `scf.while` itself will always work and it'll behave pretty much the same as it does right now.
As a downside, the new `emitc.while` will be "dead", since it's not part of the translation from the higher dialects, but tbf we can determine after `form-expression` pass, which `emitc.do` can be converted to `while` operation, since we have a very specific pattern of `do { cond_res = cond; if (cond_res) while (cond_res) }`. And do that as a separate optimization.
https://github.com/llvm/llvm-project/pull/143008
More information about the Mlir-commits
mailing list