[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Tue Jun 11 01:56:38 PDT 2024


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/73829

>From 59f4487f1e78d7e7209409b8e40ab04dc95054f9 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 22 Nov 2023 14:01:38 +0000
Subject: [PATCH 1/7] [Flang] Extracting internal constants from scalar
 literals

Constant actual arguments in function/subroutine calls are currently lowered
as an alloca + store. This can sometimes inhibit LTO, so that the constant is
not propagated to the called function, particularly in cases where the
function/subroutine call happens inside a condition.

This patch changes the lowering of these constant actual arguments to a
global constant + fir.address_of. This lowering makes it easier for
LTO to propagate the constant.

Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
 .../flang/Optimizer/Transforms/Passes.td      |  10 +
 flang/include/flang/Tools/CLOptions.inc       |   4 -
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Optimizer/Transforms/ConstExtruder.cpp    | 216 ++++++++++++++++++
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   1 +
 .../test/Driver/mlir-debug-pass-pipeline.f90  |   1 +
 flang/test/Driver/mlir-pass-pipeline.f90      |   1 +
 flang/test/Fir/basic-program.fir              |   1 +
 flang/test/Fir/boxproc.fir                    |   4 +-
 .../test/Lower/character-local-variables.f90  |   3 +-
 flang/test/Lower/dummy-arguments.f90          |   4 +-
 flang/test/Lower/host-associated.f90          |   7 +-
 flang/test/Transforms/const-extrude.f90       |  32 +++
 13 files changed, 269 insertions(+), 16 deletions(-)
 create mode 100644 flang/lib/Optimizer/Transforms/ConstExtruder.cpp
 create mode 100644 flang/test/Transforms/const-extrude.f90

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 82638200e5e20..df3b9fc6fb613 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -252,6 +252,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
   ];
 }
 
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Convert scalar literals of function arguments to global constants.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index fb3ec75d4078a..d2bf79d8789cc 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -284,10 +284,6 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
-  // FIR Inliner Callback
-  pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
-
-  pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
 
   // Polymorphic types
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 5ef930fdb2c2f..800fe44dfdc11 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_flang_library(FIRTransforms
   AnnotateConstant.cpp
   AssumedRankOpConversion.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 0000000000000..1bb1cd2269871
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+  if (!a || !a->getDefiningOp())
+    return false;
+
+  // is alloca
+  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+    // alloca has annotation
+    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+          auto constant_def = store->getOperand(0).getDefiningOp();
+          // Expect constant definition operation
+          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+  mlir::DominanceInfo &di;
+
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+      : OpRewritePattern(ctx), di(_di) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (const auto &a : callOp.getArgs()) {
+      if (auto alloca =
+              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+        if (needsExtrusion(&a)) {
+
+          mlir::Type varTy = alloca.getInType();
+          assert(!fir::hasDynamicSize(varTy) &&
+                 "only expect statically sized scalars to be by value");
+
+          // find immediate store with const argument
+          llvm::SmallVector<mlir::Operation *> stores;
+          for (mlir::Operation *s : alloca.getOperation()->getUsers())
+            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+              stores.push_back(s);
+          assert(stores.size() == 1 && "expected exactly one store");
+          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+          // Expect constant definition operation or force legalisation of the
+          // callOp and continue with its next argument
+          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            // unable to remove alloca arg
+            newOperands.push_back(a);
+            continue;
+          }
+
+          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+          auto loc = callOp.getLoc();
+          llvm::StringRef globalPrefix = "_extruded_";
+
+          std::string globalName;
+          while (!globalName.length() || builder.getNamedGlobal(globalName))
+            globalName =
+                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+          if (alloca->hasOneUse()) {
+            toErase.push_back(alloca);
+            toErase.push_back(stores[0]);
+          } else {
+            int count = -2;
+            for (mlir::Operation *s : alloca.getOperation()->getUsers())
+              if (di.dominates(stores[0], s))
+                ++count;
+
+            // delete if dominates itself and one more operation (which should
+            // be callOp)
+            if (!count)
+              toErase.push_back(stores[0]);
+          }
+          auto global = builder.createGlobalConstant(
+              loc, varTy, globalName,
+              [&](fir::FirOpBuilder &builder) {
+                mlir::Operation *cln = constant_def->clone();
+                builder.insert(cln);
+                fir::ExtendedValue exv{cln->getResult(0)};
+                mlir::Value valBase = fir::getBase(exv);
+                mlir::Value val = builder.createConvert(loc, varTy, valBase);
+                builder.create<fir::HasValueOp>(loc, val);
+              },
+              builder.createInternalLinkage());
+          mlir::Value ope = {builder.create<fir::AddrOfOp>(
+              loc, global.resultType(), global.getSymbol())};
+          newOperands.push_back(ope);
+        } else {
+          // alloca but without attr, add it
+          newOperands.push_back(a);
+        }
+      } else {
+        // non-alloca operand, add it
+        newOperands.push_back(a);
+      }
+    }
+
+    auto loc = callOp.getLoc();
+    llvm::SmallVector<mlir::Type> newResultTypes;
+    newResultTypes.append(callOp.getResultTypes().begin(),
+                          callOp.getResultTypes().end());
+    fir::CallOp newOp = builder.create<fir::CallOp>(
+        loc, newResultTypes,
+        callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                       : mlir::SymbolRefAttr{},
+        newOperands, callOp.getFastmathAttr());
+    rewriter.replaceOp(callOp, newOp);
+
+    for (auto e : toErase)
+      rewriter.eraseOp(e);
+
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+    return mlir::success();
+  }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+  mlir::DominanceInfo *di;
+
+public:
+  ConstExtruderOpt() {}
+
+  void runOnOperation() override {
+    mlir::ModuleOp mod = getOperation();
+    di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+  }
+
+  void runOnFunc(mlir::func::FuncOp &func) {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+
+    // If func is a declaration, skip it.
+    if (func.empty())
+      return;
+
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+      for (auto a : op.getArgs()) {
+        if (needsExtrusion(&a))
+          return false;
+      }
+      return true;
+    });
+
+    patterns.insert<CallOpRewriter>(context, *di);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(),
+                      "error in constant extrusion optimization\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c94b98c7c5805..8e81309209cf7 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,6 +38,7 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
+! CHECK-NEXT:   ConstExtruderOpt
 
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 49b1f8c5c3134..42180d2966cb7 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,6 +65,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 8e1a3d43edd1c..c7bb5cb5406d0 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -72,6 +72,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index dd184d99cb809..517c6054ed7a9 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -72,6 +72,7 @@ func.func @_QQmain() {
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
+// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af0..2ddc0ef525ac4 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME:              %[[VAL_0:.*]])
-// CHECK:         %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK:         store i32 4, ptr %[[VAL_1]], align 4
-// CHECK:         call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK:         call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e..b1cfc540f4389 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
-  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 331e089a60fa0..7c85b7c0a746d 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
-  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
-  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index cdc7e6a05288a..0b5311402d51e 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK:         %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK:         fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
 ! CHECK:         return
 ! CHECK:       }
 
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 0000000000000..70cdaf496f34a
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+  implicit none
+  integer x, y
+  
+  call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }

>From de6f18ab1433e68de885b2c5835d42b8b7b64a73 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 6 Dec 2023 16:12:14 +0000
Subject: [PATCH 2/7] Use greedy rewriter

---
 .../Optimizer/Transforms/ConstExtruder.cpp    | 23 +++++++------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 1bb1cd2269871..00c9f30e9dbac 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -16,7 +16,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
 protected:
   mlir::DominanceInfo *di;
+  mlir::GreedyRewriteConfig config;
 
 public:
-  ConstExtruderOpt() {}
+  ConstExtruderOpt() {
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+  }
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
   void runOnFunc(mlir::func::FuncOp &func) {
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target(*context);
 
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
-      for (auto a : op.getArgs()) {
-        if (needsExtrusion(&a))
-          return false;
-      }
-      return true;
-    });
-
     patterns.insert<CallOpRewriter>(context, *di);
-    if (mlir::failed(
-            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            func, std::move(patterns), config))) {
       mlir::emitError(func.getLoc(),
                       "error in constant extrusion optimization\n");
       signalPassFailure();

>From b9e420109955d908cc0598701666fd59fd948045 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 12 Dec 2023 14:36:10 +0000
Subject: [PATCH 3/7] Fix review comments

---
 .../flang/Optimizer/Transforms/Passes.td      |   2 +-
 .../Optimizer/Transforms/ConstExtruder.cpp    | 238 ++++++++----------
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   2 +-
 3 files changed, 111 insertions(+), 131 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index df3b9fc6fb613..12c700718a6bc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -254,7 +254,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
 
 // This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
 def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
-  let summary = "Convert scalar literals of function arguments to global constants.";
+  let summary = "Convert constant function arguments to global constants.";
   let description = [{
     Convert scalar literals of function arguments to global constants.
   }];
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 00c9f30e9dbac..796a6a05a580f 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp -----------------------------------------------===//
+//===- ConstExtruder.cpp --------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include <atomic>
 
 namespace fir {
 #define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,38 +25,16 @@ namespace fir {
 #define DEBUG_TYPE "flang-const-extruder-opt"
 
 namespace {
-std::atomic<int> uniqueLitId = 1;
-
-static bool needsExtrusion(const mlir::Value *a) {
-  if (!a || !a->getDefiningOp())
-    return false;
-
-  // is alloca
-  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
-    // alloca has annotation
-    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
-      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
-        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
-          auto constant_def = store->getOperand(0).getDefiningOp();
-          // Expect constant definition operation
-          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            return true;
-          }
-        }
-      }
-    }
-  }
-  return false;
-}
+unsigned uniqueLitId = 1;
 
 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 protected:
-  mlir::DominanceInfo &di;
+  const mlir::DominanceInfo &di;
 
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
       : OpRewritePattern(ctx), di(_di) {}
 
   mlir::LogicalResult
@@ -68,131 +42,137 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
                   mlir::PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
     auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    bool needUpdate = false;
     fir::FirOpBuilder builder(rewriter, module);
     llvm::SmallVector<mlir::Value> newOperands;
     llvm::SmallVector<mlir::Operation *> toErase;
-    for (const auto &a : callOp.getArgs()) {
-      if (auto alloca =
-              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
-        if (needsExtrusion(&a)) {
-
-          mlir::Type varTy = alloca.getInType();
-          assert(!fir::hasDynamicSize(varTy) &&
-                 "only expect statically sized scalars to be by value");
-
-          // find immediate store with const argument
-          llvm::SmallVector<mlir::Operation *> stores;
-          for (mlir::Operation *s : alloca.getOperation()->getUsers())
-            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
-              stores.push_back(s);
-          assert(stores.size() == 1 && "expected exactly one store");
-          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
-
-          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
-          // Expect constant definition operation or force legalisation of the
-          // callOp and continue with its next argument
-          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            // unable to remove alloca arg
-            newOperands.push_back(a);
-            continue;
-          }
+    for (const mlir::Value &a : callOp.getArgs()) {
+      auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
+      // We can convert arguments that are alloca, and that has
+      // the value by reference attribute. All else is just added
+      // to the argument list.
+      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+        newOperands.push_back(a);
+        continue;
+      }
 
-          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
-
-          auto loc = callOp.getLoc();
-          llvm::StringRef globalPrefix = "_extruded_";
-
-          std::string globalName;
-          while (!globalName.length() || builder.getNamedGlobal(globalName))
-            globalName =
-                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
-
-          if (alloca->hasOneUse()) {
-            toErase.push_back(alloca);
-            toErase.push_back(stores[0]);
-          } else {
-            int count = -2;
-            for (mlir::Operation *s : alloca.getOperation()->getUsers())
-              if (di.dominates(stores[0], s))
-                ++count;
-
-            // delete if dominates itself and one more operation (which should
-            // be callOp)
-            if (!count)
-              toErase.push_back(stores[0]);
+      mlir::Type varTy = alloca.getInType();
+      assert(!fir::hasDynamicSize(varTy) &&
+             "only expect statically sized scalars to be by value");
+
+      // Find immediate store with const argument
+      mlir::Operation *store = nullptr;
+      for (mlir::Operation *s : alloca->getUsers()) {
+        if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
+          // We can only deal with ONE store - if already found one,
+          // set to nullptr and exit the loop.
+          if (store) {
+            store = nullptr;
+            break;
           }
-          auto global = builder.createGlobalConstant(
-              loc, varTy, globalName,
-              [&](fir::FirOpBuilder &builder) {
-                mlir::Operation *cln = constant_def->clone();
-                builder.insert(cln);
-                fir::ExtendedValue exv{cln->getResult(0)};
-                mlir::Value valBase = fir::getBase(exv);
-                mlir::Value val = builder.createConvert(loc, varTy, valBase);
-                builder.create<fir::HasValueOp>(loc, val);
-              },
-              builder.createInternalLinkage());
-          mlir::Value ope = {builder.create<fir::AddrOfOp>(
-              loc, global.resultType(), global.getSymbol())};
-          newOperands.push_back(ope);
-        } else {
-          // alloca but without attr, add it
-          newOperands.push_back(a);
+          store = s;
         }
-      } else {
-        // non-alloca operand, add it
+      }
+
+      // If we didn't find one signle store, add argument as is, and move on.
+      if (!store) {
+        newOperands.push_back(a);
+        continue;
+      }
+
+      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
+
+      mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
+      // Expect constant definition operation or force legalisation of the
+      // callOp and continue with its next argument
+      if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+        // Unable to remove alloca arg
         newOperands.push_back(a);
+        continue;
       }
+
+      LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+      std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
+      assert(!builder.getNamedGlobal(globalName) &&
+             "We should have a unique name here");
+
+      unsigned count = 0;
+      for (mlir::Operation *s : alloca->getUsers())
+        if (di.dominates(store, s))
+          ++count;
+
+      // Delete if dominates itself and one more operation (which should
+      // be callOp)
+      if (count == 2)
+        toErase.push_back(store);
+
+      auto loc = callOp.getLoc();
+      fir::GlobalOp global = builder.createGlobalConstant(
+          loc, varTy, globalName,
+          [&](fir::FirOpBuilder &builder) {
+            mlir::Operation *cln = constant_def->clone();
+            builder.insert(cln);
+            mlir::Value val =
+                builder.createConvert(loc, varTy, cln->getResult(0));
+            builder.create<fir::HasValueOp>(loc, val);
+          },
+          builder.createInternalLinkage());
+      mlir::Value addr = {builder.create<fir::AddrOfOp>(
+          loc, global.resultType(), global.getSymbol())};
+      newOperands.push_back(addr);
+      needUpdate = true;
     }
 
-    auto loc = callOp.getLoc();
-    llvm::SmallVector<mlir::Type> newResultTypes;
-    newResultTypes.append(callOp.getResultTypes().begin(),
-                          callOp.getResultTypes().end());
-    fir::CallOp newOp = builder.create<fir::CallOp>(
-        loc, newResultTypes,
-        callOp.getCallee().has_value() ? callOp.getCallee().value()
-                                       : mlir::SymbolRefAttr{},
-        newOperands, callOp.getFastmathAttr());
-    rewriter.replaceOp(callOp, newOp);
-
-    for (auto e : toErase)
-      rewriter.eraseOp(e);
-
-    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
-                            << newOp << '\n');
-    return mlir::success();
+    if (needUpdate) {
+      auto loc = callOp.getLoc();
+      llvm::SmallVector<mlir::Type> newResultTypes;
+      newResultTypes.append(callOp.getResultTypes().begin(),
+                            callOp.getResultTypes().end());
+      fir::CallOp newOp = builder.create<fir::CallOp>(
+          loc, newResultTypes,
+          callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                         : mlir::SymbolRefAttr{},
+          newOperands, callOp.getFastmathAttr());
+      rewriter.replaceOp(callOp, newOp);
+
+      for (auto e : toErase)
+        rewriter.eraseOp(e);
+      LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                              << newOp << '\n');
+      return mlir::success();
+    }
+
+    // Failure here just means "we couldn't do the conversion", which is
+    // perfectly acceptable to the upper layers of this function.
+    return mlir::failure();
   }
 };
 
-// This pass attempts to convert immediate scalar literals in function calls
+// this pass attempts to convert immediate scalar literals in function calls
 // to global constants to allow transformations as Dead Argument Elimination
 class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
-protected:
-  mlir::DominanceInfo *di;
-  mlir::GreedyRewriteConfig config;
-
 public:
-  ConstExtruderOpt() {
-    config.enableRegionSimplification = false;
-    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
-  }
+  ConstExtruderOpt() = default;
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
-    di = &getAnalysis<mlir::DominanceInfo>();
-    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+    mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
   }
 
-  void runOnFunc(mlir::func::FuncOp &func) {
-    auto *context = &getContext();
-    mlir::RewritePatternSet patterns(context);
-
+  void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::GreedyRewriteConfig config;
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+
     patterns.insert<CallOpRewriter>(context, *di);
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 8e81309209cf7..28c06826ef415 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,8 +38,8 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
-! CHECK-NEXT:   ConstExtruderOpt
 
+! CHECK-NEXT: ConstExtruderOpt
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: CSE

>From 390c0ae13c5a242150847e68b1924cf74f2fe0a9 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 13 Dec 2023 15:35:24 +0000
Subject: [PATCH 4/7] Rename and add option to disable

---
 .../flang/Optimizer/Transforms/Passes.td      |  4 +-
 flang/include/flang/Tools/CLOptions.inc       |  4 ++
 flang/lib/Optimizer/Transforms/CMakeLists.txt |  2 +-
 ....cpp => ConstantArgumentGlobalisation.cpp} | 13 +++---
 flang/test/Transforms/const-extrude.f90       | 32 -------------
 .../constant-argument-globalisation.fir       | 45 +++++++++++++++++++
 6 files changed, 59 insertions(+), 41 deletions(-)
 rename flang/lib/Optimizer/Transforms/{ConstExtruder.cpp => ConstantArgumentGlobalisation.cpp} (94%)
 delete mode 100644 flang/test/Transforms/const-extrude.f90
 create mode 100644 flang/test/Transforms/constant-argument-globalisation.fir

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 12c700718a6bc..479272798f176 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -253,13 +253,13 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
 }
 
 // This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
-def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt", "mlir::ModuleOp"> {
   let summary = "Convert constant function arguments to global constants.";
   let description = [{
     Convert scalar literals of function arguments to global constants.
   }];
   let dependentDialects = [ "fir::FIROpsDialect" ];
-  let constructor = "::fir::createConstExtruderPass()";
+  let constructor = "::fir::createConstantArgumentGlobalisationPass()";
 }
 
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d2bf79d8789cc..d709efc9a34e7 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -86,6 +86,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
+DisableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
+    "disable the local constants to global constant conversion");
 
 using PassConstructor = std::unique_ptr<mlir::Pass>();
 
@@ -272,6 +274,8 @@ inline void createDefaultFIROptimizerPassPipeline(
     // These passes may increase code size.
     pm.addPass(fir::createSimplifyIntrinsics());
     pm.addPass(fir::createAlgebraicSimplificationPass(config));
+    if (!disableConstantArgumentGlobalisation)
+      pm.addPass(fir::createConstantArgumentGlobalisationPass());
   }
 
   if (pc.LoopVersioning)
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 800fe44dfdc11..cdf5ca368ebd1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,7 +6,7 @@ add_flang_library(FIRTransforms
   AnnotateConstant.cpp
   AssumedRankOpConversion.cpp
   CharacterConversion.cpp
-  ConstExtruder.cpp
+  ConstantArgumentGlobalisation.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
similarity index 94%
rename from flang/lib/Optimizer/Transforms/ConstExtruder.cpp
rename to flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 796a6a05a580f..2859a57226f16 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace fir {
-#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -151,10 +151,11 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 
 // this pass attempts to convert immediate scalar literals in function calls
 // to global constants to allow transformations as Dead Argument Elimination
-class ConstExtruderOpt
-    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+class ConstantArgumentGlobalisationOpt
+    : public fir::impl::ConstantArgumentGlobalisationOptBase<
+          ConstantArgumentGlobalisationOpt> {
 public:
-  ConstExtruderOpt() = default;
+  ConstantArgumentGlobalisationOpt() = default;
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
@@ -184,6 +185,6 @@ class ConstExtruderOpt
 };
 } // namespace
 
-std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
-  return std::make_unique<ConstExtruderOpt>();
+std::unique_ptr<mlir::Pass> fir::createConstantArgumentGlobalisationPass() {
+  return std::make_unique<ConstantArgumentGlobalisationOpt>();
 }
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
deleted file mode 100644
index 70cdaf496f34a..0000000000000
--- a/flang/test/Transforms/const-extrude.f90
+++ /dev/null
@@ -1,32 +0,0 @@
-! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
-
-subroutine sub1(x,y)
-  implicit none
-  integer x, y
-  
-  call sub2(0.0d0, 1.0d0, x, y, 1)
-end subroutine sub1
-
-!CHECK-LABEL: func.func @_QPsub1
-!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
-!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
-!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
-!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
-!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
-!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
-!CHECK: return
-
-!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
-!CHECK: %{{.*}} = arith.constant 1 : i32
-!CHECK: fir.has_value %{{.*}} : i32
-!CHECK: }
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
new file mode 100644
index 0000000000000..328191ea32512
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -0,0 +1,45 @@
+// RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+module {
+  func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
+    %0 = fir.alloca i32 {adapt.valuebyref}
+    %1 = fir.alloca f64 {adapt.valuebyref}
+    %2 = fir.alloca f64 {adapt.valuebyref}
+    %c1_i32 = arith.constant 1 : i32
+    %cst = arith.constant 1.000000e+00 : f64
+    %cst_0 = arith.constant 0.000000e+00 : f64
+    %3 = fir.declare %arg0 {uniq_name = "_QFsub1Ex"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %4 = fir.declare %arg1 {uniq_name = "_QFsub1Ey"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    fir.store %cst_0 to %2 : !fir.ref<f64>
+    %false = arith.constant false
+    fir.store %cst to %1 : !fir.ref<f64>
+    %false_1 = arith.constant false
+    fir.store %c1_i32 to %0 : !fir.ref<i32>
+    %false_2 = arith.constant false
+    fir.call @sub2(%2, %1, %3, %4, %0) fastmath<contract> : (!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
+}
+// CHECK-LABEL: func.func @sub1
+// CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+// CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+// CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+// CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+// CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+// CHECK: return
+
+// CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }

>From 9e038ff11d35bf8253d7d8bba3c72fd6340f3fbc Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 14 Dec 2023 18:50:04 +0000
Subject: [PATCH 5/7] Add more testing

---
 .../constant-argument-globalisation-2.fir     | 80 +++++++++++++++++++
 .../constant-argument-globalisation.fir       | 25 +++++-
 2 files changed, 103 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/constant-argument-globalisation-2.fir

diff --git a/flang/test/Transforms/constant-argument-globalisation-2.fir b/flang/test/Transforms/constant-argument-globalisation-2.fir
new file mode 100644
index 0000000000000..03855b5bfb762
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation-2.fir
@@ -0,0 +1,80 @@
+// RUN: fir-opt --split-input-file --constant-argument-globalisation-opt < %s | FileCheck %s
+
+module {
+// Test for "two conditional writes to the same alloca don't get replaced."
+  func.func @func(%arg0: i32, %arg1: i1) {
+    %c2_i32 = arith.constant 2 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.if %arg1 {
+      fir.store %c2_i32 to %addr : !fir.ref<i32>
+    } else {
+      fir.store %arg0 to %addr : !fir.ref<i32>
+    }
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK-SAME: [[ARG0:%.*]]: i32
+// CHECK-SAME: [[ARG1:%.*]]: i1)
+// CHECK: [[CONST:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.if [[ARG1]]
+// CHECK: fir.store [[CONST]] to [[ADDR]]
+// CHECK:  } else {
+// CHECK: fir.store [[ARG0]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test for "two writes to the same alloca don't get replaced."
+  func.func @func() {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.store %c1_i32 to %addr : !fir.ref<i32>
+    fir.store %c2_i32 to %addr : !fir.ref<i32>
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[CONST1:%.*]] = arith.constant
+// CHECK: [[CONST2:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.store [[CONST1]] to [[ADDR]]
+// CHECK: fir.store [[CONST2]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test for "one write to the alloca gets replaced."
+  func.func @func() {
+    %c1_i32 = arith.constant 1 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.store %c1_i32 to %addr : !fir.ref<i32>
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[ADDR:%.*]] = fir.address_of([[EXTR:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+// CHECK: fir.global internal [[EXTR]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }
+
+}
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
index 328191ea32512..1598f303755cb 100644
--- a/flang/test/Transforms/constant-argument-globalisation.fir
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -1,4 +1,5 @@
 // RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+// RUN: %flang_fc1 -emit-llvm -flang-deprecated-no-hlfir -O2 -mllvm --disable-constant-argument-globalisation -o - %s | FileCheck --check-prefix=DISABLE %s
 module {
   func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
     %0 = fir.alloca i32 {adapt.valuebyref}
@@ -19,8 +20,8 @@ module {
     return
   }
   func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
-}
-// CHECK-LABEL: func.func @sub1
+
+// CHECK-LABEL: func.func @sub1(
 // CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
 // CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
 // CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
@@ -43,3 +44,23 @@ module {
 // CHECK: %{{.*}} = arith.constant 1 : i32
 // CHECK: fir.has_value %{{.*}} : i32
 // CHECK: }
+
+// DISABLE-LABEL: ; ModuleID =
+// DISABLE-NOT: @_extruded
+// DISABLE:  define void @sub1(
+// DISABLE-SAME: ptr [[ARG0:%.*]],
+// DISABLE-SAME: ptr [[ARG1:%.*]])
+// DISABLE-SAME: {
+// DISABLE: [[CONST_I:%.*]] = alloca i32
+// DISABLE: [[CONST_R1:%.*]] = alloca double
+// DISABLE: [[CONST_R0:%.*]] = alloca double
+// DISABLE: store double 0.0{{.*}}+00, ptr [[CONST_R0]]
+// DISABLE: store double 1.0{{.*}}+00, ptr [[CONST_R1]]
+// DISABLE: store i32 1, ptr [[CONST_I]]
+// DISABLE: call void @sub2(ptr nonnull [[CONST_R0]],
+// DISABLE-SAME: ptr nonnull [[CONST_R1]], 
+// DISABLE-SAME: ptr [[ARG0]], ptr [[ARG1]],
+// DISABLE-SAME: ptr nonnull [[CONST_I]])
+// DISABLE: ret void
+// DISABLE: }
+}

>From 732c393746922f2327ce3fa822f3b8402d4b930b Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 15 Dec 2023 11:44:01 +0000
Subject: [PATCH 6/7] Fix tests

---
 flang/test/Driver/bbc-mlir-pass-pipeline.f90   | 2 +-
 flang/test/Driver/mlir-debug-pass-pipeline.f90 | 1 -
 flang/test/Driver/mlir-pass-pipeline.f90       | 2 +-
 flang/test/Fir/basic-program.fir               | 2 +-
 4 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 28c06826ef415..3a4e17f16dc81 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -32,6 +32,7 @@
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: SimplifyIntrinsics
 ! CHECK-NEXT: AlgebraicSimplification
+! CHECK-NEXT: ConstantArgumentGlobalisationOpt
 ! CHECK-NEXT: CSE
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
@@ -39,7 +40,6 @@
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
 
-! CHECK-NEXT: ConstExtruderOpt
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: CSE
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42180d2966cb7..49b1f8c5c3134 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,7 +65,6 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
-! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index c7bb5cb5406d0..43461b76cf2ab 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -66,13 +66,13 @@
 ! ALL-NEXT: SimplifyRegionLite
 !  O2-NEXT: SimplifyIntrinsics
 !  O2-NEXT: AlgebraicSimplification
+!  O2-NEXT: ConstantArgumentGlobalisationOpt
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
-! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 517c6054ed7a9..0a72d400ef410 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -66,13 +66,13 @@ func.func @_QQmain() {
 // PASSES-NEXT: SimplifyRegionLite
 // PASSES-NEXT: SimplifyIntrinsics
 // PASSES-NEXT: AlgebraicSimplification
+// PASSES-NEXT: ConstantArgumentGlobalisationOpt
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
-// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite

>From 4ca4e1ff53e134a6a39f1944e27fe0fe3c3ce4a5 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 10 Jun 2024 18:04:38 +0100
Subject: [PATCH 7/7] Fix rebase issues

---
 flang/include/flang/Optimizer/Transforms/Passes.h | 4 ++++
 flang/include/flang/Tools/CLOptions.inc           | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index a7ba704fdb39b..e3bb818a285fc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -54,6 +54,8 @@ namespace fir {
 #define GEN_PASS_DECL_OMPMAPINFOFINALIZATIONPASS
 #define GEN_PASS_DECL_OMPMARKDECLARETARGETPASS
 #define GEN_PASS_DECL_OMPFUNCTIONFILTERING
+#define GEN_PASS_DECL_CONSTANTARGUMENTGLOBALISATIONOPT
+
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
@@ -77,6 +79,8 @@ std::unique_ptr<mlir::Pass> createVScaleAttrPass();
 std::unique_ptr<mlir::Pass>
 createVScaleAttrPass(std::pair<unsigned, unsigned> vscaleAttr);
 
+std::unique_ptr<mlir::Pass> createConstantArgumentGlobalisationPass();
+
 struct FunctionAttrTypes {
   mlir::LLVM::framePointerKind::FramePointerKind framePointerKind =
       mlir::LLVM::framePointerKind::FramePointerKind::None;
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d709efc9a34e7..26b1ad0366642 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -288,6 +288,10 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
+  // FIR Inliner Callback
+  pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
+
+  pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
 
   // Polymorphic types



More information about the flang-commits mailing list