[flang-commits] [flang] e398383 - [flang][fir] add codegen for fir.load of assumed-rank fir.box (#93569)

via flang-commits flang-commits at lists.llvm.org
Thu May 30 00:30:30 PDT 2024


Author: jeanPerier
Date: 2024-05-30T09:30:27+02:00
New Revision: e398383f9a05ec6f3766e5ab49dd862a72325ba6

URL: https://github.com/llvm/llvm-project/commit/e398383f9a05ec6f3766e5ab49dd862a72325ba6
DIFF: https://github.com/llvm/llvm-project/commit/e398383f9a05ec6f3766e5ab49dd862a72325ba6.diff

LOG: [flang][fir] add codegen for fir.load of assumed-rank fir.box (#93569)

- Update the LLVM type conversion of assumed-rank fir.box/class to generate
the type of the maximum-rank descriptor. That way, allocas for assumed-rank
descriptor copies are always big enough. This is needed in the
fir.load case, which generates new storage for the value.
- Add a "computeBoxSize" helper to compute the dynamic size of a
descriptor.
- Use that size to generate an llvm.memcpy intrinsic to copy the input
descriptor into the new storage.

Looking at https://reviews.llvm.org/D108221?id=404635, it seems valid to
attach the TBAA tag to the memcpy, which this patch does.

In a follow-up patch, we should probably always use a memcpy, since
LLVM seems to optimize it better than fir.load/fir.store
patterns.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
    flang/include/flang/Optimizer/CodeGen/TypeConverter.h
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
    flang/lib/Optimizer/CodeGen/TypeConverter.cpp
    flang/test/Fir/convert-to-llvm.fir
    flang/test/Fir/tbaa.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 510ff72998914..211acdc8a38e6 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -125,6 +125,12 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
                                    mlir::ConversionPatternRewriter &rewriter,
                                    unsigned maskValue) const;
 
+  /// Compute the descriptor size in bytes. The result is not guaranteed to be a
+  /// compile time constant if the box is for an assumed rank, in which case the
+  /// box rank will be read.
+  mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
+                             mlir::ConversionPatternRewriter &rewriter) const;
+
   template <typename... ARGS>
   mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
                            mlir::ConversionPatternRewriter &rewriter,

diff  --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 79b3bfe4e80e0..58803a5cc4044 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -123,10 +123,16 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
                      mlir::Type baseFIRType, mlir::Type accessFIRType,
                      mlir::LLVM::GEPOp gep) const;
 
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout && "must be set in ctor");
+    return *dataLayout;
+  }
+
 private:
   KindMapping kindMapping;
   std::unique_ptr<CodeGenSpecifics> specifics;
   std::unique_ptr<TBAABuilder> tbaaBuilder;
+  const mlir::DataLayout *dataLayout;
 };
 
 } // namespace fir

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 664453ebaf2f7..59aa9216b707f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2863,23 +2863,32 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       // descriptor value into a new descriptor temp.
       auto inputBoxStorage = adaptor.getOperands()[0];
       mlir::Location loc = load.getLoc();
-      fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
-      // fir.box of assumed rank do not have a storage
-      // size that is know at compile time. The copy needs to be runtime driven
-      // depending on the actual dynamic rank or type.
-      if (seqTy && seqTy.hasUnknownShape())
-        TODO(loc, "loading or assumed rank fir.box");
-      auto boxValue =
-          rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy, inputBoxStorage);
-      if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
-        boxValue.setTBAATags(*optionalTag);
-      else
-        attachTBAATag(boxValue, boxTy, boxTy, nullptr);
       auto newBoxStorage =
           genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
-      auto storeOp =
-          rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
-      attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      // TODO: always generate llvm.memcpy, LLVM is better at optimizing it than
+      // aggregate loads + stores.
+      if (boxTy.isAssumedRank()) {
+
+        TypePair boxTypePair{boxTy, llvmLoadTy};
+        mlir::Value boxSize =
+            computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
+        auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
+            loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          memcpy.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(memcpy, boxTy, boxTy, nullptr);
+      } else {
+        auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy,
+                                                            inputBoxStorage);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          boxValue.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(boxValue, boxTy, boxTy, nullptr);
+        auto storeOp =
+            rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
+        attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      }
       rewriter.replaceOp(load, newBoxStorage);
     } else {
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(

diff  --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 8c726d547491a..72e072db37432 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -240,6 +240,37 @@ mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
                                              maskRes, c0);
 }
 
+mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
+    mlir::Location loc, TypePair boxTy, mlir::Value box,
+    mlir::ConversionPatternRewriter &rewriter) const { // Box size in bytes.
+  auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
+  assert(firBoxType && "must be a BaseBoxType");
+  const mlir::DataLayout &dl = lowerTy().getDataLayout();
+  if (!firBoxType.isAssumedRank()) // Known rank: constant size of LLVM type.
+    return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
+  fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
+  mlir::Type llvmScalarBoxType = // LLVM layout of the rank-0 variant.
+      lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+  llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
+  mlir::Value scalarBoxSize =
+      genConstantOffset(loc, rewriter, scalarBoxSizeCst);
+  mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter); // Runtime.
+  mlir::Value rank = // Cast to the integer type used for the offsets.
+      integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
+  mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
+  llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType); // One dim.
+  assert((scalarBoxSizeCst + sizePerDimCst == // rank-1 = rank-0 + one dim
+          dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
+              firBoxType.getBoxTypeWithNewShape(1)))) &&
+         "descriptor layout requires adding padding for dim field");
+  mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
+  mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>( // dim * rank
+      loc, sizePerDim.getType(), sizePerDim, rank);
+  mlir::Value size = rewriter.create<mlir::LLVM::AddOp>( // rank-0 + dims
+      loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
+  return size;
+}
+
 // Find the Block in which the alloca should be inserted.
 // The order to recursively find the proper block:
 // 1. An OpenMP Op that will be outlined.

diff  --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 729ece6fc1774..07d3bd713ce45 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -14,6 +14,7 @@
 
 #include "flang/Optimizer/CodeGen/TypeConverter.h"
 #include "DescriptorModel.h"
+#include "flang/Common/Fortran.h"
 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
 #include "flang/Optimizer/CodeGen/TBAABuilder.h"
 #include "flang/Optimizer/CodeGen/Target.h"
@@ -36,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
           module.getContext(), getTargetTriple(module), getKindMapping(module),
           getTargetCPU(module), getTargetFeatures(module), dl)),
       tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
-                                                forceUnifiedTBAATree)) {
+                                                forceUnifiedTBAATree)),
+      dataLayout{&dl} {
   LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
 
   // Each conversion should return a value of type mlir::Type.
@@ -243,7 +245,10 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
   // [dims]
   if (rank == unknownRank()) {
     if (auto seqTy = mlir::dyn_cast<SequenceType>(ele))
-      rank = seqTy.getDimension();
+      if (seqTy.hasUnknownShape())
+        rank = Fortran::common::maxRank;
+      else
+        rank = seqTy.getDimension();
     else
       rank = 0;
   }

diff  --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 369d4bd3029bc..81810aa4bfc74 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -931,6 +931,31 @@ func.func @test_load_box(%addr : !fir.ref<!fir.box<!fir.array<10xf32>>>) {
 
 // -----
 
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// GENERIC:         %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// AMDGPU:          %[[VAL_2A:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+// AMDGPU:          %[[VAL_2:.*]] = llvm.addrspacecast %[[VAL_2A]] : !llvm.ptr<5> to !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
 // Test `fir.box_rank` conversion.
 
 func.func @extract_rank(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {

diff  --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index f4f23d35cba25..5800e608da41d 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -247,7 +247,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i32 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i8 to i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
@@ -267,7 +267,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i8
 // CHECK:           %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i8
@@ -307,7 +307,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]]  : i32
@@ -379,3 +379,31 @@ func.func @tbaa(%arg0: !fir.ref<!fir.array<2x!fir.type<_QMtypesTt{x:!fir.box<!fi
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK: llvm.load{{.*}}{tbaa = [#[[$ANYT]]]}
 // CHECK: llvm.store{{.*}}{tbaa = [#[[$ANYT]]]}
+
+// -----
+
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-DAG:     #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG:     #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG:     #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG:     #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()


        


More information about the flang-commits mailing list