[flang-commits] [flang] [flang][fir] add codegen for fir.load of assumed-rank fir.box (PR #93569)
via flang-commits
flang-commits at lists.llvm.org
Wed May 29 01:46:12 PDT 2024
https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/93569
From 7bd4401ef4e4ccf4c52b1739157f0d36e1a1c959 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 28 May 2024 09:03:53 -0700
Subject: [PATCH 1/2] [flang][fir] add codegen for fir.load of assumed-rank fir.box
---
.../flang/Optimizer/CodeGen/FIROpPatterns.h | 6 +++
.../flang/Optimizer/CodeGen/TypeConverter.h | 6 +++
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 39 ++++++++++++-------
flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp | 26 +++++++++++++
flang/lib/Optimizer/CodeGen/TypeConverter.cpp | 9 ++++-
flang/test/Fir/convert-to-llvm.fir | 25 ++++++++++++
flang/test/Fir/tbaa.fir | 34 ++++++++++++++--
7 files changed, 125 insertions(+), 20 deletions(-)
diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 510ff72998914..211acdc8a38e6 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -125,6 +125,12 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
mlir::ConversionPatternRewriter &rewriter,
unsigned maskValue) const;
+ /// Compute the descriptor size in bytes. The result is not guaranteed to be a
+ /// compile time constant if the box is for an assumed rank, in which case the
+ /// box rank will be read.
+ mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
+ mlir::ConversionPatternRewriter &rewriter) const;
+
template <typename... ARGS>
mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
mlir::ConversionPatternRewriter &rewriter,
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 79b3bfe4e80e0..58803a5cc4044 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -123,10 +123,16 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
mlir::Type baseFIRType, mlir::Type accessFIRType,
mlir::LLVM::GEPOp gep) const;
+ const mlir::DataLayout &getDataLayout() const {
+ assert(dataLayout && "must be set in ctor");
+ return *dataLayout;
+ }
+
private:
KindMapping kindMapping;
std::unique_ptr<CodeGenSpecifics> specifics;
std::unique_ptr<TBAABuilder> tbaaBuilder;
+ const mlir::DataLayout *dataLayout;
};
} // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 664453ebaf2f7..59aa9216b707f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2863,23 +2863,32 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
// descriptor value into a new descriptor temp.
auto inputBoxStorage = adaptor.getOperands()[0];
mlir::Location loc = load.getLoc();
- fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
- // fir.box of assumed rank do not have a storage
- // size that is know at compile time. The copy needs to be runtime driven
- // depending on the actual dynamic rank or type.
- if (seqTy && seqTy.hasUnknownShape())
- TODO(loc, "loading or assumed rank fir.box");
- auto boxValue =
- rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy, inputBoxStorage);
- if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
- boxValue.setTBAATags(*optionalTag);
- else
- attachTBAATag(boxValue, boxTy, boxTy, nullptr);
auto newBoxStorage =
genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
- auto storeOp =
- rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
- attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+ // TODO: always generate llvm.memcpy, LLVM is better at optimizing it than
+ // aggregate loads + stores.
+ if (boxTy.isAssumedRank()) {
+
+ TypePair boxTypePair{boxTy, llvmLoadTy};
+ mlir::Value boxSize =
+ computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
+ auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
+ loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+ if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+ memcpy.setTBAATags(*optionalTag);
+ else
+ attachTBAATag(memcpy, boxTy, boxTy, nullptr);
+ } else {
+ auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy,
+ inputBoxStorage);
+ if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+ boxValue.setTBAATags(*optionalTag);
+ else
+ attachTBAATag(boxValue, boxTy, boxTy, nullptr);
+ auto storeOp =
+ rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
+ attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+ }
rewriter.replaceOp(load, newBoxStorage);
} else {
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
diff --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 8c726d547491a..41a55565fd025 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -240,6 +240,32 @@ mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
maskRes, c0);
}
+mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
+ mlir::Location loc, TypePair boxTy, mlir::Value box,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
+ assert(firBoxType && "must be a BaseBoxType");
+ const mlir::DataLayout &dl = lowerTy().getDataLayout();
+ if (!firBoxType.isAssumedRank())
+ return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
+ fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
+ mlir::Type llvmScalarBoxType =
+ lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+ mlir::Value scalarBoxSize =
+ genConstantOffset(loc, rewriter, dl.getTypeSize(llvmScalarBoxType));
+ mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
+ mlir::Value rank =
+ integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
+ mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
+ mlir::Value sizePerDim =
+ genConstantOffset(loc, rewriter, dl.getTypeSize(llvmDimsType));
+ mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
+ loc, sizePerDim.getType(), sizePerDim, rank);
+ mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
+ loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
+ return size;
+}
+
// Find the Block in which the alloca should be inserted.
// The order to recursively find the proper block:
// 1. An OpenMP Op that will be outlined.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 729ece6fc1774..07d3bd713ce45 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -14,6 +14,7 @@
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "DescriptorModel.h"
+#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
#include "flang/Optimizer/CodeGen/TBAABuilder.h"
#include "flang/Optimizer/CodeGen/Target.h"
@@ -36,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
module.getContext(), getTargetTriple(module), getKindMapping(module),
getTargetCPU(module), getTargetFeatures(module), dl)),
tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
- forceUnifiedTBAATree)) {
+ forceUnifiedTBAATree)),
+ dataLayout{&dl} {
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
// Each conversion should return a value of type mlir::Type.
@@ -243,7 +245,10 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
// [dims]
if (rank == unknownRank()) {
if (auto seqTy = mlir::dyn_cast<SequenceType>(ele))
- rank = seqTy.getDimension();
+ if (seqTy.hasUnknownShape())
+ rank = Fortran::common::maxRank;
+ else
+ rank = seqTy.getDimension();
else
rank = 0;
}
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 70cb0443e9a64..9ae9909f4ed44 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -931,6 +931,31 @@ func.func @test_load_box(%addr : !fir.ref<!fir.box<!fir.array<10xf32>>>) {
// -----
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+ %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+ fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+ return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-LABEL: llvm.func @test_assumed_rank_load(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// GENERIC: %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// AMDGPU: %[[VAL_2A:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+// AMDGPU: %[[VAL_2:.*]] = llvm.addrspacecast %[[VAL_2A]] : !llvm.ptr<5> to !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> i8
+// CHECK: %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
// Test `fir.box_rank` conversion.
func.func @extract_rank(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index f4f23d35cba25..5800e608da41d 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -247,7 +247,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i32 {
-// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
// CHECK: %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i8 to i32
// CHECK: llvm.return %[[VAL_3]] : i32
@@ -267,7 +267,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i8
// CHECK: %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i8
@@ -307,7 +307,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
// CHECK-LABEL: llvm.func @tbaa(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i32
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]] : i32
@@ -379,3 +379,31 @@ func.func @tbaa(%arg0: !fir.ref<!fir.array<2x!fir.type<_QMtypesTt{x:!fir.box<!fi
// CHECK-LABEL: llvm.func @tbaa(
// CHECK: llvm.load{{.*}}{tbaa = [#[[$ANYT]]]}
// CHECK: llvm.store{{.*}}{tbaa = [#[[$ANYT]]]}
+
+// -----
+
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+ %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+ fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+ return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-DAG: #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG: #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG: #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG: #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL: llvm.func @test_assumed_rank_load(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK: %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
From c191f79e5ad32da8c2acffd5a16be4271e1e228d Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 29 May 2024 01:45:06 -0700
Subject: [PATCH 2/2] add assert to ensure no padding is required
---
flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 41a55565fd025..72e072db37432 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -251,14 +251,19 @@ mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
mlir::Type llvmScalarBoxType =
lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+ llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
mlir::Value scalarBoxSize =
- genConstantOffset(loc, rewriter, dl.getTypeSize(llvmScalarBoxType));
+ genConstantOffset(loc, rewriter, scalarBoxSizeCst);
mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
mlir::Value rank =
integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
- mlir::Value sizePerDim =
- genConstantOffset(loc, rewriter, dl.getTypeSize(llvmDimsType));
+ llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType);
+ assert((scalarBoxSizeCst + sizePerDimCst ==
+ dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
+ firBoxType.getBoxTypeWithNewShape(1)))) &&
+ "descriptor layout requires adding padding for dim field");
+ mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
loc, sizePerDim.getType(), sizePerDim, rank);
mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
More information about the flang-commits mailing list