[flang-commits] [flang] [flang] Add stack reclaim pass to reclaim allocas in loop (PR #95309)

via flang-commits flang-commits at lists.llvm.org
Wed Jun 12 13:49:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-driver

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Some passes in the flang pipeline create `fir.alloca` operations (e.g. for `hlfir.concat`). When these allocas are located in a loop, stack usage can grow quickly, leading to segfaults.

This behavior can be seen in https://github.com/jacobwilliams/json-fortran/blob/master/src/tests/jf_test_36.F90

This patch inserts calls to LLVM stacksave/stackrestore in the body of the loop to reclaim the allocas made in its scope.

This PR is an alternative implementation to #<!-- -->95173

---
Full diff: https://github.com/llvm/llvm-project/pull/95309.diff


10 Files Affected:

- (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1) 
- (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+9) 
- (modified) flang/include/flang/Tools/CLOptions.inc (+1) 
- (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1) 
- (added) flang/lib/Optimizer/Transforms/StackReclaim.cpp (+56) 
- (modified) flang/test/Driver/bbc-mlir-pass-pipeline.f90 (+4) 
- (modified) flang/test/Driver/mlir-debug-pass-pipeline.f90 (+4) 
- (modified) flang/test/Driver/mlir-pass-pipeline.f90 (+4) 
- (modified) flang/test/Fir/basic-program.fir (+4) 
- (added) flang/test/Transforms/stack-reclaime.fir (+14) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 9fa819e2bf502..1ca1539e76fc6 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -49,6 +49,7 @@ namespace fir {
 #define GEN_PASS_DECL_OPENACCDATAOPERANDCONVERSION
 #define GEN_PASS_DECL_ADDDEBUGINFO
 #define GEN_PASS_DECL_STACKARRAYS
+#define GEN_PASS_DECL_STACKRECLAIM
 #define GEN_PASS_DECL_LOOPVERSIONING
 #define GEN_PASS_DECL_ADDALIASTAGS
 #define GEN_PASS_DECL_OMPMAPINFOFINALIZATIONPASS
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 7a3baca4c19da..27aee5650e75d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -260,6 +260,15 @@ def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let dependentDialects = [ "fir::FIROpsDialect" ];
 }
 
+def StackReclaim : Pass<"stack-reclaim"> {
+  let summary = "Insert stacksave/stackrestore in region with allocas";
+  let description = [{
+    Insert stacksave/stackrestore in loop region to reclaim alloca done in its
+    scope.
+  }];
+  let dependentDialects = [ "mlir::LLVM::LLVMDialect" ];
+}
+
 def AddAliasTags : Pass<"fir-add-alias-tags", "mlir::ModuleOp"> {
   let summary = "Add tbaa tags to operations that implement FirAliasAnalysisOpInterface";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 2a0cfc04aa350..df396e04b2a76 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -295,6 +295,7 @@ inline void createDefaultFIROptimizerPassPipeline(
   if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
     pm.addPass(fir::createAddAliasTags());
 
+  addNestedPassToAllTopLevelOperations(pm, fir::createStackReclaim);
   // convert control flow to CFG form
   fir::addCfgConversionPass(pm, pc);
   pm.addPass(mlir::createConvertSCFToCFPass());
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 5ef930fdb2c2f..149afdf601c93 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -21,6 +21,7 @@ add_flang_library(FIRTransforms
   OMPFunctionFiltering.cpp
   OMPMapInfoFinalization.cpp
   OMPMarkDeclareTarget.cpp
+  StackReclaim.cpp
   VScaleAttr.cpp
   FunctionAttr.cpp
   DebugTypeGenerator.cpp
diff --git a/flang/lib/Optimizer/Transforms/StackReclaim.cpp b/flang/lib/Optimizer/Transforms/StackReclaim.cpp
new file mode 100644
index 0000000000000..541311aaa183e
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/StackReclaim.cpp
@@ -0,0 +1,56 @@
+//===- StackReclaim.cpp -- Insert stacksave/stackrestore in region --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/Fortran.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_STACKRECLAIM
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+class StackReclaimPass : public fir::impl::StackReclaimBase<StackReclaimPass> {
+public:
+  using StackReclaimBase<StackReclaimPass>::StackReclaimBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void StackReclaimPass::runOnOperation() {
+  auto *op = getOperation();
+  auto *context = &getContext();
+  mlir::OpBuilder builder(context);
+  mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(context);
+
+  op->walk([&](mlir::Operation *op) {
+    if (!mlir::isa<fir::DoLoopOp>(op))
+      return;
+
+    auto loopOp = mlir::dyn_cast<fir::DoLoopOp>(op);
+    mlir::Location loc = loopOp.getLoc();
+
+    if (!loopOp.getRegion().getOps<fir::AllocaOp>().empty()) {
+      builder.setInsertionPointToStart(&loopOp.getRegion().front());
+      auto stackSaveOp = builder.create<LLVM::StackSaveOp>(loc, voidPtr);
+
+      auto *terminator = loopOp.getRegion().back().getTerminator();
+      builder.setInsertionPoint(terminator);
+      builder.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
+    }
+  });
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c94b98c7c5805..5520d750e2ce1 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -50,12 +50,16 @@
 
 ! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! CHECK-NEXT: 'fir.global' Pipeline
+! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
 ! CHECK-NEXT: 'func.func' Pipeline
+! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
 ! CHECK-NEXT: 'omp.declare_reduction' Pipeline
+! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
 ! CHECK-NEXT: 'omp.private' Pipeline
+! CHECK-NEXT:   StackReclaim
 ! CHECK-NEXT:   CFGConversion
 
 ! CHECK-NEXT: SCFToControlFlow
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 49b1f8c5c3134..6e9846fa422e5 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -77,12 +77,16 @@
 
 ! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT:   'fir.global' Pipeline
+! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
 ! ALL-NEXT:   'func.func' Pipeline
+! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
 ! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
 ! ALL-NEXT:   'omp.private' Pipeline
+! ALL-NEXT:     StackReclaim
 ! ALL-NEXT:     CFGConversion
 ! ALL-NEXT: SCFToControlFlow
 ! ALL-NEXT: Canonicalizer
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 8e1a3d43edd1c..db4551e93fe64 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -85,12 +85,16 @@
 
 ! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 ! ALL-NEXT:    'fir.global' Pipeline
+! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
 ! ALL-NEXT:    'func.func' Pipeline
+! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
 ! ALL-NEXT:   'omp.declare_reduction' Pipeline
+! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
 ! ALL-NEXT:   'omp.private' Pipeline
+! ALL-NEXT:      StackReclaim
 ! ALL-NEXT:      CFGConversion
 
 ! ALL-NEXT: SCFToControlFlow
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index dd184d99cb809..7bbfd709b0aaf 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -85,12 +85,16 @@ func.func @_QQmain() {
 
 // PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
 // PASSES-NEXT: 'fir.global' Pipeline
+// PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
 // PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
 // PASSES-NEXT: 'omp.declare_reduction' Pipeline
+// PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
 // PASSES-NEXT: 'omp.private' Pipeline
+// PASSES-NEXT:   StackReclaim
 // PASSES-NEXT:   CFGConversion
 
 // PASSES-NEXT: SCFToControlFlow
diff --git a/flang/test/Transforms/stack-reclaime.fir b/flang/test/Transforms/stack-reclaime.fir
new file mode 100644
index 0000000000000..b53cc96035751
--- /dev/null
+++ b/flang/test/Transforms/stack-reclaime.fir
@@ -0,0 +1,14 @@
+// RUN: fir-opt --split-input-file --stack-reclaim %s | FileCheck %s
+
+func.func @alloca_in_loop(%lb : index, %ub : index, %step : index, %b : i1, %addr : !fir.ref<index>) {
+  fir.do_loop %iv = %lb to %ub step %step unordered {
+    %0 = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @alloca_in_loop
+// CHECK: fir.do_loop
+// CHECK: %[[STACKPTR:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK: %{{.*}} = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+// CHECK: llvm.intr.stackrestore %0 : !llvm.ptr

``````````

</details>


https://github.com/llvm/llvm-project/pull/95309


More information about the flang-commits mailing list