[Mlir-commits] [mlir] 644b55d - [MLIR][SCF] Add for-to-while loop transformation pass
Morten Borup Petersen
llvmlistbot at llvm.org
Mon Sep 20 08:58:39 PDT 2021
Author: Morten Borup Petersen
Date: 2021-09-20T16:57:50+01:00
New Revision: 644b55d57ec76a18916d30f921781b99795f6e10
URL: https://github.com/llvm/llvm-project/commit/644b55d57ec76a18916d30f921781b99795f6e10
DIFF: https://github.com/llvm/llvm-project/commit/644b55d57ec76a18916d30f921781b99795f6e10.diff
LOG: [MLIR][SCF] Add for-to-while loop transformation pass
This pass transforms SCF.ForOp operations to SCF.WhileOp. The for-loop condition is placed in the 'before' region of the while operation, and the induction variable increment + the loop body in the 'after' region. The loop carried values of the while op are the induction variable (IV) of the for-loop + any iter_args specified for the for-loop.
The 'yield' op terminating the for-loop body is rewritten to additionally yield the (incremented) induction variable.
This transformation is useful for passes where we want to consider structured control flow solely on the basis of a loop body and the computation of a loop condition. As an example, when doing high-level synthesis in CIRCT, incrementing the IV of a for-loop is "just another part" of a circuit datapath; what we really care about is the distinction between our datapath and our control logic (the condition variable).
Differential Revision: https://reviews.llvm.org/D108454
Added:
mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Passes.h
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index 34fe3c01e2b1..e6123617f656 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -52,6 +52,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
/// loop range.
std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+/// Creates a pass that converts scf.for loops into scf.while loops.
+std::unique_ptr<Pass> createForToWhileLoopPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index ef4df99d0bf9..50346de6d9a0 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -78,4 +78,39 @@ def SCFForLoopRangeFolding
let constructor = "mlir::createForLoopRangeFoldingPass()";
}
+def SCFForToWhileLoop
+    : FunctionPass<"scf-for-to-while"> {
+  let summary = "Convert SCF for loops to SCF while loops";
+  let constructor = "mlir::createForToWhileLoopPass()";
+  let description = [{
+    This pass transforms scf.for operations to scf.while operations. The for
+    loop condition is placed in the 'before' region of the while operation,
+    and the induction variable increment and loop body in the 'after' region.
+    The loop-carried values of the while op are the induction variable (IV) of
+    the for-loop followed by any iter_args specified for the for-loop.
+    The 'scf.yield' terminating the for-loop body is rewritten to additionally
+    yield the (incremented) induction variable.
+
+    ```mlir
+    // Before:
+    scf.for %i = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %arg2 : i32
+      memref.store %0, %arg0[%i] : memref<?xi32>
+    }
+
+    // After:
+    %0 = scf.while (%i = %c0) : (index) -> index {
+      %1 = cmpi slt, %i, %arg1 : index
+      scf.condition(%1) %i : index
+    } do {
+    ^bb0(%i: index):  // no predecessors
+      %1 = addi %i, %c1 : index
+      %2 = addi %arg2, %arg2 : i32
+      memref.store %2, %arg0[%i] : memref<?xi32>
+      scf.yield %1 : index
+    }
+    ```
+  }];
+}
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 759f02b91fe6..ae157092a56a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRSCFTransforms
Bufferize.cpp
+ ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
LoopRangeFolding.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
new file mode 100644
index 000000000000..830546413d1e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -0,0 +1,110 @@
+//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForOp's into SCF.WhileOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForOp;
+using scf::WhileOp;
+
+namespace {
+
+/// Rewrites an scf.for into an scf.while. The while op carries the induction
+/// variable (IV) as its first loop-carried value, followed by the for-loop's
+/// iter_args. The 'before' region evaluates `iv < upperBound` (signed); the
+/// 'after' region increments the IV and then runs the original loop body.
+struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Generate type signature for the loop-carried values. The induction
+    // variable is placed first, followed by the forOp.iterArgs.
+    SmallVector<Type, 8> lcvTypes;
+    lcvTypes.push_back(forOp.getInductionVar().getType());
+    llvm::transform(forOp.initArgs(), std::back_inserter(lcvTypes),
+                    [&](auto v) { return v.getType(); });
+
+    // Build scf.WhileOp. The initial loop-carried values are the lower bound
+    // (initial IV) followed by the for-loop's init args.
+    SmallVector<Value> initArgs;
+    initArgs.push_back(forOp.lowerBound());
+    llvm::append_range(initArgs, forOp.initArgs());
+    auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
+                                            forOp->getAttrs());
+
+    // 'before' region contains the loop condition and forwarding of iteration
+    // arguments to the 'after' region. createBlock leaves the insertion point
+    // at the end of the newly created (empty) block.
+    auto *beforeBlock = rewriter.createBlock(
+        &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
+    auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
+                                         beforeBlock->getArgument(0),
+                                         forOp.upperBound());
+    rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
+                                      beforeBlock->getArguments());
+
+    // 'after' region: compute the next IV up front, then append the loop body.
+    // The body keeps observing the current (un-incremented) IV through the
+    // block argument; only the terminator yields the incremented value.
+    auto *afterBlock = rewriter.createBlock(
+        &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
+    auto ivIncOp = rewriter.create<AddIOp>(
+        whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
+
+    // Move the for-loop body into the 'after' block, remapping the for-loop
+    // block arguments (IV + iter_args) onto the 'after' block arguments. Going
+    // through the rewriter keeps it informed of the IR mutation.
+    rewriter.mergeBlocks(forOp.getBody(), afterBlock,
+                         afterBlock->getArguments());
+
+    // The body's terminator is the only scf.yield that belongs to the loop;
+    // yields nested in inner regions (e.g. scf.execute_region) are left
+    // untouched. Prepend the incremented IV to its operands.
+    auto yieldOp = cast<scf::YieldOp>(afterBlock->getTerminator());
+    SmallVector<Value> yieldOperands;
+    yieldOperands.push_back(ivIncOp.getResult());
+    llvm::append_range(yieldOperands, yieldOp.getOperands());
+    rewriter.updateRootInPlace(
+        yieldOp, [&] { yieldOp->setOperands(yieldOperands); });
+
+    // The while op returns one extra value (the final IV) in front of the
+    // for-loop's results, so replace the for op with all but the first result.
+    rewriter.replaceOp(forOp, whileOp.getResults().drop_front());
+    return success();
+  }
+};
+
+/// Function pass that greedily applies ForLoopLoweringPattern to the function,
+/// rewriting scf.for ops (including nested ones) into scf.while ops.
+struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopLoweringPattern>(ctx);
+    // The driver's convergence result is intentionally discarded.
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
+  return std::unique_ptr<Pass>(std::make_unique<ForToWhileLoop>());
+}
diff --git a/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
new file mode 100644
index 000000000000..c3a75bada8ff
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// CHECK-LABEL: builtin.func @single_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: i32) {
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_8:.*]]: index):
+// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK: %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK: memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK: scf.yield %[[VAL_9]] : index
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %lb = constant 0 : index
+  %step = constant 1 : index
+  scf.for %iv = %lb to %arg1 step %step {
+    %sum = addi %arg2, %arg2 : i32
+    memref.store %sum, %arg0[%iv] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @nested_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: i32) {
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK: %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK: scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_8:.*]]: index):
+// CHECK: %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK: %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK: %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index
+// CHECK: scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_13:.*]]: index):
+// CHECK: %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index
+// CHECK: %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK: memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32>
+// CHECK: scf.yield %[[VAL_14]] : index
+// CHECK: }
+// CHECK: scf.yield %[[VAL_9]] : index
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %lb = constant 0 : index
+  %step = constant 1 : index
+  scf.for %outer = %lb to %arg1 step %step {
+    scf.for %inner = %lb to %arg1 step %step {
+      %sum = addi %arg2, %arg2 : i32
+      memref.store %sum, %arg0[%outer] : memref<?xi32>
+      memref.store %sum, %arg0[%inner] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @for_iter_args(
+// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: index) -> f32 {
+// CHECK: %[[VAL_3:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) {
+// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index
+// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
+// CHECK: %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index
+// CHECK: %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32
+// CHECK: scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32
+// CHECK: }
+// CHECK: return %[[VAL_14:.*]]#2 : f32
+// CHECK: }
+func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
+  %zero = constant 0.0 : f32
+  %res:2 = scf.for %iv = %arg0 to %arg1 step %arg2 iter_args(%acc0 = %zero, %acc1 = %zero) -> (f32, f32) {
+    %next = addf %acc0, %acc1 : f32
+    scf.yield %next, %next : f32, f32
+  }
+  return %res#1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: builtin.func @exec_region_multiple_yields(
+// CHECK-SAME: %[[VAL_0:.*]]: i32,
+// CHECK-SAME: %[[VAL_1:.*]]: index,
+// CHECK-SAME: %[[VAL_2:.*]]: i32) -> i32 {
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) {
+// CHECK: %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK: scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32):
+// CHECK: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index
+// CHECK: %[[VAL_12:.*]] = scf.execute_region -> i32 {
+// CHECK: %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index
+// CHECK: cond_br %[[VAL_13]], ^bb1, ^bb2
+// CHECK: ^bb1:
+// CHECK: %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32
+// CHECK: scf.yield %[[VAL_14]] : i32
+// CHECK: ^bb2:
+// CHECK: %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32
+// CHECK: scf.yield %[[VAL_15]] : i32
+// CHECK: }
+// CHECK: scf.yield %[[VAL_11]], %[[VAL_16:.*]] : index, i32
+// CHECK: }
+// CHECK: return %[[VAL_17:.*]]#1 : i32
+// CHECK: }
+func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
+  // The scf.yield ops nested in the scf.execute_region terminate that region,
+  // not the loop body, so the pass must leave them unchanged and only extend
+  // the loop body's own terminator with the incremented induction variable.
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 {
+    %r = scf.execute_region -> i32 {
+      %1 = cmpi slt, %i, %c1 : index
+      cond_br %1, ^bb1, ^bb2
+    ^bb1:
+      %2 = subi %iarg0, %arg0 : i32
+      scf.yield %2 : i32
+    ^bb2:
+      %3 = muli %iarg0, %arg2 : i32
+      scf.yield %3 : i32
+    }
+    scf.yield %r : i32
+  }
+  return %0 : i32
+}
More information about the Mlir-commits
mailing list