[flang-commits] [flang] 2e7202b - [fir] Add data flow optimization pass

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Nov 29 02:00:16 PST 2021


Author: Valentin Clement
Date: 2021-11-29T11:00:09+01:00
New Revision: 2e7202b0082fd0e22589949aa4d3472d201949b2

URL: https://github.com/llvm/llvm-project/commit/2e7202b0082fd0e22589949aa4d3472d201949b2
DIFF: https://github.com/llvm/llvm-project/commit/2e7202b0082fd0e22589949aa4d3472d201949b2.diff

LOG: [fir] Add data flow optimization pass

Add a pass to perform store/load forwarding and potentially remove dead
stores.

This patch is part of the upstreaming effort from fir-dev branch.

Reviewed By: kiranchandramohan, schweitz, mehdi_amini, awarzynski

Differential Revision: https://reviews.llvm.org/D111288

Added: 
    flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
    flang/test/Fir/memref-data-flow.fir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index bdcd6fc9f7cb5..ddc83d6fdb39e 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -31,6 +31,7 @@ std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass> createFirToCfgPass();
 std::unique_ptr<mlir::Pass> createCharacterConversionPass();
 std::unique_ptr<mlir::Pass> createExternalNameConversionPass();
+std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 
 /// Support for inlining on FIR.

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 64cee129c18ee..bc16c8a62826c 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -120,4 +120,17 @@ def ExternalNameConversion : Pass<"external-name-interop", "mlir::ModuleOp"> {
   let constructor = "::fir::createExternalNameConversionPass()";
 }
 
+def MemRefDataFlowOpt : FunctionPass<"fir-memref-dataflow-opt"> {
+  let summary =
+    "Perform store/load forwarding and potentially remove dead stores.";
+  let description = [{
+    This pass forwards values stored to memory to subsequent loads from the
+    same address, eliminating those loads. It also removes stores to local
+    allocations (`fir.alloca`) that it determines to be dead. The allocations
+    themselves are not erased by this pass.
+  }];
+  let constructor = "::fir::createMemDataFlowOptPass()";
+  let dependentDialects = [
+    "fir::FIROpsDialect", "mlir::StandardOpsDialect"
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 16184c486f4e6..11e30730dc9ee 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
   CharacterConversion.cpp
   Inliner.cpp
   ExternalNameConversion.cpp
+  MemRefDataFlowOpt.cpp
   RewriteLoop.cpp
 
   DEPENDS

diff  --git a/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
new file mode 100644
index 0000000000000..83c15b807a77d
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/MemRefDataFlowOpt.cpp
@@ -0,0 +1,130 @@
+//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "fir-memref-dataflow-opt"
+
+namespace {
+
+/// Return all users of \p v that are operations of type \p OpT.
+template <typename OpT>
+static std::vector<OpT> getSpecificUsers(mlir::Value v) {
+  std::vector<OpT> ops;
+  for (mlir::Operation *user : v.getUsers())
+    if (auto op = dyn_cast<OpT>(user))
+      ops.push_back(op);
+  return ops;
+}
+
+/// Return the users of \p memref that write *into* \p memref. A WriteOp that
+/// only uses \p memref as the value being stored (e.g. saving the address in a
+/// reference-to-reference) does not modify the memory \p memref designates;
+/// forwarding its value would replace a loaded value by an address.
+template <typename WriteOp>
+static std::vector<WriteOp> getWritesTo(mlir::Value memref) {
+  std::vector<WriteOp> writes;
+  for (WriteOp op : getSpecificUsers<WriteOp>(memref))
+    if (op.memref() == memref)
+      writes.push_back(op);
+  return writes;
+}
+
+/// Conservatively decide whether \p op may write to memory. Operations with
+/// regions, and operations that do not implement MemoryEffectOpInterface
+/// (e.g. calls), are assumed to write.
+static bool mayWriteMemory(mlir::Operation *op) {
+  if (op->getNumRegions() != 0)
+    return true;
+  if (auto effects = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op))
+    return effects.hasEffect<mlir::MemoryEffects::Write>();
+  return true;
+}
+
+/// Return true if every use of \p addr is a fir.store *into* \p addr. Nothing
+/// can then observe the stored values (no load, and no use through which the
+/// address could escape: calls, fir.coordinate_of, fir.convert, ...), so all
+/// of those stores are dead.
+static bool isOnlyStoredTo(mlir::Value addr) {
+  return llvm::all_of(addr.getUsers(), [&](mlir::Operation *user) {
+    auto store = mlir::dyn_cast<fir::StoreOp>(user);
+    return store && store.memref() == addr;
+  });
+}
+
+/// Store-to-load forwarding helper. This is based on MLIR's MemRefDataFlowOpt
+/// which is specialized on the AffineRead and AffineWrite interfaces.
+template <typename ReadOp, typename WriteOp>
+class LoadStoreForwarding {
+public:
+  explicit LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {}
+
+  /// Return the store in \p storeOps whose value \p loadOp is guaranteed to
+  /// read, if any. \p storeOps must only contain writes into the address read
+  /// by \p loadOp.
+  ///
+  /// Dominance alone does not make forwarding sound. A store that does not
+  /// dominate the load (on a loop back edge, inside a nested region) can still
+  /// change the value between the two. So can an operation writing through an
+  /// aliasing or escaped reference. Forwarding is therefore restricted to the
+  /// nearest dominating store when it sits in the same block as the load and
+  /// no operation in between may write memory.
+  llvm::Optional<WriteOp> findStoreToForward(ReadOp loadOp,
+                                             std::vector<WriteOp> &&storeOps) {
+    llvm::SmallVector<WriteOp> candidateSet;
+
+    for (auto storeOp : storeOps)
+      if (domInfo->dominates(storeOp, loadOp))
+        candidateSet.push_back(storeOp);
+
+    if (candidateSet.empty())
+      return {};
+
+    llvm::Optional<WriteOp> nearestStore;
+    for (auto candidate : candidateSet) {
+      auto nearerThan = [&](WriteOp otherStore) {
+        if (candidate == otherStore)
+          return false;
+        bool rv = domInfo->properlyDominates(candidate, otherStore);
+        if (rv) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "candidate " << candidate << " is not the nearest to "
+                     << loadOp << " because " << otherStore << " is closer\n");
+        }
+        return rv;
+      };
+      if (!llvm::any_of(candidateSet, nearerThan)) {
+        nearestStore = candidate;
+        break;
+      }
+    }
+    if (!nearestStore) {
+      LLVM_DEBUG(
+          llvm::dbgs()
+          << "load " << loadOp << " has " << candidateSet.size()
+          << " store candidates, but this algorithm can't find a best.\n");
+      return {};
+    }
+    if (!isForwardingSafe(*nearestStore, loadOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "store " << *nearestStore
+                              << " may be clobbered before " << loadOp << '\n');
+      return {};
+    }
+    return nearestStore;
+  }
+
+private:
+  /// \p storeOp is known to dominate \p loadOp. Return true if the stored value
+  /// is necessarily the one read: both ops are in the same block and no
+  /// operation between them may write memory.
+  static bool isForwardingSafe(WriteOp storeOp, ReadOp loadOp) {
+    mlir::Operation *store = storeOp.getOperation();
+    mlir::Operation *load = loadOp.getOperation();
+    if (store->getBlock() != load->getBlock())
+      return false;
+    // In a single block, dominance means the store precedes the load, so this
+    // scan reaches the load; the null check is purely defensive.
+    for (mlir::Operation *op = store->getNextNode(); op != load;
+         op = op->getNextNode())
+      if (!op || mayWriteMemory(op))
+        return false;
+    return true;
+  }
+
+  mlir::DominanceInfo *domInfo;
+};
+
+class MemDataFlowOpt : public fir::MemRefDataFlowOptBase<MemDataFlowOpt> {
+public:
+  void runOnFunction() override {
+    mlir::FuncOp f = getFunction();
+
+    auto *domInfo = &getAnalysis<mlir::DominanceInfo>();
+    LoadStoreForwarding<fir::LoadOp, fir::StoreOp> lsf(domInfo);
+    // Only the load currently being visited is erased, which is safe in a walk.
+    f.walk([&](fir::LoadOp loadOp) {
+      auto maybeStore = lsf.findStoreToForward(
+          loadOp, getWritesTo<fir::StoreOp>(loadOp.memref()));
+      if (maybeStore) {
+        auto storeOp = maybeStore.getValue();
+        LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                                << " erasing load " << loadOp
+                                << " with value from " << storeOp << '\n');
+        loadOp.getResult().replaceAllUsesWith(storeOp.value());
+        loadOp.erase();
+      }
+    });
+
+    // Collect the allocations before erasing anything. Erasing an op other
+    // than the one being visited (e.g. a store right after its alloca) would
+    // invalidate the walk's iterator.
+    llvm::SmallVector<fir::AllocaOp> allocas;
+    f.walk([&](fir::AllocaOp alloca) { allocas.push_back(alloca); });
+    for (fir::AllocaOp alloca : allocas) {
+      mlir::Value addr = alloca.getResult();
+      // Stores are only removed when nothing may ever read the allocation. A
+      // store without a dominated load can still be read, e.g. stores in both
+      // arms of a branch feeding a load at the join point.
+      if (!isOnlyStoredTo(addr))
+        continue;
+      for (fir::StoreOp storeOp : getSpecificUsers<fir::StoreOp>(addr)) {
+        LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName()
+                                << " erasing store " << storeOp << '\n');
+        storeOp.erase();
+      }
+    }
+  }
+};
+} // namespace
+
+/// Factory for the `fir-memref-dataflow-opt` pass declared in Passes.td.
+std::unique_ptr<mlir::Pass> fir::createMemDataFlowOptPass() {
+  std::unique_ptr<mlir::Pass> pass = std::make_unique<MemDataFlowOpt>();
+  return pass;
+}

diff  --git a/flang/test/Fir/memref-data-flow.fir b/flang/test/Fir/memref-data-flow.fir
new file mode 100644
index 0000000000000..797d2a0ab3d2b
--- /dev/null
+++ b/flang/test/Fir/memref-data-flow.fir
@@ -0,0 +1,79 @@
+// RUN: fir-opt --split-input-file --fir-memref-dataflow-opt %s | FileCheck %s
+
+// Test that all load-store chains are removed
+
+// %0 is the scalar "i" (bindc_name). Each load of %0 (in ^bb2 and ^bb5) comes
+// right after a store to %0 in the same block, so both loads can be forwarded.
+// Afterwards the only users of %0 are stores, and all of them are dead.
+// Accesses to array elements go through fir.coordinate_of and are kept.
+func @load_store_chain_removal(%arg0: !fir.ref<!fir.array<60xi32>>, %arg1: !fir.ref<!fir.array<60xi32>>, %arg2: !fir.ref<!fir.array<60xi32>>) {
+  %c1_i64 = arith.constant 1 : i64
+  %c60 = arith.constant 60 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFf1dcEi"}
+  %1 = fir.alloca !fir.array<60xi32> {bindc_name = "t1", uniq_name = "_QFf1dcEt1"}
+  br ^bb1(%c1, %c60 : index, index)
+^bb1(%2: index, %3: index):  // 2 preds: ^bb0, ^bb2
+  %4 = arith.cmpi sgt, %3, %c0 : index
+  cond_br %4, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  // Store/load pair on %0 that is expected to be forwarded.
+  %5 = fir.convert %2 : (index) -> i32
+  fir.store %5 to %0 : !fir.ref<i32>
+  %6 = fir.load %0 : !fir.ref<i32>
+  %7 = fir.convert %6 : (i32) -> i64
+  %8 = arith.subi %7, %c1_i64 : i64
+  %9 = fir.coordinate_of %arg0, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %10 = fir.load %9 : !fir.ref<i32>
+  %11 = arith.addi %10, %10 : i32
+  %12 = fir.coordinate_of %1, %8 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  fir.store %11 to %12 : !fir.ref<i32>
+  %13 = arith.addi %2, %c1 : index
+  %14 = arith.subi %3, %c1 : index
+  br ^bb1(%13, %14 : index, index)
+^bb3:  // pred: ^bb1
+  %15 = fir.convert %2 : (index) -> i32
+  fir.store %15 to %0 : !fir.ref<i32>
+  br ^bb4(%c1, %c60 : index, index)
+^bb4(%16: index, %17: index):  // 2 preds: ^bb3, ^bb5
+  %18 = arith.cmpi sgt, %17, %c0 : index
+  cond_br %18, ^bb5, ^bb6
+^bb5:  // pred: ^bb4
+  // Store/load pair on %0 that is expected to be forwarded.
+  %19 = fir.convert %16 : (index) -> i32
+  fir.store %19 to %0 : !fir.ref<i32>
+  %20 = fir.load %0 : !fir.ref<i32>
+  %21 = fir.convert %20 : (i32) -> i64
+  %22 = arith.subi %21, %c1_i64 : i64
+  %23 = fir.coordinate_of %1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %24 = fir.load %23 : !fir.ref<i32>
+  %25 = fir.coordinate_of %arg1, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  %26 = fir.load %25 : !fir.ref<i32>
+  %27 = arith.muli %24, %26 : i32
+  %28 = fir.coordinate_of %arg2, %22 : (!fir.ref<!fir.array<60xi32>>, i64) -> !fir.ref<i32>
+  fir.store %27 to %28 : !fir.ref<i32>
+  %29 = arith.addi %16, %c1 : index
+  %30 = arith.subi %17, %c1 : index
+  br ^bb4(%29, %30 : index, index)
+^bb6:  // pred: ^bb4
+  %31 = fir.convert %16 : (index) -> i32
+  fir.store %31 to %0 : !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: func @load_store_chain_removal
+// CHECK-LABEL: ^bb1
+// CHECK-LABEL: ^bb2:
+// Make sure the preceding fir.store/fir.load pair has been eliminated and we
+// preserve the last pair of fir.load/fir.store.
+// CHECK-COUNT-1: %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK-COUNT-1: fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb3:
+// Make sure the fir.store has been removed.
+// CHECK-NOT:     fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb5:
+// CHECK:         %{{.*}} = fir.convert %{{.*}} : (index) -> i32
+// Check that the fir.store/fir.load pair between the two fir.convert ops has
+// been removed.
+// CHECK-NOT:     fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-NOT:     %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK:         %{{.*}} = fir.convert %{{.*}} : (i32) -> i64
+// CHECK:         %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK:         %{{.*}} = fir.load %{{.*}} : !fir.ref<i32>
+// CHECK:         fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>
+// CHECK-LABEL: ^bb6:
+// CHECK-NOT:     fir.store %{{.*}} to %{{.*}} : !fir.ref<i32>


        


More information about the flang-commits mailing list