[Mlir-commits] [mlir] [mlir] Add transformation to wrap scf::while in zero-trip-check (PR #81050)

Jerry Wu llvmlistbot at llvm.org
Thu Feb 8 10:48:06 PST 2024


================
@@ -0,0 +1,122 @@
+//===- WrapInZeroTripCheck.cpp - Loop transforms to add zero-trip-check ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+/// Create zero-trip-check around a `while` op and return the new loop op in the
+/// check. The while loop is rotated to avoid evaluating the condition twice.
+///
+/// Given an example below:
+///
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// First, clone the before block (minus its `scf.condition`) in front of the loop:
+///
+///   %pre_val = .., %init : i64
+///   %pre_cond = arith.cmpi .., %init : i32
+///   scf.while (%arg0 = %init) : (i32) -> i64 {
+///     %val = .., %arg0 : i64
+///     %cond = arith.cmpi .., %arg0 : i32
+///     scf.condition(%cond) %val : i64
+///   } do {
+///   ^bb0(%arg1: i64):
+///     %next = .., %arg1 : i32
+///     scf.yield %next : i32
+///   }
+///
+/// Create an `if` op on the cloned condition, then rotate the loop and move it
+/// into the then branch; the else branch yields the pre-computed values:
+///
+///   %pre_val = .., %init : i64
+///   %pre_cond = arith.cmpi .., %init : i32
+///   scf.if %pre_cond -> i64 {
+///     %res = scf.while (%arg1 = %pre_val) : (i64) -> i64 {
+///       // Original after block
+///       %next = .., %arg1 : i32
+///       // Original before block
+///       %val = .., %next : i64
+///       %cond = arith.cmpi .., %next : i32
+///       scf.condition(%cond) %val : i64
+///     } do {
+///     ^bb0(%arg2: i64):
+///       scf.yield %arg2 : i64
+///     }
+///     scf.yield %res : i64
+///   } else {
+///     scf.yield %pre_val : i64
+///   }
+FailureOr<scf::WhileOp>
+mlir::scf::wrapWhileLoopInZeroTripCheck(scf::WhileOp whileOp,
+                                        RewriterBase &rewriter) {
+  IRMapping mapper;
+  Block *beforeBlock = whileOp.getBeforeBody();
+  // Clone before block before the loop for zero-trip-check.
+  for (auto [arg, init] :
+       llvm::zip_equal(beforeBlock->getArguments(), whileOp.getInits())) {
+    mapper.map(arg, init);
+  }
+  rewriter.setInsertionPoint(whileOp);
+  for (auto &op : *beforeBlock) {
+    if (isa<scf::ConditionOp>(op)) {
+      break;
+    }
+    // Safe to clone everything as in a single block all defs have been cloned
+    // and added to mapper in order.
+    rewriter.insert(op.clone(mapper));
+  }
+
+  auto condOp = whileOp.getConditionOp();
+  auto clonedCondition = mapper.lookupOrDefault(condOp.getCondition());
+  auto clonedCondArgs = llvm::map_to_vector(
+      condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); });
+
+  // Create zero-trip-check and move the while loop in.
+  scf::WhileOp newLoopOp = nullptr;
+  auto ifOp = rewriter.create<scf::IfOp>(
+      whileOp->getLoc(), clonedCondition,
+      [&](OpBuilder &builder, Location loc) {
+        // Then runs the while loop.
+        newLoopOp = builder.create<scf::WhileOp>(
+            loc, whileOp.getResultTypes(), clonedCondArgs,
+            [&](OpBuilder &builder, Location loc, ValueRange args) {
+              // Rotate and move the loop body into before block.
+              auto newBlock = builder.getBlock();
+              rewriter.mergeBlocks(whileOp.getAfterBody(), newBlock, args);
+              auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
+              rewriter.mergeBlocks(whileOp.getBeforeBody(), newBlock,
+                                   yieldOp.getResults());
+              rewriter.eraseOp(yieldOp);
+            },
----------------
pzread wrote:

Done. I tried a few ways before but now I think it looks clearer to create the new while loop first.

https://github.com/llvm/llvm-project/pull/81050


More information about the Mlir-commits mailing list