[Mlir-commits] [mlir] [mlir][emitc] Add 'emitc.while' and 'emitc.do' ops to the dialect (PR #143008)
Vlad Lazar
llvmlistbot at llvm.org
Wed Sep 17 06:04:03 PDT 2025
================
@@ -336,11 +337,214 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+// Lower scf::while to emitc::do using mutable variables to maintain loop state
+// across iterations. The do-while structure ensures the condition is evaluated
+// after each iteration, matching SCF while semantics.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+  // Rewrites one scf::while op in four steps: create the loop-carried
+  // variables, build the emitc::do loop, create the result variables, and
+  // finally replace the op with values loaded from those result variables.
+  // NOTE(review): `adaptor` is never read here; the init values are taken
+  // from `whileOp` inside createInitVariables -- confirm that is intended.
+  // NOTE(review): the result variables are created only after lowerDoWhile
+  // has already modified the IR -- confirm a failure at that point is safe.
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location whileLoc = whileOp.getLoc();
+    MLIRContext *ctx = whileLoc.getContext();
+
+    // Step 1: one emitc variable per init value, holding the loop state.
+    SmallVector<Value> loopVars;
+    LogicalResult initStatus =
+        createInitVariables(whileOp, rewriter, loopVars, whileLoc, ctx);
+    if (failed(initStatus))
+      return failure();
+
+    // Step 2: build the emitc::do op operating on those variables.
+    LogicalResult loopStatus =
+        lowerDoWhile(whileOp, loopVars, ctx, rewriter, whileLoc);
+    if (failed(loopStatus))
+      return failure();
+
+    // Step 3: one emitc variable per result of the original op.
+    SmallVector<Value> resultVars;
+    LogicalResult resultStatus = createVariablesForResults(
+        whileOp, getTypeConverter(), rewriter, resultVars);
+    if (failed(resultStatus))
+      return rewriter.notifyMatchFailure(whileOp,
+                                         "Failed to create result variables");
+
+    // Step 4: everything emitted from here on is placed after the loop.
+    rewriter.setInsertionPointAfter(whileOp);
+    SmallVector<Value> replacements =
+        finalizeLoopResults(resultVars, loopVars, rewriter, whileLoc);
+    rewriter.replaceOp(whileOp, replacements);
+    return success();
+  }
+
+private:
+  // Creates one emitc variable per init operand of `whileOp`, assigns the
+  // init value to it, and appends the variable to `outVars` in operand order.
+  // Each variable is declared with an empty opaque initializer and receives
+  // its value through a separate emitc::assign. Always returns success().
+  static LogicalResult createInitVariables(WhileOp whileOp,
+                                           ConversionPatternRewriter &rewriter,
+                                           SmallVectorImpl<Value> &outVars,
+                                           Location loc, MLIRContext *context) {
+    emitc::OpaqueAttr emptyInit = emitc::OpaqueAttr::get(context, "");
+
+    for (Value initValue : whileOp.getInits()) {
+      emitc::LValueType storageType =
+          emitc::LValueType::get(initValue.getType());
+      Value storage =
+          rewriter.create<emitc::VariableOp>(loc, storageType, emptyInit)
+              .getResult();
+      rewriter.create<emitc::AssignOp>(loc, storage, initValue);
+      outVars.push_back(storage);
+    }
+
+    return success();
+  }
+
+  // For each (block argument, variable) pair, emits an emitc::load of the
+  // variable at the start of `block` and redirects every use of the argument
+  // to that load. Afterwards all arguments of `block` are erased.
+  // NOTE(review): the use replacement and the argument erasure go straight
+  // to the IR instead of through `rewriter` -- confirm this is acceptable
+  // inside a conversion pattern.
+  void replaceBlockArgsWithVarLoads(Block *block, ArrayRef<Value> vars,
+                                    ConversionPatternRewriter &rewriter,
+                                    Location loc) const {
+    rewriter.setInsertionPointToStart(block);
+
+    for (auto [blockArg, storage] : llvm::zip(block->getArguments(), vars)) {
+      auto storageType = cast<emitc::LValueType>(storage.getType());
+      Value loaded = rewriter.create<emitc::LoadOp>(
+          loc, storageType.getValueType(), storage);
+      blockArg.replaceAllUsesWith(loaded);
+    }
+
+    // No argument has users any more, so the whole argument list can go.
+    block->eraseArguments(0, block->getNumArguments());
+  }
+
+  // Replaces an scf::yield terminator with one emitc::assign per yielded
+  // value, storing the remapped operand into the matching entry of `vars`.
+  //
+  // `terminator` must be an scf::YieldOp (enforced by the cast). Returns a
+  // match failure if the yield operands cannot be remapped; in that case the
+  // yield is left untouched.
+  LogicalResult processYieldTerminator(Operation *terminator,
+                                       ArrayRef<Value> vars,
+                                       ConversionPatternRewriter &rewriter,
+                                       Location loc) const {
+    auto yieldOp = cast<scf::YieldOp>(terminator);
+    SmallVector<Value> yields;
+    if (failed(rewriter.getRemappedValues(yieldOp.getOperands(), yields)))
+      return rewriter.notifyMatchFailure(yieldOp,
+                                         "failed to lower yield operands");
+
+    // Emit the assignments at the yield's position while the op is still
+    // alive, and erase it only afterwards. This avoids touching `yieldOp`
+    // after eraseOp() and never places new ops behind the block terminator.
+    rewriter.setInsertionPoint(yieldOp);
+    for (auto [var, val] : llvm::zip(vars, yields))
+      rewriter.create<emitc::AssignOp>(loc, var, val);
+
+    rewriter.eraseOp(yieldOp);
+    return success();
+  }
+
+  // Copies the value held by each loop variable into the paired result
+  // variable, then loads every result variable and returns those loads in
+  // order. The caller uses them to replace the original scf::while results.
+  // NOTE(review): llvm::zip stops at the shorter range, so an unpaired
+  // result or loop variable is skipped silently -- confirm both lists always
+  // have the same length.
+  static SmallVector<Value>
+  finalizeLoopResults(ArrayRef<Value> resultVariables,
+                      ArrayRef<Value> loopVariables,
+                      ConversionPatternRewriter &rewriter, Location loc) {
+    // Emits an emitc::load reading the value currently stored in `lvalue`.
+    auto loadFrom = [&](Value lvalue) -> Value {
+      auto lvalueType = cast<emitc::LValueType>(lvalue.getType());
+      return rewriter.create<emitc::LoadOp>(loc, lvalueType.getValueType(),
+                                            lvalue);
+    };
+
+    // First pass: result variable <- current value of the loop variable.
+    for (auto [dst, src] : llvm::zip(resultVariables, loopVariables))
+      rewriter.create<emitc::AssignOp>(loc, dst, loadFrom(src));
+
+    // Second pass: read every result variable back as an SSA value.
+    SmallVector<Value> loadedResults;
+    for (Value dst : resultVariables)
+      loadedResults.push_back(loadFrom(dst));
+
+    return loadedResults;
+  }
+
+ // Lower scf.while to emitc.do.
+ LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> vars,
+ MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ Type i1Type = IntegerType::get(context, 1);
+ auto globalCondition =
+ rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
+ Value conditionVal = globalCondition.getResult();
+
+ auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+
+ // Lower before region as body.
+ rewriter.inlineRegionBefore(whileOp.getBefore(), loweredDo.getBodyRegion(),
+ loweredDo.getBodyRegion().end());
+
+ Block *bodyBlock = &loweredDo.getBodyRegion().front();
+ replaceBlockArgsWithVarLoads(bodyBlock, vars, rewriter, loc);
+
+ // Convert scf.condition to condition variable assignment.
+ Operation *condTerminator =
+ loweredDo.getBodyRegion().back().getTerminator();
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+
+ // Wrap body region in conditional to preserve scf semantics.
+ auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+
+ // Lower after region as then-block of conditional.
+ rewriter.inlineRegionBefore(whileOp.getAfter(), ifOp.getBodyRegion(),
+ ifOp.getBodyRegion().begin());
+
+ if (!ifOp.getBodyRegion().empty()) {
----------------
Vladislave0-0 wrote:
Thanks for noticing. Fixed it.
https://github.com/llvm/llvm-project/pull/143008
More information about the Mlir-commits
mailing list