[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)

via flang-commits flang-commits at lists.llvm.org
Wed Nov 29 09:36:28 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-driver

Author: Mats Petersson (Leporacanthicus)

<details>
<summary>Changes</summary>

Constant actual arguments in function/subroutine calls are currently lowered as an alloca + store. This can sometimes inhibit LTO, so that the constant is not propagated to the called function, particularly in cases where the function/subroutine call happens inside a condition.

This patch changes the lowering of these constant actual arguments to a global constant + fir.address_of. This lowering makes it easier for LTO to propagate the constant.

---
Full diff: https://github.com/llvm/llvm-project/pull/73829.diff


14 Files Affected:

- (modified) flang/include/flang/Optimizer/Transforms/Passes.h (+1) 
- (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+10) 
- (modified) flang/include/flang/Tools/CLOptions.inc (+2) 
- (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1) 
- (added) flang/lib/Optimizer/Transforms/ConstExtruder.cpp (+216) 
- (modified) flang/test/Driver/bbc-mlir-pass-pipeline.f90 (+1) 
- (modified) flang/test/Driver/mlir-debug-pass-pipeline.f90 (+1) 
- (modified) flang/test/Driver/mlir-pass-pipeline.f90 (+1) 
- (modified) flang/test/Fir/basic-program.fir (+1) 
- (modified) flang/test/Fir/boxproc.fir (+1-3) 
- (modified) flang/test/Lower/character-local-variables.f90 (+1-2) 
- (modified) flang/test/Lower/dummy-arguments.f90 (+1-3) 
- (modified) flang/test/Lower/host-associated.f90 (+3-4) 
- (added) flang/test/Transforms/const-extrude.f90 (+32) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 92bc7246eca7005..f1c38a026660243 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@ createExternalNameConversionPass(bool appendUnderscore);
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
 std::unique_ptr<mlir::Pass> createStackArraysPass();
 std::unique_ptr<mlir::Pass> createAliasTagsPass();
 std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c3768fd2d689c1a..179833876a7b333 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
   let constructor = "::fir::createMemoryAllocationPass()";
 }
 
+// This must be an "mlir::ModuleOp" pass because it inserts global constants.
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Convert scalar literals of function arguments to global constants.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d3e4dc6cd4a243e..b902621dfe42177 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -216,6 +216,8 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
+  pm.addPass(fir::createConstExtruderPass());
+
   // The default inliner pass adds the canonicalizer pass with the default
   // configuration. Create the inliner pass with tco config.
   llvm::StringMap<mlir::OpPassManager> pipelines;
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 03b67104a93b575..bada67729ede95c 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
   AffineDemotion.cpp
   AnnotateConstant.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 000000000000000..1bb1cd226987110
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+  if (!a || !a->getDefiningOp())
+    return false; // null pointer, or a value with no defining operation
+
+  // The value must be the result of a fir.alloca ...
+  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+    // ... that carries the attribute named by fir::getAdaptToByRefAttrName() ...
+    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+          auto constant_def = store->getOperand(0).getDefiningOp();
+          // ... and have a store of an arith.constant (NOTE(review): constant_def is not null-checked)
+          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+  mlir::DominanceInfo &di;
+
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+      : OpRewritePattern(ctx), di(_di) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (const auto &a : callOp.getArgs()) {
+      if (auto alloca =
+              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+        if (needsExtrusion(&a)) {
+
+          mlir::Type varTy = alloca.getInType();
+          assert(!fir::hasDynamicSize(varTy) &&
+                 "only expect statically sized scalars to be by value");
+
+          // collect the stores into the alloca that dominate the call
+          llvm::SmallVector<mlir::Operation *> stores;
+          for (mlir::Operation *s : alloca.getOperation()->getUsers())
+            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+              stores.push_back(s);
+          assert(stores.size() == 1 && "expected exactly one store");
+          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+          // If the stored value is not defined by an arith.constant, keep the
+          // original alloca operand and continue with the next argument
+          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            // not a constant store: pass the alloca through unchanged
+            newOperands.push_back(a);
+            continue;
+          }
+
+          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+          auto loc = callOp.getLoc();
+          llvm::StringRef globalPrefix = "_extruded_";
+
+          std::string globalName;
+          while (!globalName.length() || builder.getNamedGlobal(globalName))
+            globalName =
+                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+          if (alloca->hasOneUse()) {
+            toErase.push_back(alloca);
+            toErase.push_back(stores[0]);
+          } else {
+            int count = -2;
+            for (mlir::Operation *s : alloca.getOperation()->getUsers())
+              if (di.dominates(stores[0], s))
+                ++count;
+
+            // count starts at -2, so it is zero exactly when the store dominates
+            // two users of the alloca (expected: the store itself and callOp)
+            if (!count)
+              toErase.push_back(stores[0]);
+          }
+          auto global = builder.createGlobalConstant(
+              loc, varTy, globalName,
+              [&](fir::FirOpBuilder &builder) {
+                mlir::Operation *cln = constant_def->clone();
+                builder.insert(cln);
+                fir::ExtendedValue exv{cln->getResult(0)};
+                mlir::Value valBase = fir::getBase(exv);
+                mlir::Value val = builder.createConvert(loc, varTy, valBase);
+                builder.create<fir::HasValueOp>(loc, val);
+              },
+              builder.createInternalLinkage());
+          mlir::Value ope = {builder.create<fir::AddrOfOp>(
+              loc, global.resultType(), global.getSymbol())};
+          newOperands.push_back(ope);
+        } else {
+          // alloca that does not need extrusion: keep the operand as is
+          newOperands.push_back(a);
+        }
+      } else {
+        // operand not defined by a fir.alloca: keep it as is
+        newOperands.push_back(a);
+      }
+    }
+
+    auto loc = callOp.getLoc();
+    llvm::SmallVector<mlir::Type> newResultTypes;
+    newResultTypes.append(callOp.getResultTypes().begin(),
+                          callOp.getResultTypes().end());
+    fir::CallOp newOp = builder.create<fir::CallOp>(
+        loc, newResultTypes,
+        callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                       : mlir::SymbolRefAttr{},
+        newOperands, callOp.getFastmathAttr());
+    rewriter.replaceOp(callOp, newOp);
+
+    for (auto e : toErase)
+      rewriter.eraseOp(e);
+
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+    return mlir::success();
+  }
+};
+
+// This pass attempts to convert scalar literals passed as call arguments into
+// global constants to enable transformations such as Dead Argument Elimination
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+  mlir::DominanceInfo *di;
+
+public:
+  ConstExtruderOpt() {}
+
+  void runOnOperation() override {
+    mlir::ModuleOp mod = getOperation();
+    di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+  }
+
+  void runOnFunc(mlir::func::FuncOp &func) {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+
+    // A function with an empty body is only a declaration: nothing to rewrite.
+    if (func.empty())
+      return;
+
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+      for (auto a : op.getArgs()) {
+        if (needsExtrusion(&a))
+          return false;
+      }
+      return true;
+    });
+
+    patterns.insert<CallOpRewriter>(context, *di);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(),
+                      "error in constant extrusion optimization\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>(); // factory declared in Passes.h
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 243a620a9fd003c..c43149e07bf55f3 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,6 +31,7 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
+! CHECK-NEXT:   ConstExtruderOpt
 
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index a3ff416f4d77951..5eb2354b67012e0 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -51,6 +51,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 3d8c42f123e2eb0..10bc9b90cfd769a 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index d8a9e74c318ce18..b9aabd322399ca2 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -48,6 +48,7 @@ func.func @_QQmain() {
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
+// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af042..2ddc0ef525ac481 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME:              %[[VAL_0:.*]])
-// CHECK:         %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK:         store i32 4, ptr %[[VAL_1]], align 4
-// CHECK:         call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK:         call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e73..b1cfc540f438966 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
-  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 43d8e3c1e5d4485..46e4323e8862049 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
-  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
-  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index e2db6deb8803d08..26598ef1f16ea31 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK:         %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK:         fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
 ! CHECK:         return
 ! CHECK:       }
 
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 000000000000000..70cdaf496f34acc
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+  implicit none
+  integer x, y
+  
+  call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }

``````````

</details>


https://github.com/llvm/llvm-project/pull/73829


More information about the flang-commits mailing list