[flang-commits] [flang] ea55503 - [fir] Add fir.extract_value and fir.insert_value conversion
Valentin Clement via flang-commits
flang-commits at lists.llvm.org
Fri Nov 5 07:53:49 PDT 2021
Author: Valentin Clement
Date: 2021-11-05T15:53:42+01:00
New Revision: ea55503d7ca5b6ded73b0fd01a8c528f68e00521
URL: https://github.com/llvm/llvm-project/commit/ea55503d7ca5b6ded73b0fd01a8c528f68e00521
DIFF: https://github.com/llvm/llvm-project/commit/ea55503d7ca5b6ded73b0fd01a8c528f68e00521.diff
LOG: [fir] Add fir.extract_value and fir.insert_value conversion
This patch adds the conversion patterns for fir.extract_value
and fir.insert_value. fir.extract_value is lowered to llvm.extractvalue
and fir.insert_value is lowered to llvm.insertvalue.
This patch also adds the type conversions for BoxType and RecordType
needed to write more comprehensive tests.
This patch is part of the upstreaming effort from the fir-dev branch.
Reviewed By: awarzynski
Differential Revision: https://reviews.llvm.org/D112961
Co-authored-by: Jean Perier <jperier at nvidia.com>
Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>
Added:
flang/lib/Optimizer/CodeGen/DescriptorModel.h
Modified:
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/CodeGen/TypeConverter.h
flang/test/Fir/convert-to-llvm.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index e18bdc2bdbf4..ba22f83d91ef 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -296,6 +296,105 @@ struct ZeroOpConversion : public FIROpConversion<fir::ZeroOp> {
}
};
+// Code shared between insert_value and extract_value Ops.
+struct ValueOpCommon {
+  /// Translate, in place, the indices pertaining to any multidimensional array
+  /// to the row-major order used by LLVM-IR.
+  ///
+  /// `attrs` holds the integer positions produced by collectIndices and `ty`
+  /// is the LLVM type of the aggregate being indexed. A run of indices that
+  /// addresses an N-dimensional array (lowered as N nested LLVM array types)
+  /// is reversed as a whole; a struct index is left unchanged and selects the
+  /// member type used to interpret the following indices.
+  static void toRowMajor(SmallVectorImpl<mlir::Attribute> &attrs,
+                         mlir::Type ty) {
+    assert(ty && "type is null");
+    const auto end = attrs.size();
+    for (std::remove_const_t<decltype(end)> i = 0; i < end; ++i) {
+      if (auto seq = ty.dyn_cast<mlir::LLVM::LLVMArrayType>()) {
+        const auto dim = getDimension(seq);
+        if (dim > 1) {
+          // Reverse the indices of all nested array levels at once; clamp to
+          // `end` in case fewer indices than dimensions were supplied.
+          auto ub = std::min(i + dim, end);
+          std::reverse(attrs.begin() + i, attrs.begin() + ub);
+          i += dim - 1;
+        }
+        ty = getArrayElementType(seq);
+      } else if (auto st = ty.dyn_cast<mlir::LLVM::LLVMStructType>()) {
+        auto body = st.getBody();
+        auto member = attrs[i].cast<mlir::IntegerAttr>().getInt();
+        // Catch malformed coordinates instead of reading past the body.
+        assert(0 <= member &&
+               member < static_cast<decltype(member)>(body.size()) &&
+               "struct member index out of range");
+        ty = body[member];
+      } else {
+        llvm_unreachable("index into invalid type");
+      }
+    }
+  }
+
+  /// Convert the FIR coordinate list `arrAttr` into plain integer positions.
+  ///
+  /// Integer attributes are copied as-is. A field name (StringAttr) must be
+  /// followed by a TypeAttr holding the fir::RecordType that owns the field;
+  /// the pair is replaced by the field's index as an i32 attribute.
+  /// NOTE(review): the result of getFieldIndex is not validated here; confirm
+  /// what it returns for an unknown field name.
+  static llvm::SmallVector<mlir::Attribute>
+  collectIndices(mlir::ConversionPatternRewriter &rewriter,
+                 mlir::ArrayAttr arrAttr) {
+    llvm::SmallVector<mlir::Attribute> attrs;
+    for (auto i = arrAttr.begin(), e = arrAttr.end(); i != e; ++i) {
+      if (i->isa<mlir::IntegerAttr>()) {
+        attrs.push_back(*i);
+      } else {
+        auto fieldName = i->cast<mlir::StringAttr>().getValue();
+        // The record type is the next element; never step past the end.
+        ++i;
+        assert(i != e && "field name must be followed by its record type");
+        auto ty = i->cast<mlir::TypeAttr>().getValue();
+        auto index = ty.cast<fir::RecordType>().getFieldIndex(fieldName);
+        attrs.push_back(mlir::IntegerAttr::get(rewriter.getI32Type(), index));
+      }
+    }
+    return attrs;
+  }
+
+private:
+  /// Number of directly nested LLVM array levels starting at `ty` (>= 1).
+  static unsigned getDimension(mlir::LLVM::LLVMArrayType ty) {
+    unsigned result = 1;
+    for (auto eleTy = ty.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>();
+         eleTy;
+         eleTy = eleTy.getElementType().dyn_cast<mlir::LLVM::LLVMArrayType>())
+      ++result;
+    return result;
+  }
+
+  /// Innermost non-array element type of a (possibly nested) LLVM array type.
+  static mlir::Type getArrayElementType(mlir::LLVM::LLVMArrayType ty) {
+    auto eleTy = ty.getElementType();
+    while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
+      eleTy = arrTy.getElementType();
+    return eleTy;
+  }
+};
+
+/// Lower fir.extract_value, which reads a subobject out of an SSA value of
+/// aggregate type, to llvm.extractvalue at the equivalent LLVM position.
+struct ExtractValueOpConversion
+    : public FIROpAndTypeConversion<fir::ExtractValueOp>,
+      public ValueOpCommon {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  mlir::LogicalResult
+  doRewrite(fir::ExtractValueOp extractVal, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    auto aggregate = adaptor.getOperands()[0];
+    auto indices = collectIndices(rewriter, extractVal.coor());
+    // Reorder multidimensional array indices into LLVM's row-major order.
+    toRowMajor(indices, aggregate.getType());
+    auto *ctx = extractVal.getContext();
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
+        extractVal, ty, aggregate, mlir::ArrayAttr::get(ctx, indices));
+    return success();
+  }
+};
+
+/// Lower fir.insert_value, the generalized instruction for composing new
+/// aggregate values, to llvm.insertvalue at the equivalent LLVM position.
+struct InsertValueOpConversion
+    : public FIROpAndTypeConversion<fir::InsertValueOp>,
+      public ValueOpCommon {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  mlir::LogicalResult
+  doRewrite(fir::InsertValueOp insertVal, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    auto operands = adaptor.getOperands();
+    auto aggregate = operands[0];
+    auto element = operands[1];
+    auto indices = collectIndices(rewriter, insertVal.coor());
+    // Reorder multidimensional array indices into LLVM's row-major order.
+    toRowMajor(indices, aggregate.getType());
+    auto *ctx = insertVal.getContext();
+    rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(
+        insertVal, ty, aggregate, element, mlir::ArrayAttr::get(ctx, indices));
+    return success();
+  }
+};
+
/// InsertOnRange inserts a value into a sequence over a range of offsets.
struct InsertOnRangeOpConversion
: public FIROpAndTypeConversion<fir::InsertOnRangeOp> {
@@ -389,10 +488,11 @@ class FIRToLLVMLowering : public fir::FIRToLLVMLoweringBase<FIRToLLVMLowering> {
auto *context = getModule().getContext();
fir::LLVMTypeConverter typeConverter{getModule()};
mlir::OwningRewritePatternList pattern(context);
- pattern.insert<AddrOfOpConversion, HasValueOpConversion, GlobalOpConversion,
- InsertOnRangeOpConversion, SelectOpConversion,
- SelectRankOpConversion, UnreachableOpConversion,
- ZeroOpConversion, UndefOpConversion>(typeConverter);
+ pattern.insert<
+ AddrOfOpConversion, ExtractValueOpConversion, HasValueOpConversion,
+ GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
+ SelectOpConversion, SelectRankOpConversion, UndefOpConversion,
+ UnreachableOpConversion, ZeroOpConversion>(typeConverter);
mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
pattern);
diff --git a/flang/lib/Optimizer/CodeGen/DescriptorModel.h b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
new file mode 100644
index 000000000000..6c357ab85e8f
--- /dev/null
+++ b/flang/lib/Optimizer/CodeGen/DescriptorModel.h
@@ -0,0 +1,141 @@
+//===-- DescriptorModel.h -- model of descriptors for codegen ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// LLVM IR dialect models of C++ types.
+//
+// This supplies a set of model builders to decompose the C declaration of a
+// descriptor (as encoded in ISO_Fortran_binding.h and elsewhere) and
+// reconstruct that type in the LLVM IR dialect.
+//
+// TODO: It is understood that this is deeply incorrect as far as building a
+// portability layer for cross-compilation as these reflected types are those of
+// the build machine and not necessarily that of either the host or the target.
+// This assumption that build == host == target is actually pervasive across the
+// compiler (https://llvm.org/PR52418).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OPTIMIZER_DESCRIPTOR_MODEL_H
+#define OPTIMIZER_DESCRIPTOR_MODEL_H
+
+#include "flang/ISO_Fortran_binding.h"
+#include "flang/Runtime/descriptor.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "llvm/Support/ErrorHandling.h"
+#include <tuple>
+
+namespace fir {
+
+/// Signature of a builder that creates, in the given MLIR context, the LLVM IR
+/// dialect type modelling some C++ type.
+using TypeBuilderFunc = mlir::Type (*)(mlir::MLIRContext *);
+
+/// Get the LLVM IR dialect model for building a particular C++ type, `T`.
+///
+/// Only the explicit specializations below are defined; the primary template
+/// is declared but never defined, so requesting any other `T` fails to link.
+///
+/// The specializations are defined in this header and are therefore marked
+/// `inline`. An explicit specialization of a function template is NOT
+/// implicitly inline, regardless of the primary template. Without the
+/// specifier, every translation unit including this header would emit its own
+/// external definition, violating the ODR (duplicate symbols at link time).
+template <typename T>
+TypeBuilderFunc getModel();
+
+/// `void *` is modelled as a pointer to an 8-bit integer.
+template <>
+inline TypeBuilderFunc getModel<void *>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(context, 8));
+  };
+}
+
+// Integer-like types are modelled as `sizeof(T) * 8`-bit integers, using the
+// type sizes of the build machine (see the TODO in the file header).
+template <>
+inline TypeBuilderFunc getModel<unsigned>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<int>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(int) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<unsigned long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned long) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<unsigned long long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(unsigned long long) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<long long>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context, sizeof(long long) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<Fortran::ISO::CFI_rank_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_rank_t) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<Fortran::ISO::CFI_type_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_type_t) * 8);
+  };
+}
+template <>
+inline TypeBuilderFunc getModel<Fortran::ISO::CFI_index_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    return mlir::IntegerType::get(context,
+                                  sizeof(Fortran::ISO::CFI_index_t) * 8);
+  };
+}
+
+/// A `CFI_dim_t` is modelled as an array of three `CFI_index_t` values.
+template <>
+inline TypeBuilderFunc getModel<Fortran::ISO::CFI_dim_t>() {
+  return [](mlir::MLIRContext *context) -> mlir::Type {
+    auto indexTy = getModel<Fortran::ISO::CFI_index_t>()(context);
+    return mlir::LLVM::LLVMArrayType::get(indexTy, 3);
+  };
+}
+
+/// The trailing flexible `dim` array of a descriptor is modelled by the type
+/// of a single element; the number of dimensions is not part of the model.
+template <>
+inline TypeBuilderFunc
+getModel<Fortran::ISO::cfi_internal::FlexibleArray<Fortran::ISO::CFI_dim_t>>() {
+  return getModel<Fortran::ISO::CFI_dim_t>();
+}
+
+//===----------------------------------------------------------------------===//
+// Descriptor reflection
+//===----------------------------------------------------------------------===//
+
+/// Get the type model of the field number `Field` in an ISO CFI descriptor.
+/// The field's C++ type is recovered by destructuring a value-initialized
+/// descriptor and is then mapped through getModel.
+template <int Field>
+static constexpr TypeBuilderFunc getDescFieldTypeModel() {
+  Fortran::ISO::Fortran_2018::CFI_cdesc_t dummyDesc{};
+  // check that the descriptor is exactly 8 fields as specified in CFI_cdesc_t
+  // in flang/include/flang/ISO_Fortran_binding.h.
+  auto [a, b, c, d, e, f, g, h] = dummyDesc;
+  auto tup = std::tie(a, b, c, d, e, f, g, h);
+  // `field` is a copy (plain `auto` drops the reference), so decltype yields
+  // the field's value type, which is what getModel is specialized on.
+  auto field = std::get<Field>(tup);
+  return getModel<decltype(field)>();
+}
+
+/// An extended descriptor is defined by a class in runtime/descriptor.h.
+/// Unlike the ISO part, which is a POD and can be reflected, its extra fields
+/// are hard-coded here. Field 8 is modelled as `void *` and field 9 as
+/// `Fortran::runtime::typeInfo::TypeParameterValue`; any other index is
+/// unreachable. TypeParameterValue must alias one of the types that has a
+/// getModel specialization above, otherwise this fails to link.
+/// NOTE(review): only two extra fields are modelled here; confirm this matches
+/// the class layout in runtime/descriptor.h.
+template <int Field>
+static constexpr TypeBuilderFunc getExtendedDescFieldTypeModel() {
+  if constexpr (Field == 8) {
+    return getModel<void *>();
+  } else if constexpr (Field == 9) {
+    return getModel<Fortran::runtime::typeInfo::TypeParameterValue>();
+  } else {
+    llvm_unreachable("extended ISO descriptor only has 10 fields");
+  }
+}
+
+} // namespace fir
+
+#endif // OPTIMIZER_DESCRIPTOR_MODEL_H
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.h b/flang/lib/Optimizer/CodeGen/TypeConverter.h
index fe63da90b8ec..9f0b5b299e91 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.h
@@ -13,6 +13,9 @@
#ifndef FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
#define FORTRAN_OPTIMIZER_CODEGEN_TYPECONVERTER_H
+#include "DescriptorModel.h"
+#include "flang/Lower/Todo.h" // remove when TODO's are done
+#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Debug.h"
namespace fir {
@@ -26,10 +29,35 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
// Each conversion should return a value of type mlir::Type.
+ addConversion(
+ [&](fir::RecordType derived) { return convertRecordType(derived); });
addConversion(
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
addConversion(
[&](SequenceType sequence) { return convertSequenceType(sequence); });
+ addConversion([&](mlir::TupleType tuple) {
+ LLVM_DEBUG(llvm::dbgs() << "type convert: " << tuple << '\n');
+ llvm::SmallVector<mlir::Type> inMembers;
+ tuple.getFlattenedTypes(inMembers);
+ llvm::SmallVector<mlir::Type> members;
+ for (auto mem : inMembers)
+ members.push_back(convertType(mem).cast<mlir::Type>());
+ return mlir::LLVM::LLVMStructType::getLiteral(&getContext(), members,
+ /*isPacked=*/false);
+ });
+ }
+
+  /// Convert a FIR derived type to an identified LLVM struct named after it:
+  ///   fir.type<name(p : TY'...){f : TY...}> --> llvm<"%name = { ty... }">
+  /// Component types are converted in the order given by getTypeList(). Returns
+  /// a null type when setBody fails for the identified struct.
+  /// NOTE(review): members are converted recursively through convertType; a
+  /// self-referential record may not terminate here — confirm it is handled.
+  mlir::Type convertRecordType(fir::RecordType derived) {
+    auto structTy = mlir::LLVM::LLVMStructType::getIdentified(
+        &getContext(), derived.getName());
+    llvm::SmallVector<mlir::Type> memberTypes;
+    for (const auto &component : derived.getTypeList())
+      memberTypes.push_back(convertType(component.second).cast<mlir::Type>());
+    if (mlir::failed(structTy.setBody(memberTypes, /*isPacked=*/false)))
+      return mlir::Type();
+    return structTy;
}
template <typename A>
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index b0a628678b56..80a099039711 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -259,3 +259,67 @@ func @select_rank(%arg : i32, %arg2 : i32) -> i32 {
// CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32),
// CHECK: 4: ^bb4(%[[C1]] : i32)
// CHECK: ]
+
+// -----
+
+// Test fir.extract_value operation conversion with a derived type.
+// The field name "f" is resolved to its position in the record (0), which is
+// emitted as an i32 position attribute.
+
+func @extract_derived_type() -> f32 {
+ %0 = fir.undefined !fir.type<derived{f:f32}>
+ %1 = fir.extract_value %0, ["f", !fir.type<derived{f:f32}>] : (!fir.type<derived{f:f32}>) -> f32
+ return %1 : f32
+}
+
+// CHECK-LABEL: llvm.func @extract_derived_type
+// CHECK: %[[STRUCT:.*]] = llvm.mlir.undef : !llvm.struct<"derived", (f32)>
+// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[STRUCT]][0 : i32] : !llvm.struct<"derived", (f32)>
+// CHECK: llvm.return %[[VALUE]] : f32
+
+// -----
+
+// Test fir.extract_value operation conversion with a multi-dimensional array
+// of tuple. The array indices [5, 4] are reversed to [4, 5] to match the
+// nested LLVM array types, while the trailing tuple member index (1) is kept.
+
+func @extract_array(%a : !fir.array<10x10xtuple<i32, f32>>) -> f32 {
+ %0 = fir.extract_value %a, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>) -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: llvm.func @extract_array(
+// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %[[VALUE:.*]] = llvm.extractvalue %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: llvm.return %[[VALUE]] : f32
+
+// -----
+
+// Test fir.insert_value operation conversion with a multi-dimensional array
+// of tuple. As for extract_value, the array indices [5, 4] are reversed to
+// [4, 5] while the trailing tuple member index is kept.
+
+func @insert_array(%a : !fir.array<10x10xtuple<i32, f32>>) {
+  %f = arith.constant 2.0 : f32
+  %i = arith.constant 1 : i32
+  %0 = fir.insert_value %a, %i, [5 : index, 4 : index, 0 : index] : (!fir.array<10x10xtuple<i32, f32>>, i32) -> !fir.array<10x10xtuple<i32, f32>>
+  %1 = fir.insert_value %a, %f, [5 : index, 4 : index, 1 : index] : (!fir.array<10x10xtuple<i32, f32>>, f32) -> !fir.array<10x10xtuple<i32, f32>>
+  return
+}
+
+// CHECK-LABEL: llvm.func @insert_array(
+// CHECK-SAME: %[[ARR:.*]]: !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 0 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[ARR]][4 : index, 5 : index, 1 : index] : !llvm.array<10 x array<10 x struct<(i32, f32)>>>
+// CHECK: llvm.return
+
+// -----
+
+// Test fir.insert_value operation conversion with a tuple type.
+
+func @insert_tuple(%a : tuple<i32, f32>) {
+  %f = arith.constant 2.0 : f32
+  %1 = fir.insert_value %a, %f, [1 : index] : (tuple<i32, f32>, f32) -> tuple<i32, f32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @insert_tuple(
+// CHECK-SAME: %[[TUPLE:.*]]: !llvm.struct<(i32, f32)>
+// CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %[[TUPLE]][1 : index] : !llvm.struct<(i32, f32)>
+// CHECK: llvm.return
More information about the flang-commits
mailing list