[Mlir-commits] [mlir] [mlir][emitc] Add 'emitc.while' and 'emitc.do' ops to the dialect (PR #143008)

Gil Rapaport llvmlistbot at llvm.org
Thu Aug 21 08:22:25 PDT 2025


================
@@ -332,11 +332,260 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
   return success();
 }
 
+// Lower scf::while to either emitc::while or emitc::do based on argument usage
+// patterns. Uses mutable variables to maintain loop state across iterations.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = whileOp.getLoc();
+    MLIRContext *context = loc.getContext();
+
+    // Back every init operand with a mutable emitc variable; the lowered loop
+    // reads and updates these instead of passing SSA block arguments.
+    SmallVector<Value> loopVars;
+    if (failed(createInitVariables(whileOp, rewriter, loopVars, loc, context)))
+      return failure();
+
+    // Pick the target op from the "before" region's terminator: when
+    // scf.condition forwards exactly the region's block arguments, lower to
+    // emitc.while; any other operand list takes the emitc.do path.
+    Region &before = adaptor.getBefore();
+    auto conditionOp = cast<scf::ConditionOp>(before.back().getTerminator());
+    const bool forwardsBlockArgs =
+        llvm::equal(before.front().getArguments(), conditionOp.getArgs());
+
+    LogicalResult lowered =
+        forwardsBlockArgs
+            ? lowerWhile(whileOp, loopVars, context, rewriter, loc)
+            : lowerDoWhile(whileOp, loopVars, context, rewriter, loc);
+    if (failed(lowered))
+      return failure();
+
+    // Ask createVariablesForResults (defined outside this hunk) for the
+    // variables that stand in for the scf.while results. They are written by
+    // finalizeLoopResults below, i.e. after the loop rather than inside it.
+    SmallVector<Value> resultVars;
+    if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+                                         resultVars)))
+      return rewriter.notifyMatchFailure(whileOp,
+                                         "Failed to create result variables");
+
+    // Everything that follows is emitted right after the original op: copy the
+    // final loop state into the result variables and load the SSA values that
+    // replace the scf.while results.
+    rewriter.setInsertionPointAfter(whileOp);
+    SmallVector<Value> replacements =
+        finalizeLoopResults(resultVars, loopVars, rewriter, loc);
+    rewriter.replaceOp(whileOp, replacements);
+    return success();
+  }
+
+private:
+  // Materialize one mutable emitc variable per scf.while init operand so that
+  // loop state can be updated in place instead of being threaded through SSA
+  // block arguments.
+  //
+  // For each init value this emits an emitc.variable (lvalue of the init's
+  // type, empty opaque initializer) followed by an emitc.assign that stores
+  // the init value, and appends the variable to `outVars` in operand order.
+  // Always returns success().
+  //
+  // NOTE(review): this reads the inits and their types from the original op
+  // rather than from the adaptor / type converter -- confirm that is intended.
+  static LogicalResult createInitVariables(WhileOp whileOp,
+                                           ConversionPatternRewriter &rewriter,
+                                           SmallVectorImpl<Value> &outVars,
+                                           Location loc, MLIRContext *context) {
+    // An empty opaque attribute declares the variable without an initializer;
+    // the actual initial value is stored by the assign emitted below.
+    auto uninitialized = emitc::OpaqueAttr::get(context, "");
+
+    for (Value initialValue : whileOp.getInits()) {
+      auto lvalueType = emitc::LValueType::get(initialValue.getType());
+      emitc::VariableOp variableOp =
+          rewriter.create<emitc::VariableOp>(loc, lvalueType, uninitialized);
+      Value storage = variableOp.getResult();
+      rewriter.create<emitc::AssignOp>(loc, storage, initialValue);
+      outVars.push_back(storage);
+    }
+
+    return success();
+  }
+
+  // Switch `block` from SSA block arguments to variable-based state: for each
+  // (argument, variable) pair an emitc.load of the variable is emitted at the
+  // start of the block and takes over all uses of the argument; afterwards
+  // every argument of the block is erased.
+  //
+  // `vars` must hold emitc lvalue-typed values (enforced by cast<>). The
+  // rewriter's insertion point is moved to the start of `block`.
+  //
+  // NOTE(review): replaceAllUsesWith and eraseArguments modify the IR directly
+  // instead of going through the rewriter -- confirm this is acceptable for
+  // the conversion driver in use.
+  void replaceBlockArgsWithVarLoads(Block *block, ArrayRef<Value> vars,
+                                    ConversionPatternRewriter &rewriter,
+                                    Location loc) const {
+    rewriter.setInsertionPointToStart(block);
+
+    for (auto [blockArg, variable] : llvm::zip(block->getArguments(), vars)) {
+      auto lvalueType = cast<emitc::LValueType>(variable.getType());
+      Value currentValue = rewriter.create<emitc::LoadOp>(
+          loc, lvalueType.getValueType(), variable);
+      blockArg.replaceAllUsesWith(currentValue);
+    }
+
+    // All uses were redirected above, so the whole argument list can go.
+    block->eraseArguments(0, block->getNumArguments());
+  }
+
+  // Convert an scf.yield terminator into imperative updates of the loop
+  // variables: one emitc.assign per (variable, yielded value) pair is emitted
+  // at the end of the terminator's block and the yield itself is erased.
+  //
+  // `terminator` must be an scf::YieldOp (enforced by cast<>). On return the
+  // rewriter's insertion point is at the end of the block that held the yield.
+  void processYieldTerminator(Operation *terminator, ArrayRef<Value> vars,
+                              ConversionPatternRewriter &rewriter,
+                              Location loc) const {
+    auto yieldOp = cast<scf::YieldOp>(terminator);
+    // Remember the parent block up front; the yield must not be dereferenced
+    // once it has been handed to eraseOp.
+    Block *block = yieldOp->getBlock();
+
+    // Emit the assignments while the yield is still alive and erase it last.
+    // Inserting in front of the terminator produces the same final op order
+    // as appending to the block after the terminator is gone.
+    rewriter.setInsertionPoint(yieldOp);
+    for (auto [var, val] : llvm::zip(vars, yieldOp.getOperands()))
+      rewriter.create<emitc::AssignOp>(loc, var, val);
+
+    rewriter.eraseOp(yieldOp);
+    rewriter.setInsertionPointToEnd(block);
+  }
+
+  // Bridge the imperative loop state back to SSA. At the rewriter's current
+  // insertion point this first copies each loop variable into the result
+  // variable at the same position (emitc.load + emitc.assign), and then emits
+  // one emitc.load per result variable. The values of that second group of
+  // loads are returned, in order, as replacements for the scf.while results.
+  //
+  // Both ranges must contain emitc lvalue-typed values (enforced by cast<>).
+  // NOTE(review): llvm::zip stops at the shorter range, so any result variable
+  // without a matching loop variable is loaded without being assigned here --
+  // confirm the two ranges always have equal length.
+  static SmallVector<Value>
+  finalizeLoopResults(ArrayRef<Value> resultVariables,
+                      ArrayRef<Value> loopVariables,
+                      ConversionPatternRewriter &rewriter, Location loc) {
+    // Emits an emitc.load that reads the current value of `lvalue`.
+    auto loadFrom = [&](Value lvalue) -> Value {
+      auto lvalueType = cast<emitc::LValueType>(lvalue.getType());
+      return rewriter.create<emitc::LoadOp>(loc, lvalueType.getValueType(),
+                                            lvalue);
+    };
+
+    // Pass 1: result variable <- final value of the matching loop variable.
+    for (auto [destination, source] :
+         llvm::zip(resultVariables, loopVariables)) {
+      Value finalState = loadFrom(source);
+      rewriter.create<emitc::AssignOp>(loc, destination, finalState);
+    }
+
+    // Pass 2: read every result variable back into an SSA value.
+    SmallVector<Value> loadedResults;
+    for (Value destination : resultVariables)
+      loadedResults.push_back(loadFrom(destination));
+
+    return loadedResults;
+  }
+
+  // Direct lowering to emitc.while when condition arguments match region
+  // inputs.
+  LogicalResult lowerWhile(WhileOp whileOp, ArrayRef<Value> vars,
+                           MLIRContext *context,
+                           ConversionPatternRewriter &rewriter,
+                           Location loc) const {
+    auto loweredWhile = rewriter.create<emitc::WhileOp>(loc);
+
+    // Lower before region to condition region.
+    Region &condRegion = loweredWhile.getConditionRegion();
+    Block *condBlock = rewriter.createBlock(&condRegion);
+    rewriter.setInsertionPointToStart(condBlock);
+
+    Type i1Type = IntegerType::get(context, 1);
+    auto exprOp = rewriter.create<emitc::ExpressionOp>(loc, TypeRange{i1Type});
+    Region &exprRegion = exprOp.getBodyRegion();
+
+    rewriter.inlineRegionBefore(whileOp.getBefore(), exprRegion,
----------------
aniragil wrote:

> I wonder what the correct way to translate those things would be, given that the region may contain operations that can be translated to a CExpression but are not CExpressions yet at the time of the transformation, e.g. if you have arith ops. I thought that you just translate it anyway, and then it may fail later on due to an incomplete translation. It's not as if emitc.while won't fail now when translating to cpp, since it cannot be emitted at all.

Sorry, I wasn't clear: I didn't mean ops of other dialects like `arith`, but `emitc` ops that do not form a single expression (at least not without using C's `comma` operator). For instance, in the following code:
```C
while (true) {
   a[i+1] = b[i*7] + 3;
   foo(t[i] / 5);
   int c = k[i + 11];
   if (c)
     break;
   // do some more work with loads, stores, calls etc.
}
```
the "before" section includes computations unrelated to the condition, so it's not a "classic" while loop where the exit condition is checked before any computation. Since the `while` condition clause requires a single expression, putting all this code there requires turning it into a single expression. This can be done using the comma operator (which `emitc` currently doesn't support), i.e.
```C
while (a[i+1] = b[i*7] + 3, foo(t[i] / 5), k[i + 11]) {
   // do some more work with loads, stores, calls etc.
}
```
But that's far less clear than the original structure IMO.

> What do you think about always translating to a do-while style loop ... where such issues are eliminated because the code is translated into the loop body, which is not required to be a CExpression? That way the translation of the scf.while itself will always work and it'll behave pretty much the same as it does right now. As a downside, the new emitc.while will be "dead", since it won't be produced by the translation from the higher dialects, but to be fair we can determine after the form-expressions pass which emitc.do ops can be converted to a while operation, since we have a very specific pattern of `do { cond_res = cond; if (cond_res) { /* body */ } } while (cond_res)`. We could do that as a separate optimization.

Agreed, it's best to start with a simple and robust lowering. We shouldn't have dead ops, so I'd do this in stages - first a patch that introduces a single loop op to lower `scf.while` in a unified manner. I'm OK with starting with either `emitc.do` or `emitc.while` (with the condition variable initialized to `1`). In either case, the lowering can create a simple `emitc.expression` in the condition region that only loads from the condition variable.
The form-expressions pass (if executed) should indeed fold the computation of the condition in the loop body.
A second patch can then introduce a new pass to optimize the loops in both directions, i.e. identify the `emitc.expression` that sets the condition variable and push it down `do {} while (/*HERE*/)` or push it up `while (/*HERE*/) {}` if possible, removing the condition variable. WDYT @kchibisov?

https://github.com/llvm/llvm-project/pull/143008


More information about the Mlir-commits mailing list