[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Wed Jun 19 10:10:22 PDT 2024


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/73829

>From 08b3f4a1599ae11301e985c04121d1bc40d88e55 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 22 Nov 2023 14:01:38 +0000
Subject: [PATCH 01/11] [Flang] Extracting internal constants from scalar
 literals

Constant actual arguments in function/subroutine calls are currently lowered
as allocas + stores. This can sometimes inhibit LTO, and the constant will then
not be propagated to the called function, particularly in cases where the
function/subroutine call happens inside a condition.

This patch changes the lowering of these constant actual arguments to a
global constant + fir.address_of op. This lowering makes it easier for
LTO to propagate the constant.

Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
 .../flang/Optimizer/Transforms/Passes.td      |  10 +
 flang/include/flang/Tools/CLOptions.inc       |   4 -
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Optimizer/Transforms/ConstExtruder.cpp    | 216 ++++++++++++++++++
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   1 +
 .../test/Driver/mlir-debug-pass-pipeline.f90  |   1 +
 flang/test/Driver/mlir-pass-pipeline.f90      |   1 +
 flang/test/Fir/basic-program.fir              |   1 +
 flang/test/Fir/boxproc.fir                    |   4 +-
 .../test/Lower/character-local-variables.f90  |   3 +-
 flang/test/Lower/dummy-arguments.f90          |   4 +-
 flang/test/Lower/host-associated.f90          |   7 +-
 flang/test/Transforms/const-extrude.f90       |  32 +++
 13 files changed, 269 insertions(+), 16 deletions(-)
 create mode 100644 flang/lib/Optimizer/Transforms/ConstExtruder.cpp
 create mode 100644 flang/test/Transforms/const-extrude.f90

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 7a3baca4c19da..4c35240114e76 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -251,6 +251,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
   ];
 }
 
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Convert scalar literals of function arguments to global constants.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 2a0cfc04aa350..f180aed04ddb3 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -282,10 +282,6 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
-  // FIR Inliner Callback
-  pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
-
-  pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
 
   // Polymorphic types
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 5ef930fdb2c2f..800fe44dfdc11 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_flang_library(FIRTransforms
   AnnotateConstant.cpp
   AssumedRankOpConversion.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 0000000000000..1bb1cd2269871
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+  if (!a || !a->getDefiningOp())
+    return false;
+
+  // is alloca
+  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+    // alloca has annotation
+    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+          auto constant_def = store->getOperand(0).getDefiningOp();
+          // Expect constant definition operation
+          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+  mlir::DominanceInfo &di;
+
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+      : OpRewritePattern(ctx), di(_di) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (const auto &a : callOp.getArgs()) {
+      if (auto alloca =
+              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+        if (needsExtrusion(&a)) {
+
+          mlir::Type varTy = alloca.getInType();
+          assert(!fir::hasDynamicSize(varTy) &&
+                 "only expect statically sized scalars to be by value");
+
+          // find immediate store with const argument
+          llvm::SmallVector<mlir::Operation *> stores;
+          for (mlir::Operation *s : alloca.getOperation()->getUsers())
+            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+              stores.push_back(s);
+          assert(stores.size() == 1 && "expected exactly one store");
+          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+          // Expect constant definition operation or force legalisation of the
+          // callOp and continue with its next argument
+          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            // unable to remove alloca arg
+            newOperands.push_back(a);
+            continue;
+          }
+
+          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+          auto loc = callOp.getLoc();
+          llvm::StringRef globalPrefix = "_extruded_";
+
+          std::string globalName;
+          while (!globalName.length() || builder.getNamedGlobal(globalName))
+            globalName =
+                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+          if (alloca->hasOneUse()) {
+            toErase.push_back(alloca);
+            toErase.push_back(stores[0]);
+          } else {
+            int count = -2;
+            for (mlir::Operation *s : alloca.getOperation()->getUsers())
+              if (di.dominates(stores[0], s))
+                ++count;
+
+            // delete if dominates itself and one more operation (which should
+            // be callOp)
+            if (!count)
+              toErase.push_back(stores[0]);
+          }
+          auto global = builder.createGlobalConstant(
+              loc, varTy, globalName,
+              [&](fir::FirOpBuilder &builder) {
+                mlir::Operation *cln = constant_def->clone();
+                builder.insert(cln);
+                fir::ExtendedValue exv{cln->getResult(0)};
+                mlir::Value valBase = fir::getBase(exv);
+                mlir::Value val = builder.createConvert(loc, varTy, valBase);
+                builder.create<fir::HasValueOp>(loc, val);
+              },
+              builder.createInternalLinkage());
+          mlir::Value ope = {builder.create<fir::AddrOfOp>(
+              loc, global.resultType(), global.getSymbol())};
+          newOperands.push_back(ope);
+        } else {
+          // alloca but without attr, add it
+          newOperands.push_back(a);
+        }
+      } else {
+        // non-alloca operand, add it
+        newOperands.push_back(a);
+      }
+    }
+
+    auto loc = callOp.getLoc();
+    llvm::SmallVector<mlir::Type> newResultTypes;
+    newResultTypes.append(callOp.getResultTypes().begin(),
+                          callOp.getResultTypes().end());
+    fir::CallOp newOp = builder.create<fir::CallOp>(
+        loc, newResultTypes,
+        callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                       : mlir::SymbolRefAttr{},
+        newOperands, callOp.getFastmathAttr());
+    rewriter.replaceOp(callOp, newOp);
+
+    for (auto e : toErase)
+      rewriter.eraseOp(e);
+
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+    return mlir::success();
+  }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+  mlir::DominanceInfo *di;
+
+public:
+  ConstExtruderOpt() {}
+
+  void runOnOperation() override {
+    mlir::ModuleOp mod = getOperation();
+    di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+  }
+
+  void runOnFunc(mlir::func::FuncOp &func) {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+
+    // If func is a declaration, skip it.
+    if (func.empty())
+      return;
+
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+      for (auto a : op.getArgs()) {
+        if (needsExtrusion(&a))
+          return false;
+      }
+      return true;
+    });
+
+    patterns.insert<CallOpRewriter>(context, *di);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(),
+                      "error in constant extrusion optimization\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c94b98c7c5805..8e81309209cf7 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,6 +38,7 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
+! CHECK-NEXT:   ConstExtruderOpt
 
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 49b1f8c5c3134..42180d2966cb7 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,6 +65,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 8e1a3d43edd1c..c7bb5cb5406d0 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -72,6 +72,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index dd184d99cb809..517c6054ed7a9 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -72,6 +72,7 @@ func.func @_QQmain() {
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
+// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af0..2ddc0ef525ac4 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME:              %[[VAL_0:.*]])
-// CHECK:         %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK:         store i32 4, ptr %[[VAL_1]], align 4
-// CHECK:         call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK:         call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e..b1cfc540f4389 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
-  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 331e089a60fa0..7c85b7c0a746d 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
-  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
-  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index cdc7e6a05288a..0b5311402d51e 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK:         %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK:         fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
 ! CHECK:         return
 ! CHECK:       }
 
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 0000000000000..70cdaf496f34a
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+  implicit none
+  integer x, y
+  
+  call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }

>From 4d9a00c5d1cfad13bc53efd0c1c0f1688817ec29 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 6 Dec 2023 16:12:14 +0000
Subject: [PATCH 02/11] Use greedy rewriter

---
 .../Optimizer/Transforms/ConstExtruder.cpp    | 23 +++++++------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 1bb1cd2269871..00c9f30e9dbac 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -16,7 +16,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
 protected:
   mlir::DominanceInfo *di;
+  mlir::GreedyRewriteConfig config;
 
 public:
-  ConstExtruderOpt() {}
+  ConstExtruderOpt() {
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+  }
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
   void runOnFunc(mlir::func::FuncOp &func) {
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target(*context);
 
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
-      for (auto a : op.getArgs()) {
-        if (needsExtrusion(&a))
-          return false;
-      }
-      return true;
-    });
-
     patterns.insert<CallOpRewriter>(context, *di);
-    if (mlir::failed(
-            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            func, std::move(patterns), config))) {
       mlir::emitError(func.getLoc(),
                       "error in constant extrusion optimization\n");
       signalPassFailure();

>From 87f1009107938f5696ece0c414a9458bfc5cc1a8 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 12 Dec 2023 14:36:10 +0000
Subject: [PATCH 03/11] Fix review comments

---
 .../flang/Optimizer/Transforms/Passes.td      |   2 +-
 .../Optimizer/Transforms/ConstExtruder.cpp    | 238 ++++++++----------
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   2 +-
 3 files changed, 111 insertions(+), 131 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 4c35240114e76..02ab33c89b398 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -253,7 +253,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
 
 // This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
 def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
-  let summary = "Convert scalar literals of function arguments to global constants.";
+  let summary = "Convert constant function arguments to global constants.";
   let description = [{
     Convert scalar literals of function arguments to global constants.
   }];
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 00c9f30e9dbac..796a6a05a580f 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp -----------------------------------------------===//
+//===- ConstExtruder.cpp --------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include <atomic>
 
 namespace fir {
 #define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,38 +25,16 @@ namespace fir {
 #define DEBUG_TYPE "flang-const-extruder-opt"
 
 namespace {
-std::atomic<int> uniqueLitId = 1;
-
-static bool needsExtrusion(const mlir::Value *a) {
-  if (!a || !a->getDefiningOp())
-    return false;
-
-  // is alloca
-  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
-    // alloca has annotation
-    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
-      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
-        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
-          auto constant_def = store->getOperand(0).getDefiningOp();
-          // Expect constant definition operation
-          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            return true;
-          }
-        }
-      }
-    }
-  }
-  return false;
-}
+unsigned uniqueLitId = 1;
 
 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 protected:
-  mlir::DominanceInfo &di;
+  const mlir::DominanceInfo &di;
 
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
       : OpRewritePattern(ctx), di(_di) {}
 
   mlir::LogicalResult
@@ -68,131 +42,137 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
                   mlir::PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
     auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    bool needUpdate = false;
     fir::FirOpBuilder builder(rewriter, module);
     llvm::SmallVector<mlir::Value> newOperands;
     llvm::SmallVector<mlir::Operation *> toErase;
-    for (const auto &a : callOp.getArgs()) {
-      if (auto alloca =
-              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
-        if (needsExtrusion(&a)) {
-
-          mlir::Type varTy = alloca.getInType();
-          assert(!fir::hasDynamicSize(varTy) &&
-                 "only expect statically sized scalars to be by value");
-
-          // find immediate store with const argument
-          llvm::SmallVector<mlir::Operation *> stores;
-          for (mlir::Operation *s : alloca.getOperation()->getUsers())
-            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
-              stores.push_back(s);
-          assert(stores.size() == 1 && "expected exactly one store");
-          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
-
-          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
-          // Expect constant definition operation or force legalisation of the
-          // callOp and continue with its next argument
-          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            // unable to remove alloca arg
-            newOperands.push_back(a);
-            continue;
-          }
+    for (const mlir::Value &a : callOp.getArgs()) {
+      auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
+      // We can convert arguments that are alloca, and that has
+      // the value by reference attribute. All else is just added
+      // to the argument list.
+      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+        newOperands.push_back(a);
+        continue;
+      }
 
-          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
-
-          auto loc = callOp.getLoc();
-          llvm::StringRef globalPrefix = "_extruded_";
-
-          std::string globalName;
-          while (!globalName.length() || builder.getNamedGlobal(globalName))
-            globalName =
-                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
-
-          if (alloca->hasOneUse()) {
-            toErase.push_back(alloca);
-            toErase.push_back(stores[0]);
-          } else {
-            int count = -2;
-            for (mlir::Operation *s : alloca.getOperation()->getUsers())
-              if (di.dominates(stores[0], s))
-                ++count;
-
-            // delete if dominates itself and one more operation (which should
-            // be callOp)
-            if (!count)
-              toErase.push_back(stores[0]);
+      mlir::Type varTy = alloca.getInType();
+      assert(!fir::hasDynamicSize(varTy) &&
+             "only expect statically sized scalars to be by value");
+
+      // Find immediate store with const argument
+      mlir::Operation *store = nullptr;
+      for (mlir::Operation *s : alloca->getUsers()) {
+        if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
+          // We can only deal with ONE store - if already found one,
+          // set to nullptr and exit the loop.
+          if (store) {
+            store = nullptr;
+            break;
           }
-          auto global = builder.createGlobalConstant(
-              loc, varTy, globalName,
-              [&](fir::FirOpBuilder &builder) {
-                mlir::Operation *cln = constant_def->clone();
-                builder.insert(cln);
-                fir::ExtendedValue exv{cln->getResult(0)};
-                mlir::Value valBase = fir::getBase(exv);
-                mlir::Value val = builder.createConvert(loc, varTy, valBase);
-                builder.create<fir::HasValueOp>(loc, val);
-              },
-              builder.createInternalLinkage());
-          mlir::Value ope = {builder.create<fir::AddrOfOp>(
-              loc, global.resultType(), global.getSymbol())};
-          newOperands.push_back(ope);
-        } else {
-          // alloca but without attr, add it
-          newOperands.push_back(a);
+          store = s;
         }
-      } else {
-        // non-alloca operand, add it
+      }
+
+      // If we didn't find one signle store, add argument as is, and move on.
+      if (!store) {
+        newOperands.push_back(a);
+        continue;
+      }
+
+      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
+
+      mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
+      // Expect constant definition operation or force legalisation of the
+      // callOp and continue with its next argument
+      if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+        // Unable to remove alloca arg
         newOperands.push_back(a);
+        continue;
       }
+
+      LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+      std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
+      assert(!builder.getNamedGlobal(globalName) &&
+             "We should have a unique name here");
+
+      unsigned count = 0;
+      for (mlir::Operation *s : alloca->getUsers())
+        if (di.dominates(store, s))
+          ++count;
+
+      // Delete if dominates itself and one more operation (which should
+      // be callOp)
+      if (count == 2)
+        toErase.push_back(store);
+
+      auto loc = callOp.getLoc();
+      fir::GlobalOp global = builder.createGlobalConstant(
+          loc, varTy, globalName,
+          [&](fir::FirOpBuilder &builder) {
+            mlir::Operation *cln = constant_def->clone();
+            builder.insert(cln);
+            mlir::Value val =
+                builder.createConvert(loc, varTy, cln->getResult(0));
+            builder.create<fir::HasValueOp>(loc, val);
+          },
+          builder.createInternalLinkage());
+      mlir::Value addr = {builder.create<fir::AddrOfOp>(
+          loc, global.resultType(), global.getSymbol())};
+      newOperands.push_back(addr);
+      needUpdate = true;
     }
 
-    auto loc = callOp.getLoc();
-    llvm::SmallVector<mlir::Type> newResultTypes;
-    newResultTypes.append(callOp.getResultTypes().begin(),
-                          callOp.getResultTypes().end());
-    fir::CallOp newOp = builder.create<fir::CallOp>(
-        loc, newResultTypes,
-        callOp.getCallee().has_value() ? callOp.getCallee().value()
-                                       : mlir::SymbolRefAttr{},
-        newOperands, callOp.getFastmathAttr());
-    rewriter.replaceOp(callOp, newOp);
-
-    for (auto e : toErase)
-      rewriter.eraseOp(e);
-
-    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
-                            << newOp << '\n');
-    return mlir::success();
+    if (needUpdate) {
+      auto loc = callOp.getLoc();
+      llvm::SmallVector<mlir::Type> newResultTypes;
+      newResultTypes.append(callOp.getResultTypes().begin(),
+                            callOp.getResultTypes().end());
+      fir::CallOp newOp = builder.create<fir::CallOp>(
+          loc, newResultTypes,
+          callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                         : mlir::SymbolRefAttr{},
+          newOperands, callOp.getFastmathAttr());
+      rewriter.replaceOp(callOp, newOp);
+
+      for (auto e : toErase)
+        rewriter.eraseOp(e);
+      LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                              << newOp << '\n');
+      return mlir::success();
+    }
+
+    // Failure here just means "we couldn't do the conversion", which is
+    // perfectly acceptable to the upper layers of this function.
+    return mlir::failure();
   }
 };
 
-// This pass attempts to convert immediate scalar literals in function calls
+// this pass attempts to convert immediate scalar literals in function calls
 // to global constants to allow transformations as Dead Argument Elimination
 class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
-protected:
-  mlir::DominanceInfo *di;
-  mlir::GreedyRewriteConfig config;
-
 public:
-  ConstExtruderOpt() {
-    config.enableRegionSimplification = false;
-    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
-  }
+  ConstExtruderOpt() = default;
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
-    di = &getAnalysis<mlir::DominanceInfo>();
-    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+    mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
   }
 
-  void runOnFunc(mlir::func::FuncOp &func) {
-    auto *context = &getContext();
-    mlir::RewritePatternSet patterns(context);
-
+  void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::GreedyRewriteConfig config;
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+
     patterns.insert<CallOpRewriter>(context, *di);
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 8e81309209cf7..28c06826ef415 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,8 +38,8 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
-! CHECK-NEXT:   ConstExtruderOpt
 
+! CHECK-NEXT: ConstExtruderOpt
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: CSE

>From a56bab60e5fc823597bd1df07fae8641f726c46f Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 13 Dec 2023 15:35:24 +0000
Subject: [PATCH 04/11] Rename and add option to disable

---
 .../flang/Optimizer/Transforms/Passes.td      |  4 +-
 flang/include/flang/Tools/CLOptions.inc       |  4 ++
 flang/lib/Optimizer/Transforms/CMakeLists.txt |  2 +-
 ....cpp => ConstantArgumentGlobalisation.cpp} | 13 +++---
 flang/test/Transforms/const-extrude.f90       | 32 -------------
 .../constant-argument-globalisation.fir       | 45 +++++++++++++++++++
 6 files changed, 59 insertions(+), 41 deletions(-)
 rename flang/lib/Optimizer/Transforms/{ConstExtruder.cpp => ConstantArgumentGlobalisation.cpp} (94%)
 delete mode 100644 flang/test/Transforms/const-extrude.f90
 create mode 100644 flang/test/Transforms/constant-argument-globalisation.fir

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 02ab33c89b398..767d04981c9ab 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -252,13 +252,13 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
 }
 
 // This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
-def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt", "mlir::ModuleOp"> {
   let summary = "Convert constant function arguments to global constants.";
   let description = [{
     Convert scalar literals of function arguments to global constants.
   }];
   let dependentDialects = [ "fir::FIROpsDialect" ];
-  let constructor = "::fir::createConstExtruderPass()";
+  let constructor = "::fir::createConstantArgumentGlobalisationPass()";
 }
 
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index f180aed04ddb3..d62beb9d810cd 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -86,6 +86,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
+DisableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
+    "disable the local constants to global constant conversion");
 
 using PassConstructor = std::unique_ptr<mlir::Pass>();
 
@@ -270,6 +272,8 @@ inline void createDefaultFIROptimizerPassPipeline(
     // These passes may increase code size.
     pm.addPass(fir::createSimplifyIntrinsics());
     pm.addPass(fir::createAlgebraicSimplificationPass(config));
+    if (!disableConstantArgumentGlobalisation)
+      pm.addPass(fir::createConstantArgumentGlobalisationPass());
   }
 
   if (pc.LoopVersioning)
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 800fe44dfdc11..cdf5ca368ebd1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,7 +6,7 @@ add_flang_library(FIRTransforms
   AnnotateConstant.cpp
   AssumedRankOpConversion.cpp
   CharacterConversion.cpp
-  ConstExtruder.cpp
+  ConstantArgumentGlobalisation.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
similarity index 94%
rename from flang/lib/Optimizer/Transforms/ConstExtruder.cpp
rename to flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 796a6a05a580f..2859a57226f16 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace fir {
-#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
@@ -151,10 +151,11 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 
 // this pass attempts to convert immediate scalar literals in function calls
 // to global constants to allow transformations as Dead Argument Elimination
-class ConstExtruderOpt
-    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+class ConstantArgumentGlobalisationOpt
+    : public fir::impl::ConstantArgumentGlobalisationOptBase<
+          ConstantArgumentGlobalisationOpt> {
 public:
-  ConstExtruderOpt() = default;
+  ConstantArgumentGlobalisationOpt() = default;
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
@@ -184,6 +185,6 @@ class ConstExtruderOpt
 };
 } // namespace
 
-std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
-  return std::make_unique<ConstExtruderOpt>();
+std::unique_ptr<mlir::Pass> fir::createConstantArgumentGlobalisationPass() {
+  return std::make_unique<ConstantArgumentGlobalisationOpt>();
 }
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
deleted file mode 100644
index 70cdaf496f34a..0000000000000
--- a/flang/test/Transforms/const-extrude.f90
+++ /dev/null
@@ -1,32 +0,0 @@
-! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
-
-subroutine sub1(x,y)
-  implicit none
-  integer x, y
-  
-  call sub2(0.0d0, 1.0d0, x, y, 1)
-end subroutine sub1
-
-!CHECK-LABEL: func.func @_QPsub1
-!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
-!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
-!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
-!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
-!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
-!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
-!CHECK: return
-
-!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
-!CHECK: %{{.*}} = arith.constant 1 : i32
-!CHECK: fir.has_value %{{.*}} : i32
-!CHECK: }
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
new file mode 100644
index 0000000000000..328191ea32512
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -0,0 +1,45 @@
+// RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+module {
+  func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
+    %0 = fir.alloca i32 {adapt.valuebyref}
+    %1 = fir.alloca f64 {adapt.valuebyref}
+    %2 = fir.alloca f64 {adapt.valuebyref}
+    %c1_i32 = arith.constant 1 : i32
+    %cst = arith.constant 1.000000e+00 : f64
+    %cst_0 = arith.constant 0.000000e+00 : f64
+    %3 = fir.declare %arg0 {uniq_name = "_QFsub1Ex"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %4 = fir.declare %arg1 {uniq_name = "_QFsub1Ey"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    fir.store %cst_0 to %2 : !fir.ref<f64>
+    %false = arith.constant false
+    fir.store %cst to %1 : !fir.ref<f64>
+    %false_1 = arith.constant false
+    fir.store %c1_i32 to %0 : !fir.ref<i32>
+    %false_2 = arith.constant false
+    fir.call @sub2(%2, %1, %3, %4, %0) fastmath<contract> : (!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
+}
+// CHECK-LABEL: func.func @sub1
+// CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+// CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+// CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+// CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+// CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+// CHECK: return
+
+// CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }

>From 416514605848dc18a965c43eec189d6dae7e5957 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 14 Dec 2023 18:50:04 +0000
Subject: [PATCH 05/11] Add more testing

---
 .../constant-argument-globalisation-2.fir     | 80 +++++++++++++++++++
 .../constant-argument-globalisation.fir       | 25 +++++-
 2 files changed, 103 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Transforms/constant-argument-globalisation-2.fir

diff --git a/flang/test/Transforms/constant-argument-globalisation-2.fir b/flang/test/Transforms/constant-argument-globalisation-2.fir
new file mode 100644
index 0000000000000..03855b5bfb762
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation-2.fir
@@ -0,0 +1,80 @@
+// RUN: fir-opt --split-input-file --constant-argument-globalisation-opt < %s | FileCheck %s
+
+module {
+// Test that an alloca with two conditional writes doesn't get replaced.
+  func.func @func(%arg0: i32, %arg1: i1) {
+    %c2_i32 = arith.constant 2 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.if %arg1 {
+      fir.store %c2_i32 to %addr : !fir.ref<i32>
+    } else {
+      fir.store %arg0 to %addr : !fir.ref<i32>
+    }
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK-SAME: [[ARG0:%.*]]: i32
+// CHECK-SAME: [[ARG1:%.*]]: i1)
+// CHECK: [[CONST:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.if [[ARG1]]
+// CHECK: fir.store [[CONST]] to [[ADDR]]
+// CHECK:  } else {
+// CHECK: fir.store [[ARG0]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test that an alloca with two writes doesn't get replaced.
+  func.func @func() {
+    %c1_i32 = arith.constant 1 : i32
+    %c2_i32 = arith.constant 2 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.store %c1_i32 to %addr : !fir.ref<i32>
+    fir.store %c2_i32 to %addr : !fir.ref<i32>
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[CONST1:%.*]] = arith.constant
+// CHECK: [[CONST2:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.store [[CONST1]] to [[ADDR]]
+// CHECK: fir.store [[CONST2]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test that an alloca with a single write gets replaced.
+  func.func @func() {
+    %c1_i32 = arith.constant 1 : i32
+    %addr = fir.alloca i32 {adapt.valuebyref}
+    fir.store %c1_i32 to %addr : !fir.ref<i32>
+    fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+    return
+  }
+  func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[ADDR:%.*]] = fir.address_of([[EXTR:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+// CHECK: fir.global internal [[EXTR]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }
+
+}
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
index 328191ea32512..1598f303755cb 100644
--- a/flang/test/Transforms/constant-argument-globalisation.fir
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -1,4 +1,5 @@
 // RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+// RUN: %flang_fc1 -emit-llvm -flang-deprecated-no-hlfir -O2 -mllvm --disable-constant-argument-globalisation -o - %s | FileCheck --check-prefix=DISABLE %s
 module {
   func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
     %0 = fir.alloca i32 {adapt.valuebyref}
@@ -19,8 +20,8 @@ module {
     return
   }
   func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
-}
-// CHECK-LABEL: func.func @sub1
+
+// CHECK-LABEL: func.func @sub1(
 // CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
 // CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
 // CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
@@ -43,3 +44,23 @@ module {
 // CHECK: %{{.*}} = arith.constant 1 : i32
 // CHECK: fir.has_value %{{.*}} : i32
 // CHECK: }
+
+// DISABLE-LABEL: ; ModuleID =
+// DISABLE-NOT: @_global_const_
+// DISABLE:  define void @sub1(
+// DISABLE-SAME: ptr [[ARG0:%.*]],
+// DISABLE-SAME: ptr [[ARG1:%.*]])
+// DISABLE-SAME: {
+// DISABLE: [[CONST_I:%.*]] = alloca i32
+// DISABLE: [[CONST_R1:%.*]] = alloca double
+// DISABLE: [[CONST_R0:%.*]] = alloca double
+// DISABLE: store double 0.0{{.*}}+00, ptr [[CONST_R0]]
+// DISABLE: store double 1.0{{.*}}+00, ptr [[CONST_R1]]
+// DISABLE: store i32 1, ptr [[CONST_I]]
+// DISABLE: call void @sub2(ptr nonnull [[CONST_R0]],
+// DISABLE-SAME: ptr nonnull [[CONST_R1]],
+// DISABLE-SAME: ptr [[ARG0]], ptr [[ARG1]],
+// DISABLE-SAME: ptr nonnull [[CONST_I]])
+// DISABLE: ret void
+// DISABLE: }
+}

>From d4f22bb86a841b8c0920bbd308acf62d7f419054 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 15 Dec 2023 11:44:01 +0000
Subject: [PATCH 06/11] Fix tests

---
 flang/test/Driver/bbc-mlir-pass-pipeline.f90   | 2 +-
 flang/test/Driver/mlir-debug-pass-pipeline.f90 | 1 -
 flang/test/Driver/mlir-pass-pipeline.f90       | 2 +-
 flang/test/Fir/basic-program.fir               | 2 +-
 4 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 28c06826ef415..3a4e17f16dc81 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -32,6 +32,7 @@
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: SimplifyIntrinsics
 ! CHECK-NEXT: AlgebraicSimplification
+! CHECK-NEXT: ConstantArgumentGlobalisationOpt
 ! CHECK-NEXT: CSE
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
@@ -39,7 +40,6 @@
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
 
-! CHECK-NEXT: ConstExtruderOpt
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: CSE
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42180d2966cb7..49b1f8c5c3134 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,7 +65,6 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
-! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index c7bb5cb5406d0..43461b76cf2ab 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -66,13 +66,13 @@
 ! ALL-NEXT: SimplifyRegionLite
 !  O2-NEXT: SimplifyIntrinsics
 !  O2-NEXT: AlgebraicSimplification
+!  O2-NEXT: ConstantArgumentGlobalisationOpt
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
-! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 517c6054ed7a9..0a72d400ef410 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -66,13 +66,13 @@ func.func @_QQmain() {
 // PASSES-NEXT: SimplifyRegionLite
 // PASSES-NEXT: SimplifyIntrinsics
 // PASSES-NEXT: AlgebraicSimplification
+// PASSES-NEXT: ConstantArgumentGlobalisationOpt
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
-// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite

>From 0052dd8f5e3377a279b3c15af50e1d1f501a865f Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 10 Jun 2024 18:04:38 +0100
Subject: [PATCH 07/11] Fix rebase issues

---
 flang/include/flang/Optimizer/Transforms/Passes.h | 2 ++
 flang/include/flang/Tools/CLOptions.inc           | 4 ++++
 2 files changed, 6 insertions(+)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 9fa819e2bf502..ff6921f1a79ac 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -56,6 +56,8 @@ namespace fir {
 #define GEN_PASS_DECL_OMPFUNCTIONFILTERING
 #define GEN_PASS_DECL_VSCALEATTR
 #define GEN_PASS_DECL_FUNCTIONATTR
+#define GEN_PASS_DECL_CONSTANTARGUMENTGLOBALISATIONOPT
+
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 
 std::unique_ptr<mlir::Pass> createAffineDemotionPass();
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d62beb9d810cd..04933b7eea01a 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -286,6 +286,10 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
+  // FIR Inliner Callback
+  pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
+
+  pm.addPass(fir::createSimplifyRegionLite());
   pm.addPass(mlir::createCSEPass());
 
   // Polymorphic types

>From f589f300823bcf5f446653ee15cc4e104b082f9c Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 14 Jun 2024 12:17:38 +0100
Subject: [PATCH 08/11] Review comment updates

---
 .../flang/Optimizer/Transforms/Passes.td      |  1 -
 flang/include/flang/Tools/CLOptions.inc       |  2 +-
 .../ConstantArgumentGlobalisation.cpp         | 30 +++++++------------
 .../test/Lower/character-local-variables.f90  |  2 +-
 flang/test/Lower/dummy-arguments.f90          |  2 +-
 flang/test/Lower/host-associated.f90          |  2 +-
 6 files changed, 14 insertions(+), 25 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 767d04981c9ab..8f66793ed7a49 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -258,7 +258,6 @@ def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt
     Convert scalar literals of function arguments to global constants.
   }];
   let dependentDialects = [ "fir::FIROpsDialect" ];
-  let constructor = "::fir::createConstantArgumentGlobalisationPass()";
 }
 
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 04933b7eea01a..3a322b111876c 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -273,7 +273,7 @@ inline void createDefaultFIROptimizerPassPipeline(
     pm.addPass(fir::createSimplifyIntrinsics());
     pm.addPass(fir::createAlgebraicSimplificationPass(config));
     if (!disableConstantArgumentGlobalisation)
-      pm.addPass(fir::createConstantArgumentGlobalisationPass());
+      pm.addPass(fir::createConstantArgumentGlobalisationOpt());
   }
 
   if (pc.LoopVersioning)
diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 2859a57226f16..7b7b1af9fc09d 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp --------------------------------------------------===//
+//===- ConstantArgumentGlobalisation.cpp ----------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -22,7 +22,7 @@ namespace fir {
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
-#define DEBUG_TYPE "flang-const-extruder-opt"
+#define DEBUG_TYPE "flang-constang-argument-globalisation-opt"
 
 namespace {
 unsigned uniqueLitId = 1;
@@ -93,7 +93,8 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 
       LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
 
-      std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
+      std::string globalName =
+          "_global_const_." + std::to_string(uniqueLitId++);
       assert(!builder.getNamedGlobal(globalName) &&
              "We should have a unique name here");
 
@@ -138,7 +139,7 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 
       for (auto e : toErase)
         rewriter.eraseOp(e);
-      LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+      LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                               << newOp << '\n');
       return mlir::success();
     }
@@ -150,7 +151,8 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 };
 
 // this pass attempts to convert immediate scalar literals in function calls
-// to global constants to allow transformations as Dead Argument Elimination
+// to global constants to allow transformations such as Dead Argument
+// Elimination
 class ConstantArgumentGlobalisationOpt
     : public fir::impl::ConstantArgumentGlobalisationOptBase<
           ConstantArgumentGlobalisationOpt> {
@@ -160,14 +162,6 @@ class ConstantArgumentGlobalisationOpt
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
     mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
-    mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
-  }
-
-  void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
-    // If func is a declaration, skip it.
-    if (func.empty())
-      return;
-
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     mlir::GreedyRewriteConfig config;
@@ -176,15 +170,11 @@ class ConstantArgumentGlobalisationOpt
 
     patterns.insert<CallOpRewriter>(context, *di);
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            func, std::move(patterns), config))) {
-      mlir::emitError(func.getLoc(),
-                      "error in constant extrusion optimization\n");
+            mod, std::move(patterns), config))) {
+      mlir::emitError(mod.getLoc(),
+                      "error in constant globalisation optimization\n");
       signalPassFailure();
     }
   }
 };
 } // namespace
-
-std::unique_ptr<mlir::Pass> fir::createConstantArgumentGlobalisationPass() {
-  return std::make_unique<ConstantArgumentGlobalisationOpt>();
-}
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index b1cfc540f4389..70a6c49afc7be 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,7 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 7c85b7c0a746d..7953020937ddb 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,7 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index 0b5311402d51e..bf1fe3a6c574d 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -449,7 +449,7 @@ subroutine bar()
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i32>
 ! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
 
 ! CHECK:         return

>From 1f531fd59e154addca02ef67a975d2f077a3e6aa Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 17 Jun 2024 14:05:07 +0000
Subject: [PATCH 09/11] Copy all attributes when creating new operation.

Instead of just copying the fast-math attribute, copy all of the
attributes of the call operation.

Also amend the test to check that attributes are still there in the
updated code.
---
 .../Optimizer/Transforms/ConstantArgumentGlobalisation.cpp    | 4 +++-
 flang/test/Transforms/constant-argument-globalisation.fir     | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 7b7b1af9fc09d..de86fe7188541 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -134,7 +134,9 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
           loc, newResultTypes,
           callOp.getCallee().has_value() ? callOp.getCallee().value()
                                          : mlir::SymbolRefAttr{},
-          newOperands, callOp.getFastmathAttr());
+          newOperands);
+      // Copy all the attributes from the old to new op.
+      newOp->setAttrs(callOp->getAttrs());
       rewriter.replaceOp(callOp, newOp);
 
       for (auto e : toErase)
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
index 1598f303755cb..e88493a01d515 100644
--- a/flang/test/Transforms/constant-argument-globalisation.fir
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -30,6 +30,7 @@ module {
 // CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
 // CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
 // CHECK: fir.call @sub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+// CHECK-SAME: fastmath<contract>
 // CHECK: return
 
 // CHECK: fir.global internal [[EXTR_0]] constant : f64 {

>From e39059034f69ffd7417af1081fd38f61d0e77c07 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 18 Jun 2024 09:41:15 +0000
Subject: [PATCH 10/11] Fix further review comments.

Additionally, fixed a bug where the same argument used twice would not have
its store and alloca removed.

Add test to check for those instructions being removed.
---
 .../ConstantArgumentGlobalisation.cpp         | 47 ++++++++++---------
 .../constant-argument-globalisation-2.fir     | 18 +++++++
 2 files changed, 43 insertions(+), 22 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index de86fe7188541..168fd49026022 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -22,7 +22,7 @@ namespace fir {
 #include "flang/Optimizer/Transforms/Passes.h.inc"
 } // namespace fir
 
-#define DEBUG_TYPE "flang-constang-argument-globalisation-opt"
+#define DEBUG_TYPE "flang-constant-argument-globalisation-opt"
 
 namespace {
 unsigned uniqueLitId = 1;
@@ -45,7 +45,7 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
     bool needUpdate = false;
     fir::FirOpBuilder builder(rewriter, module);
     llvm::SmallVector<mlir::Value> newOperands;
-    llvm::SmallVector<mlir::Operation *> toErase;
+    llvm::SmallVector<std::pair<mlir::Operation *, mlir::Operation *>> allocas;
     for (const mlir::Value &a : callOp.getArgs()) {
       auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
       // We can convert arguments that are alloca, and that has
@@ -74,7 +74,8 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
         }
       }
 
-      // If we didn't find one signle store, add argument as is, and move on.
+      // If we didn't find any store, or multiple stores, add argument as is
+      // and move on.
       if (!store) {
         newOperands.push_back(a);
         continue;
@@ -82,45 +83,36 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 
       LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
 
-      mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
-      // Expect constant definition operation or force legalisation of the
-      // callOp and continue with its next argument
-      if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+      mlir::Operation *definingOp = store->getOperand(0).getDefiningOp();
+      // If not a constant, add to operands and move on.
+      if (!mlir::isa<mlir::arith::ConstantOp>(definingOp)) {
         // Unable to remove alloca arg
         newOperands.push_back(a);
         continue;
       }
 
-      LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+      LLVM_DEBUG(llvm::dbgs() << " found define " << *definingOp << "\n");
 
       std::string globalName =
           "_global_const_." + std::to_string(uniqueLitId++);
       assert(!builder.getNamedGlobal(globalName) &&
              "We should have a unique name here");
 
-      unsigned count = 0;
-      for (mlir::Operation *s : alloca->getUsers())
-        if (di.dominates(store, s))
-          ++count;
-
-      // Delete if dominates itself and one more operation (which should
-      // be callOp)
-      if (count == 2)
-        toErase.push_back(store);
+      allocas.push_back(std::make_pair(alloca, store));
 
       auto loc = callOp.getLoc();
       fir::GlobalOp global = builder.createGlobalConstant(
           loc, varTy, globalName,
           [&](fir::FirOpBuilder &builder) {
-            mlir::Operation *cln = constant_def->clone();
+            mlir::Operation *cln = definingOp->clone();
             builder.insert(cln);
             mlir::Value val =
                 builder.createConvert(loc, varTy, cln->getResult(0));
             builder.create<fir::HasValueOp>(loc, val);
           },
           builder.createInternalLinkage());
-      mlir::Value addr = {builder.create<fir::AddrOfOp>(
-          loc, global.resultType(), global.getSymbol())};
+      mlir::Value addr = builder.create<fir::AddrOfOp>(
+          loc, global.resultType(), global.getSymbol());
       newOperands.push_back(addr);
       needUpdate = true;
     }
@@ -139,8 +131,19 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
       newOp->setAttrs(callOp->getAttrs());
       rewriter.replaceOp(callOp, newOp);
 
-      for (auto e : toErase)
-        rewriter.eraseOp(e);
+      for (auto a : allocas) {
+	unsigned count = 0;
+      
+	for (auto i : a.first->getUsers())
+	  ++count;
+
+	// If the alloca is only used for a store and the call operand, the
+	// store is no longer required.
+	if (count == 1) {
+	  rewriter.eraseOp(a.second);
+	  rewriter.eraseOp(a.first);
+	}
+      }
       LLVM_DEBUG(llvm::dbgs() << "global constant for " << callOp << " as "
                               << newOp << '\n');
       return mlir::success();
diff --git a/flang/test/Transforms/constant-argument-globalisation-2.fir b/flang/test/Transforms/constant-argument-globalisation-2.fir
index 03855b5bfb762..03e08a0dcb914 100644
--- a/flang/test/Transforms/constant-argument-globalisation-2.fir
+++ b/flang/test/Transforms/constant-argument-globalisation-2.fir
@@ -78,3 +78,21 @@ module {
 // CHECK: }
 
 }
+
+// -----
+// Check that the same argument used twice is converted.
+module {
+  func.func @func(%arg0: !fir.ref<i32>, %arg1: i1) {
+    %c2_i32 = arith.constant 2 : i32
+    %addr1 = fir.alloca i32 {adapt.valuebyref}
+    fir.store %c2_i32 to %addr1 : !fir.ref<i32>
+    fir.call @sub1(%addr1, %addr1) : (!fir.ref<i32>, !fir.ref<i32>) -> ()
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @func
+// CHECK-NEXT: %[[ARG1:.*]] = fir.address_of([[CONST1:@.*]]) : !fir.ref<i32>
+// CHECK-NEXT: %[[ARG2:.*]] = fir.address_of([[CONST2:@.*]]) : !fir.ref<i32>
+// CHECK-NEXT: fir.call @sub1(%[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: return

>From 4742569fe8db2d4eb0c0c3d13f61664d3541e8c6 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 19 Jun 2024 16:01:36 +0000
Subject: [PATCH 11/11] Make constant argument globalisation disabled by
 default.

This pass converts local constant arguments into global constants, but doing so
can break code that is borderline non-conforming: code in which the callee
modifies a constant actual argument, which many compilers accept. The pass is
therefore disabled by default and must be enabled explicitly.

To enable it in flang-new, use `-mmlir --enable-constant-argument-globalisation`.
---
 flang/include/flang/Tools/CLOptions.inc                | 10 +++++++---
 .../Transforms/ConstantArgumentGlobalisation.cpp       |  2 +-
 flang/test/Driver/bbc-mlir-pass-pipeline.f90           |  1 -
 flang/test/Driver/mlir-pass-pipeline.f90               |  1 -
 flang/test/Fir/basic-program.fir                       |  1 -
 flang/test/Fir/boxproc.fir                             |  2 +-
 flang/test/Lower/character-local-variables.f90         |  8 ++++++--
 flang/test/Lower/dummy-arguments.f90                   |  4 +++-
 flang/test/Lower/host-associated.f90                   |  7 ++++---
 .../Transforms/constant-argument-globalisation.fir     |  4 ++--
 10 files changed, 24 insertions(+), 16 deletions(-)

diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 3a322b111876c..feabefaac1918 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -25,6 +25,10 @@
   static llvm::cl::opt<bool> disable##DOName("disable-" DOOption, \
       llvm::cl::desc("disable " DODescription " pass"), llvm::cl::init(false), \
       llvm::cl::Hidden)
+#define EnableOption(EOName, EOOption, EODescription) \
+  static llvm::cl::opt<bool> enable##EOName("enable-" EOOption, \
+      llvm::cl::desc("enable " EODescription " pass"), llvm::cl::init(false), \
+      llvm::cl::Hidden)
 
 /// Shared option in tools to control whether dynamically sized array
 /// allocations should always be on the heap.
@@ -86,8 +90,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
 
 DisableOption(ExternalNameConversion, "external-name-interop",
     "convert names with external convention");
-DisableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
-    "disable the local constants to global constant conversion");
+EnableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
+    "enable the local constants to global constant conversion");
 
 using PassConstructor = std::unique_ptr<mlir::Pass>();
 
@@ -272,7 +276,7 @@ inline void createDefaultFIROptimizerPassPipeline(
     // These passes may increase code size.
     pm.addPass(fir::createSimplifyIntrinsics());
     pm.addPass(fir::createAlgebraicSimplificationPass(config));
-    if (!disableConstantArgumentGlobalisation)
+    if (enableConstantArgumentGlobalisation)
       pm.addPass(fir::createConstantArgumentGlobalisationOpt());
   }
 
diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 168fd49026022..33d4fbe739f9c 100644
--- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -134,7 +134,7 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
       for (auto a : allocas) {
 	unsigned count = 0;
       
-	for (auto i : a.first->getUsers())
+	for ([[maybe_unused]]auto i : a.first->getUsers())
 	  ++count;
 
 	// If the alloca is only used for a store and the call operand, the
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 3a4e17f16dc81..c94b98c7c5805 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -32,7 +32,6 @@
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: SimplifyIntrinsics
 ! CHECK-NEXT: AlgebraicSimplification
-! CHECK-NEXT: ConstantArgumentGlobalisationOpt
 ! CHECK-NEXT: CSE
 ! CHECK-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! CHECK-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 43461b76cf2ab..8e1a3d43edd1c 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -66,7 +66,6 @@
 ! ALL-NEXT: SimplifyRegionLite
 !  O2-NEXT: SimplifyIntrinsics
 !  O2-NEXT: AlgebraicSimplification
-!  O2-NEXT: ConstantArgumentGlobalisationOpt
 ! ALL-NEXT: CSE
 ! ALL-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 ! ALL-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a72d400ef410..dd184d99cb809 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -66,7 +66,6 @@ func.func @_QQmain() {
 // PASSES-NEXT: SimplifyRegionLite
 // PASSES-NEXT: SimplifyIntrinsics
 // PASSES-NEXT: AlgebraicSimplification
-// PASSES-NEXT: ConstantArgumentGlobalisationOpt
 // PASSES-NEXT: CSE
 // PASSES-NEXT:   (S) 0 num-cse'd - Number of operations CSE'd
 // PASSES-NEXT:   (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 2ddc0ef525ac4..9f28b77e0a0b7 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,7 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME:              %[[VAL_0:.*]])
-// CHECK:         call void %[[VAL_0]](ptr @{{.*}})
+// CHECK:         call void %[[VAL_0]](ptr %{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 70a6c49afc7be..d5b959eca1ff6 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -1,4 +1,6 @@
 ! RUN: bbc -hlfir=false %s -o - | FileCheck %s
+! RUN: bbc -hlfir=false --enable-constant-argument-globalisation %s -o - \
+! RUN:    | FileCheck %s --check-prefix=CHECK-CONST
 
 ! Test lowering of local character variables
 
@@ -116,8 +118,10 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[tmp:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i64>
-  ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
+  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
+  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK-CONST: %[[tmp:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i64>
+  ! CHECK-CONST: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
 
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 7953020937ddb..331e089a60fa0 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,7 +2,9 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i32>
+  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
+  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
+  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index bf1fe3a6c574d..cdc7e6a05288a 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,10 +448,11 @@ subroutine bar()
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
+! CHECK:         %[[VAL_1:.*]] = arith.constant 4 : i32
+! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
+! CHECK:         fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_global_const_.{{.*}}) : !fir.ref<i32>
-! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
-
+! CHECK:         fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
 ! CHECK:         return
 ! CHECK:       }
 
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
index e88493a01d515..f0be8bcef2c6d 100644
--- a/flang/test/Transforms/constant-argument-globalisation.fir
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -1,5 +1,5 @@
-// RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
-// RUN: %flang_fc1 -emit-llvm -flang-deprecated-no-hlfir -O2 -mllvm --disable-constant-argument-globalisation -o - %s | FileCheck --check-prefix=DISABLE %s
+// RUN: fir-opt --constant-argument-globalisation-opt  < %s | FileCheck %s
+// RUN: %flang_fc1 -emit-llvm -flang-deprecated-no-hlfir -O2 -o - %s | FileCheck --check-prefix=DISABLE %s
 module {
   func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
     %0 = fir.alloca i32 {adapt.valuebyref}



More information about the flang-commits mailing list