[flang-commits] [flang] [flang][fir] add codegen for fir.load of assumed-rank fir.box (PR #93569)

via flang-commits flang-commits at lists.llvm.org
Wed May 29 01:46:12 PDT 2024


https://github.com/jeanPerier updated https://github.com/llvm/llvm-project/pull/93569

>From 7bd4401ef4e4ccf4c52b1739157f0d36e1a1c959 Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Tue, 28 May 2024 09:03:53 -0700
Subject: [PATCH 1/2] [flang][fir] add codegen for fir.load of assumed-rank
 fir.box

---
 .../flang/Optimizer/CodeGen/FIROpPatterns.h   |  6 +++
 .../flang/Optimizer/CodeGen/TypeConverter.h   |  6 +++
 flang/lib/Optimizer/CodeGen/CodeGen.cpp       | 39 ++++++++++++-------
 flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp | 26 +++++++++++++
 flang/lib/Optimizer/CodeGen/TypeConverter.cpp |  9 ++++-
 flang/test/Fir/convert-to-llvm.fir            | 25 ++++++++++++
 flang/test/Fir/tbaa.fir                       | 34 ++++++++++++++--
 7 files changed, 125 insertions(+), 20 deletions(-)

diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 510ff72998914..211acdc8a38e6 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -125,6 +125,12 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
                                    mlir::ConversionPatternRewriter &rewriter,
                                    unsigned maskValue) const;
 
+  /// Compute the descriptor size in bytes. The result is not guaranteed to be a
+  /// compile-time constant: for an assumed-rank box, the rank is read from the
+  /// descriptor at runtime.
+  mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
+                             mlir::ConversionPatternRewriter &rewriter) const;
+
   template <typename... ARGS>
   mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
                            mlir::ConversionPatternRewriter &rewriter,
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 79b3bfe4e80e0..58803a5cc4044 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -123,10 +123,16 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
                      mlir::Type baseFIRType, mlir::Type accessFIRType,
                      mlir::LLVM::GEPOp gep) const;
 
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout && "must be set in ctor");
+    return *dataLayout;
+  }
+
 private:
   KindMapping kindMapping;
   std::unique_ptr<CodeGenSpecifics> specifics;
   std::unique_ptr<TBAABuilder> tbaaBuilder;
+  const mlir::DataLayout *dataLayout;
 };
 
 } // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 664453ebaf2f7..59aa9216b707f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2863,23 +2863,32 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       // descriptor value into a new descriptor temp.
       auto inputBoxStorage = adaptor.getOperands()[0];
       mlir::Location loc = load.getLoc();
-      fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
-      // fir.box of assumed rank do not have a storage
-      // size that is know at compile time. The copy needs to be runtime driven
-      // depending on the actual dynamic rank or type.
-      if (seqTy && seqTy.hasUnknownShape())
-        TODO(loc, "loading or assumed rank fir.box");
-      auto boxValue =
-          rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy, inputBoxStorage);
-      if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
-        boxValue.setTBAATags(*optionalTag);
-      else
-        attachTBAATag(boxValue, boxTy, boxTy, nullptr);
       auto newBoxStorage =
           genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
-      auto storeOp =
-          rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
-      attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      // TODO: always generate llvm.memcpy, LLVM is better at optimizing it than
+      // aggregate loads + stores.
+      if (boxTy.isAssumedRank()) {
+
+        TypePair boxTypePair{boxTy, llvmLoadTy};
+        mlir::Value boxSize =
+            computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
+        auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
+            loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          memcpy.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(memcpy, boxTy, boxTy, nullptr);
+      } else {
+        auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy,
+                                                            inputBoxStorage);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          boxValue.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(boxValue, boxTy, boxTy, nullptr);
+        auto storeOp =
+            rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
+        attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      }
       rewriter.replaceOp(load, newBoxStorage);
     } else {
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
diff --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 8c726d547491a..41a55565fd025 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -240,6 +240,32 @@ mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
                                              maskRes, c0);
 }
 
+mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
+    mlir::Location loc, TypePair boxTy, mlir::Value box,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
+  assert(firBoxType && "must be a BaseBoxType");
+  const mlir::DataLayout &dl = lowerTy().getDataLayout();
+  if (!firBoxType.isAssumedRank())
+    return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
+  fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
+  mlir::Type llvmScalarBoxType =
+      lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+  mlir::Value scalarBoxSize =
+      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmScalarBoxType));
+  mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
+  mlir::Value rank =
+      integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
+  mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
+  mlir::Value sizePerDim =
+      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmDimsType));
+  mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
+      loc, sizePerDim.getType(), sizePerDim, rank);
+  mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
+      loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
+  return size;
+}
+
 // Find the Block in which the alloca should be inserted.
 // The order to recursively find the proper block:
 // 1. An OpenMP Op that will be outlined.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 729ece6fc1774..07d3bd713ce45 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -14,6 +14,7 @@
 
 #include "flang/Optimizer/CodeGen/TypeConverter.h"
 #include "DescriptorModel.h"
+#include "flang/Common/Fortran.h"
 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
 #include "flang/Optimizer/CodeGen/TBAABuilder.h"
 #include "flang/Optimizer/CodeGen/Target.h"
@@ -36,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
           module.getContext(), getTargetTriple(module), getKindMapping(module),
           getTargetCPU(module), getTargetFeatures(module), dl)),
       tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
-                                                forceUnifiedTBAATree)) {
+                                                forceUnifiedTBAATree)),
+      dataLayout{&dl} {
   LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
 
   // Each conversion should return a value of type mlir::Type.
@@ -243,7 +245,10 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
   // [dims]
   if (rank == unknownRank()) {
     if (auto seqTy = mlir::dyn_cast<SequenceType>(ele))
-      rank = seqTy.getDimension();
+      if (seqTy.hasUnknownShape())
+        rank = Fortran::common::maxRank;
+      else
+        rank = seqTy.getDimension();
     else
       rank = 0;
   }
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 70cb0443e9a64..9ae9909f4ed44 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -931,6 +931,31 @@ func.func @test_load_box(%addr : !fir.ref<!fir.box<!fir.array<10xf32>>>) {
 
 // -----
 
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// GENERIC:         %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// AMDGPU:          %[[VAL_2A:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+// AMDGPU:          %[[VAL_2:.*]] = llvm.addrspacecast %[[VAL_2A]] : !llvm.ptr<5> to !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
 // Test `fir.box_rank` conversion.
 
 func.func @extract_rank(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index f4f23d35cba25..5800e608da41d 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -247,7 +247,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i32 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i8 to i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
@@ -267,7 +267,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i8
 // CHECK:           %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i8
@@ -307,7 +307,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]]  : i32
@@ -379,3 +379,31 @@ func.func @tbaa(%arg0: !fir.ref<!fir.array<2x!fir.type<_QMtypesTt{x:!fir.box<!fi
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK: llvm.load{{.*}}{tbaa = [#[[$ANYT]]]}
 // CHECK: llvm.store{{.*}}{tbaa = [#[[$ANYT]]]}
+
+// -----
+
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-DAG:     #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG:     #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG:     #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG:     #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()

>From c191f79e5ad32da8c2acffd5a16be4271e1e228d Mon Sep 17 00:00:00 2001
From: Jean Perier <jperier at nvidia.com>
Date: Wed, 29 May 2024 01:45:06 -0700
Subject: [PATCH 2/2] add assert to ensure no padding is required

---
 flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 41a55565fd025..72e072db37432 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -251,14 +251,19 @@ mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
   fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
   mlir::Type llvmScalarBoxType =
       lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+  llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
   mlir::Value scalarBoxSize =
-      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmScalarBoxType));
+      genConstantOffset(loc, rewriter, scalarBoxSizeCst);
   mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
   mlir::Value rank =
       integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
   mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
-  mlir::Value sizePerDim =
-      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmDimsType));
+  llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType);
+  assert((scalarBoxSizeCst + sizePerDimCst ==
+          dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
+              firBoxType.getBoxTypeWithNewShape(1)))) &&
+         "descriptor layout requires adding padding for dim field");
+  mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
   mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
       loc, sizePerDim.getType(), sizePerDim, rank);
   mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(



More information about the flang-commits mailing list