[flang-commits] [flang] 408f419 - [flang] use greedy mlir driver for stack arrays pass

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Wed May 31 07:14:14 PDT 2023


Author: Tom Eccles
Date: 2023-05-31T14:06:57Z
New Revision: 408f4196ba4ac66328ebfcf41cb372572257c4f6

URL: https://github.com/llvm/llvm-project/commit/408f4196ba4ac66328ebfcf41cb372572257c4f6
DIFF: https://github.com/llvm/llvm-project/commit/408f4196ba4ac66328ebfcf41cb372572257c4f6.diff

LOG: [flang] use greedy mlir driver for stack arrays pass

In upstream MLIR, the dialect conversion infrastructure is used for
lowering from one dialect to another: such passes are of the form
XToYPass. Transformations within a single dialect, by contrast, tend to
use applyPatternsAndFoldGreedily.

In this case, the full complexity of applyPatternsAndFoldGreedily isn't
needed, so we can get away with the simpler applyOpPatternsAndFold.

This change was suggested by @jeanPerier

The old differential revision for this patch was
https://reviews.llvm.org/D150853

Re-applying here with a fix for the issue that led to the patch being
reverted. The issue came from erasing users of the allocation operation
while still iterating over those users (leading to a use-after-free). I
have added a regression test that catches this bug in -fsanitize=address
builds, but it is hard to reliably trigger a crash from the
use-after-free in normal builds.

Differential Revision: https://reviews.llvm.org/D151728

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/StackArrays.cpp
    flang/test/Transforms/stack-arrays.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 60a30d2d1ef64..0f21e755dad27 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -26,7 +26,7 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -167,25 +167,22 @@ class StackArraysAnalysisWrapper {
 
   StackArraysAnalysisWrapper(mlir::Operation *op) {}
 
-  bool hasErrors() const;
-
-  const AllocMemMap &getCandidateOps(mlir::Operation *func);
+  // returns nullptr if analysis failed
+  const AllocMemMap *getCandidateOps(mlir::Operation *func);
 
 private:
   llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
-  bool gotError = false;
 
-  void analyseFunction(mlir::Operation *func);
+  mlir::LogicalResult analyseFunction(mlir::Operation *func);
 };
 
 /// Converts a fir.allocmem to a fir.alloca
 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-
-  AllocMemConversion(
+  explicit AllocMemConversion(
       mlir::MLIRContext *ctx,
-      const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps);
+      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
+      : OpRewritePattern(ctx), candidateOps{candidateOps} {}
 
   mlir::LogicalResult
   matchAndRewrite(fir::AllocMemOp allocmem,
@@ -196,9 +193,8 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
   static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);
 
 private:
-  /// allocmem operations that DFA has determined are safe to move to the stack
-  /// mapping to where to insert replacement freemem operations
-  const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps;
+  /// Handle to the DFA (already run)
+  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;
 
   /// If we failed to find an insertion point not inside a loop, see if it would
   /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
@@ -412,7 +408,8 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
   visitOperationImpl(op, *before, after);
 }
 
-void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
+mlir::LogicalResult
+StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
   assert(mlir::isa<mlir::func::FuncOp>(func));
   mlir::DataFlowSolver solver;
   // constant propagation is required for dead code analysis, dead code analysis
@@ -426,8 +423,7 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
   solver.load<AllocationAnalysis>();
   if (failed(solver.initializeAndRun(func))) {
     llvm::errs() << "DataFlowSolver failed!";
-    gotError = true;
-    return;
+    return mlir::failure();
   }
 
   LatticePoint point{func};
@@ -458,22 +454,17 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
                   : candidateOps) {
     llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
   });
+  return mlir::success();
 }
 
-bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; }
-
-const StackArraysAnalysisWrapper::AllocMemMap &
+const StackArraysAnalysisWrapper::AllocMemMap *
 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
-  if (!funcMaps.count(func))
-    analyseFunction(func);
-  return funcMaps[func];
+  if (!funcMaps.contains(func))
+    if (mlir::failed(analyseFunction(func)))
+      return nullptr;
+  return &funcMaps[func];
 }
 
-AllocMemConversion::AllocMemConversion(
-    mlir::MLIRContext *ctx,
-    const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps)
-    : OpRewritePattern(ctx), candidateOps(candidateOps) {}
-
 mlir::LogicalResult
 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                     mlir::PatternRewriter &rewriter) const {
@@ -485,9 +476,13 @@ AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
     return mlir::failure();
 
   // remove freemem operations
+  llvm::SmallVector<mlir::Operation *> erases;
   for (mlir::Operation *user : allocmem.getOperation()->getUsers())
     if (mlir::isa<fir::FreeMemOp>(user))
-      rewriter.eraseOp(user);
+      erases.push_back(user);
+  // now we are done iterating the users, it is safe to mutate them
+  for (mlir::Operation *erase : erases)
+    rewriter.eraseOp(erase);
 
   // replace references to heap allocation with references to stack allocation
   rewriter.replaceAllUsesWith(allocmem.getResult(), alloca->getResult());
@@ -709,29 +704,31 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
   assert(mlir::isa<mlir::func::FuncOp>(func));
 
   auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
-  const auto &candidateOps = analysis.getCandidateOps(func);
-  if (analysis.hasErrors()) {
+  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
+      analysis.getCandidateOps(func);
+  if (!candidateOps) {
     signalPassFailure();
     return;
   }
 
-  if (candidateOps.empty())
+  if (candidateOps->empty())
     return;
-  runCount += candidateOps.size();
+  runCount += candidateOps->size();
+
+  llvm::SmallVector<mlir::Operation *> opsToConvert;
+  opsToConvert.reserve(candidateOps->size());
+  for (auto [op, _] : *candidateOps)
+    opsToConvert.push_back(op);
 
   mlir::MLIRContext &context = getContext();
   mlir::RewritePatternSet patterns(&context);
-  mlir::ConversionTarget target(context);
-
-  target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                         mlir::func::FuncDialect>();
-  target.addDynamicallyLegalOp<fir::AllocMemOp>([&](fir::AllocMemOp alloc) {
-    return !candidateOps.count(alloc.getOperation());
-  });
+  mlir::GreedyRewriteConfig config;
+  // prevent the pattern driver form merging blocks
+  config.enableRegionSimplification = false;
 
-  patterns.insert<AllocMemConversion>(&context, candidateOps);
-  if (mlir::failed(
-          mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+  patterns.insert<AllocMemConversion>(&context, *candidateOps);
+  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
+                                                std::move(patterns), config))) {
     mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
     signalPassFailure();
   }

diff  --git a/flang/test/Transforms/stack-arrays.fir b/flang/test/Transforms/stack-arrays.fir
index d470ea704be48..046a402831aa8 100644
--- a/flang/test/Transforms/stack-arrays.fir
+++ b/flang/test/Transforms/stack-arrays.fir
@@ -84,6 +84,33 @@ func.func @dfa3(%arg0: i1) {
 // CHECK-NEXT:  return
 // CHECK-NEXT:  }
 
+func.func private @dfa3a_foo(!fir.ref<!fir.array<1xi8>>) -> ()
+func.func private @dfa3a_bar(!fir.ref<!fir.array<1xi8>>) -> ()
+
+// Check freemem in both regions, with other uses
+func.func @dfa3a(%arg0: i1) {
+  %a = fir.allocmem !fir.array<1xi8>
+  fir.if %arg0 {
+    %ref = fir.convert %a : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
+    func.call @dfa3a_foo(%ref) : (!fir.ref<!fir.array<1xi8>>) -> ()
+    fir.freemem %a : !fir.heap<!fir.array<1xi8>>
+  } else {
+    %ref = fir.convert %a : (!fir.heap<!fir.array<1xi8>>) -> !fir.ref<!fir.array<1xi8>>
+    func.call @dfa3a_bar(%ref) : (!fir.ref<!fir.array<1xi8>>) -> ()
+    fir.freemem %a : !fir.heap<!fir.array<1xi8>>
+  }
+  return
+}
+// CHECK:     func.func @dfa3a(%arg0: i1) {
+// CHECK-NEXT:  %[[MEM:.*]] = fir.alloca !fir.array<1xi8>
+// CHECK-NEXT:  fir.if %arg0 {
+// CHECK-NEXT:    func.call @dfa3a_foo(%[[MEM]])
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    func.call @dfa3a_bar(%[[MEM]])
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return
+// CHECK-NEXT:  }
+
 // check the alloca is placed after all operands become available
 func.func @placement1() {
   // do some stuff with other ssa values


        


More information about the flang-commits mailing list