[flang-commits] [flang] [Flang] Extracting internal constants from scalar literals (PR #73829)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Thu Dec 7 09:06:40 PST 2023


================
@@ -0,0 +1,209 @@
+//===- ConstExtruder.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <atomic>
+
+namespace fir {
+#define GEN_PASS_DEF_CONSTEXTRUDEROPT
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+#define DEBUG_TYPE "flang-const-extruder-opt"
+
+namespace {
+std::atomic<int> uniqueLitId = 1;
+
+static bool needsExtrusion(const mlir::Value *a) {
+  if (!a || !a->getDefiningOp())
+    return false;
+
+  // is alloca
+  if (auto alloca = mlir::dyn_cast_or_null<fir::AllocaOp>(a->getDefiningOp())) {
+    // alloca has annotation
+    if (alloca->hasAttr(fir::getAdaptToByRefAttrName())) {
+      for (mlir::Operation *s : alloca.getOperation()->getUsers()) {
+        if (const auto store = mlir::dyn_cast_or_null<fir::StoreOp>(s)) {
+          auto constant_def = store->getOperand(0).getDefiningOp();
+          // Expect constant definition operation
+          if (mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return false;
+}
+
+class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
+protected:
+  mlir::DominanceInfo &di;
+
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CallOpRewriter(mlir::MLIRContext *ctx, mlir::DominanceInfo &_di)
+      : OpRewritePattern(ctx), di(_di) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CallOp callOp,
+                  mlir::PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(llvm::dbgs() << "Processing call op: " << callOp << "\n");
+    auto module = callOp->getParentOfType<mlir::ModuleOp>();
+    fir::FirOpBuilder builder(rewriter, module);
+    llvm::SmallVector<mlir::Value> newOperands;
+    llvm::SmallVector<mlir::Operation *> toErase;
+    for (const auto &a : callOp.getArgs()) {
+      if (auto alloca =
+              mlir::dyn_cast_or_null<fir::AllocaOp>(a.getDefiningOp())) {
+        if (needsExtrusion(&a)) {
+
+          mlir::Type varTy = alloca.getInType();
+          assert(!fir::hasDynamicSize(varTy) &&
+                 "only expect statically sized scalars to be by value");
+
+          // find immediate store with const argument
+          llvm::SmallVector<mlir::Operation *> stores;
+          for (mlir::Operation *s : alloca.getOperation()->getUsers())
+            if (mlir::isa<fir::StoreOp>(s) && di.dominates(s, callOp))
+              stores.push_back(s);
+          assert(stores.size() == 1 && "expected exactly one store");
+          LLVM_DEBUG(llvm::dbgs() << " found store " << *stores[0] << "\n");
+
+          auto constant_def = stores[0]->getOperand(0).getDefiningOp();
+          // Expect constant definition operation or force legalisation of the
+          // callOp and continue with its next argument
+          if (!mlir::isa<mlir::arith::ConstantOp>(constant_def)) {
+            // unable to remove alloca arg
+            newOperands.push_back(a);
+            continue;
+          }
+
+          LLVM_DEBUG(llvm::dbgs() << " found define " << *constant_def << "\n");
+
+          auto loc = callOp.getLoc();
+          llvm::StringRef globalPrefix = "_extruded_";
+
+          std::string globalName;
+          while (!globalName.length() || builder.getNamedGlobal(globalName))
+            globalName =
+                globalPrefix.str() + "." + std::to_string(uniqueLitId++);
+
+          if (alloca->hasOneUse()) {
+            toErase.push_back(alloca);
+            toErase.push_back(stores[0]);
+          } else {
+            int count = -2;
----------------
tblah wrote:

I presume the -2 is to discount the allocation and the call operation? Please could you add a comment explaining where this comes from

https://github.com/llvm/llvm-project/pull/73829


More information about the flang-commits mailing list