[flang-commits] [flang] 74c2ec5 - [flang] use greedy mlir driver for stack arrays pass

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Tue May 23 07:52:01 PDT 2023


Author: Tom Eccles
Date: 2023-05-23T14:51:42Z
New Revision: 74c2ec50f393bad8b31d0dd0bd8b2ff44d361198

URL: https://github.com/llvm/llvm-project/commit/74c2ec50f393bad8b31d0dd0bd8b2ff44d361198
DIFF: https://github.com/llvm/llvm-project/commit/74c2ec50f393bad8b31d0dd0bd8b2ff44d361198.diff

LOG: [flang] use greedy mlir driver for stack arrays pass

In upstream mlir, the dialect conversion infrastructure is used for
lowering from one dialect to another: the passes are of the form
XToYPass. Transformations within the same dialect, by contrast, tend to
use applyPatternsAndFoldGreedily.

In this case, the full complexity of applyPatternsAndFoldGreedily isn't
needed, so we can get away with the simpler applyOpPatternsAndFold.

This change was suggested by @jeanPerier

Differential Revision: https://reviews.llvm.org/D150853

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/StackArrays.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp
index 60a30d2d1ef64..d09ce292c46d7 100644
--- a/flang/lib/Optimizer/Transforms/StackArrays.cpp
+++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp
@@ -26,7 +26,7 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
@@ -167,25 +167,22 @@ class StackArraysAnalysisWrapper {
 
   StackArraysAnalysisWrapper(mlir::Operation *op) {}
 
-  bool hasErrors() const;
-
-  const AllocMemMap &getCandidateOps(mlir::Operation *func);
+  // returns nullptr if analysis failed
+  const AllocMemMap *getCandidateOps(mlir::Operation *func);
 
 private:
   llvm::DenseMap<mlir::Operation *, AllocMemMap> funcMaps;
-  bool gotError = false;
 
-  void analyseFunction(mlir::Operation *func);
+  mlir::LogicalResult analyseFunction(mlir::Operation *func);
 };
 
 /// Converts a fir.allocmem to a fir.alloca
 class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-
-  AllocMemConversion(
+  explicit AllocMemConversion(
       mlir::MLIRContext *ctx,
-      const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps);
+      const StackArraysAnalysisWrapper::AllocMemMap &candidateOps)
+      : OpRewritePattern(ctx), candidateOps{candidateOps} {}
 
   mlir::LogicalResult
   matchAndRewrite(fir::AllocMemOp allocmem,
@@ -196,9 +193,8 @@ class AllocMemConversion : public mlir::OpRewritePattern<fir::AllocMemOp> {
   static InsertionPoint findAllocaInsertionPoint(fir::AllocMemOp &oldAlloc);
 
 private:
-  /// allocmem operations that DFA has determined are safe to move to the stack
-  /// mapping to where to insert replacement freemem operations
-  const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps;
+  /// Handle to the DFA (already run)
+  const StackArraysAnalysisWrapper::AllocMemMap &candidateOps;
 
   /// If we failed to find an insertion point not inside a loop, see if it would
   /// be safe to use an llvm.stacksave/llvm.stackrestore inside the loop
@@ -412,7 +408,8 @@ void AllocationAnalysis::processOperation(mlir::Operation *op) {
   visitOperationImpl(op, *before, after);
 }
 
-void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
+mlir::LogicalResult
+StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
   assert(mlir::isa<mlir::func::FuncOp>(func));
   mlir::DataFlowSolver solver;
   // constant propagation is required for dead code analysis, dead code analysis
@@ -426,8 +423,7 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
   solver.load<AllocationAnalysis>();
   if (failed(solver.initializeAndRun(func))) {
     llvm::errs() << "DataFlowSolver failed!";
-    gotError = true;
-    return;
+    return mlir::failure();
   }
 
   LatticePoint point{func};
@@ -458,22 +454,17 @@ void StackArraysAnalysisWrapper::analyseFunction(mlir::Operation *func) {
                   : candidateOps) {
     llvm::dbgs() << "StackArrays: Found candidate op: " << *allocMemOp << '\n';
   });
+  return mlir::success();
 }
 
-bool StackArraysAnalysisWrapper::hasErrors() const { return gotError; }
-
-const StackArraysAnalysisWrapper::AllocMemMap &
+const StackArraysAnalysisWrapper::AllocMemMap *
 StackArraysAnalysisWrapper::getCandidateOps(mlir::Operation *func) {
-  if (!funcMaps.count(func))
-    analyseFunction(func);
-  return funcMaps[func];
+  if (!funcMaps.contains(func))
+    if (mlir::failed(analyseFunction(func)))
+      return nullptr;
+  return &funcMaps[func];
 }
 
-AllocMemConversion::AllocMemConversion(
-    mlir::MLIRContext *ctx,
-    const llvm::DenseMap<mlir::Operation *, InsertionPoint> &candidateOps)
-    : OpRewritePattern(ctx), candidateOps(candidateOps) {}
-
 mlir::LogicalResult
 AllocMemConversion::matchAndRewrite(fir::AllocMemOp allocmem,
                                     mlir::PatternRewriter &rewriter) const {
@@ -709,29 +700,31 @@ void StackArraysPass::runOnFunc(mlir::Operation *func) {
   assert(mlir::isa<mlir::func::FuncOp>(func));
 
   auto &analysis = getAnalysis<StackArraysAnalysisWrapper>();
-  const auto &candidateOps = analysis.getCandidateOps(func);
-  if (analysis.hasErrors()) {
+  const StackArraysAnalysisWrapper::AllocMemMap *candidateOps =
+      analysis.getCandidateOps(func);
+  if (!candidateOps) {
     signalPassFailure();
     return;
   }
 
-  if (candidateOps.empty())
+  if (candidateOps->empty())
     return;
-  runCount += candidateOps.size();
+  runCount += candidateOps->size();
+
+  llvm::SmallVector<mlir::Operation *> opsToConvert;
+  opsToConvert.reserve(candidateOps->size());
+  for (auto [op, _] : *candidateOps)
+    opsToConvert.push_back(op);
 
   mlir::MLIRContext &context = getContext();
   mlir::RewritePatternSet patterns(&context);
-  mlir::ConversionTarget target(context);
-
-  target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                         mlir::func::FuncDialect>();
-  target.addDynamicallyLegalOp<fir::AllocMemOp>([&](fir::AllocMemOp alloc) {
-    return !candidateOps.count(alloc.getOperation());
-  });
+  mlir::GreedyRewriteConfig config;
+  // prevent the pattern driver form merging blocks
+  config.enableRegionSimplification = false;
 
-  patterns.insert<AllocMemConversion>(&context, candidateOps);
-  if (mlir::failed(
-          mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+  patterns.insert<AllocMemConversion>(&context, *candidateOps);
+  if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
+                                                std::move(patterns), config))) {
     mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
     signalPassFailure();
   }


        


More information about the flang-commits mailing list