[llvm-branch-commits] [flang] [flang] translate derived type array init to attribute if possible (PR #140268)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon May 19 01:56:09 PDT 2025
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/140268
>From d71c0b7f45582ece43016eb98367251e54e75280 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Fri, 16 May 2025 08:09:37 -0700
Subject: [PATCH 1/2] [flang] translate derived type array init to attribute if
possible
---
.../Optimizer/CodeGen/LLVMInsertChainFolder.h | 31 +++
.../include/flang/Optimizer/Dialect/FIROps.td | 5 +
flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 +
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 51 +++--
.../CodeGen/LLVMInsertChainFolder.cpp | 204 ++++++++++++++++++
flang/lib/Optimizer/Dialect/FIROps.cpp | 15 ++
.../Fir/convert-and-fold-insert-on-range.fir | 33 +++
7 files changed, 319 insertions(+), 21 deletions(-)
create mode 100644 flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
create mode 100644 flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
create mode 100644 flang/test/Fir/convert-and-fold-insert-on-range.fir
diff --git a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
new file mode 100644
index 0000000000000..d577c4c0fa70b
--- /dev/null
+++ b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
@@ -0,0 +1,31 @@
+//===-- LLVMInsertChainFolder.h -- insertvalue chain folder ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper to fold LLVM dialect llvm.insertvalue chain representing constants
+// into an Attribute representation.
+// This sits in Flang because it is incomplete and tailored for flang needs.
+//
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+class Attribute;
+class OpBuilder;
+class Value;
+} // namespace mlir
+
+namespace fir {
+
+/// Attempt to fold an llvm.insertvalue chain into an attribute representation
+/// suitable as llvm.constant operand. The returned value will be a null pointer
+/// if this is not an llvm.insertvalue result or if the chain is not a constant,
+/// or cannot be represented as an Attribute. The operations are not deleted,
+/// but some llvm.insertvalue value operands may be folded with the builder on
+/// the way.
+mlir::Attribute tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
+ mlir::OpBuilder &builder);
+} // namespace fir
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 458b780806144..dc66885f776f0 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2129,6 +2129,11 @@ def fir_InsertOnRangeOp : fir_OneResultOp<"insert_on_range", [NoMemoryEffect]> {
$seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
}];
+ let extraClassDeclaration = [{
+ /// Is this insert_on_range inserting on all the values of the result type?
+ bool isFullRange();
+ }];
+
let hasVerifier = 1;
}
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 04480bac552b7..980307db315d9 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -3,6 +3,7 @@ add_flang_library(FIRCodeGen
CodeGen.cpp
CodeGenOpenMP.cpp
FIROpPatterns.cpp
+ LLVMInsertChainFolder.cpp
LowerRepackArrays.cpp
PreCGRewrite.cpp
TBAABuilder.cpp
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ad9119ba4a031..ed76a77ced047 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -14,6 +14,7 @@
#include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
#include "flang/Optimizer/CodeGen/FIROpPatterns.h"
+#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRCG/CGOps.h"
@@ -2412,15 +2413,38 @@ struct InsertOnRangeOpConversion
doRewrite(fir::InsertOnRangeOp range, mlir::Type ty, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
- llvm::SmallVector<std::int64_t> dims;
- auto type = adaptor.getOperands()[0].getType();
+ auto arrayType = adaptor.getSeq().getType();
// Iteratively extract the array dimensions from the type.
+ llvm::SmallVector<std::int64_t> dims;
+ mlir::Type type = arrayType;
while (auto t = mlir::dyn_cast<mlir::LLVM::LLVMArrayType>(type)) {
dims.push_back(t.getNumElements());
type = t.getElementType();
}
+ // Avoid generating long insert chains that are very slow to fold back
+ // (which is required in globals when later generating LLVM IR). Attempt to
+ // fold the inserted element value to an attribute and build an ArrayAttr
+ // for the resulting array.
+ if (range.isFullRange()) {
+ if (mlir::Attribute cst =
+ fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter)) {
+ mlir::Attribute dimVal = cst;
+ for (auto dim : llvm::reverse(dims)) {
+ // Use std::vector in case the number of elements is big.
+ std::vector<mlir::Attribute> elements(dim, dimVal);
+ dimVal = mlir::ArrayAttr::get(range.getContext(), elements);
+ }
+ // Replace insert chain with constant.
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(range, arrayType,
+ dimVal);
+ return mlir::success();
+ }
+ }
+
+ // The inserted value cannot be folded to an attribute, turn the
+ // insert_on_range into an llvm.insertvalue chain.
llvm::SmallVector<std::int64_t> lBounds;
llvm::SmallVector<std::int64_t> uBounds;
@@ -2434,8 +2458,8 @@ struct InsertOnRangeOpConversion
auto &subscripts = lBounds;
auto loc = range.getLoc();
- mlir::Value lastOp = adaptor.getOperands()[0];
- mlir::Value insertVal = adaptor.getOperands()[1];
+ mlir::Value lastOp = adaptor.getSeq();
+ mlir::Value insertVal = adaptor.getVal();
while (subscripts != uBounds) {
lastOp = rewriter.create<mlir::LLVM::InsertValueOp>(
@@ -3131,7 +3155,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
// initialization is on the full range.
auto insertOnRangeOps = gr.front().getOps<fir::InsertOnRangeOp>();
for (auto insertOp : insertOnRangeOps) {
- if (isFullRange(insertOp.getCoor(), insertOp.getType())) {
+ if (insertOp.isFullRange()) {
auto seqTyAttr = convertType(insertOp.getType());
auto *op = insertOp.getVal().getDefiningOp();
auto constant = mlir::dyn_cast<mlir::arith::ConstantOp>(op);
@@ -3161,22 +3185,7 @@ struct GlobalOpConversion : public fir::FIROpConversion<fir::GlobalOp> {
return mlir::success();
}
- bool isFullRange(mlir::DenseIntElementsAttr indexes,
- fir::SequenceType seqTy) const {
- auto extents = seqTy.getShape();
- if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
- return false;
- auto cur_index = indexes.value_begin<int64_t>();
- for (unsigned i = 0; i < indexes.size(); i += 2) {
- if (*(cur_index++) != 0)
- return false;
- if (*(cur_index++) != extents[i / 2] - 1)
- return false;
- }
- return true;
- }
-
- // TODO: String comparaison should be avoided. Replace linkName with an
+ // TODO: String comparisons should be avoided. Replace linkName with an
// enumeration.
mlir::LLVM::Linkage
convertLinkage(std::optional<llvm::StringRef> optLinkage) const {
diff --git a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
new file mode 100644
index 0000000000000..0fc8697b735cf
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
@@ -0,0 +1,204 @@
+//===-- LLVMInsertChainFolder.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/LLVMInsertChainFolder.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "flang-insert-folder"
+
+#include <deque>
+
+namespace {
+// Helper class to construct the attribute elements of an aggregate value being
+// folded without creating a full mlir::Attribute representation for each step
+// of the insert value chain, which would both be expensive in terms of
+// compilation time and memory (since the intermediate Attribute would survive,
+// unused, inside the mlir context).
+class InsertChainBackwardFolder {
+ // Type for the current value of an element of the aggregate value being
+ // constructed by the insert chain.
+ // At any point of the insert chain, the value of an element is either:
+ // - nullptr: not yet known, the insert has not yet been seen.
+ // - an mlir::Attribute: the element is fully defined.
+ // - a nested InsertChainBackwardFolder: the element is itself an aggregate
+ // and its sub-elements have been partially defined (inserts with multiple
+ // indices have been seen).
+
+ // The insertion folder assumes backward walk of the insert chain. Once an
+ // element or sub-element has been defined, it is not overridden by new
+ // insertions (last insert wins).
+ using InFlightValue =
+ llvm::PointerUnion<mlir::Attribute, InsertChainBackwardFolder *>;
+
+public:
+ InsertChainBackwardFolder(
+ mlir::Type type, std::deque<InsertChainBackwardFolder> *folderStorage)
+ : values(getNumElements(type), mlir::Attribute{}),
+ folderStorage{folderStorage}, type{type} {}
+
+ /// Record the value inserted at position `at`. Returns false on failure.
+ bool pushValue(mlir::Attribute val, llvm::ArrayRef<int64_t> at);
+
+ mlir::Attribute finalize(mlir::Attribute defaultFieldValue);
+
+private:
+ static int64_t getNumElements(mlir::Type type) {
+ if (auto structTy =
+ llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
+ return structTy.getBody().size();
+ if (auto arrayTy =
+ llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
+ return arrayTy.getNumElements();
+ return 0;
+ }
+
+ static mlir::Type getSubElementType(mlir::Type type, int64_t field) {
+ if (auto arrayTy =
+ llvm::dyn_cast_if_present<mlir::LLVM::LLVMArrayType>(type))
+ return arrayTy.getElementType();
+ if (auto structTy =
+ llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
+ return structTy.getBody()[field];
+ return {};
+ }
+
+ // Current element value of the aggregate value being built.
+ llvm::SmallVector<InFlightValue> values;
+ // std::deque is used to allocate storage for nested list and guarantee the
+ // stability of the InsertChainBackwardFolder* used as element value.
+ std::deque<InsertChainBackwardFolder> *folderStorage;
+ // Type of the aggregate value being built.
+ mlir::Type type;
+};
+} // namespace
+
+// Helper to fold the value being inserted by an llvm.insert_value.
+// This may call tryFoldingLLVMInsertChain if the value is an aggregate and
+// was itself constructed by a different insert chain.
+static mlir::Attribute getAttrIfConstant(mlir::Value val,
+ mlir::OpBuilder &rewriter) {
+ if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
+ return cst.getValue();
+ if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>())
+ return fir::tryFoldingLLVMInsertChain(val, rewriter);
+ if (val.getDefiningOp<mlir::LLVM::ZeroOp>())
+ return mlir::LLVM::ZeroAttr::get(val.getContext());
+ if (val.getDefiningOp<mlir::LLVM::UndefOp>())
+ return mlir::LLVM::UndefAttr::get(val.getContext());
+ if (mlir::Operation *op = val.getDefiningOp()) {
+ unsigned resNum = llvm::cast<mlir::OpResult>(val).getResultNumber();
+ llvm::SmallVector<mlir::Value> results;
+ if (mlir::succeeded(rewriter.tryFold(op, results)) &&
+ results.size() > resNum) {
+ if (auto cst = results[resNum].getDefiningOp<mlir::LLVM::ConstantOp>())
+ return cst.getValue();
+ }
+ }
+ if (auto trunc = val.getDefiningOp<mlir::LLVM::TruncOp>())
+ if (auto attr = getAttrIfConstant(trunc.getArg(), rewriter))
+ if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr))
+ return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt());
+ LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val
+ << "\n");
+ return {};
+}
+
+mlir::Attribute
+InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) {
+ std::vector<mlir::Attribute> attrs;
+ attrs.reserve(values.size());
+ for (InFlightValue &inFlight : values) {
+ if (!inFlight) {
+ attrs.push_back(defaultFieldValue);
+ } else if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) {
+ attrs.push_back(attr);
+ } else {
+ auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight);
+ attrs.push_back(inFlightList->finalize(defaultFieldValue));
+ }
+ }
+ return mlir::ArrayAttr::get(type.getContext(), attrs);
+}
+
+bool InsertChainBackwardFolder::pushValue(mlir::Attribute val,
+ llvm::ArrayRef<int64_t> at) {
+ if (at.size() == 0 || at[0] >= static_cast<int64_t>(values.size()))
+ return false;
+ InFlightValue &inFlight = values[at[0]];
+ if (!inFlight) {
+ if (at.size() == 1) {
+ inFlight = val;
+ return true;
+ }
+ // This is the first insert to a nested field. Create a
+ // InsertChainBackwardFolder for the current element value.
+ InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back(
+ getSubElementType(type, at[0]), folderStorage);
+ inFlight = &inFlightList;
+ return inFlightList.pushValue(val, at.drop_front());
+ }
+ // Keep last inserted value if already set.
+ if (llvm::isa<mlir::Attribute>(inFlight))
+ return true;
+ auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight);
+ if (at.size() == 1) {
+ if (!llvm::isa<mlir::LLVM::ZeroAttr, mlir::LLVM::UndefAttr>(val)) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "insert chain sub-element partially overwritten initial "
+ "value is not zero or undef\n");
+ return false;
+ }
+ inFlight = inFlightList->finalize(val);
+ return true;
+ }
+ return inFlightList->pushValue(val, at.drop_front());
+}
+
+mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val,
+ mlir::OpBuilder &rewriter) {
+ if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
+ return cst.getValue();
+ if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+ LLVM_DEBUG(llvm::dbgs() << "trying to fold insert chain:" << val << "\n");
+ if (auto structTy =
+ llvm::dyn_cast<mlir::LLVM::LLVMStructType>(insert.getType())) {
+ mlir::LLVM::InsertValueOp currentInsert = insert;
+ mlir::LLVM::InsertValueOp lastInsert;
+ std::deque<InsertChainBackwardFolder> folderStorage;
+ InsertChainBackwardFolder inFlightList(structTy, &folderStorage);
+ while (currentInsert) {
+ mlir::Attribute attr =
+ getAttrIfConstant(currentInsert.getValue(), rewriter);
+ if (!attr)
+ return {};
+ if (!inFlightList.pushValue(attr, currentInsert.getPosition()))
+ return {};
+ lastInsert = currentInsert;
+ currentInsert = currentInsert.getContainer()
+ .getDefiningOp<mlir::LLVM::InsertValueOp>();
+ }
+ mlir::Attribute defaultVal;
+ if (lastInsert) {
+ if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::ZeroOp>())
+ defaultVal = mlir::LLVM::ZeroAttr::get(val.getContext());
+ else if (lastInsert.getContainer().getDefiningOp<mlir::LLVM::UndefOp>())
+ defaultVal = mlir::LLVM::UndefAttr::get(val.getContext());
+ }
+ if (!defaultVal) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "insert chain initial value is not Zero or Undef\n");
+ return {};
+ }
+ return inFlightList.finalize(defaultVal);
+ }
+ }
+ return {};
+}
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index d85b38c467857..e12af7782a578 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -2365,6 +2365,21 @@ llvm::LogicalResult fir::InsertOnRangeOp::verify() {
return mlir::success();
}
+bool fir::InsertOnRangeOp::isFullRange() {
+ auto extents = getType().getShape();
+ mlir::DenseIntElementsAttr indexes = getCoor();
+ if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
+ return false;
+ auto cur_index = indexes.value_begin<int64_t>();
+ for (unsigned i = 0; i < indexes.size(); i += 2) {
+ if (*(cur_index++) != 0)
+ return false;
+ if (*(cur_index++) != extents[i / 2] - 1)
+ return false;
+ }
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// InsertValueOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/convert-and-fold-insert-on-range.fir b/flang/test/Fir/convert-and-fold-insert-on-range.fir
new file mode 100644
index 0000000000000..df18614d80b63
--- /dev/null
+++ b/flang/test/Fir/convert-and-fold-insert-on-range.fir
@@ -0,0 +1,33 @@
+// Test codegen of constant insert_on_range without symbol reference into mlir.constant.
+// RUN: fir-opt --cg-rewrite --split-input-file --fir-to-llvm-ir %s | FileCheck %s
+
+module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, "dlti.endianness" = "little", "dlti.mangling_mode" = "e", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
+ fir.global @derived_array : !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>> {
+ %c0 = arith.constant 0 : index
+ %0 = fir.undefined !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>
+ %1 = fir.zero_bits !fir.heap<!fir.array<?xf64>>
+ %2 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %3 = fir.embox %1(%2) : (!fir.heap<!fir.array<?xf64>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf64>>>
+ %4 = fir.insert_value %0, %3, ["comp", !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>] : (!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>, !fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>
+ %5 = fir.undefined !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>>
+ %6 = fir.insert_on_range %5, %4 from (0) to (1) : (!fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>>, !fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>) -> !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>>
+ fir.has_value %6 : !fir.array<2x!fir.type<sometype{comp:!fir.box<!fir.heap<!fir.array<?xf64>>>}>>
+ }
+}
+
+//CHECK-LABEL: llvm.mlir.global external @derived_array()
+//CHECK: %[[CST:.*]] = llvm.mlir.constant([
+//CHECK-SAME: [
+//CHECK-SAME: [#llvm.zero, 8, 20240719 : i32, 1 : i8, 28 : i8, 2 : i8, 0 : i8,
+//CHECK-SAME: [
+//CHECK-SAME: [1, 0 : index, 8]
+//CHECK-SAME: ]
+//CHECK-SAME: ],
+//CHECK-SAME: [
+//CHECK-SAME: [#llvm.zero, 8, 20240719 : i32, 1 : i8, 28 : i8, 2 : i8, 0 : i8,
+//CHECK-SAME: [
+//CHECK-SAME: [1, 0 : index, 8]
+//CHECK-SAME: ]
+//CHECK-SAME: ]) :
+//CHECK-SAME: !llvm.array<2 x struct<"sometype", (struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>)>>
+//CHECK: llvm.return %[[CST]] : !llvm.array<2 x struct<"sometype", (struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>)>>
>From 796a1e0269baf1c77ffabf47a8fa155356bc9096 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Mon, 19 May 2025 01:37:14 -0700
Subject: [PATCH 2/2] use map_to_vector and FailureOr
---
.../Optimizer/CodeGen/LLVMInsertChainFolder.h | 7 ++-
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 7 +--
.../CodeGen/LLVMInsertChainFolder.cpp | 54 ++++++++++---------
3 files changed, 39 insertions(+), 29 deletions(-)
diff --git a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
index d577c4c0fa70b..321bda91aa6fe 100644
--- a/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
+++ b/flang/include/flang/Optimizer/CodeGen/LLVMInsertChainFolder.h
@@ -12,6 +12,8 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Support/LogicalResult.h"
+
namespace mlir {
class Attribute;
class OpBuilder;
@@ -26,6 +28,7 @@ namespace fir {
/// or cannot be represented as an Attribute. The operations are not deleted,
/// but some llvm.insertvalue value operands may be folded with the builder on
/// the way.
-mlir::Attribute tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
- mlir::OpBuilder &builder);
+llvm::FailureOr<mlir::Attribute>
+tryFoldingLLVMInsertChain(mlir::Value insertChainResult,
+ mlir::OpBuilder &builder);
} // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ed76a77ced047..70c90fae34086 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2428,9 +2428,10 @@ struct InsertOnRangeOpConversion
// fold the inserted element value to an attribute and build an ArrayAttr
// for the resulting array.
if (range.isFullRange()) {
- if (mlir::Attribute cst =
- fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter)) {
- mlir::Attribute dimVal = cst;
+ llvm::FailureOr<mlir::Attribute> cst =
+ fir::tryFoldingLLVMInsertChain(adaptor.getVal(), rewriter);
+ if (llvm::succeeded(cst)) {
+ mlir::Attribute dimVal = *cst;
for (auto dim : llvm::reverse(dims)) {
// Use std::vector in case the number of elements is big.
std::vector<mlir::Attribute> elements(dim, dimVal);
diff --git a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
index 0fc8697b735cf..5b522f2647916 100644
--- a/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
+++ b/flang/lib/Optimizer/CodeGen/LLVMInsertChainFolder.cpp
@@ -67,7 +67,7 @@ class InsertChainBackwardFolder {
if (auto structTy =
llvm::dyn_cast_if_present<mlir::LLVM::LLVMStructType>(type))
return structTy.getBody()[field];
- return {};
+ return nullptr;
}
// Current element value of the aggregate value being built.
@@ -83,12 +83,18 @@ class InsertChainBackwardFolder {
// Helper to fold the value being inserted by an llvm.insert_value.
// This may call tryFoldingLLVMInsertChain if the value is an aggregate and
// was itself constructed by a different insert chain.
+// Returns a nullptr Attribute if the value could not be folded.
static mlir::Attribute getAttrIfConstant(mlir::Value val,
mlir::OpBuilder &rewriter) {
if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
return cst.getValue();
- if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>())
- return fir::tryFoldingLLVMInsertChain(val, rewriter);
+ if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
+ llvm::FailureOr<mlir::Attribute> attr =
+ fir::tryFoldingLLVMInsertChain(val, rewriter);
+ if (succeeded(attr))
+ return *attr;
+ return nullptr;
+ }
if (val.getDefiningOp<mlir::LLVM::ZeroOp>())
return mlir::LLVM::ZeroAttr::get(val.getContext());
if (val.getDefiningOp<mlir::LLVM::UndefOp>())
@@ -108,23 +114,20 @@ static mlir::Attribute getAttrIfConstant(mlir::Value val,
return mlir::IntegerAttr::get(trunc.getType(), intAttr.getInt());
LLVM_DEBUG(llvm::dbgs() << "cannot fold insert value operand: " << val
<< "\n");
- return {};
+ return nullptr;
}
mlir::Attribute
InsertChainBackwardFolder::finalize(mlir::Attribute defaultFieldValue) {
- std::vector<mlir::Attribute> attrs;
- attrs.reserve(values.size());
- for (InFlightValue &inFlight : values) {
- if (!inFlight) {
- attrs.push_back(defaultFieldValue);
- } else if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight)) {
- attrs.push_back(attr);
- } else {
- auto *inFlightList = llvm::cast<InsertChainBackwardFolder *>(inFlight);
- attrs.push_back(inFlightList->finalize(defaultFieldValue));
- }
- }
+ llvm::SmallVector<mlir::Attribute> attrs = llvm::map_to_vector(
+ values, [&](InFlightValue inFlight) -> mlir::Attribute {
+ if (!inFlight)
+ return defaultFieldValue;
+ if (auto attr = llvm::dyn_cast<mlir::Attribute>(inFlight))
+ return attr;
+ return llvm::cast<InsertChainBackwardFolder *>(inFlight)->finalize(
+ defaultFieldValue);
+ });
return mlir::ArrayAttr::get(type.getContext(), attrs);
}
@@ -140,8 +143,11 @@ bool InsertChainBackwardFolder::pushValue(mlir::Attribute val,
}
// This is the first insert to a nested field. Create a
// InsertChainBackwardFolder for the current element value.
- InsertChainBackwardFolder &inFlightList = folderStorage->emplace_back(
- getSubElementType(type, at[0]), folderStorage);
+ mlir::Type subType = getSubElementType(type, at[0]);
+ if (!subType)
+ return false;
+ InsertChainBackwardFolder &inFlightList =
+ folderStorage->emplace_back(subType, folderStorage);
inFlight = &inFlightList;
return inFlightList.pushValue(val, at.drop_front());
}
@@ -162,8 +168,8 @@ bool InsertChainBackwardFolder::pushValue(mlir::Attribute val,
return inFlightList->pushValue(val, at.drop_front());
}
-mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val,
- mlir::OpBuilder &rewriter) {
+llvm::FailureOr<mlir::Attribute>
+fir::tryFoldingLLVMInsertChain(mlir::Value val, mlir::OpBuilder &rewriter) {
if (auto cst = val.getDefiningOp<mlir::LLVM::ConstantOp>())
return cst.getValue();
if (auto insert = val.getDefiningOp<mlir::LLVM::InsertValueOp>()) {
@@ -178,9 +184,9 @@ mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val,
mlir::Attribute attr =
getAttrIfConstant(currentInsert.getValue(), rewriter);
if (!attr)
- return {};
+ return llvm::failure();
if (!inFlightList.pushValue(attr, currentInsert.getPosition()))
- return {};
+ return llvm::failure();
lastInsert = currentInsert;
currentInsert = currentInsert.getContainer()
.getDefiningOp<mlir::LLVM::InsertValueOp>();
@@ -195,10 +201,10 @@ mlir::Attribute fir::tryFoldingLLVMInsertChain(mlir::Value val,
if (!defaultVal) {
LLVM_DEBUG(llvm::dbgs()
<< "insert chain initial value is not Zero or Undef\n");
- return {};
+ return llvm::failure();
}
return inFlightList.finalize(defaultVal);
}
}
- return {};
+ return llvm::failure();
}
More information about the llvm-branch-commits
mailing list