[Mlir-commits] [mlir] [mlir][scf] `scf.while` uplifting: Add preparation patterns. (PR #89222)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 18 05:02:37 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

<details>
<summary>Changes</summary>

`scf.while` -> `scf.for` uplifting expects the `before` block to consist of a single cmp op, so we need to clean it up before running the uplifting. One possible cleanup is LICM.
Another is moving and duplicating ops from the `before` block into the `after` block and after the loop. Add a pattern for this transformation.

---
Full diff: https://github.com/llvm/llvm-project/pull/89222.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h (+4) 
- (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+192) 
- (added) mlir/test/Dialect/SCF/uplift-while-prepare.mlir (+74) 
- (modified) mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp (+23) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf25706269803..244423274c0555 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to prepare `scf.while` loops for uplifting to `scf.for`,
+/// e.g. cleaning up the `before` block by moving ops out of it so that it
+/// consists of a single cmp op.
+void populatePrepareUpliftWhileToForPatterns(RewritePatternSet &patterns);
+
 /// Populate patterns to uplift `scf.while` ops to `scf.for`.
 /// Uplifitng expects a specific ops pattern:
 ///  * `before` block consisting of single arith.cmp op
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a72..959c30315ae8af 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -20,7 +20,194 @@
 
 using namespace mlir;
 
+/// Selects an op from the `before` block of `loop` that can be moved into the
+/// `after` block and duplicated after the loop, or returns nullptr if there is
+/// none.
+///
+/// Candidates, in order of preference:
+///  * the last non-terminator op, unless it produces the loop condition;
+///  * a pure op whose results are used only as `scf.condition` arguments.
+/// Ops with regions are never selected: the rewrite only remaps direct
+/// operands when cloning, so `before` block values captured implicitly by
+/// nested regions would not be forwarded and would produce invalid IR.
+static Operation *findOpToMoveFromBefore(scf::WhileOp loop) {
+  Block *body = loop.getBeforeBody();
+  if (body->without_terminator().empty())
+    return nullptr;
+
+  // Conservatively reject ops with regions (see above).
+  auto isClonable = [](Operation *op) { return op->getNumRegions() == 0; };
+
+  // Check last op first.
+  // TODO: It's usually safe to move and duplicate last op even if it has side
+  // effects, as long as the sequence of the ops executed on each path will stay
+  // the same. Exceptions are GPU barrier/group ops, LLVM proper has
+  // convergent attribute/semantics to check this, but we don't model it yet.
+  Operation *lastOp = &(*std::prev(body->without_terminator().end()));
+
+  auto term = loop.getConditionOp();
+  Operation *termCondOp = term.getCondition().getDefiningOp();
+  if (lastOp != termCondOp && isClonable(lastOp))
+    return lastOp;
+
+  // Try to move terminator args producers.
+  for (Value termArg : term.getArgs()) {
+    Operation *op = termArg.getDefiningOp();
+    if (!op || op->getParentOp() != loop || op == termCondOp || !isPure(op) ||
+        !isClonable(op))
+      continue;
+
+    // Each result must be used at most once, and only by the terminator: any
+    // other user would be left with a dangling operand once `op` is erased.
+    // Duplicated terminator args must be already cleaned up by
+    // canonicalizations at this point.
+    if (!llvm::all_of(op->getResults(), [&](Value val) {
+          return val.use_empty() ||
+                 (val.hasOneUse() && *val.user_begin() == term.getOperation());
+        }))
+      continue;
+
+    return op;
+  }
+  return nullptr;
+}
+
 namespace {
+/// `scf.while` uplifting expects before block consisting of single cmp op,
+/// try to move ops from before block to after block and to after loop.
+///
+/// Each application moves one op chosen by `findOpToMoveFromBefore`. The loop
+/// is rebuilt with a new `scf.condition` whose args are the original args
+/// minus the moved op's results, followed by those operands of the moved op
+/// that do not properly dominate the loop. The op is then cloned at the start
+/// of the `after` block and right after the new loop, and the original op is
+/// erased.
+///
+/// ```
+/// scf.while(...) {
+/// before:
+///   ...
+///   some_op()
+///   scf.condition ..
+/// after:
+///   ...
+/// }
+/// ```
+/// to
+/// ```
+/// scf.while(...) {
+/// before:
+///   ...
+///   scf.condition ..
+/// after:
+///   some_op()
+///   ...
+/// }
+/// some_op()
+/// ```
+struct MoveOpsFromBefore : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Operation *opToMove = findOpToMoveFromBefore(loop);
+    if (!opToMove)
+      return rewriter.notifyMatchFailure(loop, "No suitable ops found");
+
+    auto condOp = loop.getConditionOp();
+    SmallVector<Value> newCondArgs;
+
+    // Populate new terminator args.
+
+    // Add original terminator args, except args produced by the op we decided
+    // to move.
+    for (Value arg : condOp.getArgs()) {
+      if (arg.getDefiningOp() == opToMove)
+        continue;
+
+      newCondArgs.emplace_back(arg);
+    }
+    // newCondArgs[0, originalArgsOffset) hold the kept original args; the
+    // moved op's forwarded operands are appended after them. The mapping code
+    // below relies on this layout.
+    auto originalArgsOffset = newCondArgs.size();
+
+    // Add moved op operands to terminator args, if they are defined in loop
+    // block.
+    // Operands that properly dominate the loop are defined outside of it, so
+    // the clones can use them directly and they are not forwarded.
+    DominanceInfo dom;
+    for (Value arg : opToMove->getOperands()) {
+      if (dom.properlyDominates(arg, loop))
+        continue;
+
+      newCondArgs.emplace_back(arg);
+    }
+
+    // Create new loop.
+    // Its result types follow the new condition args; inits are unchanged.
+    // Both regions are populated below by inlining the old blocks.
+    ValueRange tempRange(newCondArgs);
+    auto newLoop = rewriter.create<mlir::scf::WhileOp>(
+        loop.getLoc(), TypeRange(tempRange), loop.getInits(), nullptr, nullptr);
+
+    OpBuilder::InsertionGuard g(rewriter);
+
+    // Create new terminator, old terminator will be deleted later.
+    rewriter.setInsertionPoint(condOp);
+    rewriter.create<scf::ConditionOp>(condOp.getLoc(), condOp.getCondition(),
+                                      newCondArgs);
+
+    Block *oldBefore = loop.getBeforeBody();
+    Block *newBefore = newLoop.getBeforeBody();
+
+    // Inline before block as is.
+    // This carries opToMove, the new condition op and the old condition op
+    // into the new loop; the latter two ops are erased at the end.
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+
+    Block *oldAfter = loop.getAfterBody();
+    Block *newAfter = newLoop.getAfterBody();
+
+    // Build mapping between original op args and new after block args/new loop
+    // results.
+    // Forwarded operands occupy the tail (from originalArgsOffset on) of both
+    // the new after block args and the new loop results, in the same order
+    // they were appended to newCondArgs.
+    IRMapping afterBodyMapping;
+    IRMapping afterLoopMapping;
+    {
+      ValueRange blockArgs =
+          newAfter->getArguments().drop_front(originalArgsOffset);
+      ValueRange newLoopArgs =
+          newLoop.getResults().drop_front(originalArgsOffset);
+      for (Value arg : opToMove->getOperands()) {
+        if (dom.properlyDominates(arg, loop))
+          continue;
+
+        assert(!blockArgs.empty());
+        assert(!newLoopArgs.empty());
+        afterBodyMapping.map(arg, blockArgs.front());
+        afterLoopMapping.map(arg, newLoopArgs.front());
+        blockArgs = blockArgs.drop_front();
+        newLoopArgs = newLoopArgs.drop_front();
+      }
+    }
+
+    {
+      // Clone op into after body.
+      // NOTE(review): the mappings only cover opToMove's direct operands;
+      // values captured implicitly by nested regions of opToMove are not
+      // remapped. Confirm the selector never returns such ops.
+      rewriter.setInsertionPointToStart(oldAfter);
+      Operation *newAfterBodyOp = rewriter.clone(*opToMove, afterBodyMapping);
+
+      // Clone op after loop.
+      rewriter.setInsertionPointAfter(newLoop);
+      Operation *newAfterLoopOp = rewriter.clone(*opToMove, afterLoopMapping);
+
+      // Build mapping between old and new after block args and between old and
+      // new loop results.
+      // Walk the original condition args in order: results of the moved op map
+      // to the matching results of its clones, and every other arg consumes the
+      // next leading new after block arg / new loop result.
+      ValueRange blockArgs =
+          newAfter->getArguments().take_front(originalArgsOffset);
+      ValueRange newLoopArgs =
+          newLoop.getResults().take_front(originalArgsOffset);
+      SmallVector<Value> argsMapping;
+      SmallVector<Value> newLoopResults;
+      for (Value arg : condOp.getArgs()) {
+        if (arg.getDefiningOp() == opToMove) {
+          auto resNumber = cast<OpResult>(arg).getResultNumber();
+          argsMapping.emplace_back(newAfterBodyOp->getResult(resNumber));
+          newLoopResults.emplace_back(newAfterLoopOp->getResult(resNumber));
+          continue;
+        }
+
+        assert(!blockArgs.empty());
+        assert(!newLoopArgs.empty());
+        argsMapping.emplace_back(blockArgs.front());
+        newLoopResults.emplace_back(newLoopArgs.front());
+        blockArgs = blockArgs.drop_front();
+        newLoopArgs = newLoopArgs.drop_front();
+      }
+
+      // Inline after block.
+      // The clone placed at the start of oldAfter moves along with the block.
+      rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                                 argsMapping);
+
+      // Replace loop.
+      rewriter.replaceOp(loop, newLoopResults);
+    }
+
+    // Finally, we can remove old terminator and the original op.
+    // Both were moved into the new before block by the first inlining, so
+    // they are still alive after the old loop was replaced.
+    rewriter.eraseOp(condOp);
+    rewriter.eraseOp(opToMove);
+    return success();
+  }
+};
+
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -209,6 +396,11 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   return newLoop;
 }
 
+void mlir::scf::populatePrepareUpliftWhileToForPatterns(
+    RewritePatternSet &patterns) {
+  // Currently the only preparation step: peel ops out of the `before` block.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<MoveOpsFromBefore>(context);
+}
+
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+  // Only the uplifting pattern itself; the before-block cleanup patterns are
+  // added by populatePrepareUpliftWhileToForPatterns.
   patterns.add<UpliftWhileOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/uplift-while-prepare.mlir b/mlir/test/Dialect/SCF/uplift-while-prepare.mlir
new file mode 100644
index 00000000000000..fd359efa20ba0c
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while-prepare.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-prepare-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// The last op of the `before` block does not produce the loop condition, so it
+// is moved into the `after` block and duplicated after the loop.
+// CHECK-LABEL: func.func @test()
+//       CHECK:  scf.while
+//   CHECK-NOT:  "test.test1"
+//       CHECK:  scf.condition(%{{.*}})
+//       CHECK:  } do {
+//       CHECK:  "test.test1"() : () -> ()
+//       CHECK:  "test.test2"() : () -> ()
+//       CHECK:  scf.yield
+//       CHECK:  "test.test1"() : () -> ()
+//       CHECK:  return
+func.func @test() {
+  scf.while () : () -> () {
+    %1 = "test.cond"() : () -> i1
+    "test.test1"() : () -> ()
+    scf.condition(%1)
+  } do {
+  ^bb0():
+    "test.test2"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+// A multi-result last op whose results are passed to scf.condition in
+// permuted order is moved; uses of the loop results and of the after block
+// args are remapped to the matching results of the clones.
+// CHECK-LABEL: func.func @test()
+//       CHECK:  scf.while
+//   CHECK-NOT:  "test.test1"
+//       CHECK:  scf.condition(%{{.*}})
+//       CHECK:  } do {
+//       CHECK:  %[[R1:.*]]:2 = "test.test1"() : () -> (i32, i64)
+//       CHECK:  "test.test2"(%[[R1]]#1, %[[R1]]#0) : (i64, i32) -> ()
+//       CHECK:  scf.yield
+//       CHECK:  %[[R2:.*]]:2 = "test.test1"() : () -> (i32, i64)
+//       CHECK:  return %[[R2]]#1, %[[R2]]#0 : i64, i32
+func.func @test() -> (i64, i32) {
+  %0:2 = scf.while () : () -> (i64, i32) {
+    %1 = "test.cond"() : () -> i1
+    %2:2 = "test.test1"() : () -> (i32, i64)
+    scf.condition(%1) %2#1, %2#0 : i64, i32
+  } do {
+  ^bb0(%arg1: i64, %arg2: i32):
+    "test.test2"(%arg1, %arg2) : (i64, i32) -> ()
+    scf.yield
+  }
+  return %0#0, %0#1 : i64, i32
+}
+
+// -----
+
+// The last op produces the loop condition, so the pure producer of the
+// scf.condition arg (arith.addi) is moved instead. Its in-loop operand is
+// forwarded through the condition args, and the function argument defined
+// outside the loop is used directly by the clones.
+// CHECK-LABEL: func.func @test
+//  CHECK-SAME:  (%[[ARG0:.*]]: index)
+//       CHECK:  %[[RES:.*]] = scf.while (%[[ARG1:.*]] = %[[ARG0]]) : (index) -> index {
+//   CHECK-NOT:  arith.addi
+//       CHECK:  scf.condition(%{{.*}}) %[[ARG1]] : index
+//       CHECK:  } do {
+//       CHECK:  ^bb0(%[[ARG2:.*]]: index):
+//       CHECK:  %[[A1:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : index
+//       CHECK:  scf.yield %[[A1]]
+//       CHECK:  %[[A2:.*]] = arith.addi %[[ARG0]], %[[RES]] : index
+//       CHECK:  return %[[A2]]
+func.func @test(%arg0: index) -> index {
+  %res = scf.while (%arg1 = %arg0) : (index) -> (index) {
+    %0 = arith.addi %arg0, %arg1 : index
+    %1 = "test.cond"() : () -> i1
+    scf.condition(%1) %0 : index
+  } do {
+  ^bb0(%arg2: index):
+    scf.yield %arg2 : index
+  }
+  return %res : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
index 468bc0ca78489f..3eaad9eaa8a731 100644
--- a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
@@ -19,6 +19,28 @@ using namespace mlir;
 
 namespace {
 
+/// Test pass that greedily applies the `scf.while` -> `scf.for` uplifting
+/// preparation patterns to whatever operation it is scheduled on.
+struct TestSCFPrepareUpliftWhileToFor
+    : public PassWrapper<TestSCFPrepareUpliftWhileToFor, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPrepareUpliftWhileToFor)
+
+  /// Command-line name used to invoke this pass from mlir-opt.
+  StringRef getArgument() const final {
+    return "test-scf-prepare-uplift-while-to-for";
+  }
+
+  StringRef getDescription() const final {
+    return "test scf while to for uplifting preparation";
+  }
+
+  /// Runs the preparation patterns; signals pass failure if the greedy driver
+  /// returns failure.
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    RewritePatternSet prepPatterns(root->getContext());
+    scf::populatePrepareUpliftWhileToForPatterns(prepPatterns);
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(root, std::move(prepPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
 struct TestSCFUpliftWhileToFor
     : public PassWrapper<TestSCFUpliftWhileToFor, OperationPass<void>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFUpliftWhileToFor)
@@ -44,6 +66,7 @@ struct TestSCFUpliftWhileToFor
 namespace mlir {
 namespace test {
 void registerTestSCFUpliftWhileToFor() {
+  // Registers both the preparation pass (test-scf-prepare-uplift-while-to-for)
+  // and the uplifting test pass.
+  PassRegistration<TestSCFPrepareUpliftWhileToFor>();
   PassRegistration<TestSCFUpliftWhileToFor>();
 }
 } // namespace test

``````````

</details>


https://github.com/llvm/llvm-project/pull/89222


More information about the Mlir-commits mailing list