[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)
Mats Petersson via flang-commits
flang-commits at lists.llvm.org
Wed Dec 13 04:08:40 PST 2023
https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/73829
>From ea5e93a06bd32cad0ccf402ecccfe9451e2d6c9b Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 22 Nov 2023 14:01:38 +0000
Subject: [PATCH 1/3] [Flang] Extracting internal constants from scalar
literals
Constant actual arguments in function/subroutine calls are currently lowered
as an alloca plus a store. This can sometimes inhibit LTO, so that the constant
is not propagated to the called function, particularly when the
function/subroutine call happens inside a condition.
This patch changes the lowering of these constant actual arguments to a
global constant plus a fir.address_of op. This form makes it easier for
LTO to propagate the constant.
Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
.../flang/Optimizer/Transforms/Passes.h | 1 +
.../flang/Optimizer/Transforms/Passes.td | 10 +
flang/include/flang/Tools/CLOptions.inc | 2 +
flang/lib/Optimizer/Transforms/CMakeLists.txt | 1 +
.../Optimizer/Transforms/ConstExtruder.cpp | 216 ++++++++++++++++++
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 1 +
.../test/Driver/mlir-debug-pass-pipeline.f90 | 1 +
flang/test/Driver/mlir-pass-pipeline.f90 | 1 +
flang/test/Fir/basic-program.fir | 1 +
flang/test/Fir/boxproc.fir | 4 +-
.../test/Lower/character-local-variables.f90 | 3 +-
flang/test/Lower/dummy-arguments.f90 | 4 +-
flang/test/Lower/host-associated.f90 | 7 +-
flang/test/Transforms/const-extrude.f90 | 32 +++
14 files changed, 272 insertions(+), 12 deletions(-)
create mode 100644 flang/lib/Optimizer/Transforms/ConstExtruder.cpp
create mode 100644 flang/test/Transforms/const-extrude.f90
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 92bc7246eca700..f1c38a02666024 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@ createExternalNameConversionPass(bool appendUnderscore);
std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
std::unique_ptr<mlir::Pass> createStackArraysPass();
std::unique_ptr<mlir::Pass> createAliasTagsPass();
std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c3768fd2d689c1..179833876a7b33 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
let constructor = "::fir::createMemoryAllocationPass()";
}
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+ let summary = "Convert scalar literals of function arguments to global constants.";
+ let description = [{
+ Convert scalar literals of function arguments to global constants.
+ }];
+ let dependentDialects = [ "fir::FIROpsDialect" ];
+ let constructor = "::fir::createConstExtruderPass()";
+}
+
def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
let summary = "Move local array allocations from heap memory into stack memory";
let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d3e4dc6cd4a243..b902621dfe4217 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -216,6 +216,8 @@ inline void createDefaultFIROptimizerPassPipeline(
else
fir::addMemoryAllocationOpt(pm);
+ pm.addPass(fir::createConstExtruderPass());
+
// The default inliner pass adds the canonicalizer pass with the default
// configuration. Create the inliner pass with tco config.
llvm::StringMap<mlir::OpPassManager> pipelines;
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 03b67104a93b57..bada67729ede95 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
AffineDemotion.cpp
AnnotateConstant.cpp
CharacterConversion.cpp
+ ConstExtruder.cpp
ControlFlowConverter.cpp
ArrayValueCopy.cpp
ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 00000000000000..1bb1cd22698711
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+ if (!a || !a->getDefiningOp())
+ return false;
+
+ // is alloca
+ if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+ // alloca has annotation
+ if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+ for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+ if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+ auto constant_def = store->getOperand(0).getDefiningOp();
+ // Expect constant definition operation
+ if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+ mlir::DominanceInfo &di;
+
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+ : OpRewritePattern(ctx), di(_di) {}
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::CallOp callOp,
+ mlir::PatternRewriter &rewriter) const override {
+ LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+ auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ fir::FirOpBuilder builder(rewriter, module);
+ llvm::SmallVector<mlir::Value> newOperands;
+ llvm::SmallVector<mlir::Operation *> toErase;
+ for (const auto &a : callOp.getArgs()) {
+ if (auto alloca =
+ mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+ if (needsExtrusion(&a)) {
+
+ mlir::Type varTy = alloca.getInType();
+ assert(!fir::hasDynamicSize(varTy) &&
+ "only expect statically sized scalars to be by value");
+
+ // find immediate store with const argument
+ llvm::SmallVector<mlir::Operation *> stores;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+ stores.push_back(s);
+ assert(stores.size() == 1 && "expected exactly one store");
+ LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+ auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+ // Expect constant definition operation or force legalisation of the
+ // callOp and continue with its next argument
+ if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+ // unable to remove alloca arg
+ newOperands.push_back(a);
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+ auto loc = callOp.getLoc();
+ llvm::StringRef globalPrefix = "_extruded_";
+
+ std::string globalName;
+ while (!globalName.length() || builder.getNamedGlobal(globalName))
+ globalName =
+ globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+ if (alloca->hasOneUse()) {
+ toErase.push_back(alloca);
+ toErase.push_back(stores[0]);
+ } else {
+ int count = -2;
+ for (mlir::Operation *s : alloca.getOperation()->getUsers())
+ if (di.dominates(stores[0], s))
+ ++count;
+
+ // delete if dominates itself and one more operation (which should
+ // be callOp)
+ if (!count)
+ toErase.push_back(stores[0]);
+ }
+ auto global = builder.createGlobalConstant(
+ loc, varTy, globalName,
+ [&](fir::FirOpBuilder &builder) {
+ mlir::Operation *cln = constant_def->clone();
+ builder.insert(cln);
+ fir::ExtendedValue exv{cln->getResult(0)};
+ mlir::Value valBase = fir::getBase(exv);
+ mlir::Value val = builder.createConvert(loc, varTy, valBase);
+ builder.create<fir::HasValueOp>(loc, val);
+ },
+ builder.createInternalLinkage());
+ mlir::Value ope = {builder.create<fir::AddrOfOp>(
+ loc, global.resultType(), global.getSymbol())};
+ newOperands.push_back(ope);
+ } else {
+ // alloca but without attr, add it
+ newOperands.push_back(a);
+ }
+ } else {
+ // non-alloca operand, add it
+ newOperands.push_back(a);
+ }
+ }
+
+ auto loc = callOp.getLoc();
+ llvm::SmallVector<mlir::Type> newResultTypes;
+ newResultTypes.append(callOp.getResultTypes().begin(),
+ callOp.getResultTypes().end());
+ fir::CallOp newOp = builder.create<fir::CallOp>(
+ loc, newResultTypes,
+ callOp.getCallee().has_value() ? callOp.getCallee().value()
+ : mlir::SymbolRefAttr{},
+ newOperands, callOp.getFastmathAttr());
+ rewriter.replaceOp(callOp, newOp);
+
+ for (auto e : toErase)
+ rewriter.eraseOp(e);
+
+ LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+ << newOp << '\n');
+ return mlir::success();
+ }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+ : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+ mlir::DominanceInfo *di;
+
+public:
+ ConstExtruderOpt() {}
+
+ void runOnOperation() override {
+ mlir::ModuleOp mod = getOperation();
+ di = &getAnalysis<mlir::DominanceInfo>();
+ mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+ }
+
+ void runOnFunc(mlir::func::FuncOp &func) {
+ auto *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ mlir::ConversionTarget target(*context);
+
+ // If func is a declaration, skip it.
+ if (func.empty())
+ return;
+
+ target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+ mlir::func::FuncDialect>();
+ target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+ for (auto a : op.getArgs()) {
+ if (needsExtrusion(&a))
+ return false;
+ }
+ return true;
+ });
+
+ patterns.insert<CallOpRewriter>(context, *di);
+ if (mlir::failed(
+ mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+ mlir::emitError(func.getLoc(),
+ "error in constant extrusion optimization\n");
+ signalPassFailure();
+ }
+ }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+ return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 243a620a9fd003..c43149e07bf55f 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,6 +31,7 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
+! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index a3ff416f4d7795..5eb2354b67012e 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -51,6 +51,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 3d8c42f123e2eb..10bc9b90cfd769 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
! ALL-NEXT: 'func.func' Pipeline
! ALL-NEXT: MemoryAllocationOpt
+! ALL-NEXT: ConstExtruderOpt
! ALL-NEXT: Inliner
! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index d8a9e74c318ce1..b9aabd322399ca 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -48,6 +48,7 @@ func.func @_QQmain() {
// PASSES-NEXT: 'func.func' Pipeline
// PASSES-NEXT: MemoryAllocationOpt
+// PASSES-NEXT: ConstExtruderOpt
// PASSES-NEXT: Inliner
// PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af04..2ddc0ef525ac48 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
// CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
// CHECK-SAME: %[[VAL_0:.*]])
-// CHECK: %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK: store i32 4, ptr %[[VAL_1]], align 4
-// CHECK: call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK: call void %[[VAL_0]](ptr @{{.*}})
func.func @_QPtest_proc_dummy() {
%c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e7..b1cfc540f43896 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
subroutine assumed_length_param(n)
character(*), parameter :: c(1)=(/"abcd"/)
integer :: n
- ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
- ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+ ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
call take_int(len(c(n), kind=8))
end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 43d8e3c1e5d448..46e4323e886204 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
! CHECK-LABEL: _QQmain
program test1
- ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
- ! CHECK-DAG: %[[TEN:.*]] = arith.constant
- ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+ ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
! CHECK-NEXT: fir.call @_QFPfoo
call foo(10)
contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index e2db6deb8803d0..26598ef1f16ea3 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
! CHECK-LABEL: func @_QPtest_proc_dummy_other(
! CHECK-SAME: %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK: %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK: fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
! CHECK: %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK: fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK: %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK: fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
! CHECK: return
! CHECK: }
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 00000000000000..70cdaf496f34ac
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+ implicit none
+ integer x, y
+
+ call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }
>From 6a3c3ba62e08d0be8dbf9293b7f89bfb6dd09389 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 6 Dec 2023 16:12:14 +0000
Subject: [PATCH 2/3] Use greedy rewriter
---
.../Optimizer/Transforms/ConstExtruder.cpp | 23 +++++++------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 1bb1cd22698711..00c9f30e9dbacc 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -16,7 +16,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
#include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
protected:
mlir::DominanceInfo *di;
+ mlir::GreedyRewriteConfig config;
public:
- ConstExtruderOpt() {}
+ ConstExtruderOpt() {
+ config.enableRegionSimplification = false;
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+ }
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
void runOnFunc(mlir::func::FuncOp &func) {
auto *context = &getContext();
mlir::RewritePatternSet patterns(context);
- mlir::ConversionTarget target(*context);
// If func is a declaration, skip it.
if (func.empty())
return;
- target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
- mlir::func::FuncDialect>();
- target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
- for (auto a : op.getArgs()) {
- if (needsExtrusion(&a))
- return false;
- }
- return true;
- });
-
patterns.insert<CallOpRewriter>(context, *di);
- if (mlir::failed(
- mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+ if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+ func, std::move(patterns), config))) {
mlir::emitError(func.getLoc(),
"error in constant extrusion optimization\n");
signalPassFailure();
>From f936dde64196dcd0004238525a12fd84b1b3838f Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 12 Dec 2023 14:36:10 +0000
Subject: [PATCH 3/3] Fix review comments
---
.../flang/Optimizer/Transforms/Passes.td | 2 +-
.../Optimizer/Transforms/ConstExtruder.cpp | 239 ++++++++----------
flang/test/Driver/bbc-mlir-pass-pipeline.f90 | 2 +-
3 files changed, 111 insertions(+), 132 deletions(-)
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 179833876a7b33..1ed8c90830d08e 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -244,7 +244,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
- let summary = "Convert scalar literals of function arguments to global constants.";
+ let summary = "Convert constant function arguments to global constants.";
let description = [{
Convert scalar literals of function arguments to global constants.
}];
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 00c9f30e9dbacc..355166f3422bb8 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp -----------------------------------------------===//
+//===- ConstExtruder.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include <atomic>
namespace fir {
#define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,38 +25,16 @@ namespace fir {
#define DEBUG_TYPE "flang-const-extruder-opt"
namespace {
-std::atomic<int> uniqueLitId = 1;
-
-static bool needsExtrusion(const mlir::Value *a) {
- if (!a || !a->getDefiningOp())
- return false;
-
- // is alloca
- if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
- // alloca has annotation
- if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
- for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
- if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
- auto constant_def = store->getOperand(0).getDefiningOp();
- // Expect constant definition operation
- if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
- return true;
- }
- }
- }
- }
- }
- return false;
-}
+unsigned uniqueLitId = 1;
class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
protected:
- mlir::DominanceInfo &di;
+ const mlir::DominanceInfo &di;
public:
using OpRewritePattern::OpRewritePattern;
- CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+ CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
: OpRewritePattern(ctx), di(_di) {}
mlir::LogicalResult
@@ -68,131 +42,136 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
mlir::PatternRewriter &rewriter) const override {
LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ bool needUpdate = false;
fir::FirOpBuilder builder(rewriter, module);
llvm::SmallVector<mlir::Value> newOperands;
llvm::SmallVector<mlir::Operation *> toErase;
- for (const auto &a : callOp.getArgs()) {
- if (auto alloca =
- mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
- if (needsExtrusion(&a)) {
-
- mlir::Type varTy = alloca.getInType();
- assert(!fir::hasDynamicSize(varTy) &&
- "only expect statically sized scalars to be by value");
-
- // find immediate store with const argument
- llvm::SmallVector<mlir::Operation *> stores;
- for (mlir::Operation *s : alloca.getOperation()->getUsers())
- if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
- stores.push_back(s);
- assert(stores.size() == 1 && "expected exactly one store");
- LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
-
- auto constant_def = stores[0]->getOperand(0).getDefiningOp();
- // Expect constant definition operation or force legalisation of the
- // callOp and continue with its next argument
- if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
- // unable to remove alloca arg
- newOperands.push_back(a);
- continue;
- }
-
- LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
-
- auto loc = callOp.getLoc();
- llvm::StringRef globalPrefix = "_extruded_";
-
- std::string globalName;
- while (!globalName.length() || builder.getNamedGlobal(globalName))
- globalName =
- globalPrefix.str() + "." + std::to_string(uniqueLitId++);
-
- if (alloca->hasOneUse()) {
- toErase.push_back(alloca);
- toErase.push_back(stores[0]);
- } else {
- int count = -2;
- for (mlir::Operation *s : alloca.getOperation()->getUsers())
- if (di.dominates(stores[0], s))
- ++count;
-
- // delete if dominates itself and one more operation (which should
- // be callOp)
- if (!count)
- toErase.push_back(stores[0]);
+ for (const mlir::Value &a : callOp.getArgs()) {
+ auto alloca =
+ mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
+ // We can convert arguments that are alloca, and that has
+ // the value by reference attribute. All else is just added
+ // to the argument list.
+ if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+ newOperands.push_back(a);
+ continue;
+ }
+
+ mlir::Type varTy = alloca.getInType();
+ assert(!fir::hasDynamicSize(varTy) &&
+ "only expect statically sized scalars to be by value");
+
+ // Find immediate store with const argument
+ mlir::Operation *store = nullptr;
+ for (mlir::Operation *s : alloca->getUsers()) {
+ if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
+ // We can only deal with ONE store - if already found one,
+ // set to nullptr and exit the loop.
+ if (store) {
+ store = nullptr;
+ break;
+ }
- auto global = builder.createGlobalConstant(
- loc, varTy, globalName,
- [&](fir::FirOpBuilder &builder) {
- mlir::Operation *cln = constant_def->clone();
- builder.insert(cln);
- fir::ExtendedValue exv{cln->getResult(0)};
- mlir::Value valBase = fir::getBase(exv);
- mlir::Value val = builder.createConvert(loc, varTy, valBase);
- builder.create<fir::HasValueOp>(loc, val);
- },
- builder.createInternalLinkage());
- mlir::Value ope = {builder.create<fir::AddrOfOp>(
- loc, global.resultType(), global.getSymbol())};
- newOperands.push_back(ope);
- } else {
- // alloca but without attr, add it
- newOperands.push_back(a);
+ store = s;
}
- } else {
- // non-alloca operand, add it
+ }
+
+ // If we didn't find exactly one dominating store, add argument as is, and move on.
+ if (!store) {
newOperands.push_back(a);
+ continue;
}
+
+ LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
+
+ mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
+ // The stored value may be a block argument (no defining op), so use a
+ // null-tolerant check; only arith.constant values can be extruded.
+ if (!mlir::isa_and_nonnull<mlir::arith::ConstantOp>(constant_def)) {
+ // Unable to remove alloca arg
+ newOperands.push_back(a);
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+ // Skip names already taken, e.g. by an earlier run of this pass on the same IR.
+ std::string globalName;
+ do { globalName = "_extruded_." + std::to_string(uniqueLitId++); } while (builder.getNamedGlobal(globalName));
+
+ unsigned count = 0;
+ for (mlir::Operation *s : alloca->getUsers())
+ if (di.dominates(store, s))
+ ++count;
+
+ // Delete if dominates itself and one more operation (which should
+ // be callOp)
+ if (count == 2)
+ toErase.push_back(store);
+
+ auto loc = callOp.getLoc();
+ fir::GlobalOp global = builder.createGlobalConstant(
+ loc, varTy, globalName,
+ [&](fir::FirOpBuilder &builder) {
+ mlir::Operation *cln = constant_def->clone();
+ builder.insert(cln);
+ mlir::Value val = builder.createConvert(loc, varTy, cln->getResult(0));
+ builder.create<fir::HasValueOp>(loc, val);
+ },
+ builder.createInternalLinkage());
+ mlir::Value addr = {builder.create<fir::AddrOfOp>(
+ loc, global.resultType(), global.getSymbol())};
+ newOperands.push_back(addr);
+ needUpdate = true;
}
- auto loc = callOp.getLoc();
- llvm::SmallVector<mlir::Type> newResultTypes;
- newResultTypes.append(callOp.getResultTypes().begin(),
- callOp.getResultTypes().end());
- fir::CallOp newOp = builder.create<fir::CallOp>(
- loc, newResultTypes,
- callOp.getCallee().has_value() ? callOp.getCallee().value()
- : mlir::SymbolRefAttr{},
- newOperands, callOp.getFastmathAttr());
- rewriter.replaceOp(callOp, newOp);
-
- for (auto e : toErase)
- rewriter.eraseOp(e);
-
- LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
- << newOp << '\n');
- return mlir::success();
+ if (needUpdate) {
+ auto loc = callOp.getLoc();
+ llvm::SmallVector<mlir::Type> newResultTypes;
+ newResultTypes.append(callOp.getResultTypes().begin(),
+ callOp.getResultTypes().end());
+ fir::CallOp newOp = builder.create<fir::CallOp>(
+ loc, newResultTypes,
+ callOp.getCallee().has_value() ? callOp.getCallee().value()
+ : mlir::SymbolRefAttr{},
+ newOperands, callOp.getFastmathAttr());
+ rewriter.replaceOp(callOp, newOp);
+
+ for (auto e : toErase)
+ rewriter.eraseOp(e);
+ LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+ << newOp << '\n');
+ return mlir::success();
+ }
+
+ // Failure here just means "we couldn't do the conversion", which is
+ // perfectly acceptable to the upper layers of this function.
+ return mlir::failure();
}
};
-// This pass attempts to convert immediate scalar literals in function calls
+// this pass attempts to convert immediate scalar literals in function calls
// to global constants to allow transformations as Dead Argument Elimination
class ConstExtruderOpt
: public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
-protected:
- mlir::DominanceInfo *di;
- mlir::GreedyRewriteConfig config;
-
public:
- ConstExtruderOpt() {
- config.enableRegionSimplification = false;
- config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
- }
+ ConstExtruderOpt() = default;
void runOnOperation() override {
mlir::ModuleOp mod = getOperation();
- di = &getAnalysis<mlir::DominanceInfo>();
- mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+ mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
+ mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
}
- void runOnFunc(mlir::func::FuncOp &func) {
- auto *context = &getContext();
- mlir::RewritePatternSet patterns(context);
-
+ void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo* di) {
// If func is a declaration, skip it.
if (func.empty())
return;
+ auto *context = &getContext();
+ mlir::RewritePatternSet patterns(context);
+ mlir::GreedyRewriteConfig config;
+ config.enableRegionSimplification = false;
+ config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
func, std::move(patterns), config))) {
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c43149e07bf55f..942ee76d5155e1 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,8 +31,8 @@
! CHECK-NEXT: 'func.func' Pipeline
! CHECK-NEXT: MemoryAllocationOpt
-! CHECK-NEXT: ConstExtruderOpt
+! CHECK-NEXT: ConstExtruderOpt
! CHECK-NEXT: Inliner
! CHECK-NEXT: SimplifyRegionLite
! CHECK-NEXT: CSE
More information about the flang-commits
mailing list