[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 2 15:00:42 PDT 2024


================
@@ -1509,6 +1535,199 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
   }
 };
 
+/// The following canonicalization pattern folds the iter arguments of
+/// scf.forall op if :-
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+/// Example of first case :-
+///  INPUT:
+///   %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                <SOME USE OF %arg2>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                    <STORE OP WITH DESTINATION %arg0>
+///                    <STORE OP WITH DESTINATION %arg2>
+///                }
+///             }
+///   return %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                <SOME USE OF %c>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+///          scf.forall are replaced by their corresponding operands.
+///       2. The canonicalization assumes that there are no <STORE OP WITH
+///          DESTINATION *> ops within the body of the scf.forall except within
+///          scf.forall.in_parallel terminator.
+///       3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+///          within scf.forall.in_parallel - the code below takes care of this
+///          by traversing the uses of the corresponding iter arg.
+///       4. TODO(avarma): Generalize it for other store ops. Currently it
+///          handles tensor.parallel_insert_slice ops only.
+///
+/// Example of second case :-
+///  INPUT:
+///   %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                }
+///             }
+///   return %res#0, %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    scf::InParallelOp terminatorOp = forallOp.getTerminator();
+    SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+        terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+
+    // The following check should indeed be part of SCF::ForallOp::verify.
+    SmallVector<SubsetInsertionOpInterface> subsetInsertionOpInterfaceOps;
+    for (Operation *op : yieldingOps) {
+      if (auto subsetInsertionOpInterfaceOp =
+              dyn_cast<SubsetInsertionOpInterface>(op)) {
+        subsetInsertionOpInterfaceOps.push_back(subsetInsertionOpInterfaceOp);
+        continue;
+      }
+      return failure();
+    }
+
+    // Step 1: For a given i-th result of scf.forall, check the following :-
+    //         a. If it has any use.
+    //         b. If the corresponding iter argument is being modified within
+    //            the loop, i.e. fetch a unique store op.
+    //
+    //         Based on the check we maintain the following :-
+    //         a. `resultToDelete` - i-th result of scf.forall that'll be
+    //            deleted.
+    //         b. `resultToReplace` - i-th result of the old scf.forall
+    //            whose uses will be replaced by the new scf.forall.
+    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
+    //            corresponding to the i-th result with at least one use.
+    //         d. `mapping` - mapping the old iter block argument of scf.forall
+    //            with the corresponding shared_outs' operand. This will be
+    //            used when creating a new scf.forall op.
+    SmallVector<OpResult> resultToDelete;
+    SmallVector<Value> resultToReplace;
+    SmallVector<Value> newOuts;
+    IRMapping mapping;
+    for (OpResult result : forallOp.getResults()) {
+      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+      if (result.use_empty() || failed(forallOp.getStoreOpUser(blockArg))) {
+        resultToDelete.push_back(result);
+        mapping.map(blockArg, opOperand->get());
+      } else {
+        resultToReplace.push_back(result);
+        newOuts.push_back(opOperand->get());
+      }
+    }
+
+    // Return early if all results of scf.forall have at least one use and are
+    // being modified within the loop.
+    if (resultToDelete.empty()) {
+      return failure();
+    }
+
+    // Step 2: For the i-th result, do the following :-
+    //         a. Fetch the corresponding BlockArgument.
+    //         b. Look for a unique store op (currently
+    //            tensor.parallel_insert_slice) with the BlockArgument as its
+    //            Destination operand.
+    //         c. Remove the operation fetched in b.
+    for (OpResult result : resultToDelete) {
+      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+      FailureOr<SubsetInsertionOpInterface> storeOp =
+          forallOp.getStoreOpUser(blockArg);
+      rewriter.replaceAllUsesWith(blockArg, opOperand->get());
+      if (failed(storeOp)) {
+        continue;
+      }
+      rewriter.eraseOp(storeOp.value());
+    }
+
+    // Step 3. Create a new scf.forall op with the new shared_outs' operands
+    //         fetched earlier
+    auto newforallOp = rewriter.create<scf::ForallOp>(
+        forallOp.getLoc(), forallOp.getMixedLowerBound(),
+        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+        forallOp.getMapping());
+
+    // Step 4. Merge the block of the old scf.forall into the newly created
+    //         scf.forall using the new set of arguments.
+    Block *loopBody = forallOp.getBody();
+    Block *newLoopBody = newforallOp.getBody();
+    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
+    SmallVector<Value> newBlockArgs =
+        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
+                            [](BlockArgument b) -> Value { return b; });
+    Block::BlockArgListType newSharedOutsArgs = newforallOp.getRegionOutArgs();
+    unsigned index = 0;
+    for (BlockArgument bbArg : forallOp.getRegionOutArgs()) {
----------------
MaheshRavishankar wrote:

Instead, iterate over the results. For any results you need to delete (computed in `resultsToDelete` above, which is now a `SetVector`), use the OpOperand; for the others, use the shared out args. You basically don't need the `IRMapping`.

https://github.com/llvm/llvm-project/pull/90189


More information about the Mlir-commits mailing list