[flang-commits] [flang] 2e7202b - [fir] Add data flow optimization pass
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Mon Nov 29 02:00:16 PST 2021
Author: Valentin Clement
Date: 2021-11-29T11:00:09+01:00
New Revision: 2e7202b0082fd0e22589949aa4d3472d201949b2
URL: https://github.com/llvm/llvm-project/commit/2e7202b0082fd0e22589949aa4d3472d201949b2
DIFF: https://github.com/llvm/llvm-project/commit/2e7202b0082fd0e22589949aa4d3472d201949b2.diff
LOG: [fir] Add data flow optimization pass
Add pass to perform store/load forwarding and potentially remove dead
stores.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: kiranchandramohan, schweitz, mehdi_amini, awarzynski
Differential Revision: https://reviews.llvm.org/D111288
Added:
flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
flang/test/Fir/memref-data-flow.fir
Modified:
flang/include/flang/Optimizer/Transforms/Passes.h
flang/include/flang/Optimizer/Transforms/Passes.td
flang/lib/Optimizer/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index bdcd6fc9f7cb5..ddc83d6fdb39e 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createAffineDemotionPass();
std::unique_ptr<mlir::Pass> createFirToCfgPass();
std::unique_ptr<mlir::Pass> createCharacterConversionPass();
std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
+std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
/// Support for inlining on FIR.
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 64cee129c18ee..bc16c8a62826c 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -120,4 +120,17 @@ def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
let constructor = "::fir::createExternalNameConversionPass()";
}
+// Store-to-load forwarding and dead store removal on FIR memory references.
+def MemRefDataFlowOpt : FunctionPass<"fir-memref-dataflow-opt"> {
+  let summary =
+    "Perform store/load forwarding and potentially remove dead stores.";
+  let description = [{
+    This pass forwards stored values to subsequent loads to eliminate memory
+    accesses, and removes stores whose value is never read.
+  }];
+  let constructor = "::fir::createMemDataFlowOptPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect"
+  ];
+}
+
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 16184c486f4e6..11e30730dc9ee 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
CharacterConversion.cpp
Inliner.cpp
ExternalNameConversion.cpp
+ MemRefDataFlowOpt.cpp
RewriteLoop.cpp
DEPENDS
diff --git a/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
new file mode 100644
index 0000000000000..83c15b807a77d
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
@@ -0,0 +1,130 @@
+//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+#define DEBUG_TYPE "fir-memref-dataflow-opt"
+
+namespace {
+
+/// Return the users of \p v that are operations of type \p OpT.
+template <typename OpT>
+static std::vector<OpT> getSpecificUsers(mlir::Value v) {
+  std::vector<OpT> ops;
+  for (mlir::Operation *user : v.getUsers())
+    if (auto op = mlir::dyn_cast<OpT>(user))
+      ops.push_back(op);
+  return ops;
+}
+
+/// Return true if \p memref is the result of a fir.alloca whose only users are
+/// fir.load operations reading from it and fir.store operations writing to it.
+///
+/// The memory behind such a reference is never passed to another operation
+/// (a call, a coordinate computation, ...) and its address is never stored
+/// anywhere, so the loads and stores among its users are all of its accesses.
+/// Forwarding and dead store elimination are only sound under this condition.
+static bool isNonEscapingAlloca(mlir::Value memref) {
+  if (!memref.getDefiningOp<fir::AllocaOp>())
+    return false;
+  return llvm::all_of(memref.getUsers(), [&](mlir::Operation *user) {
+    if (mlir::isa<fir::LoadOp>(user))
+      return true;
+    // A fir.store that has `memref` as its *value* operand publishes the
+    // address somewhere else, which lets the memory escape.
+    if (auto store = mlir::dyn_cast<fir::StoreOp>(user))
+      return store.memref() == memref;
+    return false;
+  });
+}
+
+/// Store-to-load forwarding and dead store detection, generic over the read
+/// and write operation types.
+///
+/// Loosely modelled on MLIR's affine MemRefDataFlowOpt, but restricted to what
+/// can be proven without a reaching-definitions analysis. Both queries assume
+/// the caller passes *all* operations that may access the memory in question
+/// (see isNonEscapingAlloca); they do not reason about aliasing.
+template <typename ReadOp, typename WriteOp>
+class LoadStoreForwarding {
+public:
+  /// Return the store whose value \p loadOp is guaranteed to read, if any.
+  ///
+  /// A store that merely dominates the load is not enough: another store in a
+  /// conditional block, a loop latch, or a nested region may overwrite it on
+  /// some path to the load. Instead, every store is mapped to the operation of
+  /// the load's block that contains it (the store itself, or an enclosing
+  /// region operation). The last such operation before the load is the last
+  /// one that may write before the load executes; it is forwarded only when
+  /// it is a plain store and not a region operation that may or may not write.
+  llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
+                                             std::vector<WriteOp> &&storeOps) {
+    mlir::Operation *load = loadOp.getOperation();
+    mlir::Block *block = load->getBlock();
+    mlir::Operation *lastWriter = nullptr;
+    for (WriteOp storeOp : storeOps) {
+      mlir::Operation *writer =
+          block->findAncestorOpInBlock(*storeOp.getOperation());
+      // Stores outside the load's block, or after the load in it, cannot
+      // execute between a write in this block and the load.
+      if (!writer || !writer->isBeforeInBlock(load))
+        continue;
+      if (!lastWriter || lastWriter->isBeforeInBlock(writer))
+        lastWriter = writer;
+    }
+    if (!lastWriter) {
+      LLVM_DEBUG(llvm::dbgs() << "load " << loadOp
+                              << " has no preceding store in its block\n");
+      return llvm::None;
+    }
+    auto nearestStore = mlir::dyn_cast<WriteOp>(lastWriter);
+    if (!nearestStore) {
+      LLVM_DEBUG(llvm::dbgs() << "load " << loadOp
+                              << " may read a value stored inside a "
+                              << lastWriter->getName() << " operation\n");
+      return llvm::None;
+    }
+    return nearestStore;
+  }
+
+  /// Return a load that may read the value written by \p storeOp, if any.
+  ///
+  /// A load that is not dominated by the store can still observe it, e.g.
+  /// through a loop back edge. Without a reachability analysis every load of
+  /// the same memory is therefore treated as a potential reader, and the
+  /// store is reported dead only when \p loadOps is empty.
+  llvm::Optional<ReadOp> findReadForWrite(WriteOp storeOp,
+                                          std::vector<ReadOp> &&loadOps) {
+    if (loadOps.empty())
+      return llvm::None;
+    return loadOps.front();
+  }
+};
+
+/// Forward stores to loads, then erase stores that are never read, on every
+/// non-escaping fir.alloca of the function.
+class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
+public:
+  void runOnFunction() override {
+    mlir::FuncOp f = getFunction();
+
+    // Collect the allocations before mutating anything. f.walk pre-fetches
+    // the next operation so that the *visited* one may be erased; erasing any
+    // other operation (e.g. a store right after an alloca) would leave the
+    // walk iterator dangling.
+    llvm::SmallVector<fir::AllocaOp> allocas;
+    f.walk([&](fir::AllocaOp alloca) { allocas.push_back(alloca); });
+
+    LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf;
+    for (fir::AllocaOp alloca : allocas) {
+      mlir::Value memref = alloca.getResult();
+      // Memory that may be reached through an alias or by another operation
+      // (e.g. a call) is left untouched.
+      if (!isNonEscapingAlloca(memref))
+        continue;
+
+      for (fir::LoadOp loadOp : getSpecificUsers<fir::LoadOp>(memref)) {
+        auto maybeStore = lsf.findStoreToForward(
+            loadOp, getSpecificUsers<fir::StoreOp>(memref));
+        if (!maybeStore)
+          continue;
+        fir::StoreOp storeOp = maybeStore.getValue();
+        LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                                << " erasing load " << loadOp
+                                << " with value from " << storeOp << '\n');
+        loadOp.getResult().replaceAllUsesWith(storeOp.value());
+        loadOp.erase();
+      }
+
+      for (fir::StoreOp storeOp : getSpecificUsers<fir::StoreOp>(memref)) {
+        if (lsf.findReadForWrite(storeOp,
+                                 getSpecificUsers<fir::LoadOp>(memref)))
+          continue;
+        LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                                << " erasing store " << storeOp << '\n');
+        storeOp.erase();
+      }
+    }
+  }
+};
+} // namespace
+
+/// Factory for the "fir-memref-dataflow-opt" pass; referenced as the pass
+/// constructor in Passes.td.
+std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
+  std::unique_ptr<mlir::Pass> pass = std::make_unique<MemDataFlowOpt>();
+  return pass;
+}
diff --git a/flang/test/Fir/memref-data-flow.fir b/flang/test/Fir/memref-data-flow.fir
new file mode 100644
index 0000000000000..797d2a0ab3d2b
--- /dev/null
+++ b/flang/test/Fir/memref-data-flow.fir
@@ -0,0 +1,79 @@
+// RUN: fir-opt --split-input-file --fir-memref-dataflow-opt %s | FileCheck %s
+
+// Test that the load/store chains on the scalar local `i` are removed.
+
+// Input IR: %0 is the scalar local `i` and %1 the local array `t1` (see their
+// bindc_name attributes). Each of the two loops stores the loop index into `i`
+// and immediately reloads it; `i` is stored once more after each loop exits.
+func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.ref<!fir.array<60xi32>>, %arg2: !fir.ref<!fir.array<60xi32>>) {
+ %c1_i64 = arith.constant 1 : i64
+ %c60 = arith.constant 60 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
+ %1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
+ br ^bb1(%c1, %c60 : index, index)
+^bb1(%2: index, %3: index): // 2 preds: ^bb0, ^bb2
+ %4 = arith.cmpi sgt, %3, %c0 : index
+ cond_br %4, ^bb2, ^bb3
+^bb2: // pred: ^bb1
+ // First loop body: i = index; t1(i) = arg0(i) + arg0(i), addressed with the
+ // zero-based offset i - 1.
+ %5 = fir.convert %2 : (index) -> i32
+ fir.store %5 to %0 : !fir.ref<i32>
+ %6 = fir.load %0 : !fir.ref<i32>
+ %7 = fir.convert %6 : (i32) -> i64
+ %8 = arith.subi %7, %c1_i64 : i64
+ %9 = fir.coordinate_of %arg0, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+ %10 = fir.load %9 : !fir.ref<i32>
+ %11 = arith.addi %10, %10 : i32
+ %12 = fir.coordinate_of %1, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+ fir.store %11 to %12 : !fir.ref<i32>
+ %13 = arith.addi %2, %c1 : index
+ %14 = arith.subi %3, %c1 : index
+ br ^bb1(%13, %14 : index, index)
+^bb3: // pred: ^bb1
+ // First loop exit: store the final index to i.
+ %15 = fir.convert %2 : (index) -> i32
+ fir.store %15 to %0 : !fir.ref<i32>
+ br ^bb4(%c1, %c60 : index, index)
+^bb4(%16: index, %17: index): // 2 preds: ^bb3, ^bb5
+ %18 = arith.cmpi sgt, %17, %c0 : index
+ cond_br %18, ^bb5, ^bb6
+^bb5: // pred: ^bb4
+ // Second loop body: i = index; arg2(i) = t1(i) * arg1(i).
+ %19 = fir.convert %16 : (index) -> i32
+ fir.store %19 to %0 : !fir.ref<i32>
+ %20 = fir.load %0 : !fir.ref<i32>
+ %21 = fir.convert %20 : (i32) -> i64
+ %22 = arith.subi %21, %c1_i64 : i64
+ %23 = fir.coordinate_of %1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+ %24 = fir.load %23 : !fir.ref<i32>
+ %25 = fir.coordinate_of %arg1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+ %26 = fir.load %25 : !fir.ref<i32>
+ %27 = arith.muli %24, %26 : i32
+ %28 = fir.coordinate_of %arg2, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+ fir.store %27 to %28 : !fir.ref<i32>
+ %29 = arith.addi %16, %c1 : index
+ %30 = arith.subi %17, %c1 : index
+ br ^bb4(%29, %30 : index, index)
+^bb6: // pred: ^bb4
+ // Second loop exit: store the final index to i.
+ %31 = fir.convert %16 : (index) -> i32
+ fir.store %31 to %0 : !fir.ref<i32>
+ return
+}
+
+// CHECK-LABEL: func @load_store_chain_removal
+// CHECK-LABEL: ^bb1
+// CHECK-LABEL: ^bb2:
+// Make sure the fir.store/fir.load pair on `i` has been eliminated and that
+// the remaining fir.load/fir.store pair on the array elements is preserved.
+// CHECK-COUNT-1: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK-COUNT-1: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb3:
+// Make sure the fir.store has been removed.
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb5:
+// CHECK: %{{.*}} = fir.convert %{{.*}} : (index) -> i32
+// Check that the fir.store/fir.load pair between the two converts has been removed.
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-NOT: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: %{{.*}} = fir.convert %{{.*}} : (i32) -> i64
+// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb6:
+// CHECK-NOT: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
More information about the flang-commits
mailing list