[flang-commits] [flang] cb33e4a - [flang] Generalize `AbstractResultOpt` pass

Daniil Dudkin via flang-commits flang-commits at lists.llvm.org
Wed Jul 27 01:56:24 PDT 2022


Author: Daniil Dudkin
Date: 2022-07-27T11:55:17+03:00
New Revision: cb33e4ab149363169fad47d6c26e2eda167bc8ff

URL: https://github.com/llvm/llvm-project/commit/cb33e4ab149363169fad47d6c26e2eda167bc8ff
DIFF: https://github.com/llvm/llvm-project/commit/cb33e4ab149363169fad47d6c26e2eda167bc8ff.diff

LOG: [flang] Generalize `AbstractResultOpt` pass

This change decouples the common functionality for converting abstract
results so that it can be reused later.

Depends on D129485

Reviewed By: clementval, jeanPerier

Differential Revision: https://reviews.llvm.org/D129778

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/include/flang/Tools/CLOptions.inc
    flang/lib/Optimizer/Transforms/AbstractResult.cpp
    flang/test/Driver/mlir-pass-pipeline.f90
    flang/test/Fir/abstract-results.fir
    flang/test/Fir/basic-program.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index f9c48b374fdcb..7896fd3a90069 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -26,7 +26,7 @@ namespace fir {
 // Passes defined in Passes.td
 //===----------------------------------------------------------------------===//
 
-std::unique_ptr<mlir::Pass> createAbstractResultOptPass();
+std::unique_ptr<mlir::Pass> createAbstractResultOnFuncOptPass();
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
 std::unique_ptr<mlir::Pass> createArrayValueCopyPass();
 std::unique_ptr<mlir::Pass> createFirToCfgPass();

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 69684c6861c31..442f542bf897b 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -16,14 +16,14 @@
 
 include "mlir/Pass/PassBase.td"
 
-def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> {
+class AbstractResultOptBase<string optExt, string operation> 
+  : Pass<"abstract-result-on-" # optExt # "-opt", operation> {
   let summary = "Convert fir.array, fir.box and fir.rec function result to "
                 "function argument";
   let description = [{
     This pass is required before code gen to the LLVM IR dialect,
     including the pre-cg rewrite pass.
   }];
-  let constructor = "::fir::createAbstractResultOptPass()";
   let dependentDialects = [
     "fir::FIROpsDialect", "mlir::func::FuncDialect"
   ];
@@ -35,6 +35,10 @@ def AbstractResultOpt : Pass<"abstract-result-opt", "mlir::func::FuncOp"> {
   ];
 }
 
+def AbstractResultOnFuncOpt : AbstractResultOptBase<"func", "mlir::func::FuncOp"> {
+  let constructor = "::fir::createAbstractResultOnFuncOptPass()";
+}
+
 def AffineDialectPromotion : Pass<"promote-to-affine", "::mlir::func::FuncOp"> {
   let summary = "Promotes `fir.{do_loop,if}` to `affine.{for,if}`.";
   let description = [{

diff  --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 5cf89d79ffc85..fd770fe542f40 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -191,7 +191,8 @@ inline void createDefaultFIROptimizerPassPipeline(
 #if !defined(FLANG_EXCLUDE_CODEGEN)
 inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm) {
   fir::addBoxedProcedurePass(pm);
-  pm.addNestedPass<mlir::func::FuncOp>(fir::createAbstractResultOptPass());
+  pm.addNestedPass<mlir::func::FuncOp>(
+      fir::createAbstractResultOnFuncOptPass());
   fir::addCodeGenRewritePass(pm);
   fir::addTargetRewritePass(pm);
   fir::addExternalNameConversionPass(pm);

diff  --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 3e86620901bdb..0a15ffd1af8b6 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -191,40 +191,26 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
   bool shouldBoxResult;
 };
 
-class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
+/// @brief Base CRTP class for AbstractResult pass family.
+/// Contains common logic for abstract result conversion in a reusable fashion.
+/// @tparam Pass target class that implements operation-specific logic.
+/// @tparam PassBase base class template for the pass generated by TableGen.
+/// The `Pass` class must define runOnSpecificOperation(OpTy, bool,
+/// mlir::RewritePatternSet&, mlir::ConversionTarget&) member function.
+/// This function should implement operation-specific functionality.
+template <typename Pass, template <typename> class PassBase>
+class AbstractResultOptTemplate : public PassBase<Pass> {
 public:
   void runOnOperation() override {
-    auto *context = &getContext();
-    auto func = getOperation();
-    auto loc = func.getLoc();
+    auto *context = &this->getContext();
+    auto op = this->getOperation();
+
     mlir::RewritePatternSet patterns(context);
     mlir::ConversionTarget target = *context;
-    const bool shouldBoxResult = passResultAsBox.getValue();
-
-    // Convert function type itself if it has an abstract result
-    auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
-    if (hasAbstractResult(funcTy)) {
-      func.setType(getNewFunctionType(funcTy, shouldBoxResult));
-      unsigned zero = 0;
-      if (!func.empty()) {
-        // Insert new argument
-        mlir::OpBuilder rewriter(context);
-        auto resultType = funcTy.getResult(0);
-        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
-        mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
-        if (mustEmboxResult(resultType, shouldBoxResult)) {
-          auto bufferType = fir::ReferenceType::get(resultType);
-          rewriter.setInsertionPointToStart(&func.front());
-          newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
-        }
-        patterns.insert<ReturnOpConversion>(context, newArg);
-        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
-            [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
-      }
-    }
+    const bool shouldBoxResult = this->passResultAsBox.getValue();
 
-    if (func.empty())
-      return;
+    auto &self = static_cast<Pass &>(*this);
+    self.runOnSpecificOperation(op, shouldBoxResult, patterns, target);
 
     // Convert the calls and, if needed,  the ReturnOp in the function body.
     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
@@ -253,15 +239,47 @@ class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
     patterns.insert<SaveResultOpConversion>(context);
     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
     if (mlir::failed(
-            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
-      mlir::emitError(func.getLoc(), "error in converting abstract results\n");
-      signalPassFailure();
+            mlir::applyPartialConversion(op, target, std::move(patterns)))) {
+      mlir::emitError(op.getLoc(), "error in converting abstract results\n");
+      this->signalPassFailure();
+    }
+  }
+};
+
+class AbstractResultOnFuncOpt
+    : public AbstractResultOptTemplate<AbstractResultOnFuncOpt,
+                                       fir::AbstractResultOnFuncOptBase> {
+public:
+  void runOnSpecificOperation(mlir::func::FuncOp func, bool shouldBoxResult,
+                              mlir::RewritePatternSet &patterns,
+                              mlir::ConversionTarget &target) {
+    auto loc = func.getLoc();
+    auto *context = &getContext();
+    // Convert function type itself if it has an abstract result.
+    auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
+    if (hasAbstractResult(funcTy)) {
+      func.setType(getNewFunctionType(funcTy, shouldBoxResult));
+      if (!func.empty()) {
+        // Insert new argument.
+        mlir::OpBuilder rewriter(context);
+        auto resultType = funcTy.getResult(0);
+        auto argTy = getResultArgumentType(resultType, shouldBoxResult);
+        mlir::Value newArg = func.front().insertArgument(0u, argTy, loc);
+        if (mustEmboxResult(resultType, shouldBoxResult)) {
+          auto bufferType = fir::ReferenceType::get(resultType);
+          rewriter.setInsertionPointToStart(&func.front());
+          newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
+        }
+        patterns.insert<ReturnOpConversion>(context, newArg);
+        target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
+            [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
+      }
     }
   }
 };
 } // end anonymous namespace
 } // namespace fir
 
-std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
-  return std::make_unique<AbstractResultOpt>();
+std::unique_ptr<mlir::Pass> fir::createAbstractResultOnFuncOptPass() {
+  return std::make_unique<AbstractResultOnFuncOpt>();
 }

diff  --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 5d9f0ce32b795..ac70f0576892b 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -52,7 +52,7 @@
 ! ALL-NEXT: BoxedProcedurePass
 
 ! ALL-NEXT: 'func.func' Pipeline
-! ALL-NEXT:   AbstractResultOpt
+! ALL-NEXT:   AbstractResultOnFuncOpt
 
 ! ALL-NEXT: CodeGenRewrite
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations eliminated

diff  --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir
index 4816717ac00dc..52e34241cf87e 100644
--- a/flang/test/Fir/abstract-results.fir
+++ b/flang/test/Fir/abstract-results.fir
@@ -1,8 +1,8 @@
 // Test rewrite of functions that return fir.array<>, fir.type<>, fir.box<> to
 // functions that take an additional argument for the result.
 
-// RUN: fir-opt %s --abstract-result-opt | FileCheck %s
-// RUN: fir-opt %s --abstract-result-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX
+// RUN: fir-opt %s --abstract-result-on-func-opt | FileCheck %s
+// RUN: fir-opt %s --abstract-result-on-func-opt=abstract-result-as-box | FileCheck %s --check-prefix=CHECK-BOX
 
 // ----------------------- Test declaration rewrite ----------------------------
 

diff  --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 29efa845fafac..7501f82b90175 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -52,7 +52,7 @@ func.func @_QQmain() {
 // PASSES-NEXT: BoxedProcedurePass
 
 // PASSES-NEXT: 'func.func' Pipeline
-// PASSES-NEXT:   AbstractResultOpt
+// PASSES-NEXT:   AbstractResultOnFuncOpt
 
 // PASSES-NEXT: CodeGenRewrite
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations eliminated


        


More information about the flang-commits mailing list