[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Tue Jun 11 01:56:38 PDT 2024
https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/73829
>From 59f4487f1e78d7e7209409b8e40ab04dc95054f9 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 22 Nov 2023 14:01:38 +0000
Subject: [PATCH 1/7] [Flang] Extracting internal constants from scalar
literals
Constant actual arguments in function/subroutine calls are currently lowered
as allocas + store. This can sometimes inhibit LTO, so the constant is not
propagated to the called function, particularly in cases where the
function/subroutine call happens inside a conditional.
This patch changes the lowering of these constant actual arguments to a
global constant plus a fir.address_of op. This lowering makes it easier for
LTO to propagate the constant into the callee.
Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
.../flang/Optimizer/Transforms/Passes.td | 10 +
flang/include/flang/Tools/CLOptions.inc | 4 -
flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 +
.../Optimizer/Transforms/ConstExtruder.cpp | 216 ++++++++++++++++++
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 1 +
.../test/Driver/mlir-debug-pass-pipeline.f90 | 1 +
flang/test/Driver/mlir-pass-pipeline.f90 | 1 +
flang/test/Fir/basic-program.fir | 1 +
flang/test/Fir/boxproc.fir | 4 +-
.../test/Lower/character-local-variables.f90 | 3 +-
flang/test/Lower/dummy-arguments.f90 | 4 +-
flang/test/Lower/host-associated.f90 | 7 +-
flang/test/Transforms/const-extrude.f90 | 32 +++
13 files changed, 269 insertions(+), 16 deletions(-)
create mode 100644 flang/lib/Optimizer/Transforms/ConstExtruder.cpp
create mode 100644 flang/test/Transforms/const-extrude.f90
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 82638200e5e20..df3b9fc6fb613 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -252,6 +252,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
];
}
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+ let summary = "Convert scalar literals of function arguments to global constants.";
+ let description = [{
+ Convert scalar literals of function arguments to global constants.
+ }];
+ let dependentDialects = [ "fir::FIROpsDialect" ];
+ let constructor = "::fir::createConstExtruderPass()";
+}
+
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
let summary = "Move local array allocations from heap memory into stack memory";
let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index fb3ec75d4078a..d2bf79d8789cc 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -284,10 +284,6 @@ inline void createDefaultFIROptimizerPassPipeline(
else
fir::addMemoryAllocationOpt(pm);
- // FIR Inliner Callback
- pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
-
- pm.addPass(fir::createSimplifyRegionLite());
pm.addPass(mlir::createCSEPass());
// Polymorphic types
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 5ef930fdb2c2f..800fe44dfdc11 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_flang_library(FIRTransforms
AnnotateConstant.cpp
AssumedRankOpConversion.cpp
CharacterConversion.cpp
+ ConstExtruder.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 0000000000000..1bb1cd2269871
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+ if (!a || !a->getDefiningOp())
+ return false;
+
+ // is alloca
+ if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+ // alloca has annotation
+ if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+ for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+ if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+ auto constant_def = store->getOperand(0).getDefiningOp();
+ // Expect constant definition operation
+ if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+ mlir::DominanceInfo &di;
+
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+ : OpRewritePattern(ctx), di(_di) {}
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::CallOp callOp,
+ mlir::PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+ auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, module);
+ llvm::SmallVector<mlir::Value> newOperands;
+ llvm::SmallVector<mlir::Operation *> toErase;
+ for (const auto &a : callOp.getArgs()) {
+ if (auto alloca =
+ mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+ if (needsExtrusion(&a)) {
+
+ mlir::Type varTy = alloca.getInType();
+ assert(!fir::hasDynamicSize(varTy) &&
+ "only expect statically sized scalars to be by value");
+
+ // find immediate store with const argument
+ llvm::SmallVector<mlir::Operation *> stores;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+ stores.push_back(s);
+ assert(stores.size() == 1 && "expected exactly one store");
+ LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+ auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+ // Expect constant definition operation or force legalisation of the
+ // callOp and continue with its next argument
+ if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ // unable to remove alloca arg
+ newOperands.push_back(a);
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+ auto loc = callOp.getLoc();
+ llvm::StringRef globalPrefix = "_extruded_";
+
+ std::string globalName;
+ while (!globalName.length() || builder.getNamedGlobal(globalName))
+ globalName =
+ globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+ if (alloca->hasOneUse()) {
+ toErase.push_back(alloca);
+ toErase.push_back(stores[0]);
+ } else {
+ int count = -2;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (di.dominates(stores[0], s))
+ ++count;
+
+ // delete if dominates itself and one more operation (which should
+ // be callOp)
+ if (!count)
+ toErase.push_back(stores[0]);
+ }
+ auto global = builder.createGlobalConstant(
+ loc, varTy, globalName,
+ [&](fir::FirOpBuilder &builder) {
+ mlir::Operation *cln = constant_def->clone();
+ builder.insert(cln);
+ fir::ExtendedValue exv{cln->getResult(0)};
+ mlir::Value valBase = fir::getBase(exv);
+ mlir::Value val = builder.createConvert(loc, varTy, valBase);
+ builder.create<fir::HasValueOp>(loc, val);
+ },
+ builder.createInternalLinkage());
+ mlir::Value ope = {builder.create<fir::AddrOfOp>(
+ loc, global.resultType(), global.getSymbol())};
+ newOperands.push_back(ope);
+ } else {
+ // alloca but without attr, add it
+ newOperands.push_back(a);
+ }
+ } else {
+ // non-alloca operand, add it
+ newOperands.push_back(a);
+ }
+ }
+
+ auto loc = callOp.getLoc();
+ llvm::SmallVector<mlir::Type> newResultTypes;
+ newResultTypes.append(callOp.getResultTypes().begin(),
+ callOp.getResultTypes().end());
+ fir::CallOp newOp = builder.create<fir::CallOp>(
+ loc, newResultTypes,
+ callOp.getCallee().has_value() ? callOp.getCallee().value()
+ : mlir::SymbolRefAttr{},
+ newOperands, callOp.getFastmathAttr());
+ rewriter.replaceOp(callOp, newOp);
+
+ for (auto e : toErase)
+ rewriter.eraseOp(e);
+
+ LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+ << newOp << '\n');
+ return mlir::success();
+ }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+ : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+ mlir::DominanceInfo *di;
+
+public:
+ ConstExtruderOpt() {}
+
+ void runOnOperation() override {
+ mlir::ModuleOp mod = getOperation();
+ di = &getAnalysis<mlir::DominanceInfo>();
+ mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+ }
+
+ void runOnFunc(mlir::func::FuncOp &func) {
+ auto *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ mlir::ConversionTarget target(*context);
+
+ // If func is a declaration, skip it.
+ if (func.empty())
+ return;
+
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::func::FuncDialect>();
+ target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+ for (auto a : op.getArgs()) {
+ if (needsExtrusion(&a))
+ return false;
+ }
+ return true;
+ });
+
+ patterns.insert<CallOpRewriter>(context, *di);
+ if (mlir::failed(
+ mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+ mlir::emitError(func.getLoc(),
+ "error in constant extrusion optimization\n");
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+ return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c94b98c7c5805..8e81309209cf7 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,6 +38,7 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
+! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 49b1f8c5c3134..42180d2966cb7 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,6 +65,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 8e1a3d43edd1c..c7bb5cb5406d0 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -72,6 +72,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index dd184d99cb809..517c6054ed7a9 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -72,6 +72,7 @@ func.func @_QQmain() {
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: MemoryAllocationOpt
+// PASSES-NEXT: ConstExtruderOpt
// PASSES-NEXT: Inliner
// PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af0..2ddc0ef525ac4 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
// CHECK-SAME: %[[VAL_0:.*]])
-// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK: store i32 4, ptr %[[VAL_1]], align 4
-// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK: call void %[[VAL_0]](ptr @{{.*}})
func.func @_QPtest_proc_dummy() {
%c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e..b1cfc540f4389 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
subroutine assumed_length_param(n)
character(*), parameter :: c(1)=(/"abcd"/)
integer :: n
- ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
- ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+ ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
call take_int(len(c(n), kind=8))
end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 331e089a60fa0..7c85b7c0a746d 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
! CHECK-LABEL: _QQmain
program test1
- ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
- ! CHECK-DAG: %[[TEN:.*]] = arith.constant
- ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+ ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
! CHECK-NEXT: fir.call @_QFPfoo
call foo(10)
contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index cdc7e6a05288a..0b5311402d51e 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
! CHECK-LABEL: func @_QPtest_proc_dummy_other(
! CHECK-SAME: %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK: %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK: fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
! CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK: fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK: %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK: fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
! CHECK: return
! CHECK: }
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 0000000000000..70cdaf496f34a
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+ implicit none
+ integer x, y
+
+ call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }
>From de6f18ab1433e68de885b2c5835d42b8b7b64a73 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 6 Dec 2023 16:12:14 +0000
Subject: [PATCH 2/7] Use greedy rewriter
---
.../Optimizer/Transforms/ConstExtruder.cpp | 23 +++++++------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 1bb1cd2269871..00c9f30e9dbac 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -16,7 +16,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
protected:
mlir::DominanceInfo *di;
+ mlir::GreedyRewriteConfig config;
public:
- ConstExtruderOpt() {}
+ ConstExtruderOpt() {
+ config.enableRegionSimplification = false;
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+ }
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
void runOnFunc(mlir::func::FuncOp &func) {
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
- mlir::ConversionTarget target(*context);
// If func is a declaration, skip it.
if (func.empty())
return;
- target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
- mlir::func::FuncDialect>();
- target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
- for (auto a : op.getArgs()) {
- if (needsExtrusion(&a))
- return false;
- }
- return true;
- });
-
patterns.insert<CallOpRewriter>(context, *di);
- if (mlir::failed(
- mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ func, std::move(patterns), config))) {
mlir::emitError(func.getLoc(),
"error in constant extrusion optimization\n");
signalPassFailure();
>From b9e420109955d908cc0598701666fd59fd948045 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 12 Dec 2023 14:36:10 +0000
Subject: [PATCH 3/7] Fix review comments
---
.../flang/Optimizer/Transforms/Passes.td | 2 +-
.../Optimizer/Transforms/ConstExtruder.cpp | 238 ++++++++----------
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 2 +-
3 files changed, 111 insertions(+), 131 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index df3b9fc6fb613..12c700718a6bc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -254,7 +254,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
- let summary = "Convert scalar literals of function arguments to global constants.";
+ let summary = "Convert constant function arguments to global constants.";
let description = [{
Convert scalar literals of function arguments to global constants.
}];
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 00c9f30e9dbac..796a6a05a580f 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp -----------------------------------------------===//
+//===- ConstExtruder.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include <atomic>
namespace fir {
#define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,38 +25,16 @@ namespace fir {
#define DEBUG_TYPE "flang-const-extruder-opt"
namespace {
-std::atomic<int> uniqueLitId = 1;
-
-static bool needsExtrusion(const mlir::Value *a) {
- if (!a || !a->getDefiningOp())
- return false;
-
- // is alloca
- if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
- // alloca has annotation
- if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
- for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
- if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
- auto constant_def = store->getOperand(0).getDefiningOp();
- // Expect constant definition operation
- if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
- return true;
- }
- }
- }
- }
- }
- return false;
-}
+unsigned uniqueLitId = 1;
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
- mlir::DominanceInfo &di;
+ const mlir::DominanceInfo &di;
public:
using OpRewritePattern::OpRewritePattern;
- CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+ CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
: OpRewritePattern(ctx), di(_di) {}
mlir::LogicalResult
@@ -68,131 +42,137 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
mlir::PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ bool needUpdate = false;
fir::FirOpBuilder builder(rewriter, module);
llvm::SmallVector<mlir::Value> newOperands;
llvm::SmallVector<mlir::Operation *> toErase;
- for (const auto &a : callOp.getArgs()) {
- if (auto alloca =
- mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
- if (needsExtrusion(&a)) {
-
- mlir::Type varTy = alloca.getInType();
- assert(!fir::hasDynamicSize(varTy) &&
- "only expect statically sized scalars to be by value");
-
- // find immediate store with const argument
- llvm::SmallVector<mlir::Operation *> stores;
- for (mlir::Operation *s : alloca.getOperation()->getUsers())
- if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
- stores.push_back(s);
- assert(stores.size() == 1 && "expected exactly one store");
- LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
-
- auto constant_def = stores[0]->getOperand(0).getDefiningOp();
- // Expect constant definition operation or force legalisation of the
- // callOp and continue with its next argument
- if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
- // unable to remove alloca arg
- newOperands.push_back(a);
- continue;
- }
+ for (const mlir::Value &a : callOp.getArgs()) {
+ auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
+ // We can convert arguments that are alloca, and that has
+ // the value by reference attribute. All else is just added
+ // to the argument list.
+ if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+ newOperands.push_back(a);
+ continue;
+ }
- LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
-
- auto loc = callOp.getLoc();
- llvm::StringRef globalPrefix = "_extruded_";
-
- std::string globalName;
- while (!globalName.length() || builder.getNamedGlobal(globalName))
- globalName =
- globalPrefix.str() + "." + std::to_string(uniqueLitId++);
-
- if (alloca->hasOneUse()) {
- toErase.push_back(alloca);
- toErase.push_back(stores[0]);
- } else {
- int count = -2;
- for (mlir::Operation *s : alloca.getOperation()->getUsers())
- if (di.dominates(stores[0], s))
- ++count;
-
- // delete if dominates itself and one more operation (which should
- // be callOp)
- if (!count)
- toErase.push_back(stores[0]);
+ mlir::Type varTy = alloca.getInType();
+ assert(!fir::hasDynamicSize(varTy) &&
+ "only expect statically sized scalars to be by value");
+
+ // Find immediate store with const argument
+ mlir::Operation *store = nullptr;
+ for (mlir::Operation *s : alloca->getUsers()) {
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
+ // We can only deal with ONE store - if already found one,
+ // set to nullptr and exit the loop.
+ if (store) {
+ store = nullptr;
+ break;
}
- auto global = builder.createGlobalConstant(
- loc, varTy, globalName,
- [&](fir::FirOpBuilder &builder) {
- mlir::Operation *cln = constant_def->clone();
- builder.insert(cln);
- fir::ExtendedValue exv{cln->getResult(0)};
- mlir::Value valBase = fir::getBase(exv);
- mlir::Value val = builder.createConvert(loc, varTy, valBase);
- builder.create<fir::HasValueOp>(loc, val);
- },
- builder.createInternalLinkage());
- mlir::Value ope = {builder.create<fir::AddrOfOp>(
- loc, global.resultType(), global.getSymbol())};
- newOperands.push_back(ope);
- } else {
- // alloca but without attr, add it
- newOperands.push_back(a);
+ store = s;
}
- } else {
- // non-alloca operand, add it
+ }
+
+ // If we didn't find one signle store, add argument as is, and move on.
+ if (!store) {
+ newOperands.push_back(a);
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
+
+ mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
+ // Expect constant definition operation or force legalisation of the
+ // callOp and continue with its next argument
+ if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ // Unable to remove alloca arg
newOperands.push_back(a);
+ continue;
}
+
+ LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+ std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
+ assert(!builder.getNamedGlobal(globalName) &&
+ "We should have a unique name here");
+
+ unsigned count = 0;
+ for (mlir::Operation *s : alloca->getUsers())
+ if (di.dominates(store, s))
+ ++count;
+
+ // Delete if dominates itself and one more operation (which should
+ // be callOp)
+ if (count == 2)
+ toErase.push_back(store);
+
+ auto loc = callOp.getLoc();
+ fir::GlobalOp global = builder.createGlobalConstant(
+ loc, varTy, globalName,
+ [&](fir::FirOpBuilder &builder) {
+ mlir::Operation *cln = constant_def->clone();
+ builder.insert(cln);
+ mlir::Value val =
+ builder.createConvert(loc, varTy, cln->getResult(0));
+ builder.create<fir::HasValueOp>(loc, val);
+ },
+ builder.createInternalLinkage());
+ mlir::Value addr = {builder.create<fir::AddrOfOp>(
+ loc, global.resultType(), global.getSymbol())};
+ newOperands.push_back(addr);
+ needUpdate = true;
}
- auto loc = callOp.getLoc();
- llvm::SmallVector<mlir::Type> newResultTypes;
- newResultTypes.append(callOp.getResultTypes().begin(),
- callOp.getResultTypes().end());
- fir::CallOp newOp = builder.create<fir::CallOp>(
- loc, newResultTypes,
- callOp.getCallee().has_value() ? callOp.getCallee().value()
- : mlir::SymbolRefAttr{},
- newOperands, callOp.getFastmathAttr());
- rewriter.replaceOp(callOp, newOp);
-
- for (auto e : toErase)
- rewriter.eraseOp(e);
-
- LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
- << newOp << '\n');
- return mlir::success();
+ if (needUpdate) {
+ auto loc = callOp.getLoc();
+ llvm::SmallVector<mlir::Type> newResultTypes;
+ newResultTypes.append(callOp.getResultTypes().begin(),
+ callOp.getResultTypes().end());
+ fir::CallOp newOp = builder.create<fir::CallOp>(
+ loc, newResultTypes,
+ callOp.getCallee().has_value() ? callOp.getCallee().value()
+ : mlir::SymbolRefAttr{},
+ newOperands, callOp.getFastmathAttr());
+ rewriter.replaceOp(callOp, newOp);
+
+ for (auto e : toErase)
+ rewriter.eraseOp(e);
+ LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+ << newOp << '\n');
+ return mlir::success();
+ }
+
+ // Failure here just means "we couldn't do the conversion", which is
+ // perfectly acceptable to the upper layers of this function.
+ return mlir::failure();
}
};
-// This pass attempts to convert immediate scalar literals in function calls
+// this pass attempts to convert immediate scalar literals in function calls
// to global constants to allow transformations as Dead Argument Elimination
class ConstExtruderOpt
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
-protected:
- mlir::DominanceInfo *di;
- mlir::GreedyRewriteConfig config;
-
public:
- ConstExtruderOpt() {
- config.enableRegionSimplification = false;
- config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
- }
+ ConstExtruderOpt() = default;
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
- di = &getAnalysis<mlir::DominanceInfo>();
- mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+ mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
+ mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
}
- void runOnFunc(mlir::func::FuncOp &func) {
- auto *context = &getContext();
- mlir::RewritePatternSet patterns(context);
-
+ void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo *di) {
// If func is a declaration, skip it.
if (func.empty())
return;
+ auto *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ mlir::GreedyRewriteConfig config;
+ config.enableRegionSimplification = false;
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 8e81309209cf7..28c06826ef415 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -38,8 +38,8 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
-! CHECK-NEXT: ConstExtruderOpt
+! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: CSE
>From 390c0ae13c5a242150847e68b1924cf74f2fe0a9 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 13 Dec 2023 15:35:24 +0000
Subject: [PATCH 4/7] Rename and add option to disable
---
.../flang/Optimizer/Transforms/Passes.td | 4 +-
flang/include/flang/Tools/CLOptions.inc | 4 ++
flang/lib/Optimizer/Transforms/CMakeLists.txt | 2 +-
....cpp => ConstantArgumentGlobalisation.cpp} | 13 +++---
flang/test/Transforms/const-extrude.f90 | 32 -------------
.../constant-argument-globalisation.fir | 45 +++++++++++++++++++
6 files changed, 59 insertions(+), 41 deletions(-)
rename flang/lib/Optimizer/Transforms/{ConstExtruder.cpp => ConstantArgumentGlobalisation.cpp} (94%)
delete mode 100644 flang/test/Transforms/const-extrude.f90
create mode 100644 flang/test/Transforms/constant-argument-globalisation.fir
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 12c700718a6bc..479272798f176 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -253,13 +253,13 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
}
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
-def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+def ConstantArgumentGlobalisationOpt : Pass<"constant-argument-globalisation-opt", "mlir::ModuleOp"> {
let summary = "Convert constant function arguments to global constants.";
let description = [{
Convert scalar literals of function arguments to global constants.
}];
let dependentDialects = [ "fir::FIROpsDialect" ];
- let constructor = "::fir::createConstExtruderPass()";
+ let constructor = "::fir::createConstantArgumentGlobalisationPass()";
}
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d2bf79d8789cc..d709efc9a34e7 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -86,6 +86,8 @@ DisableOption(BoxedProcedureRewrite, "boxed-procedure-rewrite",
DisableOption(ExternalNameConversion, "external-name-interop",
"convert names with external convention");
+DisableOption(ConstantArgumentGlobalisation, "constant-argument-globalisation",
+ "disable the local constants to global constant conversion");
using PassConstructor = std::unique_ptr<mlir::Pass>();
@@ -272,6 +274,8 @@ inline void createDefaultFIROptimizerPassPipeline(
// These passes may increase code size.
pm.addPass(fir::createSimplifyIntrinsics());
pm.addPass(fir::createAlgebraicSimplificationPass(config));
+ if (!disableConstantArgumentGlobalisation)
+ pm.addPass(fir::createConstantArgumentGlobalisationPass());
}
if (pc.LoopVersioning)
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 800fe44dfdc11..cdf5ca368ebd1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -6,7 +6,7 @@ add_flang_library(FIRTransforms
AnnotateConstant.cpp
AssumedRankOpConversion.cpp
CharacterConversion.cpp
- ConstExtruder.cpp
+ ConstantArgumentGlobalisation.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
similarity index 94%
rename from flang/lib/Optimizer/Transforms/ConstExtruder.cpp
rename to flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
index 796a6a05a580f..2859a57226f16 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp
@@ -18,7 +18,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace fir {
-#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#define GEN_PASS_DEF_CONSTANTARGUMENTGLOBALISATIONOPT
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
@@ -151,10 +151,11 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
// this pass attempts to convert immediate scalar literals in function calls
// to global constants to allow transformations as Dead Argument Elimination
-class ConstExtruderOpt
- : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+class ConstantArgumentGlobalisationOpt
+ : public fir::impl::ConstantArgumentGlobalisationOptBase<
+ ConstantArgumentGlobalisationOpt> {
public:
- ConstExtruderOpt() = default;
+ ConstantArgumentGlobalisationOpt() = default;
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
@@ -184,6 +185,6 @@ class ConstExtruderOpt
};
} // namespace
-std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
- return std::make_unique<ConstExtruderOpt>();
+std::unique_ptr<mlir::Pass> fir::createConstantArgumentGlobalisationPass() {
+ return std::make_unique<ConstantArgumentGlobalisationOpt>();
}
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
deleted file mode 100644
index 70cdaf496f34a..0000000000000
--- a/flang/test/Transforms/const-extrude.f90
+++ /dev/null
@@ -1,32 +0,0 @@
-! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
-
-subroutine sub1(x,y)
- implicit none
- integer x, y
-
- call sub2(0.0d0, 1.0d0, x, y, 1)
-end subroutine sub1
-
-!CHECK-LABEL: func.func @_QPsub1
-!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
-!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
-!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
-!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
-!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
-!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
-!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
-!CHECK: return
-
-!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
-!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
-!CHECK: fir.has_value %{{.*}} : f64
-!CHECK: }
-!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
-!CHECK: %{{.*}} = arith.constant 1 : i32
-!CHECK: fir.has_value %{{.*}} : i32
-!CHECK: }
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
new file mode 100644
index 0000000000000..328191ea32512
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -0,0 +1,45 @@
+// RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+module {
+ func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
+ %0 = fir.alloca i32 {adapt.valuebyref}
+ %1 = fir.alloca f64 {adapt.valuebyref}
+ %2 = fir.alloca f64 {adapt.valuebyref}
+ %c1_i32 = arith.constant 1 : i32
+ %cst = arith.constant 1.000000e+00 : f64
+ %cst_0 = arith.constant 0.000000e+00 : f64
+ %3 = fir.declare %arg0 {uniq_name = "_QFsub1Ex"} : (!fir.ref<i32>) -> !fir.ref<i32>
+ %4 = fir.declare %arg1 {uniq_name = "_QFsub1Ey"} : (!fir.ref<i32>) -> !fir.ref<i32>
+ fir.store %cst_0 to %2 : !fir.ref<f64>
+ %false = arith.constant false
+ fir.store %cst to %1 : !fir.ref<f64>
+ %false_1 = arith.constant false
+ fir.store %c1_i32 to %0 : !fir.ref<i32>
+ %false_2 = arith.constant false
+ fir.call @sub2(%2, %1, %3, %4, %0) fastmath<contract> : (!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>) -> ()
+ return
+ }
+ func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
+}
+// CHECK-LABEL: func.func @sub1
+// CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+// CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+// CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+// CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+// CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+// CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+// CHECK: return
+
+// CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+// CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+// CHECK: fir.has_value %{{.*}} : f64
+// CHECK: }
+// CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }
>From 9e038ff11d35bf8253d7d8bba3c72fd6340f3fbc Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Thu, 14 Dec 2023 18:50:04 +0000
Subject: [PATCH 5/7] Add more testing
---
.../constant-argument-globalisation-2.fir | 80 +++++++++++++++++++
.../constant-argument-globalisation.fir | 25 +++++-
2 files changed, 103 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Transforms/constant-argument-globalisation-2.fir
diff --git a/flang/test/Transforms/constant-argument-globalisation-2.fir b/flang/test/Transforms/constant-argument-globalisation-2.fir
new file mode 100644
index 0000000000000..03855b5bfb762
--- /dev/null
+++ b/flang/test/Transforms/constant-argument-globalisation-2.fir
@@ -0,0 +1,80 @@
+// RUN: fir-opt --split-input-file --constant-argument-globalisation-opt < %s | FileCheck %s
+
+module {
+// Test that two conditional writes to the same alloca don't get replaced.
+ func.func @func(%arg0: i32, %arg1: i1) {
+ %c2_i32 = arith.constant 2 : i32
+ %addr = fir.alloca i32 {adapt.valuebyref}
+ fir.if %arg1 {
+ fir.store %c2_i32 to %addr : !fir.ref<i32>
+ } else {
+ fir.store %arg0 to %addr : !fir.ref<i32>
+ }
+ fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+ return
+ }
+ func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK-SAME: [[ARG0:%.*]]: i32
+// CHECK-SAME: [[ARG1:%.*]]: i1)
+// CHECK: [[CONST:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.if [[ARG1]]
+// CHECK: fir.store [[CONST]] to [[ADDR]]
+// CHECK: } else {
+// CHECK: fir.store [[ARG0]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test that two writes to the same alloca don't get replaced.
+ func.func @func() {
+ %c1_i32 = arith.constant 1 : i32
+ %c2_i32 = arith.constant 2 : i32
+ %addr = fir.alloca i32 {adapt.valuebyref}
+ fir.store %c1_i32 to %addr : !fir.ref<i32>
+ fir.store %c2_i32 to %addr : !fir.ref<i32>
+ fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+ return
+ }
+ func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[CONST1:%.*]] = arith.constant
+// CHECK: [[CONST2:%.*]] = arith.constant
+// CHECK: [[ADDR:%.*]] = fir.alloca i32
+// CHECK: fir.store [[CONST1]] to [[ADDR]]
+// CHECK: fir.store [[CONST2]] to [[ADDR]]
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+
+}
+
+// -----
+
+module {
+// Test that a single write to the alloca gets replaced.
+ func.func @func() {
+ %c1_i32 = arith.constant 1 : i32
+ %addr = fir.alloca i32 {adapt.valuebyref}
+ fir.store %c1_i32 to %addr : !fir.ref<i32>
+ fir.call @sub2(%addr) : (!fir.ref<i32>) -> ()
+ return
+ }
+ func.func private @sub2(!fir.ref<i32>)
+
+// CHECK-LABEL: func.func @func
+// CHECK: [[ADDR:%.*]] = fir.address_of([[EXTR:@.*]]) : !fir.ref<i32>
+// CHECK: fir.call @sub2([[ADDR]])
+// CHECK: return
+// CHECK: fir.global internal [[EXTR]] constant : i32 {
+// CHECK: %{{.*}} = arith.constant 1 : i32
+// CHECK: fir.has_value %{{.*}} : i32
+// CHECK: }
+
+}
diff --git a/flang/test/Transforms/constant-argument-globalisation.fir b/flang/test/Transforms/constant-argument-globalisation.fir
index 328191ea32512..1598f303755cb 100644
--- a/flang/test/Transforms/constant-argument-globalisation.fir
+++ b/flang/test/Transforms/constant-argument-globalisation.fir
@@ -1,4 +1,5 @@
// RUN: fir-opt --constant-argument-globalisation-opt < %s | FileCheck %s
+// RUN: %flang_fc1 -emit-llvm -flang-deprecated-no-hlfir -O2 -mllvm --disable-constant-argument-globalisation -o - %s | FileCheck --check-prefix=DISABLE %s
module {
func.func @sub1(%arg0: !fir.ref<i32> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "y"}) {
%0 = fir.alloca i32 {adapt.valuebyref}
@@ -19,8 +20,8 @@ module {
return
}
func.func private @sub2(!fir.ref<f64>, !fir.ref<f64>, !fir.ref<i32>, !fir.ref<i32>, !fir.ref<i32>)
-}
-// CHECK-LABEL: func.func @sub1
+
+// CHECK-LABEL: func.func @sub1(
// CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
// CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
// CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
@@ -43,3 +44,23 @@ module {
// CHECK: %{{.*}} = arith.constant 1 : i32
// CHECK: fir.has_value %{{.*}} : i32
// CHECK: }
+
+// DISABLE-LABEL: ; ModuleID =
+// DISABLE-NOT: @_extruded
+// DISABLE: define void @sub1(
+// DISABLE-SAME: ptr [[ARG0:%.*]],
+// DISABLE-SAME: ptr [[ARG1:%.*]])
+// DISABLE-SAME: {
+// DISABLE: [[CONST_I:%.*]] = alloca i32
+// DISABLE: [[CONST_R1:%.*]] = alloca double
+// DISABLE: [[CONST_R0:%.*]] = alloca double
+// DISABLE: store double 0.0{{.*}}+00, ptr [[CONST_R0]]
+// DISABLE: store double 1.0{{.*}}+00, ptr [[CONST_R1]]
+// DISABLE: store i32 1, ptr [[CONST_I]]
+// DISABLE: call void @sub2(ptr nonnull [[CONST_R0]],
+// DISABLE-SAME: ptr nonnull [[CONST_R1]],
+// DISABLE-SAME: ptr [[ARG0]], ptr [[ARG1]],
+// DISABLE-SAME: ptr nonnull [[CONST_I]])
+// DISABLE: ret void
+// DISABLE: }
+}
>From 732c393746922f2327ce3fa822f3b8402d4b930b Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Fri, 15 Dec 2023 11:44:01 +0000
Subject: [PATCH 6/7] Fix tests
---
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 2 +-
flang/test/Driver/mlir-debug-pass-pipeline.f90 | 1 -
flang/test/Driver/mlir-pass-pipeline.f90 | 2 +-
flang/test/Fir/basic-program.fir | 2 +-
4 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 28c06826ef415..3a4e17f16dc81 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -32,6 +32,7 @@
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: SimplifyIntrinsics
! CHECK-NEXT: AlgebraicSimplification
+! CHECK-NEXT: ConstantArgumentGlobalisationOpt
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
@@ -39,7 +40,6 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
-! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: CSE
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42180d2966cb7..49b1f8c5c3134 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -65,7 +65,6 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
-! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index c7bb5cb5406d0..43461b76cf2ab 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -66,13 +66,13 @@
! ALL-NEXT: SimplifyRegionLite
! O2-NEXT: SimplifyIntrinsics
! O2-NEXT: AlgebraicSimplification
+! O2-NEXT: ConstantArgumentGlobalisationOpt
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
-! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 517c6054ed7a9..0a72d400ef410 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -66,13 +66,13 @@ func.func @_QQmain() {
// PASSES-NEXT: SimplifyRegionLite
// PASSES-NEXT: SimplifyIntrinsics
// PASSES-NEXT: AlgebraicSimplification
+// PASSES-NEXT: ConstantArgumentGlobalisationOpt
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: MemoryAllocationOpt
-// PASSES-NEXT: ConstExtruderOpt
// PASSES-NEXT: Inliner
// PASSES-NEXT: SimplifyRegionLite
>From 4ca4e1ff53e134a6a39f1944e27fe0fe3c3ce4a5 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Mon, 10 Jun 2024 18:04:38 +0100
Subject: [PATCH 7/7] Fix rebase issues
---
flang/include/flang/Optimizer/Transforms/Passes.h | 4 ++++
flang/include/flang/Tools/CLOptions.inc | 4 ++++
2 files changed, 8 insertions(+)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index a7ba704fdb39b..e3bb818a285fc 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -54,6 +54,8 @@ namespace fir {
#define GEN_PASS_DECL_OMPMAPINFOFINALIZATIONPASS
#define GEN_PASS_DECL_OMPMARKDECLARETARGETPASS
#define GEN_PASS_DECL_OMPFUNCTIONFILTERING
+#define GEN_PASS_DECL_CONSTANTARGUMENTGLOBALISATIONOPT
+
#include "flang/Optimizer/Transforms/Passes.h.inc"
std::unique_ptr<mlir::Pass> createAffineDemotionPass();
@@ -77,6 +79,8 @@ std::unique_ptr<mlir::Pass> createVScaleAttrPass();
std::unique_ptr<mlir::Pass>
createVScaleAttrPass(std::pair<unsigned, unsigned> vscaleAttr);
+std::unique_ptr<mlir::Pass> createConstantArgumentGlobalisationPass();
+
struct FunctionAttrTypes {
mlir::LLVM::framePointerKind::FramePointerKind framePointerKind =
mlir::LLVM::framePointerKind::FramePointerKind::None;
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d709efc9a34e7..26b1ad0366642 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -288,6 +288,10 @@ inline void createDefaultFIROptimizerPassPipeline(
else
fir::addMemoryAllocationOpt(pm);
+ // FIR Inliner Callback
+ pc.invokeFIRInlinerCallback(pm, pc.OptLevel);
+
+ pm.addPass(fir::createSimplifyRegionLite());
pm.addPass(mlir::createCSEPass());
// Polymorphic types
More information about the flang-commits
mailing list