[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)

Mats Petersson via flang-commits flang-commits at lists.llvm.org
Wed Dec 13 04:08:40 PST 2023


https://github.com/Leporacanthicus updated https://github.com/llvm/llvm-project/pull/73829

>From ea5e93a06bd32cad0ccf402ecccfe9451e2d6c9b Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 22 Nov 2023 14:01:38 +0000
Subject: [PATCH 1/3] [Flang] Extracting internal constants from scalar
 literals

Constant actual arguments in function/subroutine calls are currently lowered
as allocas + store. This can sometimes inhibit LTO, so the constant will not
be propagated to the called function, particularly in cases where the
function/subroutine call happens inside a condition.

This patch changes the lowering of these constant actual arguments to a
global constant + fir.address_of_op. This lowering makes it easier for
LTO to propagate the constant.

Co-authored-by: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
---
 .../flang/Optimizer/Transforms/Passes.h       |   1 +
 .../flang/Optimizer/Transforms/Passes.td      |  10 +
 flang/include/flang/Tools/CLOptions.inc       |   2 +
 flang/lib/Optimizer/Transforms/CMakeLists.txt |   1 +
 .../Optimizer/Transforms/ConstExtruder.cpp    | 216 ++++++++++++++++++
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   1 +
 .../test/Driver/mlir-debug-pass-pipeline.f90  |   1 +
 flang/test/Driver/mlir-pass-pipeline.f90      |   1 +
 flang/test/Fir/basic-program.fir              |   1 +
 flang/test/Fir/boxproc.fir                    |   4 +-
 .../test/Lower/character-local-variables.f90  |   3 +-
 flang/test/Lower/dummy-arguments.f90          |   4 +-
 flang/test/Lower/host-associated.f90          |   7 +-
 flang/test/Transforms/const-extrude.f90       |  32 +++
 14 files changed, 272 insertions(+), 12 deletions(-)
 create mode 100644 flang/lib/Optimizer/Transforms/ConstExtruder.cpp
 create mode 100644 flang/test/Transforms/const-extrude.f90

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 92bc7246eca700..f1c38a02666024 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -60,6 +60,7 @@ createExternalNameConversionPass(bool appendUnderscore);
 std::unique_ptr<mlir::Pass> createMemDataFlowOptPass();
 std::unique_ptr<mlir::Pass> createPromoteToAffinePass();
 std::unique_ptr<mlir::Pass> createMemoryAllocationPass();
+std::unique_ptr<mlir::Pass> createConstExtruderPass();
 std::unique_ptr<mlir::Pass> createStackArraysPass();
 std::unique_ptr<mlir::Pass> createAliasTagsPass();
 std::unique_ptr<mlir::Pass> createSimplifyIntrinsicsPass();
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index c3768fd2d689c1..179833876a7b33 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -242,6 +242,16 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
   let constructor = "::fir::createMemoryAllocationPass()";
 }
 
+// This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
+def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
+  let summary = "Convert scalar literals of function arguments to global constants.";
+  let description = [{
+    Convert scalar literals of function arguments to global constants.
+  }];
+  let dependentDialects = [ "fir::FIROpsDialect" ];
+  let constructor = "::fir::createConstExtruderPass()";
+}
+
 def StackArrays : Pass<"stack-arrays", "mlir::ModuleOp"> {
   let summary = "Move local array allocations from heap memory into stack memory";
   let description = [{
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index d3e4dc6cd4a243..b902621dfe4217 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -216,6 +216,8 @@ inline void createDefaultFIROptimizerPassPipeline(
   else
     fir::addMemoryAllocationOpt(pm);
 
+  pm.addPass(fir::createConstExtruderPass());
+
   // The default inliner pass adds the canonicalizer pass with the default
   // configuration. Create the inliner pass with tco config.
   llvm::StringMap<mlir::OpPassManager> pipelines;
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 03b67104a93b57..bada67729ede95 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_flang_library(FIRTransforms
   AffineDemotion.cpp
   AnnotateConstant.cpp
   CharacterConversion.cpp
+  ConstExtruder.cpp
   ControlFlowConverter.cpp
   ArrayValueCopy.cpp
   ExternalNameConversion.cpp
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
new file mode 100644
index 00000000000000..1bb1cd22698711
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -0,0 +1,216 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+  if (!a || !a->getDefiningOp())
+    return false;
+
+  // is alloca
+  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+    // alloca has annotation
+    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+          auto constant_def = store->getOperand(0).getDefiningOp();
+          // Expect constant definition operation
+          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+  mlir::DominanceInfo &di;
+
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+      : OpRewritePattern(ctx), di(_di) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (const auto &a : callOp.getArgs()) {
+      if (auto alloca =
+              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+        if (needsExtrusion(&a)) {
+
+          mlir::Type varTy = alloca.getInType();
+          assert(!fir::hasDynamicSize(varTy) &&
+                 "only expect statically sized scalars to be by value");
+
+          // find immediate store with const argument
+          llvm::SmallVector<mlir::Operation *> stores;
+          for (mlir::Operation *s : alloca.getOperation()->getUsers())
+            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+              stores.push_back(s);
+          assert(stores.size() == 1 && "expected exactly one store");
+          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+          // Expect constant definition operation or force legalisation of the
+          // callOp and continue with its next argument
+          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            // unable to remove alloca arg
+            newOperands.push_back(a);
+            continue;
+          }
+
+          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+          auto loc = callOp.getLoc();
+          llvm::StringRef globalPrefix = "_extruded_";
+
+          std::string globalName;
+          while (!globalName.length() || builder.getNamedGlobal(globalName))
+            globalName =
+                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+          if (alloca->hasOneUse()) {
+            toErase.push_back(alloca);
+            toErase.push_back(stores[0]);
+          } else {
+            int count = -2;
+            for (mlir::Operation *s : alloca.getOperation()->getUsers())
+              if (di.dominates(stores[0], s))
+                ++count;
+
+            // delete if dominates itself and one more operation (which should
+            // be callOp)
+            if (!count)
+              toErase.push_back(stores[0]);
+          }
+          auto global = builder.createGlobalConstant(
+              loc, varTy, globalName,
+              [&](fir::FirOpBuilder &builder) {
+                mlir::Operation *cln = constant_def->clone();
+                builder.insert(cln);
+                fir::ExtendedValue exv{cln->getResult(0)};
+                mlir::Value valBase = fir::getBase(exv);
+                mlir::Value val = builder.createConvert(loc, varTy, valBase);
+                builder.create<fir::HasValueOp>(loc, val);
+              },
+              builder.createInternalLinkage());
+          mlir::Value ope = {builder.create<fir::AddrOfOp>(
+              loc, global.resultType(), global.getSymbol())};
+          newOperands.push_back(ope);
+        } else {
+          // alloca but without attr, add it
+          newOperands.push_back(a);
+        }
+      } else {
+        // non-alloca operand, add it
+        newOperands.push_back(a);
+      }
+    }
+
+    auto loc = callOp.getLoc();
+    llvm::SmallVector<mlir::Type> newResultTypes;
+    newResultTypes.append(callOp.getResultTypes().begin(),
+                          callOp.getResultTypes().end());
+    fir::CallOp newOp = builder.create<fir::CallOp>(
+        loc, newResultTypes,
+        callOp.getCallee().has_value() ? callOp.getCallee().value()
+                                       : mlir::SymbolRefAttr{},
+        newOperands, callOp.getFastmathAttr());
+    rewriter.replaceOp(callOp, newOp);
+
+    for (auto e : toErase)
+      rewriter.eraseOp(e);
+
+    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                            << newOp << '\n');
+    return mlir::success();
+  }
+};
+
+// This pass attempts to convert immediate scalar literals in function calls
+// to global constants to allow transformations as Dead Argument Elimination
+class ConstExtruderOpt
+    : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
+protected:
+  mlir::DominanceInfo *di;
+
+public:
+  ConstExtruderOpt() {}
+
+  void runOnOperation() override {
+    mlir::ModuleOp mod = getOperation();
+    di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+  }
+
+  void runOnFunc(mlir::func::FuncOp &func) {
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::ConversionTarget target(*context);
+
+    // If func is a declaration, skip it.
+    if (func.empty())
+      return;
+
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::func::FuncDialect>();
+    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
+      for (auto a : op.getArgs()) {
+        if (needsExtrusion(&a))
+          return false;
+      }
+      return true;
+    });
+
+    patterns.insert<CallOpRewriter>(context, *di);
+    if (mlir::failed(
+            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+      mlir::emitError(func.getLoc(),
+                      "error in constant extrusion optimization\n");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::createConstExtruderPass() {
+  return std::make_unique<ConstExtruderOpt>();
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index 243a620a9fd003..c43149e07bf55f 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,6 +31,7 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
+! CHECK-NEXT:   ConstExtruderOpt
 
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index a3ff416f4d7795..5eb2354b67012e 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -51,6 +51,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index 3d8c42f123e2eb..10bc9b90cfd769 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -42,6 +42,7 @@
 
 ! ALL-NEXT: 'func.func' Pipeline
 ! ALL-NEXT:   MemoryAllocationOpt
+! ALL-NEXT:   ConstExtruderOpt
 
 ! ALL-NEXT: Inliner
 ! ALL-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index d8a9e74c318ce1..b9aabd322399ca 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -48,6 +48,7 @@ func.func @_QQmain() {
 
 // PASSES-NEXT: 'func.func' Pipeline
 // PASSES-NEXT:   MemoryAllocationOpt
+// PASSES-NEXT:   ConstExtruderOpt
 
 // PASSES-NEXT: Inliner
 // PASSES-NEXT: SimplifyRegionLite
diff --git a/flang/test/Fir/boxproc.fir b/flang/test/Fir/boxproc.fir
index 1fed16a808af04..2ddc0ef525ac48 100644
--- a/flang/test/Fir/boxproc.fir
+++ b/flang/test/Fir/boxproc.fir
@@ -16,9 +16,7 @@
 
 // CHECK-LABEL: define void @_QPtest_proc_dummy_other(ptr
 // CHECK-SAME:              %[[VAL_0:.*]])
-// CHECK:         %[[VAL_1:.*]] = alloca i32, i64 1, align 4
-// CHECK:         store i32 4, ptr %[[VAL_1]], align 4
-// CHECK:         call void %[[VAL_0]](ptr %[[VAL_1]])
+// CHECK:         call void %[[VAL_0]](ptr @{{.*}})
 
 func.func @_QPtest_proc_dummy() {
   %c0_i32 = arith.constant 0 : i32
diff --git a/flang/test/Lower/character-local-variables.f90 b/flang/test/Lower/character-local-variables.f90
index 0cf61a2623c4e7..b1cfc540f43896 100644
--- a/flang/test/Lower/character-local-variables.f90
+++ b/flang/test/Lower/character-local-variables.f90
@@ -116,8 +116,7 @@ subroutine dyn_array_dyn_len_lb(l, n)
 subroutine assumed_length_param(n)
   character(*), parameter :: c(1)=(/"abcd"/)
   integer :: n
-  ! CHECK: %[[c4:.*]] = arith.constant 4 : i64
-  ! CHECK: fir.store %[[c4]] to %[[tmp:.*]] : !fir.ref<i64>
+  ! CHECK: %[[tmp:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i64>
   ! CHECK: fir.call @_QPtake_int(%[[tmp]]) {{.*}}: (!fir.ref<i64>) -> ()
   call take_int(len(c(n), kind=8))
 end
diff --git a/flang/test/Lower/dummy-arguments.f90 b/flang/test/Lower/dummy-arguments.f90
index 43d8e3c1e5d448..46e4323e886204 100644
--- a/flang/test/Lower/dummy-arguments.f90
+++ b/flang/test/Lower/dummy-arguments.f90
@@ -2,9 +2,7 @@
 
 ! CHECK-LABEL: _QQmain
 program test1
-  ! CHECK-DAG: %[[TMP:.*]] = fir.alloca
-  ! CHECK-DAG: %[[TEN:.*]] = arith.constant
-  ! CHECK: fir.store %[[TEN]] to %[[TMP]]
+  ! CHECK-DAG: %[[TEN:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
   ! CHECK-NEXT: fir.call @_QFPfoo
   call foo(10)
 contains
diff --git a/flang/test/Lower/host-associated.f90 b/flang/test/Lower/host-associated.f90
index e2db6deb8803d0..26598ef1f16ea3 100644
--- a/flang/test/Lower/host-associated.f90
+++ b/flang/test/Lower/host-associated.f90
@@ -448,11 +448,10 @@ subroutine bar()
 
 ! CHECK-LABEL: func @_QPtest_proc_dummy_other(
 ! CHECK-SAME:           %[[VAL_0:.*]]: !fir.boxproc<() -> ()>) {
-! CHECK:         %[[VAL_1:.*]] = arith.constant 4 : i32
-! CHECK:         %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref}
-! CHECK:         fir.store %[[VAL_1]] to %[[VAL_2]] : !fir.ref<i32>
 ! CHECK:         %[[VAL_3:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxproc<() -> ()>) -> ((!fir.ref<i32>) -> ())
-! CHECK:         fir.call %[[VAL_3]](%[[VAL_2]]) {{.*}}: (!fir.ref<i32>) -> ()
+! CHECK:         %[[VAL_1:.*]] = fir.address_of(@_extruded_.{{.*}}) : !fir.ref<i32>
+! CHECK:         fir.call %[[VAL_3]](%[[VAL_1]]) {{.*}}: (!fir.ref<i32>) -> ()
+
 ! CHECK:         return
 ! CHECK:       }
 
diff --git a/flang/test/Transforms/const-extrude.f90 b/flang/test/Transforms/const-extrude.f90
new file mode 100644
index 00000000000000..70cdaf496f34ac
--- /dev/null
+++ b/flang/test/Transforms/const-extrude.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-fir %s -o - | fir-opt --const-extruder-opt | FileCheck %s
+
+subroutine sub1(x,y)
+  implicit none
+  integer x, y
+  
+  call sub2(0.0d0, 1.0d0, x, y, 1)
+end subroutine sub1
+
+!CHECK-LABEL: func.func @_QPsub1
+!CHECK-SAME: [[ARG0:%.*]]: !fir.ref<i32> {{{.*}}},
+!CHECK-SAME: [[ARG1:%.*]]: !fir.ref<i32> {{{.*}}}) {
+!CHECK: [[X:%.*]] = fir.declare [[ARG0]] {{.*}}
+!CHECK: [[Y:%.*]] = fir.declare [[ARG1]] {{.*}}
+!CHECK: [[CONST_R0:%.*]] = fir.address_of([[EXTR_0:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_R1:%.*]] = fir.address_of([[EXTR_1:@.*]]) : !fir.ref<f64>
+!CHECK: [[CONST_I:%.*]] = fir.address_of([[EXTR_2:@.*]]) : !fir.ref<i32>
+!CHECK: fir.call @_QPsub2([[CONST_R0]], [[CONST_R1]], [[X]], [[Y]], [[CONST_I]])
+!CHECK: return
+
+!CHECK: fir.global internal [[EXTR_0]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 0.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_1]] constant : f64 {
+!CHECK: %{{.*}} = arith.constant 1.000000e+00 : f64
+!CHECK: fir.has_value %{{.*}} : f64
+!CHECK: }
+!CHECK: fir.global internal [[EXTR_2]] constant : i32 {
+!CHECK: %{{.*}} = arith.constant 1 : i32
+!CHECK: fir.has_value %{{.*}} : i32
+!CHECK: }

>From 6a3c3ba62e08d0be8dbf9293b7f89bfb6dd09389 Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Wed, 6 Dec 2023 16:12:14 +0000
Subject: [PATCH 2/3] Use greedy rewriter

---
 .../Optimizer/Transforms/ConstExtruder.cpp    | 23 +++++++------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 1bb1cd22698711..00c9f30e9dbacc 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -16,7 +16,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <atomic>
@@ -171,9 +171,13 @@ class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
 protected:
   mlir::DominanceInfo *di;
+  mlir::GreedyRewriteConfig config;
 
 public:
-  ConstExtruderOpt() {}
+  ConstExtruderOpt() {
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+  }
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
@@ -184,25 +188,14 @@ class ConstExtruderOpt
   void runOnFunc(mlir::func::FuncOp &func) {
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    mlir::ConversionTarget target(*context);
 
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
-                           mlir::func::FuncDialect>();
-    target.addDynamicallyLegalOp<fir::CallOp>([&](fir::CallOp op) {
-      for (auto a : op.getArgs()) {
-        if (needsExtrusion(&a))
-          return false;
-      }
-      return true;
-    });
-
     patterns.insert<CallOpRewriter>(context, *di);
-    if (mlir::failed(
-            mlir::applyPartialConversion(func, target, std::move(patterns)))) {
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            func, std::move(patterns), config))) {
       mlir::emitError(func.getLoc(),
                       "error in constant extrusion optimization\n");
       signalPassFailure();

>From f936dde64196dcd0004238525a12fd84b1b3838f Mon Sep 17 00:00:00 2001
From: Mats Petersson <mats.petersson at arm.com>
Date: Tue, 12 Dec 2023 14:36:10 +0000
Subject: [PATCH 3/3] Fix review comments

---
 .../flang/Optimizer/Transforms/Passes.td      |   2 +-
 .../Optimizer/Transforms/ConstExtruder.cpp    | 239 ++++++++----------
 flang/test/Driver/bbc-mlir-pass-pipeline.f90  |   2 +-
 3 files changed, 111 insertions(+), 132 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 179833876a7b33..1ed8c90830d08e 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -244,7 +244,7 @@ def MemoryAllocationOpt : Pass<"memory-allocation-opt", "mlir::func::FuncOp"> {
 
 // This needs to be a "mlir::ModuleOp" pass, because it inserts global constants
 def ConstExtruderOpt : Pass<"const-extruder-opt", "mlir::ModuleOp"> {
-  let summary = "Convert scalar literals of function arguments to global constants.";
+  let summary = "Convert constant function arguments to global constants.";
   let description = [{
     Convert scalar literals of function arguments to global constants.
   }];
diff --git a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
index 00c9f30e9dbacc..355166f3422bb8 100644
--- a/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
+++ b/flang/lib/Optimizer/Transforms/ConstExtruder.cpp
@@ -1,4 +1,4 @@
-//===- ConstExtruder.cpp -----------------------------------------------===//
+//===- ConstExtruder.cpp --------------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "flang/Optimizer/Builder/BoxValue.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
@@ -17,9 +16,6 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include <atomic>
 
 namespace fir {
 #define GEN_PASS_DEF_CONSTEXTRUDEROPT
@@ -29,38 +25,16 @@ namespace fir {
 #define DEBUG_TYPE "flang-const-extruder-opt"
 
 namespace {
-std::atomic<int> uniqueLitId = 1;
-
-static bool needsExtrusion(const mlir::Value *a) {
-  if (!a || !a->getDefiningOp())
-    return false;
-
-  // is alloca
-  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
-    // alloca has annotation
-    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
-      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
-        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
-          auto constant_def = store->getOperand(0).getDefiningOp();
-          // Expect constant definition operation
-          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            return true;
-          }
-        }
-      }
-    }
-  }
-  return false;
-}
+unsigned uniqueLitId = 1;
 
 class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
 protected:
-  mlir::DominanceInfo &di;
+  const mlir::DominanceInfo &di;
 
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+  CallOpRewriter(mlir::MLIRContext *ctx, const mlir::DominanceInfo &_di)
       : OpRewritePattern(ctx), di(_di) {}
 
   mlir::LogicalResult
@@ -68,131 +42,136 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
                   mlir::PatternRewriter &rewriter) const override {
     LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
     auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    bool needUpdate = false;
     fir::FirOpBuilder builder(rewriter, module);
     llvm::SmallVector<mlir::Value> newOperands;
     llvm::SmallVector<mlir::Operation *> toErase;
-    for (const auto &a : callOp.getArgs()) {
-      if (auto alloca =
-              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
-        if (needsExtrusion(&a)) {
-
-          mlir::Type varTy = alloca.getInType();
-          assert(!fir::hasDynamicSize(varTy) &&
-                 "only expect statically sized scalars to be by value");
-
-          // find immediate store with const argument
-          llvm::SmallVector<mlir::Operation *> stores;
-          for (mlir::Operation *s : alloca.getOperation()->getUsers())
-            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
-              stores.push_back(s);
-          assert(stores.size() == 1 && "expected exactly one store");
-          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
-
-          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
-          // Expect constant definition operation or force legalisation of the
-          // callOp and continue with its next argument
-          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
-            // unable to remove alloca arg
-            newOperands.push_back(a);
-            continue;
-          }
-
-          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
-
-          auto loc = callOp.getLoc();
-          llvm::StringRef globalPrefix = "_extruded_";
-
-          std::string globalName;
-          while (!globalName.length() || builder.getNamedGlobal(globalName))
-            globalName =
-                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
-
-          if (alloca->hasOneUse()) {
-            toErase.push_back(alloca);
-            toErase.push_back(stores[0]);
-          } else {
-            int count = -2;
-            for (mlir::Operation *s : alloca.getOperation()->getUsers())
-              if (di.dominates(stores[0], s))
-                ++count;
-
-            // delete if dominates itself and one more operation (which should
-            // be callOp)
-            if (!count)
-              toErase.push_back(stores[0]);
+    for (const mlir::Value &a : callOp.getArgs()) {
+      auto alloca =
+          mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp());
+      // We can convert arguments that are allocas and that have
+      // the value-by-reference attribute. All else is just added
+      // to the argument list.
+      if (!alloca || !alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+        newOperands.push_back(a);
+        continue;
+      }
+        
+      mlir::Type varTy = alloca.getInType();
+      assert(!fir::hasDynamicSize(varTy) &&
+               "only expect statically sized scalars to be by value");
+
+      // Find immediate store with const argument
+      mlir::Operation *store = nullptr;
+      for (mlir::Operation *s : alloca->getUsers()) {
+        if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp)) {
+          // We can only deal with ONE store - if already found one,
+          // set to nullptr and exit the loop. 
+          if (store) {
+            store = nullptr;
+            break;
           }
-          auto global = builder.createGlobalConstant(
-              loc, varTy, globalName,
-              [&](fir::FirOpBuilder &builder) {
-                mlir::Operation *cln = constant_def->clone();
-                builder.insert(cln);
-                fir::ExtendedValue exv{cln->getResult(0)};
-                mlir::Value valBase = fir::getBase(exv);
-                mlir::Value val = builder.createConvert(loc, varTy, valBase);
-                builder.create<fir::HasValueOp>(loc, val);
-              },
-              builder.createInternalLinkage());
-          mlir::Value ope = {builder.create<fir::AddrOfOp>(
-              loc, global.resultType(), global.getSymbol())};
-          newOperands.push_back(ope);
-        } else {
-          // alloca but without attr, add it
-          newOperands.push_back(a);
+          store = s;
         }
-      } else {
-        // non-alloca operand, add it
+      }
+      
+      // If we didn't find one single store, add argument as is, and move on.
+      if (!store) {
         newOperands.push_back(a);
+        continue;
       }
+        
+      LLVM_DEBUG(llvm::dbgs() << " found store " << *store << "\n");
+
+      mlir::Operation *constant_def = store->getOperand(0).getDefiningOp();
+      // Expect constant definition operation or force legalisation of the
+      // callOp and continue with its next argument
+      if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+        // Unable to remove alloca arg
+        newOperands.push_back(a);
+        continue;
+      }
+
+      LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+      std::string globalName = "_extruded_." + std::to_string(uniqueLitId++);
+      assert(!builder.getNamedGlobal(globalName) && "We should have a unique name here");
+        
+      unsigned count = 0;
+      for (mlir::Operation *s : alloca->getUsers())
+        if (di.dominates(store, s))
+          ++count;
+
+      // Delete if dominates itself and one more operation (which should
+      // be callOp)
+      if (count == 2)
+        toErase.push_back(store);
+        
+      auto loc = callOp.getLoc();
+      fir::GlobalOp global = builder.createGlobalConstant(
+          loc, varTy, globalName,
+          [&](fir::FirOpBuilder &builder) {
+            mlir::Operation *cln = constant_def->clone();
+            builder.insert(cln);
+            mlir::Value val = builder.createConvert(loc, varTy, cln->getResult(0));
+            builder.create<fir::HasValueOp>(loc, val);
+          },
+          builder.createInternalLinkage());
+      mlir::Value addr = {builder.create<fir::AddrOfOp>(
+            loc, global.resultType(), global.getSymbol())};
+      newOperands.push_back(addr);
+      needUpdate = true;
     }
 
-    auto loc = callOp.getLoc();
-    llvm::SmallVector<mlir::Type> newResultTypes;
-    newResultTypes.append(callOp.getResultTypes().begin(),
-                          callOp.getResultTypes().end());
-    fir::CallOp newOp = builder.create<fir::CallOp>(
-        loc, newResultTypes,
-        callOp.getCallee().has_value() ? callOp.getCallee().value()
-                                       : mlir::SymbolRefAttr{},
-        newOperands, callOp.getFastmathAttr());
-    rewriter.replaceOp(callOp, newOp);
-
-    for (auto e : toErase)
-      rewriter.eraseOp(e);
-
-    LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
-                            << newOp << '\n');
-    return mlir::success();
+    if (needUpdate) {
+      auto loc = callOp.getLoc();
+      llvm::SmallVector<mlir::Type> newResultTypes;
+      newResultTypes.append(callOp.getResultTypes().begin(),
+                            callOp.getResultTypes().end());
+      fir::CallOp newOp = builder.create<fir::CallOp>(
+          loc, newResultTypes,
+          callOp.getCallee().has_value() ? callOp.getCallee().value()
+          : mlir::SymbolRefAttr{},
+          newOperands, callOp.getFastmathAttr());
+      rewriter.replaceOp(callOp, newOp);
+
+      for (auto e : toErase)
+        rewriter.eraseOp(e);
+      LLVM_DEBUG(llvm::dbgs() << "extruded constant for " << callOp << " as "
+                 << newOp << '\n');
+      return mlir::success();
+    }
+
+    // Failure here just means "we couldn't do the conversion", which is
+    // perfectly acceptable to the upper layers of this function.
+    return mlir::failure();
   }
 };
 
-// This pass attempts to convert immediate scalar literals in function calls
+// this pass attempts to convert immediate scalar literals in function calls
 // to global constants to allow transformations as Dead Argument Elimination
 class ConstExtruderOpt
     : public fir::impl::ConstExtruderOptBase<ConstExtruderOpt> {
-protected:
-  mlir::DominanceInfo *di;
-  mlir::GreedyRewriteConfig config;
-
 public:
-  ConstExtruderOpt() {
-    config.enableRegionSimplification = false;
-    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
-  }
+  ConstExtruderOpt() = default;
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
-    di = &getAnalysis<mlir::DominanceInfo>();
-    mod.walk([this](mlir::func::FuncOp func) { runOnFunc(func); });
+    mlir::DominanceInfo *di = &getAnalysis<mlir::DominanceInfo>();
+    mod.walk([di, this](mlir::func::FuncOp func) { runOnFunc(func, di); });
   }
 
-  void runOnFunc(mlir::func::FuncOp &func) {
-    auto *context = &getContext();
-    mlir::RewritePatternSet patterns(context);
-
+  void runOnFunc(mlir::func::FuncOp &func, const mlir::DominanceInfo* di) {
     // If func is a declaration, skip it.
     if (func.empty())
       return;
 
+    auto *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    mlir::GreedyRewriteConfig config;
+    config.enableRegionSimplification = false;
+    config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
+
     patterns.insert<CallOpRewriter>(context, *di);
     if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
             func, std::move(patterns), config))) {
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index c43149e07bf55f..942ee76d5155e1 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -31,8 +31,8 @@
 
 ! CHECK-NEXT: 'func.func' Pipeline
 ! CHECK-NEXT:   MemoryAllocationOpt
-! CHECK-NEXT:   ConstExtruderOpt
 
+! CHECK-NEXT: ConstExtruderOpt
 ! CHECK-NEXT: Inliner
 ! CHECK-NEXT: SimplifyRegionLite
 ! CHECK-NEXT: CSE



More information about the flang-commits mailing list