[Mlir-commits] [mlir] [MLIR][SCF] Add canonicalization pattern to fold away iter args of scf.forall (PR #90189)

Abhishek Varma llvmlistbot at llvm.org
Mon May 6 01:33:13 PDT 2024


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/90189

>From f71cf509c45f61f106b5ffd7974d17377b6e646f Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 25 Apr 2024 11:01:04 +0000
Subject: [PATCH 1/2] [MLIR][SCF] Add canonicalization pattern to fold away
 iter args of scf.forall

-- This commit adds a canonicalization pattern to fold away iter args
   of scf.forall if :-
   a. The corresponding tied result has no use.
   b. It is not being modified within the loop.

Signed-off-by: Abhishek Varma <avarma094 at gmail.com>
---
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt  |   1 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 200 +++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir |  80 ++++++++++
 3 files changed, 280 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 423e1c3e1e042c..6e5d80078e8022 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
+  MLIRSubsetOpInterface
   MLIRTensorDialect
   MLIRValueBoundsOpInterface
   )
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 7a1aafc9f1c2f9..355cfc8b3ee626 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
+#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -1509,6 +1510,203 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
   }
 };
 
+/// The following canonicalization pattern folds the iter arguments of
+/// scf.forall op if :-
+/// 1. The corresponding result has zero uses.
+/// 2. The iter argument is NOT being modified within the loop body.
+///
+///
+/// Example of first case :-
+///  INPUT:
+///   %res:3 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b, %arg2 = %c)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                <SOME USE OF %arg2>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                    <STORE OP WITH DESTINATION %arg0>
+///                    <STORE OP WITH DESTINATION %arg2>
+///                }
+///             }
+///   return %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                <SOME USE OF %c>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %res
+///
+/// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
+///          scf.forall is replaced by their corresponding operands.
+///       2. The canonicalization assumes that there are no <STORE OP WITH
+///          DESTINATION *> ops within the body of the scf.forall except within
+///          scf.forall.in_parallel terminator.
+///       3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
+///          within scf.forall.in_parallel - the code below takes care of this
+///          by traversing the uses of the corresponding iter arg.
+///
+/// Example of second case :-
+///  INPUT:
+///   %res:2 = scf.forall ... shared_outs(%arg0 = %a, %arg1 = %b)
+///            {
+///                ...
+///                <SOME USE OF %arg0>
+///                <SOME USE OF %arg1>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %arg1>
+///                }
+///             }
+///   return %res#0, %res#1
+///
+///  OUTPUT:
+///   %res = scf.forall ... shared_outs(%new_arg0 = %b)
+///            {
+///                ...
+///                <SOME USE OF %a>
+///                <SOME USE OF %new_arg0>
+///                ...
+///                scf.forall.in_parallel {
+///                    <STORE OP WITH DESTINATION %new_arg0>
+///                }
+///             }
+///   return %a, %res
+struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  /// Utility function that checks if a candidate value satisifies any of the
+  /// conditions (see above doc comment) to make it viable for folding away.
+  static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
+    if (result.use_empty()) {
+      return true;
+    }
+    Value::user_range users = blockArg.getUsers();
+    return llvm::all_of(users, [&](Operation *user) {
+      return !isa<SubsetInsertionOpInterface>(user);
+    });
+  }
+
+  LogicalResult matchAndRewrite(ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    scf::InParallelOp terminatorOp = forallOp.getTerminator();
+    SmallVector<Operation *> yieldingOps = llvm::map_to_vector(
+        terminatorOp.getYieldingOps(), [](Operation &op) { return &op; });
+
+    // The following check should indeed be part of SCF::ForallOp::verify.
+    SmallVector<SubsetInsertionOpInterface> subsetInsertionOpInterfaceOps;
+    for (Operation *op : yieldingOps) {
+      if (auto subsetInsertionOpInterfaceOp =
+              dyn_cast<SubsetInsertionOpInterface>(op)) {
+        subsetInsertionOpInterfaceOps.push_back(subsetInsertionOpInterfaceOp);
+        continue;
+      }
+      return failure();
+    }
+
+    // Step 1: For a given i-th result of scf.forall, check the following :-
+    //         a. If it has any use.
+    //         b. If the corresponding iter argument is being modified within
+    //            the loop.
+    //
+    //         Based on the check we maintain the following :-
+    //         a. `resultToDelete` - i-th result of scf.forall that'll be
+    //            deleted.
+    //         b. `resultToReplace` - i-th result of the old scf.forall
+    //            whose uses will be replaced by the new scf.forall.
+    //         c. `newOuts` - the shared_outs' operand of the new scf.forall
+    //            corresponding to the i-th result with at least one use.
+    //         d. `mapping` - mapping the old iter block argument of scf.forall
+    //            with the corresponding shared_outs' operand. This will be
+    //            used when creating a new scf.forall op.
+    SmallVector<OpResult> resultToDelete;
+    SmallVector<Value> resultToReplace;
+    SmallVector<Value> newOuts;
+    IRMapping mapping;
+    for (OpResult result : forallOp.getResults()) {
+      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+      if (isCandidateValueToDelete(result, blockArg)) {
+        resultToDelete.push_back(result);
+        mapping.map(blockArg, opOperand->get());
+      } else {
+        resultToReplace.push_back(result);
+        newOuts.push_back(opOperand->get());
+      }
+    }
+
+    // Return early if every result of scf.forall has at least one use and its
+    // iter argument is modified within the loop.
+    if (resultToDelete.empty()) {
+      return failure();
+    }
+
+    // Step 2: For the the i-th result, do the following :-
+    //         a. Fetch the corresponding BlockArgument.
+    //         b. Look for an op within scf.forall.in_parallel whose destination
+    //            operand is the BlockArgument fetched in step a.
+    //         c. Remove the operation fetched in b.
+    //         d. For any use of the BlockArgument in the body of the scf.forall
+    //            replace it with the corresponding Output value.
+    for (OpResult result : resultToDelete) {
+      OpOperand *opOperand = forallOp.getTiedOpOperand(result);
+      BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
+      Value::user_range users = blockArg.getUsers();
+      Operation *terminatorOperationToDelete = nullptr;
+      for (Operation *user : users) {
+        if (auto subsetInsertionOpInterfaceOp =
+                dyn_cast<SubsetInsertionOpInterface>(user)) {
+          if (subsetInsertionOpInterfaceOp.getDestinationOperand().get() ==
+              blockArg) {
+            terminatorOperationToDelete = subsetInsertionOpInterfaceOp;
+            break;
+          }
+        }
+      }
+      if (terminatorOperationToDelete)
+        rewriter.eraseOp(terminatorOperationToDelete);
+    }
+
+    // Step 3. Create a new scf.forall op with the new shared_outs' operands
+    //         fetched earlier
+    auto newforallOp = rewriter.create<scf::ForallOp>(
+        forallOp.getLoc(), forallOp.getMixedLowerBound(),
+        forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
+        forallOp.getMapping());
+
+    // Step 4. Clone the region of the old scf.forall into the newly created
+    //         scf.forall using the IRMapping formed in Step 1.
+    newforallOp.getBodyRegion().getBlocks().clear();
+    rewriter.cloneRegionBefore(forallOp.getRegion(), newforallOp.getRegion(),
+                               newforallOp.getRegion().begin(), mapping);
+
+    // Step 5. Replace the uses of result of old scf.forall with that of the new
+    //         scf.forall.
+    for (auto &&[oldResult, newResult] :
+         llvm::zip(resultToReplace, newforallOp->getResults())) {
+      rewriter.replaceAllUsesWith(oldResult, newResult);
+    }
+    // Step 6. Replace the uses of every deleted result (a no-op for results
+    //         that have no use) with the value of its tied shared_outs
+    //         operand.
+    for (OpResult oldResult : resultToDelete) {
+      rewriter.replaceAllUsesWith(oldResult,
+                                  forallOp.getTiedOpOperand(oldResult)->get());
+    }
+    return success();
+  }
+};
+
 struct ForallOpSingleOrZeroIterationDimsFolder
     : public OpRewritePattern<ForallOp> {
   using OpRewritePattern<ForallOp>::OpRewritePattern;
@@ -1667,7 +1865,7 @@ struct FoldTensorCastOfOutputIntoForallOp
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
-              ForallOpControlOperandsFolder,
+              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index b4c9ed4db94e0e..9b379ad15f1ecf 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1735,6 +1735,86 @@ func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
 
 // -----
 
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+  func.func @fold_iter_args_not_being_modified_within_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 4.200000e+01 : f32
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map()[%dim, %arg0]
+    %2:2 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg1, %arg5 = %arg2) -> (tensor<?xf32>, tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg3)[%arg0]
+      %4 = affine.min #map2(%arg3)[%dim, %arg0]
+      %extracted_slice0 = tensor.extract_slice %arg4[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %extracted_slice1 = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %5 = linalg.elemwise_unary ins(%extracted_slice0 : tensor<?xf32>) outs(%extracted_slice1 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    return %2#0, %2#1 : tensor<?xf32>, tensor<?xf32>
+  }
+}
+// CHECK-LABEL: @fold_iter_args_not_being_modified_within_scfforall
+//  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//  CHECK-SAME:                       shared_outs(%[[ITER_ARG_5:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+//       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+//       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ITER_ARG_5]]
+//       CHECK:      %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+//       CHECK:      scf.forall.in_parallel {
+//  CHECK-NEXT:         tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_5]]
+//  CHECK-NEXT:      }
+//  CHECK-NEXT:    }
+//  CHECK-NEXT:    return %[[ARG1]], %[[RESULT]]
+
+// -----
+
+#map = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+module {
+  func.func @fold_iter_args_with_no_use_of_result_scfforall(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+    %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?xf32>) -> tensor<?xf32>
+    %dim = tensor.dim %arg1, %c0 : tensor<?xf32>
+    %1 = affine.apply #map()[%dim, %arg0]
+    %2:3 = scf.forall (%arg4) in (%1) shared_outs(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+      %3 = affine.apply #map1(%arg4)[%arg0]
+      %4 = affine.min #map2(%arg4)[%dim, %arg0]
+      %extracted_slice = tensor.extract_slice %arg5[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg6[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg7[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %extracted_slice_2 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+      %5 = linalg.elemwise_unary ins(%extracted_slice : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %5 into %arg6[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %extracted_slice into %arg5[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %extracted_slice_0 into %arg7[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    return %2#1 : tensor<?xf32>
+  }
+}
+// CHECK-LABEL: @fold_iter_args_with_no_use_of_result_scfforall
+//  CHECK-SAME:   (%{{.*}}: index, %[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: tensor<?xf32>, %[[ARG3:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+//       CHECK:    %[[RESULT:.*]] = scf.forall 
+//  CHECK-SAME:                       shared_outs(%[[ITER_ARG_6:.*]] = %[[ARG2]]) -> (tensor<?xf32>) {
+//       CHECK:      %[[OPERAND0:.*]] = tensor.extract_slice %[[ARG1]]
+//       CHECK:      %[[OPERAND1:.*]] = tensor.extract_slice %[[ARG3]]
+//       CHECK:      %[[ELEM:.*]] = linalg.elemwise_unary ins(%[[OPERAND0]] : tensor<?xf32>) outs(%[[OPERAND1]] : tensor<?xf32>) -> tensor<?xf32>
+//       CHECK:      scf.forall.in_parallel {
+//  CHECK-NEXT:         tensor.parallel_insert_slice %[[ELEM]] into %[[ITER_ARG_6]]
+//  CHECK-NEXT:      }
+//  CHECK-NEXT:    }
+//  CHECK-NEXT:    return %[[RESULT]]
+
+// -----
+
 func.func @index_switch_fold() -> (f32, f32) {
   %switch_cst = arith.constant 1: index
   %0 = scf.index_switch %switch_cst -> f32

>From 09501e19d20b825ab4503021c121062af8840175 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 2 May 2024 08:30:52 +0000
Subject: [PATCH 2/2] Address review comments

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td |   4 +
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt     |   1 -
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 102 ++++++++++-----------
 3 files changed, 55 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index b3d085bfff1af9..2cf4575affcb76 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -608,6 +608,10 @@ def ForallOp : SCF_Op<"forall", [
 
     // Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
+
+    /// Returns operations within scf.forall.in_parallel whose destination
+    /// operand is the block argument `bbArg`.
+    SmallVector<Operation*> getStoreOpUser(BlockArgument bbArg);
   }];
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 6e5d80078e8022..423e1c3e1e042c 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -17,7 +17,6 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
-  MLIRSubsetOpInterface
   MLIRTensorDialect
   MLIRValueBoundsOpInterface
   )
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 355cfc8b3ee626..dd2921dc03f3f1 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -20,7 +20,6 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Interfaces/SubsetOpInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -1416,6 +1415,19 @@ InParallelOp ForallOp::getTerminator() {
   return cast<InParallelOp>(getBody()->getTerminator());
 }
 
+SmallVector<Operation *> ForallOp::getStoreOpUser(BlockArgument bbArg) {
+  SmallVector<Operation *> storeOps;
+  InParallelOp inParallelOp = getTerminator();
+  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
+    if (auto parallelInsertSliceOp =
+            dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
+        parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
+      storeOps.push_back(parallelInsertSliceOp);
+    }
+  }
+  return storeOps;
+}
+
 std::optional<Value> ForallOp::getSingleInductionVar() {
   if (getRank() != 1)
     return std::nullopt;
@@ -1549,12 +1561,13 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
 ///
 /// NOTE: 1. All uses of the folded shared_outs (iter argument) within the
 ///          scf.forall is replaced by their corresponding operands.
-///       2. The canonicalization assumes that there are no <STORE OP WITH
-///          DESTINATION *> ops within the body of the scf.forall except within
-///          scf.forall.in_parallel terminator.
-///       3. The order of the <STORE OP WITH DESTINATION *> can be arbitrary
-///          within scf.forall.in_parallel - the code below takes care of this
-///          by traversing the uses of the corresponding iter arg.
+///       2. Even if there are <STORE OP WITH DESTINATION *> ops in the body of
+///          the scf.forall outside the scf.forall.in_parallel terminator, this
+///          canonicalization remains valid. For more details, please refer
+///          to:
+///          https://github.com/llvm/llvm-project/pull/90189#discussion_r1589011124
+///       3. TODO(avarma): Generalize it for other store ops. Currently it
+///          handles tensor.parallel_insert_slice ops only.
 ///
 /// Example of second case :-
 ///  INPUT:
@@ -1585,18 +1598,6 @@ class ForallOpControlOperandsFolder : public OpRewritePattern<ForallOp> {
 struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
   using OpRewritePattern<ForallOp>::OpRewritePattern;
 
-  /// Utility function that checks if a candidate value satisifies any of the
-  /// conditions (see above doc comment) to make it viable for folding away.
-  static bool isCandidateValueToDelete(Value result, BlockArgument blockArg) {
-    if (result.use_empty()) {
-      return true;
-    }
-    Value::user_range users = blockArg.getUsers();
-    return llvm::all_of(users, [&](Operation *user) {
-      return !isa<SubsetInsertionOpInterface>(user);
-    });
-  }
-
   LogicalResult matchAndRewrite(ForallOp forallOp,
                                 PatternRewriter &rewriter) const final {
     scf::InParallelOp terminatorOp = forallOp.getTerminator();
@@ -1617,7 +1618,9 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     // Step 1: For a given i-th result of scf.forall, check the following :-
     //         a. If it has any use.
     //         b. If the corresponding iter argument is being modified within
-    //            the loop.
+    //            the loop, i.e. has at least one store op with the iter arg as
+    //            its destination operand. For this we use
+    //            ForallOp::getStoreOpUser(iter_arg).
     //
     //         Based on the check we maintain the following :-
     //         a. `resultToDelete` - i-th result of scf.forall that'll be
@@ -1626,19 +1629,14 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     //            whose uses will be replaced by the new scf.forall.
     //         c. `newOuts` - the shared_outs' operand of the new scf.forall
     //            corresponding to the i-th result with at least one use.
-    //         d. `mapping` - mapping the old iter block argument of scf.forall
-    //            with the corresponding shared_outs' operand. This will be
-    //            used when creating a new scf.forall op.
-    SmallVector<OpResult> resultToDelete;
+    SetVector<OpResult> resultToDelete;
     SmallVector<Value> resultToReplace;
     SmallVector<Value> newOuts;
-    IRMapping mapping;
     for (OpResult result : forallOp.getResults()) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
-      if (isCandidateValueToDelete(result, blockArg)) {
-        resultToDelete.push_back(result);
-        mapping.map(blockArg, opOperand->get());
+      if (result.use_empty() || forallOp.getStoreOpUser(blockArg).empty()) {
+        resultToDelete.insert(result);
       } else {
         resultToReplace.push_back(result);
         newOuts.push_back(opOperand->get());
@@ -1653,28 +1651,16 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
 
     // Step 2: For the the i-th result, do the following :-
     //         a. Fetch the corresponding BlockArgument.
-    //         b. Look for an op within scf.forall.in_parallel whose destination
-    //            operand is the BlockArgument fetched in step a.
-    //         c. Remove the operation fetched in b.
-    //         d. For any use of the BlockArgument in the body of the scf.forall
-    //            replace it with the corresponding Output value.
+    //         b. Look for store ops (currently tensor.parallel_insert_slice)
+    //            with the BlockArgument as its destination operand.
+    //         c. Remove the operations fetched in b.
     for (OpResult result : resultToDelete) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
-      Value::user_range users = blockArg.getUsers();
-      Operation *terminatorOperationToDelete = nullptr;
-      for (Operation *user : users) {
-        if (auto subsetInsertionOpInterfaceOp =
-                dyn_cast<SubsetInsertionOpInterface>(user)) {
-          if (subsetInsertionOpInterfaceOp.getDestinationOperand().get() ==
-              blockArg) {
-            terminatorOperationToDelete = subsetInsertionOpInterfaceOp;
-            break;
-          }
-        }
+      SmallVector<Operation *> storeOps = forallOp.getStoreOpUser(blockArg);
+      for (Operation *storeOp : storeOps) {
+        rewriter.eraseOp(storeOp);
       }
-      if (terminatorOperationToDelete)
-        rewriter.eraseOp(terminatorOperationToDelete);
     }
 
     // Step 3. Create a new scf.forall op with the new shared_outs' operands
@@ -1684,11 +1670,25 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
         forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts,
         forallOp.getMapping());
 
-    // Step 4. Clone the region of the old scf.forall into the newly created
-    //         scf.forall using the IRMapping formed in Step 1.
-    newforallOp.getBodyRegion().getBlocks().clear();
-    rewriter.cloneRegionBefore(forallOp.getRegion(), newforallOp.getRegion(),
-                               newforallOp.getRegion().begin(), mapping);
+    // Step 4. Merge the block of the old scf.forall into the newly created
+    //         scf.forall using the new set of arguments.
+    Block *loopBody = forallOp.getBody();
+    Block *newLoopBody = newforallOp.getBody();
+    ArrayRef<BlockArgument> newBbArgs = newLoopBody->getArguments();
+    SmallVector<Value> newBlockArgs =
+        llvm::map_to_vector(newBbArgs.take_front(forallOp.getRank()),
+                            [](BlockArgument b) -> Value { return b; });
+    Block::BlockArgListType newSharedOutsArgs = newforallOp.getRegionOutArgs();
+    unsigned index = 0;
+    for (OpResult result : forallOp.getResults()) {
+      if (resultToDelete.count(result)) {
+        newBlockArgs.push_back(forallOp.getTiedOpOperand(result)->get());
+      } else {
+        newBlockArgs.push_back(newSharedOutsArgs[index++]);
+      }
+    }
+    rewriter.eraseOp(newforallOp.getTerminator());
+    rewriter.mergeBlocks(loopBody, newLoopBody, newBlockArgs);
 
     // Step 5. Replace the uses of result of old scf.forall with that of the new
     //         scf.forall.



More information about the Mlir-commits mailing list