[flang-commits] [flang] [flang][fir] add codegen for fir.load of assumed-rank fir.box (PR #93569)

via flang-commits flang-commits at lists.llvm.org
Tue May 28 09:14:30 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-codegen

Author: None (jeanPerier)

<details>
<summary>Changes</summary>

- Update LLVM type conversion of assumed-rank fir.box/class to generate the type of the maximum ranked descriptor. That way, allocas for assumed-rank descriptor copies are always big enough. This is needed in the fir.load case, which generates new storage for the value.
- Add a "computeBoxSize" helper to compute the dynamic size of a descriptor.
- Use that size to generate an llvm.memcpy intrinsic to copy the input descriptor into the new storage.

Looking at https://reviews.llvm.org/D108221?id=404635, it seems valid to add the TBAA node on the memcpy, which I did.

In a follow-up patch, I think we should probably always use a memcpy, since LLVM seems to optimize it better than fir.load/fir.store patterns.

---
Full diff: https://github.com/llvm/llvm-project/pull/93569.diff


7 Files Affected:

- (modified) flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h (+6) 
- (modified) flang/include/flang/Optimizer/CodeGen/TypeConverter.h (+6) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+24-15) 
- (modified) flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp (+26) 
- (modified) flang/lib/Optimizer/CodeGen/TypeConverter.cpp (+7-2) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+25) 
- (modified) flang/test/Fir/tbaa.fir (+31-3) 


``````````diff
diff --git a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
index 510ff72998914..211acdc8a38e6 100644
--- a/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
+++ b/flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
@@ -125,6 +125,12 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
                                    mlir::ConversionPatternRewriter &rewriter,
                                    unsigned maskValue) const;
 
+  /// Compute the descriptor size in bytes. The result is not guaranteed to be a
+  /// compile time constant if the box is for an assumed rank, in which case the
+  /// box rank will be read.
+  mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
+                             mlir::ConversionPatternRewriter &rewriter) const;
+
   template <typename... ARGS>
   mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
                            mlir::ConversionPatternRewriter &rewriter,
diff --git a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
index 79b3bfe4e80e0..58803a5cc4044 100644
--- a/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
+++ b/flang/include/flang/Optimizer/CodeGen/TypeConverter.h
@@ -123,10 +123,16 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
                      mlir::Type baseFIRType, mlir::Type accessFIRType,
                      mlir::LLVM::GEPOp gep) const;
 
+  const mlir::DataLayout &getDataLayout() const {
+    assert(dataLayout && "must be set in ctor");
+    return *dataLayout;
+  }
+
 private:
   KindMapping kindMapping;
   std::unique_ptr<CodeGenSpecifics> specifics;
   std::unique_ptr<TBAABuilder> tbaaBuilder;
+  const mlir::DataLayout *dataLayout;
 };
 
 } // namespace fir
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 664453ebaf2f7..59aa9216b707f 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2863,23 +2863,32 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
       // descriptor value into a new descriptor temp.
       auto inputBoxStorage = adaptor.getOperands()[0];
       mlir::Location loc = load.getLoc();
-      fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
-      // fir.box of assumed rank do not have a storage
-      // size that is know at compile time. The copy needs to be runtime driven
-      // depending on the actual dynamic rank or type.
-      if (seqTy && seqTy.hasUnknownShape())
-        TODO(loc, "loading or assumed rank fir.box");
-      auto boxValue =
-          rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy, inputBoxStorage);
-      if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
-        boxValue.setTBAATags(*optionalTag);
-      else
-        attachTBAATag(boxValue, boxTy, boxTy, nullptr);
       auto newBoxStorage =
           genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
-      auto storeOp =
-          rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
-      attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      // TODO: always generate llvm.memcpy, LLVM is better at optimizing it than
+      // aggregate loads + stores.
+      if (boxTy.isAssumedRank()) {
+
+        TypePair boxTypePair{boxTy, llvmLoadTy};
+        mlir::Value boxSize =
+            computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
+        auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
+            loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          memcpy.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(memcpy, boxTy, boxTy, nullptr);
+      } else {
+        auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy,
+                                                            inputBoxStorage);
+        if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
+          boxValue.setTBAATags(*optionalTag);
+        else
+          attachTBAATag(boxValue, boxTy, boxTy, nullptr);
+        auto storeOp =
+            rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
+        attachTBAATag(storeOp, boxTy, boxTy, nullptr);
+      }
       rewriter.replaceOp(load, newBoxStorage);
     } else {
       auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(
diff --git a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
index 8c726d547491a..41a55565fd025 100644
--- a/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
+++ b/flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp
@@ -240,6 +240,32 @@ mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
                                              maskRes, c0);
 }
 
+mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
+    mlir::Location loc, TypePair boxTy, mlir::Value box,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
+  assert(firBoxType && "must be a BaseBoxType");
+  const mlir::DataLayout &dl = lowerTy().getDataLayout();
+  if (!firBoxType.isAssumedRank())
+    return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
+  fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
+  mlir::Type llvmScalarBoxType =
+      lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
+  mlir::Value scalarBoxSize =
+      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmScalarBoxType));
+  mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
+  mlir::Value rank =
+      integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
+  mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
+  mlir::Value sizePerDim =
+      genConstantOffset(loc, rewriter, dl.getTypeSize(llvmDimsType));
+  mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
+      loc, sizePerDim.getType(), sizePerDim, rank);
+  mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
+      loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
+  return size;
+}
+
 // Find the Block in which the alloca should be inserted.
 // The order to recursively find the proper block:
 // 1. An OpenMP Op that will be outlined.
diff --git a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
index 729ece6fc1774..07d3bd713ce45 100644
--- a/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
+++ b/flang/lib/Optimizer/CodeGen/TypeConverter.cpp
@@ -14,6 +14,7 @@
 
 #include "flang/Optimizer/CodeGen/TypeConverter.h"
 #include "DescriptorModel.h"
+#include "flang/Common/Fortran.h"
 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
 #include "flang/Optimizer/CodeGen/TBAABuilder.h"
 #include "flang/Optimizer/CodeGen/Target.h"
@@ -36,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
           module.getContext(), getTargetTriple(module), getKindMapping(module),
           getTargetCPU(module), getTargetFeatures(module), dl)),
       tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
-                                                forceUnifiedTBAATree)) {
+                                                forceUnifiedTBAATree)),
+      dataLayout{&dl} {
   LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
 
   // Each conversion should return a value of type mlir::Type.
@@ -243,7 +245,10 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
   // [dims]
   if (rank == unknownRank()) {
     if (auto seqTy = mlir::dyn_cast<SequenceType>(ele))
-      rank = seqTy.getDimension();
+      if (seqTy.hasUnknownShape())
+        rank = Fortran::common::maxRank;
+      else
+        rank = seqTy.getDimension();
     else
       rank = 0;
   }
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 70cb0443e9a64..9ae9909f4ed44 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -931,6 +931,31 @@ func.func @test_load_box(%addr : !fir.ref<!fir.box<!fir.array<10xf32>>>) {
 
 // -----
 
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// GENERIC:         %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// AMDGPU:          %[[VAL_2A:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
+// AMDGPU:          %[[VAL_2:.*]] = llvm.addrspacecast %[[VAL_2A]] : !llvm.ptr<5> to !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
+
+// -----
+
 // Test `fir.box_rank` conversion.
 
 func.func @extract_rank(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
diff --git a/flang/test/Fir/tbaa.fir b/flang/test/Fir/tbaa.fir
index f4f23d35cba25..5800e608da41d 100644
--- a/flang/test/Fir/tbaa.fir
+++ b/flang/test/Fir/tbaa.fir
@@ -247,7 +247,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i32 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i8 to i32
 // CHECK:           llvm.return %[[VAL_3]] : i32
@@ -267,7 +267,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i8
 // CHECK:           %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i8
@@ -307,7 +307,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
 
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK-SAME:                    %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
-// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
+// CHECK:           %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
 // CHECK:           %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i32
 // CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
 // CHECK:           %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]]  : i32
@@ -379,3 +379,31 @@ func.func @tbaa(%arg0: !fir.ref<!fir.array<2x!fir.type<_QMtypesTt{x:!fir.box<!fi
 // CHECK-LABEL:   llvm.func @tbaa(
 // CHECK: llvm.load{{.*}}{tbaa = [#[[$ANYT]]]}
 // CHECK: llvm.store{{.*}}{tbaa = [#[[$ANYT]]]}
+
+// -----
+
+func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
+  %0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
+  fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
+  return
+}
+func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
+
+// CHECK-DAG:     #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
+// CHECK-DAG:     #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
+// CHECK-DAG:     #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
+// CHECK-DAG:     #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
+
+// CHECK-LABEL:   llvm.func @test_assumed_rank_load(
+// CHECK-SAME:                                      %[[VAL_0:.*]]: !llvm.ptr) {
+// CHECK:           %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK:           %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+// CHECK:           %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
+// CHECK:           %[[VAL_5:.*]] = llvm.load %[[VAL_4]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
+// CHECK:           %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
+// CHECK:           %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
+// CHECK:           %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
+// CHECK:           %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
+// CHECK:           "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
+// CHECK:           llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()

``````````

</details>


https://github.com/llvm/llvm-project/pull/93569


More information about the flang-commits mailing list