[Mlir-commits] [mlir] [mlir][emitc] Add a structured for operation (PR #68206)

Simon Camphausen llvmlistbot at llvm.org
Mon Oct 23 06:57:43 PDT 2023


================
@@ -37,6 +37,106 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
   void runOnOperation() override;
 };
 
+// Lowers scf::for to emitc::for. Loop-carried values (iter_args) and the
+// loop's results are modeled as emitc::variable ops, which emitc::assign ops
+// update inside the loop body and after the loop.
+struct ForLowering : OpRewritePattern<ForOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Create an uninitialized emitc::variable op for each result of the given op.
+// The variables are inserted directly before `op`. The rewriter's insertion
+// point is unchanged on return. Returns an empty vector if `op` has no
+// results.
+template <typename T>
+static SmallVector<Value> createVariablesForResults(T op,
+                                                    PatternRewriter &rewriter) {
+  SmallVector<Value> resultVariables;
+
+  if (!op.getNumResults())
+    return resultVariables;
+
+  Location loc = op->getLoc();
+  MLIRContext *context = op.getContext();
+
+  // RAII guard: restores the caller's insertion point on every exit path.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // The empty opaque initializer (no initial value) is the same for every
+  // result, so build it once outside the loop.
+  auto noInit = emitc::OpaqueAttr::get(context, "");
+  resultVariables.reserve(op.getNumResults());
+  for (OpResult result : op.getResults()) {
+    auto var =
+        rewriter.create<emitc::VariableOp>(loc, result.getType(), noInit);
+    resultVariables.push_back(var);
+  }
+
+  return resultVariables;
+}
+
+// Create a series of emitc::assign ops at the current insertion point of the
+// given rewriter. Each op assigns an element of `values` to the corresponding
+// element of `variables`. Both ranges must have the same length.
+static void assignValues(ValueRange values, ValueRange variables,
+                         PatternRewriter &rewriter, Location loc) {
+  // zip_equal asserts that the lengths match. llvm::zip would silently stop at
+  // the shorter range and drop the remaining assignments.
+  for (auto [value, var] : llvm::zip_equal(values, variables))
+    rewriter.create<emitc::AssignOp>(loc, var, value);
+}
+
+// Replace the given scf.yield with emitc::assign ops followed by an
+// emitc::yield. Each assign stores a yielded operand into the corresponding
+// entry of `resultVariables`. The original scf.yield is erased, and the
+// rewriter's insertion point is unchanged on return.
+static void lowerYield(SmallVector<Value> &resultVariables,
+                       PatternRewriter &rewriter, scf::YieldOp yield) {
+  Location loc = yield.getLoc();
+  ValueRange operands = yield.getOperands();
+
+  {
+    // Restore the caller's insertion point when this scope ends. That happens
+    // before `yield`, which the temporary insertion point refers to, is
+    // erased.
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(yield);
+
+    assignValues(operands, resultVariables, rewriter, loc);
+    rewriter.create<emitc::YieldOp>(loc);
+  }
+
+  rewriter.eraseOp(yield);
+}
+
+LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
+                                           PatternRewriter &rewriter) const {
+  Location loc = forOp.getLoc();
+
+  // New ops that live outside the loop body are placed directly before the
+  // scf.for being rewritten.
+  rewriter.setInsertionPoint(forOp);
+
+  // Two sets of emitc::variable ops are created:
+  // - iterArgsVariables hold the loop-carried values, which emitc::assign ops
+  //   update inside the body.
+  // - resultVariables receive the final values once the loop has finished.
+  SmallVector<Value> resultVariables =
+      createVariablesForResults(forOp, rewriter);
+  SmallVector<Value> iterArgsVariables =
+      createVariablesForResults(forOp, rewriter);
+
+  // Seed the loop-carried variables with the initial iter_args values.
+  assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
+
+  emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
+      loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+  Block *loweredBody = loweredFor.getBody();
+
+  // Drop the terminator auto-generated for the new loop. The original body,
+  // with its own scf.yield, is moved in below.
+  rewriter.eraseOp(loweredBody->getTerminator());
+
+  // Map the scf.for body arguments to replacement values. The first argument
+  // maps to the emitc induction variable; the rest map to the loop-carried
+  // variables, in order.
+  SmallVector<Value> blockArgReplacements;
+  blockArgReplacements.push_back(loweredFor.getInductionVar());
+  llvm::append_range(blockArgReplacements, iterArgsVariables);
+  rewriter.mergeBlocks(forOp.getBody(), loweredBody, blockArgReplacements);
+
+  auto movedYield = cast<scf::YieldOp>(loweredBody->getTerminator());
+  lowerYield(iterArgsVariables, rewriter, movedYield);
+
+  // Once the loop is done, copy the loop-carried values into the results.
+  rewriter.setInsertionPointAfter(loweredFor);
+  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+
+  rewriter.replaceOp(forOp, resultVariables);
+  return success();
+}
+
 // Lower scf::if to emitc::if, implementing return values as emitc::variable's
----------------
simon-camp wrote:

```suggestion
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
```

https://github.com/llvm/llvm-project/pull/68206


More information about the Mlir-commits mailing list