[Mlir-commits] [mlir] 3f429e8 - Implement an scf.for range folding optimization pass.

Mehdi Amini llvmlistbot at llvm.org
Wed Jun 23 18:07:41 PDT 2021


Author: Anthony Canino
Date: 2021-06-24T01:07:28Z
New Revision: 3f429e82d3ea1710ee0a841675acba9bb7b658d2

URL: https://github.com/llvm/llvm-project/commit/3f429e82d3ea1710ee0a841675acba9bb7b658d2
DIFF: https://github.com/llvm/llvm-project/commit/3f429e82d3ea1710ee0a841675acba9bb7b658d2.diff

LOG: Implement an scf.for range folding optimization pass.

In cases where arithmetic (addi/muli) ops are performed on an scf.for loop's induction variable that has a single use, we can fold those ops directly into the scf.for loop's range (its bounds and step).

For example, in the following code:

```
scf.for %i = %c0 to %arg1 step %c1 {
  %0 = addi %arg2, %i : index
  %1 = muli %0, %c4 : index
  %2 = memref.load %arg0[%1] : memref<?xi32>
  %3 = muli %2, %2 : i32
  memref.store %3, %arg0[%1] : memref<?xi32>
}
```

we can lift `%0` up into the scf.for loop range, as it is the only user of `%i`, and then replace the uses of `%0` with `%i`:

```
%lb = addi %arg2, %c0 : index
%ub = addi %arg2, %arg1 : index
scf.for %i = %lb to %ub step %c1 {
  %1 = muli %i, %c4 : index
  %2 = memref.load %arg0[%1] : memref<?xi32>
  %3 = muli %2, %2 : i32
  memref.store %3, %arg0[%1] : memref<?xi32>
}
```

Reviewed By: mehdi_amini, ftynse, Anthony

Differential Revision: https://reviews.llvm.org/D104289

Added: 
    mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
    mlir/test/Dialect/SCF/loop-range.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Passes.h
    mlir/include/mlir/Dialect/SCF/Passes.td
    mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index f3dda9bec335c..2d1f8b5aff051 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -35,6 +35,10 @@ std::unique_ptr<Pass> createParallelLoopSpecializationPass();
 std::unique_ptr<Pass>
 createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
 
+/// Creates a pass that folds addi/muli ops on an scf.for induction variable
+/// into the loop's range (its bounds and step).
+std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 458cf769bfc9e..172fb63206809 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -45,4 +45,10 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
   let dependentDialects = ["AffineDialect"];
 }
 
+// Folds addi/muli ops on an scf.for induction variable into the loop range.
+def SCFForLoopRangeFolding : Pass<"for-loop-range-folding"> {
+  let summary = "Fold add/mul ops into loop range";
+  let constructor = "mlir::createForLoopRangeFoldingPass()";
+}
+
 #endif // MLIR_DIALECT_SCF_PASSES

diff  --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 3c44e1d0472d7..689b53762a673 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFTransforms
   Bufferize.cpp
+  LoopRangeFolding.cpp
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp

diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
new file mode 100644
index 0000000000000..e929f710c8758
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
@@ -0,0 +1,86 @@
+//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop range folding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Folds addi/muli users of an scf.for induction variable into its range.
+struct ForLoopRangeFolding : SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ForLoopRangeFolding::runOnOperation() {
+  getOperation()->walk([&](ForOp op) {
+    Value indVar = op.getInductionVar();
+
+    // An operand is foldable if it is loop-invariant or the induction var.
+    auto canBeFolded = [&](Value value) {
+      return op.isDefinedOutsideOfLoop(value) || value == indVar;
+    };
+
+    // Fold until a fixed point is reached: each iteration folds the single
+    // addi/muli user of the induction variable into the loop bounds/step.
+    while (true) {
+      // If the induction variable is used more than once, we can't fold its
+      // arith ops into the loop range.
+      if (!indVar.hasOneUse())
+        break;
+
+      Operation *user = *indVar.getUsers().begin();
+      if (!isa<AddIOp, MulIOp>(user))
+        break;
+
+      if (!llvm::all_of(user->getOperands(), canBeFolded))
+        break;
+
+      // Each clone re-evaluates `user` with the induction variable replaced
+      // by the corresponding bound, and is inserted right before the loop.
+      OpBuilder b(op);
+      BlockAndValueMapping lbMap, ubMap;
+      lbMap.map(indVar, op.lowerBound());
+      ubMap.map(indVar, op.upperBound());
+
+      // For both `iv + x` and `iv * x` the new range is [f(lb), f(ub)): muli
+      // must transform the lower bound too, not only the upper bound.
+      op.setLowerBound(b.clone(*user, lbMap)->getResult(0));
+      op.setUpperBound(b.clone(*user, ubMap)->getResult(0));
+
+      // An invariant offset leaves the step unchanged; an invariant factor
+      // scales it. NOTE(review): this assumes the factor is positive -- a zero
+      // or negative factor yields a non-positive step, which scf.for forbids.
+      if (isa<MulIOp>(user)) {
+        BlockAndValueMapping stepMap;
+        stepMap.map(indVar, op.step());
+        op.setStep(b.clone(*user, stepMap)->getResult(0));
+      }
+
+      // The folded loop's induction variable now takes exactly the values
+      // `user` used to compute, so it replaces `user`'s result.
+      user->getResult(0).replaceAllUsesWith(indVar);
+      user->erase();
+    }
+  });
+}
+
+std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
+  return std::make_unique<ForLoopRangeFolding>();
+}

diff  --git a/mlir/test/Dialect/SCF/loop-range.mlir b/mlir/test/Dialect/SCF/loop-range.mlir
new file mode 100644
index 0000000000000..3164a8936adc6
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-range.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt %s -pass-pipeline='func(for-loop-range-folding)' -split-input-file | FileCheck %s
+
+func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %i : index
+    %1 = muli %0, %c4 : index
+    %2 = memref.load %arg0[%1] : memref<?xi32>
+    %3 = muli %2, %2 : i32
+    memref.store %3, %arg0[%1] : memref<?xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:       %[[I2:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:       %[[I3:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK:       %[[I4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:       scf.for %[[I:.*]] = %[[I2]] to %[[I3]] step %[[I4]] {
+// CHECK:         %[[I5:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:         %[[I6:.*]] = muli %[[I5]], %[[I5]] : i32
+// CHECK:         memref.store %[[I6]], %[[ARG0]]{{\[}}%[[I]]
+
+func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %c0 to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_one_loop2
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[C10:.*]] = constant 10 : index
+// CHECK:       scf.for %[[J:.*]] = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK:         %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:         %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:         %[[I2:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:         %[[I3:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK:         %[[I4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:         scf.for %[[I:.*]] = %[[I2]] to %[[I3]] step %[[I4]] {
+// CHECK:           %[[I5:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:           %[[I6:.*]] = muli %[[I5]], %[[I5]] : i32
+// CHECK:           memref.store %[[I6]], %[[ARG0]]{{\[}}%[[I]]
+
+func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  scf.for %j = %c0 to %c10 step %c1 {
+    scf.for %i = %j to %arg1 step %c1 {
+      %0 = addi %arg2, %i : index
+      %1 = muli %0, %c4 : index
+      %2 = memref.load %arg0[%1] : memref<?xi32>
+      %3 = muli %2, %2 : i32
+      memref.store %3, %arg0[%1] : memref<?xi32>
+    }
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_two_loops
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[C10:.*]] = constant 10 : index
+// CHECK:       %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[I1:.*]] = addi %[[ARG2]], %[[C10]] : index
+// CHECK:       %[[I2:.*]] = muli %[[I0]], %[[C4]] : index
+// CHECK:       %[[I3:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK:       %[[I4:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:       scf.for %[[J:.*]] = %[[I2]] to %[[I3]] step %[[I4]] {
+// CHECK:         %[[I5:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:         %[[I6:.*]] = muli %[[I5]], %[[C4]] : index
+// CHECK:         %[[I7:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK:         scf.for %[[I:.*]] = %[[J]] to %[[I6]] step %[[I7]] {
+// CHECK:           %[[I8:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK:           %[[I9:.*]] = muli %[[I8]], %[[I8]] : i32
+// CHECK:           memref.store %[[I9]], %[[ARG0]]{{\[}}%[[I]]
+
+// Only the addi is folded: the muli that then uses the induction variable
+// has an operand (%1) computed inside the loop, so folding stops there.
+// (Running loop-invariant code motion first would let folding continue.)
+func @fold_only_first_add(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  scf.for %i = %c0 to %arg1 step %c1 {
+    %0 = addi %arg2, %i : index
+    %1 = addi %arg2, %c4 : index
+    %2 = muli %0, %1 : index
+    %3 = memref.load %arg0[%2] : memref<?xi32>
+    %4 = muli %3, %3 : i32
+    memref.store %4, %arg0[%2] : memref<?xi32>
+  }
+  return
+}
+
+// CHECK-LABEL: func @fold_only_first_add
+// CHECK-SAME:   (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK:       %[[C0:.*]] = constant 0 : index
+// CHECK:       %[[C1:.*]] = constant 1 : index
+// CHECK:       %[[C4:.*]] = constant 4 : index
+// CHECK:       %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK:       %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK:       scf.for %[[I:.*]] = %[[I0]] to %[[I1]] step %[[C1]] {
+// CHECK:         %[[I2:.*]] = addi %[[ARG2]], %[[C4]] : index
+// CHECK:         %[[I3:.*]] = muli %[[I]], %[[I2]] : index
+// CHECK:         %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I3]]
+// CHECK:         %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK:         memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I3]]


        


More information about the Mlir-commits mailing list