[Mlir-commits] [mlir] 3f429e8 - Implement an scf.for range folding optimization pass.
Mehdi Amini
llvmlistbot at llvm.org
Wed Jun 23 18:07:41 PDT 2021
Author: Anthony Canino
Date: 2021-06-24T01:07:28Z
New Revision: 3f429e82d3ea1710ee0a841675acba9bb7b658d2
URL: https://github.com/llvm/llvm-project/commit/3f429e82d3ea1710ee0a841675acba9bb7b658d2
DIFF: https://github.com/llvm/llvm-project/commit/3f429e82d3ea1710ee0a841675acba9bb7b658d2.diff
LOG: Implement an scf.for range folding optimization pass.
In cases where an arithmetic (addi/muli) op is performed on an scf.for loop's induction variable, and that op is the induction variable's only use, we can fold the op directly into the scf.for loop's range.
For example, in the following code:
```
scf.for %i = %c0 to %arg1 step %c1 {
%0 = addi %arg2, %i : index
%1 = muli %0, %c4 : index
%2 = memref.load %arg0[%1] : memref<?xi32>
%3 = muli %2, %2 : i32
memref.store %3, %arg0[%1] : memref<?xi32>
}
```
we can lift `%0` up into the scf.for loop range, as it is the only user of %i:
```
%lb = addi %arg2, %c0 : index
%ub = addi %arg2, %arg1 : index
scf.for %i = %lb to %ub step %c1 {
%1 = muli %i, %c4 : index
%2 = memref.load %arg0[%1] : memref<?xi32>
%3 = muli %2, %2 : i32
memref.store %3, %arg0[%1] : memref<?xi32>
}
```
Reviewed By: mehdi_amini, ftynse, Anthony
Differential Revision: https://reviews.llvm.org/D104289
Added:
mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
mlir/test/Dialect/SCF/loop-range.mlir
Modified:
mlir/include/mlir/Dialect/SCF/Passes.h
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h
index f3dda9bec335c..2d1f8b5aff051 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Passes.h
@@ -35,6 +35,10 @@ std::unique_ptr<Pass> createParallelLoopSpecializationPass();
std::unique_ptr<Pass>
createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {});
+/// Creates a pass which folds addi/muli ops applied to the induction variable
+/// of an scf.for loop into the loop's range (bounds and step).
+std::unique_ptr<Pass> createForLoopRangeFoldingPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 458cf769bfc9e..172fb63206809 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -45,4 +45,10 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
let dependentDialects = ["AffineDialect"];
}
+// Declares the "for-loop-range-folding" pass. Instances are created through
+// mlir::createForLoopRangeFoldingPass().
+def SCFForLoopRangeFolding
+ : Pass<"for-loop-range-folding"> {
+ let summary = "Fold add/mul ops into loop range";
+ let constructor = "mlir::createForLoopRangeFoldingPass()";
+}
+
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index 3c44e1d0472d7..689b53762a673 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRSCFTransforms
Bufferize.cpp
+ LoopRangeFolding.cpp
LoopSpecialization.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
new file mode 100644
index 0000000000000..e929f710c8758
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopRangeFolding.cpp
@@ -0,0 +1,86 @@
+//===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop range folding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/SCF/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Pass that folds addi/muli ops on an scf.for induction variable into the
+/// loop's range. The transformation itself is implemented in
+/// runOnOperation().
+struct ForLoopRangeFolding
+ : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Folds `addi`/`muli` ops applied to an `scf.for` induction variable into the
+/// loop's range.
+///
+/// For every `scf.for` found while walking the pass's root operation: as long
+/// as the induction variable has exactly one use, that use is an `addi` or a
+/// `muli`, and every operand of that op is either the induction variable or
+/// defined outside of the loop, the op is re-materialized in front of the loop
+/// on the lower bound and the upper bound (and, for `muli`, on the step), and
+/// the in-loop op is replaced by the induction variable itself.
+void ForLoopRangeFolding::runOnOperation() {
+  getOperation()->walk([&](ForOp op) {
+    Value indVar = op.getInductionVar();
+
+    auto canBeFolded = [&](Value value) {
+      return op.isDefinedOutsideOfLoop(value) || value == indVar;
+    };
+
+    // Fold until a fixed point is reached.
+    while (true) {
+      // If the induction variable is used more than once, we can't fold its
+      // arith ops into the loop range.
+      if (!indVar.hasOneUse())
+        break;
+
+      Operation *user = *indVar.getUsers().begin();
+      if (!isa<AddIOp, MulIOp>(user))
+        break;
+
+      if (!llvm::all_of(user->getOperands(), canBeFolded))
+        break;
+
+      // Clones `user` right before the loop with the induction variable
+      // substituted by `replacement`, and returns the clone's result.
+      OpBuilder b(op);
+      auto foldOnto = [&](Value replacement) -> Value {
+        BlockAndValueMapping mapping;
+        mapping.map(indVar, replacement);
+        return b.clone(*user, mapping)->getResult(0);
+      };
+
+      // Both `i + c` and `i * c` move the whole iteration range, so the lower
+      // AND the upper bound have to be rewritten for either op. Leaving the
+      // lower bound untouched for `muli` is only correct when it is zero.
+      Value newLowerBound = foldOnto(op.lowerBound());
+      Value newUpperBound = foldOnto(op.upperBound());
+
+      // A multiplication additionally stretches the distance between two
+      // consecutive iterations, so the step is scaled as well.
+      // NOTE(review): nothing here checks the sign of the multiplier; a
+      // non-positive factor would yield a non-positive step - confirm whether
+      // that needs to be rejected.
+      if (isa<MulIOp>(user))
+        op.setStep(foldOnto(op.step()));
+
+      op.setLowerBound(newLowerBound);
+      op.setUpperBound(newUpperBound);
+
+      // The loop now iterates directly over the values `user` used to compute,
+      // so its result is replaced by the induction variable and it is removed.
+      ValueRange wrapIndvar(indVar);
+      user->replaceAllUsesWith(wrapIndvar);
+      user->erase();
+    }
+  });
+}
+
+/// Creates a new instance of the scf.for loop range folding pass.
+std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
+ return std::make_unique<ForLoopRangeFolding>();
+}
diff --git a/mlir/test/Dialect/SCF/loop-range.mlir b/mlir/test/Dialect/SCF/loop-range.mlir
new file mode 100644
index 0000000000000..3164a8936adc6
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop-range.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s -pass-pipeline='func(for-loop-range-folding)' -split-input-file | FileCheck %s
+
+// The induction variable's only use is the addi, and the addi's result feeds
+// a muli whose other operand is defined outside the loop, so both ops are
+// candidates for folding into the loop range.
+func @fold_one_loop(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ scf.for %i = %c0 to %arg1 step %c1 {
+ %0 = addi %arg2, %i : index
+ %1 = muli %0, %c4 : index
+ %2 = memref.load %arg0[%1] : memref<?xi32>
+ %3 = muli %2, %2 : i32
+ memref.store %3, %arg0[%1] : memref<?xi32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_one_loop
+// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK: %[[I2:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK: %[[I3:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK: scf.for %[[I:.*]] = %[[I0]] to %[[I2]] step %[[I3]] {
+// CHECK: %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK: %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I]]
+
+// Same inner loop as @fold_one_loop, nested in an outer loop whose induction
+// variable %j is never used: only the inner loop can be folded.
+func @fold_one_loop2(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %c10 = constant 10 : index
+ scf.for %j = %c0 to %c10 step %c1 {
+ scf.for %i = %c0 to %arg1 step %c1 {
+ %0 = addi %arg2, %i : index
+ %1 = muli %0, %c4 : index
+ %2 = memref.load %arg0[%1] : memref<?xi32>
+ %3 = muli %2, %2 : i32
+ memref.store %3, %arg0[%1] : memref<?xi32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_one_loop2
+// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[C10:.*]] = constant 10 : index
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C10]] step %[[C1]] {
+// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK: %[[I2:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK: %[[I3:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK: scf.for %[[I:.*]] = %[[I0]] to %[[I2]] step %[[I3]] {
+// CHECK: %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK: %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I]]
+
+// The inner loop's lower bound is the outer induction variable %j. Ops folded
+// onto the inner loop's lower bound therefore become users of %j, which makes
+// them candidates for folding into the outer loop's range in turn.
+func @fold_two_loops(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %c10 = constant 10 : index
+ scf.for %j = %c0 to %c10 step %c1 {
+ scf.for %i = %j to %arg1 step %c1 {
+ %0 = addi %arg2, %i : index
+ %1 = muli %0, %c4 : index
+ %2 = memref.load %arg0[%1] : memref<?xi32>
+ %3 = muli %2, %2 : i32
+ memref.store %3, %arg0[%1] : memref<?xi32>
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_two_loops
+// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[C10:.*]] = constant 10 : index
+// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[C10]] : index
+// CHECK: scf.for %[[J:.*]] = %[[I0]] to %[[I1]] step %[[C1]] {
+// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK: %[[I2:.*]] = muli %[[I1]], %[[C4]] : index
+// CHECK: %[[I3:.*]] = muli %[[C1]], %[[C4]] : index
+// CHECK: scf.for %[[I:.*]] = %[[J]] to %[[I2]] step %[[I3]] {
+// CHECK: %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I]]
+// CHECK: %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I]]
+
+// An op can only be folded when each of its operands is either the induction
+// variable or defined outside the loop. The muli below uses %1, which is
+// defined inside the loop body, so only the first addi is folded. (If paired
+// with loop invariant code motion, which could move %1 out of the loop,
+// folding could continue.)
+func @fold_only_first_add(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ scf.for %i = %c0 to %arg1 step %c1 {
+ %0 = addi %arg2, %i : index
+ %1 = addi %arg2, %c4 : index
+ %2 = muli %0, %1 : index
+ %3 = memref.load %arg0[%2] : memref<?xi32>
+ %4 = muli %3, %3 : i32
+ memref.store %4, %arg0[%2] : memref<?xi32>
+ }
+ return
+}
+
+// CHECK-LABEL: func @fold_only_first_add
+// CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[C4:.*]] = constant 4 : index
+// CHECK: %[[I0:.*]] = addi %[[ARG2]], %[[C0]] : index
+// CHECK: %[[I1:.*]] = addi %[[ARG2]], %[[ARG1]] : index
+// CHECK: scf.for %[[I:.*]] = %[[I0]] to %[[I1]] step %[[C1]] {
+// CHECK: %[[I2:.*]] = addi %[[ARG2]], %[[C4]] : index
+// CHECK: %[[I3:.*]] = muli %[[I]], %[[I2]] : index
+// CHECK: %[[I4:.*]] = memref.load %[[ARG0]]{{\[}}%[[I3]]
+// CHECK: %[[I5:.*]] = muli %[[I4]], %[[I4]] : i32
+// CHECK: memref.store %[[I5]], %[[ARG0]]{{\[}}%[[I3]]
More information about the Mlir-commits
mailing list