[Mlir-commits] [mlir] [SCF][Transform] Add support for scf.for in LoopFuseSibling op (PR #81495)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Wed Mar 27 09:52:58 PDT 2024


================
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
   unsigned numTargetOuts = target.getNumResults();
   unsigned numSourceOuts = source.getNumResults();
 
-  OperandRange targetOuts = target.getOutputs();
-  OperandRange sourceOuts = source.getOutputs();
-
   // Create fused shared_outs.
   SmallVector<Value> fusedOuts;
-  fusedOuts.reserve(numTargetOuts + numSourceOuts);
-  fusedOuts.append(targetOuts.begin(), targetOuts.end());
-  fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
+  llvm::append_range(fusedOuts, target.getOutputs());
+  llvm::append_range(fusedOuts, source.getOutputs());
 
-  // Create a new scf::forall op after the source loop.
+  // Create a new scf.forall op after the source loop.
   rewriter.setInsertionPointAfter(source);
   scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
       source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
       source.getMixedStep(), fusedOuts, source.getMapping());
 
   // Map control operands.
-  IRMapping fusedMapping;
-  fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
-  fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
+  IRMapping mapping;
+  mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
+  mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
 
   // Map shared outs.
-  fusedMapping.map(target.getRegionIterArgs(),
-                   fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
-  fusedMapping.map(
-      source.getRegionIterArgs(),
-      fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
   for (Operation &op : target.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : source.getBody()->without_terminator())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
   scf::InParallelOp targetTerm = target.getTerminator();
   scf::InParallelOp sourceTerm = source.getTerminator();
   scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
-
   rewriter.setInsertionPointToStart(fusedTerm.getBody());
   for (Operation &op : targetTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
+    rewriter.clone(op, mapping);
   for (Operation &op : sourceTerm.getYieldingOps())
-    rewriter.clone(op, fusedMapping);
-
-  // Replace all uses of the old loops with the fused loop.
-  rewriter.replaceAllUsesWith(target.getResults(),
-                              fusedLoop.getResults().slice(0, numTargetOuts));
-  rewriter.replaceAllUsesWith(
-      source.getResults(),
-      fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
-
-  // Erase the old loops.
-  rewriter.eraseOp(target);
-  rewriter.eraseOp(source);
+    rewriter.clone(op, mapping);
+
+  // Replace old loops by substituting their uses by results of the fused loop.
+  rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
+  rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
+
+  return fusedLoop;
+}
+
+scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
+                                                scf::ForOp source,
+                                                RewriterBase &rewriter) {
+  unsigned numTargetOuts = target.getNumResults();
+  unsigned numSourceOuts = source.getNumResults();
+
+  // Create fused init_args, with target's init_args before source's init_args.
+  SmallVector<Value> fusedInitArgs;
+  llvm::append_range(fusedInitArgs, target.getInitArgs());
+  llvm::append_range(fusedInitArgs, source.getInitArgs());
+
+  // Create a new scf.for op after the source loop (with scf.yield terminator
+  // (without arguments) only in case its init_args is empty).
+  rewriter.setInsertionPointAfter(source);
+  scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
+      source.getLoc(), source.getLowerBound(), source.getUpperBound(),
+      source.getStep(), fusedInitArgs);
+
+  // Map original induction variables and operands to those of the fused loop.
+  IRMapping mapping;
+  mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(target.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
+  mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
+  mapping.map(source.getRegionIterArgs(),
+              fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
+
+  // Merge target's body into the new (fused) for loop and then source's body.
+  rewriter.setInsertionPointToStart(fusedLoop.getBody());
+  for (Operation &op : target.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+  for (Operation &op : source.getBody()->without_terminator())
+    rewriter.clone(op, mapping);
+
+  // Build fused yield results by appropriately mapping original yield operands.
+  SmallVector<Value> yieldResults;
+  for (Value operand : target.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  for (Value operand : source.getBody()->getTerminator()->getOperands())
+    yieldResults.push_back(mapping.lookupOrDefault(operand));
+  if (yieldResults.size())
----------------
ftynse wrote:

clang-tidy: prefer `!yieldResults.empty()` over `yieldResults.size()` when checking that the container is non-empty.

https://github.com/llvm/llvm-project/pull/81495


More information about the Mlir-commits mailing list