[Mlir-commits] [mlir] 032cb16 - [MLIR][SCF] Add for-to-while loop transformation pass

Morten Borup Petersen llvmlistbot at llvm.org
Tue Sep 21 01:10:10 PDT 2021


Author: Morten Borup Petersen
Date: 2021-09-21T09:09:54+01:00
New Revision: 032cb1650fe62f9c93b0a0857e21597b2d11a69f

URL: https://github.com/llvm/llvm-project/commit/032cb1650fe62f9c93b0a0857e21597b2d11a69f
DIFF: https://github.com/llvm/llvm-project/commit/032cb1650fe62f9c93b0a0857e21597b2d11a69f.diff

LOG: [MLIR][SCF] Add for-to-while loop transformation pass

This pass transforms SCF.ForOp operations to SCF.WhileOp. The for-loop condition is placed in the 'before' region of the while operation, and the induction variable incrementation + the loop body in the 'after' region. The loop carried values of the while op are the induction variable (IV) of the for-loop + any iter_args specified for the for-loop.
Any 'yield' ops in the for-loop are rewritten to additionally yield the (incremented) induction variable.

This transformation is useful for passes where we want to consider structured control flow solely on the basis of a loop body and the computation of a loop condition. As an example, when doing high-level synthesis in CIRCT, the incrementation of an IV in a for-loop is "just another part" of a circuit datapath, and what we really care about is the distinction between our datapath and our control logic (the condition variable).

Differential Revision: https://reviews.llvm.org/D108454

Added: 
    mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
    mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Passes.h
    mlir/include/mlir/Dialect/SCF/Passes.td
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index 34fe3c01e2b16..e6123617f656e 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -52,6 +52,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass that converts `scf.for` loops into equivalent `scf.while`
+/// loops.
+std::unique_ptr<Pass> createForToWhileLoopPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index ef4df99d0bf94..50346de6d9a05 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -78,4 +78,39 @@ def SCFForLoopRangeFolding
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForToWhileLoop
+    : FunctionPass<"scf-for-to-while"> {
+  let summary = "Convert SCF for loops to SCF while loops";
+  let constructor = "mlir::createForToWhileLoopPass()";
+  let description = [{
+    This pass transforms `scf.for` operations into `scf.while` operations. The
+    for-loop condition is placed in the 'before' region of the while
+    operation, and the induction variable incrementation and loop body in the
+    'after' region.
+    The loop carried values of the while op are the induction variable (IV) of
+    the for-loop, followed by any iter_args specified for the for-loop.
+    The terminating 'scf.yield' of the for-loop body is rewritten to
+    additionally yield the (incremented) induction variable; 'scf.yield' ops
+    nested inside other operations of the body are left unchanged.
+
+    ```mlir
+    // Before:
+      scf.for %i = %c0 to %arg1 step %c1 {
+        %0 = addi %arg2, %arg2 : i32
+        memref.store %0, %arg0[%i] : memref<?xi32>
+      }
+
+    // After:
+      %0 = scf.while (%i = %c0) : (index) -> index {
+        %1 = cmpi slt, %i, %arg1 : index
+        scf.condition(%1) %i : index
+      } do {
+      ^bb0(%i: index):  // no predecessors
+        %1 = addi %i, %c1 : index
+        %2 = addi %arg2, %arg2 : i32
+        memref.store %2, %arg0[%i] : memref<?xi32>
+        scf.yield %1 : index
+      }
+    ```
+  }];
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 759f02b91fe66..ae157092a56a9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
   LoopRangeFolding.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
new file mode 100644
index 0000000000000..830546413d1ef
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -0,0 +1,110 @@
+//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForOp's into SCF.WhileOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForOp;
+using scf::WhileOp;
+
+namespace {
+
+/// Rewrites an `scf.for` into an equivalent `scf.while`.
+///
+/// The while loop carries the induction variable (IV) as its first value,
+/// followed by the for-loop's iter_args. The 'before' region compares the IV
+/// against the upper bound with a signed `slt` and forwards all carried
+/// values; the 'after' region increments the IV by the step and then runs the
+/// original loop body, whose terminating `scf.yield` is extended to also yield
+/// the incremented IV.
+struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    // Generate type signature for the loop-carried values. The induction
+    // variable is placed first, followed by the forOp.iterArgs.
+    SmallVector<Type, 8> lcvTypes;
+    lcvTypes.push_back(forOp.getInductionVar().getType());
+    llvm::append_range(lcvTypes, forOp.initArgs().getTypes());
+
+    // Build scf.WhileOp; the IV starts at the for-loop's lower bound.
+    SmallVector<Value> initArgs;
+    initArgs.push_back(forOp.lowerBound());
+    llvm::append_range(initArgs, forOp.initArgs());
+    auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
+                                            forOp->getAttrs());
+
+    // 'before' region contains the loop condition and forwarding of iteration
+    // arguments to the 'after' region.
+    auto *beforeBlock = rewriter.createBlock(
+        &whileOp.before(), whileOp.before().begin(), lcvTypes, {});
+    rewriter.setInsertionPointToStart(beforeBlock);
+    auto cmpOp = rewriter.create<CmpIOp>(whileOp.getLoc(), CmpIPredicate::slt,
+                                         beforeBlock->getArgument(0),
+                                         forOp.upperBound());
+    rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
+                                      beforeBlock->getArguments());
+
+    // 'after' region: its block arguments mirror the loop-carried values. The
+    // IV increment is emitted first and the for-loop body is inlined directly
+    // after it.
+    auto *afterBlock = rewriter.createBlock(
+        &whileOp.after(), whileOp.after().begin(), lcvTypes, {});
+
+    // Add induction variable incrementation.
+    rewriter.setInsertionPointToEnd(afterBlock);
+    auto ivIncOp = rewriter.create<AddIOp>(
+        whileOp.getLoc(), afterBlock->getArgument(0), forOp.step());
+
+    // Move the for-loop body to the end of the 'after' block, remapping the
+    // for-loop block arguments (IV + iter_args) onto the 'after' block
+    // arguments. Going through the rewriter (rather than mutating the IR
+    // directly) keeps the driver running this pattern informed of the change.
+    rewriter.mergeBlocks(forOp.getBody(), afterBlock,
+                         afterBlock->getArguments());
+
+    // Prepend the incremented IV to the operands of the former for-loop
+    // terminator. 'scf.yield' ops nested inside other ops of the body belong
+    // to those ops and are intentionally left untouched.
+    auto yieldOp = cast<scf::YieldOp>(afterBlock->getTerminator());
+    SmallVector<Value> yieldOperands;
+    yieldOperands.push_back(ivIncOp.getResult());
+    llvm::append_range(yieldOperands, yieldOp.getOperands());
+    rewriter.updateRootInPlace(
+        yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
+
+    // The while op returns one extra leading value (the final IV) that has no
+    // counterpart among the for-loop results, so the for-loop is replaced by
+    // the remaining while results.
+    rewriter.replaceOp(forOp, whileOp->getResults().drop_front());
+    return success();
+  }
+};
+
+/// Function pass that greedily applies ForLoopLoweringPattern to rewrite the
+/// scf.for ops in the function into scf.while ops.
+struct ForToWhileLoop : public SCFForToWhileLoopBase<ForToWhileLoop> {
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet patterns(ctx);
+    patterns.add<ForLoopLoweringPattern>(ctx);
+    // The driver's LogicalResult is discarded; this pass never signals
+    // failure.
+    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+  }
+};
+} // namespace
+
+/// Returns a fresh instance of the scf.for -> scf.while conversion pass.
+std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<ForToWhileLoop>();
+  return pass;
+}

diff  --git a/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
new file mode 100644
index 0000000000000..5f9a0d117de04
--- /dev/null
+++ b/mlir/test/Dialect/SCF/for-loop-to-while-loop.mlir
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(scf-for-to-while)' -split-input-file | FileCheck %s
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+
+// A loop without iter_args: the while op carries only the induction variable,
+// which is incremented at the start of the 'after' region and yielded back.
+// CHECK-LABEL:   func @single_loop(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: index,
+// CHECK-SAME:                              %[[VAL_2:.*]]: i32) {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:             %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_10:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK:             memref.store %[[VAL_10]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK:             scf.yield %[[VAL_9]] : index
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func @single_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %arg2 : i32
+    memref.store %0, %arg0[%i] : memref<?xi32>
+  }
+  return
+}
+
+// -----
+
+// Nested loops: both scf.for ops are converted, and uses of the outer
+// induction variable inside the inner body refer to the outer while's 'after'
+// block argument.
+// CHECK-LABEL:   func @nested_loop(
+// CHECK-SAME:                              %[[VAL_0:.*]]: memref<?xi32>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: index,
+// CHECK-SAME:                              %[[VAL_2:.*]]: i32) {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = scf.while (%[[VAL_6:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:             %[[VAL_7:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_7]]) %[[VAL_6]] : index
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_8:.*]]: index):
+// CHECK:             %[[VAL_9:.*]] = addi %[[VAL_8]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_10:.*]] = scf.while (%[[VAL_11:.*]] = %[[VAL_3]]) : (index) -> index {
+// CHECK:               %[[VAL_12:.*]] = cmpi slt, %[[VAL_11]], %[[VAL_1]] : index
+// CHECK:               scf.condition(%[[VAL_12]]) %[[VAL_11]] : index
+// CHECK:             } do {
+// CHECK:             ^bb0(%[[VAL_13:.*]]: index):
+// CHECK:               %[[VAL_14:.*]] = addi %[[VAL_13]], %[[VAL_4]] : index
+// CHECK:               %[[VAL_15:.*]] = addi %[[VAL_2]], %[[VAL_2]] : i32
+// CHECK:               memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
+// CHECK:               memref.store %[[VAL_15]], %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<?xi32>
+// CHECK:               scf.yield %[[VAL_14]] : index
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_9]] : index
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
+func @nested_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: i32) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    scf.for %j = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %arg2 : i32
+      memref.store %0, %arg0[%i] : memref<?xi32>
+      memref.store %0, %arg0[%j] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// -----
+
+// The while op carries the IV as result #0, so for-loop result #N maps to
+// while result #N+1 (the function returns the for-loop's result #1).
+// CHECK-LABEL:   func @for_iter_args(
+// CHECK-SAME:                                %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME:                                %[[VAL_2:.*]]: index) -> f32 {
+// CHECK:           %[[VAL_3:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]]:3 = scf.while (%[[VAL_5:.*]] = %[[VAL_0]], %[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_3]]) : (index, f32, f32) -> (index, f32, f32) {
+// CHECK:             %[[VAL_8:.*]] = cmpi slt, %[[VAL_5]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_8]]) %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] : index, f32, f32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
+// CHECK:             %[[VAL_12:.*]] = addi %[[VAL_9]], %[[VAL_2]] : index
+// CHECK:             %[[VAL_13:.*]] = addf %[[VAL_10]], %[[VAL_11]] : f32
+// CHECK:             scf.yield %[[VAL_12]], %[[VAL_13]], %[[VAL_13]] : index, f32, f32
+// CHECK:           }
+// CHECK:           return %[[VAL_4]]#2 : f32
+// CHECK:         }
+func @for_iter_args(%arg0 : index, %arg1: index, %arg2: index) -> f32 {
+  %s0 = constant 0.0 : f32
+  %result:2 = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iarg0 = %s0, %iarg1 = %s0) -> (f32, f32) {
+    %sn = addf %iarg0, %iarg1 : f32
+    scf.yield %sn, %sn : f32, f32
+  }
+  return %result#1 : f32
+}
+
+// -----
+
+// Only the terminating scf.yield of the for-loop body gains the IV operand;
+// the scf.yield ops nested in the scf.execute_region keep their single operand.
+// CHECK-LABEL:   func @exec_region_multiple_yields(
+// CHECK-SAME:                                              %[[VAL_0:.*]]: i32,
+// CHECK-SAME:                                              %[[VAL_1:.*]]: index,
+// CHECK-SAME:                                              %[[VAL_2:.*]]: i32) -> i32 {
+// CHECK:           %[[VAL_3:.*]] = constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = constant 1 : index
+// CHECK:           %[[VAL_5:.*]]:2 = scf.while (%[[VAL_6:.*]] = %[[VAL_3]], %[[VAL_7:.*]] = %[[VAL_0]]) : (index, i32) -> (index, i32) {
+// CHECK:             %[[VAL_8:.*]] = cmpi slt, %[[VAL_6]], %[[VAL_1]] : index
+// CHECK:             scf.condition(%[[VAL_8]]) %[[VAL_6]], %[[VAL_7]] : index, i32
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[VAL_9:.*]]: index, %[[VAL_10:.*]]: i32):
+// CHECK:             %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_4]] : index
+// CHECK:             %[[VAL_12:.*]] = scf.execute_region -> i32 {
+// CHECK:               %[[VAL_13:.*]] = cmpi slt, %[[VAL_9]], %[[VAL_4]] : index
+// CHECK:               cond_br %[[VAL_13]], ^bb1, ^bb2
+// CHECK:             ^bb1:
+// CHECK:               %[[VAL_14:.*]] = subi %[[VAL_10]], %[[VAL_0]] : i32
+// CHECK:               scf.yield %[[VAL_14]] : i32
+// CHECK:             ^bb2:
+// CHECK:               %[[VAL_15:.*]] = muli %[[VAL_10]], %[[VAL_2]] : i32
+// CHECK:               scf.yield %[[VAL_15]] : i32
+// CHECK:             }
+// CHECK:             scf.yield %[[VAL_11]], %[[VAL_12]] : index, i32
+// CHECK:           }
+// CHECK:           return %[[VAL_5]]#1 : i32
+// CHECK:         }
+func @exec_region_multiple_yields(%arg0: i32, %arg1: index, %arg2: i32) -> i32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = scf.for %i = %c0 to %arg1 step %c1 iter_args(%iarg0 = %arg0) -> i32 {
+    %1 = scf.execute_region -> i32 {
+      %cond = cmpi slt, %i, %c1 : index
+      cond_br %cond, ^bb1, ^bb2
+    ^bb1:
+      %2 = subi %iarg0, %arg0 : i32
+      scf.yield %2 : i32
+    ^bb2:
+      %3 = muli %iarg0, %arg2 : i32
+      scf.yield %3 : i32
+    }
+    scf.yield %1 : i32
+  }
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list