[flang-commits] [flang] 016fbc4 - [flang] Carry dynamic type in fir.rebox code generation

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon Oct 24 11:43:50 PDT 2022


Author: Valentin Clement
Date: 2022-10-24T20:43:41+02:00
New Revision: 016fbc40a5f3d8fe35cf6a87d9697709fcca1c36

URL: https://github.com/llvm/llvm-project/commit/016fbc40a5f3d8fe35cf6a87d9697709fcca1c36
DIFF: https://github.com/llvm/llvm-project/commit/016fbc40a5f3d8fe35cf6a87d9697709fcca1c36.diff

LOG: [flang] Carry dynamic type in fir.rebox code generation

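When fir.rebox is applied to a polymorphic entity (both the input box and
the result are fir.class), load the address of the dynamic type descriptor
from the original box and store it in the destination descriptor, so that
the dynamic type is carried over. Otherwise the destination descriptor is
given the type descriptor derived from the static type of the result.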

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D136618

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIROps.td
    flang/lib/Optimizer/CodeGen/CodeGen.cpp
    flang/test/Lower/allocatable-polymorphic.f90

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index cdcfc9a2f2ddd..08079435862f2 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -836,7 +836,8 @@ def fir_ReboxOp : fir_Op<"rebox", [NoMemoryEffect, AttrSizedOperandSegments]> {
   let results = (outs BoxOrClassType);
 
   let assemblyFormat = [{
-    $box (`(` $shape^ `)`)? (`[` $slice^ `]`)? attr-dict `:` functional-type(operands, results)
+    $box (`(` $shape^ `)`)? (`[` $slice^ `]`)?
+        attr-dict `:` functional-type(operands, results)
   }];
 
   let hasVerifier = 1;

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index a4ec6258646e1..d3caa60e2c272 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -91,6 +91,11 @@ static int64_t getConstantIntValue(mlir::Value val) {
   fir::emitFatalError(val.getLoc(), "must be a constant");
 }
 
+static unsigned getTypeDescFieldId(mlir::Type ty) {
+  auto isArray = fir::dyn_cast_ptrOrBoxEleTy(ty).isa<fir::SequenceType>();
+  return isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
+}
+
 namespace {
 /// FIR conversion pattern template
 template <typename FromOp>
@@ -244,6 +249,18 @@ class FIROpConversion : public mlir::ConvertOpToLLVMPattern<FromOp> {
     return getBoxEleTy(type, {kAddrPosInBox});
   }
 
+  /// Read the address of the type descriptor from a box.
+  mlir::Value
+  loadTypeDescAddress(mlir::Location loc, mlir::Type ty, mlir::Value box,
+                      mlir::ConversionPatternRewriter &rewriter) const {
+    unsigned typeDescFieldId = getTypeDescFieldId(ty);
+    mlir::Type tdescType = lowerTy().convertTypeDescType(rewriter.getContext());
+    auto pty = mlir::LLVM::LLVMPointerType::get(tdescType);
+    mlir::LLVM::GEPOp p = genGEP(loc, pty, rewriter, box, 0,
+                                 static_cast<std::int32_t>(typeDescFieldId));
+    return rewriter.create<mlir::LLVM::LoadOp>(loc, tdescType, p);
+  }
+
   // Load the attribute from the \p box and perform a check against \p maskValue
   // The final comparison is implemented as `(attribute & maskValue) != 0`.
   mlir::Value genBoxAttributeCheck(mlir::Location loc, mlir::Value box,
@@ -940,9 +957,7 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
       typeDescTy = global.getType();
     }
 
-    auto isArray = fir::dyn_cast_ptrOrBoxEleTy(passedObject.getType())
-                       .template isa<fir::SequenceType>();
-    unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
+    unsigned typeDescFieldId = getTypeDescFieldId(passedObject.getType());
 
     auto descPtr = adaptor.getOperands()[0]
                        .getType()
@@ -1458,7 +1473,8 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
   template <typename BOX>
   std::tuple<fir::BaseBoxType, mlir::Value, mlir::Value>
   consDescriptorPrefix(BOX box, mlir::ConversionPatternRewriter &rewriter,
-                       unsigned rank, mlir::ValueRange lenParams) const {
+                       unsigned rank, mlir::ValueRange lenParams,
+                       mlir::Value typeDesc = {}) const {
     auto loc = box.getLoc();
     auto boxTy = box.getType().template dyn_cast<fir::BaseBoxType>();
     auto convTy = this->lowerTy().convertBoxType(boxTy, rank);
@@ -1492,11 +1508,10 @@ struct EmboxCommonConversion : public FIROpConversion<OP> {
                     this->genI32Constant(loc, rewriter, hasAddendum ? 1 : 0));
 
     if (hasAddendum) {
-      auto isArray =
-          fir::dyn_cast_ptrOrBoxEleTy(boxTy).template isa<fir::SequenceType>();
-      unsigned typeDescFieldId = isArray ? kOptTypePtrPosInBox : kDimsPosInBox;
-      auto typeDesc =
-          getTypeDescriptor(box, rewriter, loc, unwrapIfDerived(boxTy));
+      unsigned typeDescFieldId = getTypeDescFieldId(boxTy);
+      if (!typeDesc)
+        typeDesc =
+            getTypeDescriptor(box, rewriter, loc, unwrapIfDerived(boxTy));
       descriptor =
           insertField(rewriter, loc, descriptor, {typeDescFieldId}, typeDesc,
                       /*bitCast=*/true);
@@ -1849,8 +1864,16 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
       if (recTy.getNumLenParams() != 0)
         TODO(loc, "reboxing descriptor of derived type with length parameters");
     }
-    auto [boxTy, dest, eleSize] =
-        consDescriptorPrefix(rebox, rewriter, rebox.getOutRank(), lenParams);
+
+    // Rebox on polymorphic entities needs to carry over the dynamic type.
+    mlir::Value typeDescAddr;
+    if (rebox.getBox().getType().isa<fir::ClassType>() &&
+        rebox.getType().isa<fir::ClassType>())
+      typeDescAddr = loadTypeDescAddress(loc, rebox.getBox().getType(),
+                                         loweredBox, rewriter);
+
+    auto [boxTy, dest, eleSize] = consDescriptorPrefix(
+        rebox, rewriter, rebox.getOutRank(), lenParams, typeDescAddr);
 
     // Read input extents, strides, and base address
     llvm::SmallVector<mlir::Value> inputExtents;
@@ -2835,7 +2858,7 @@ struct LoadOpConversion : public FIROpConversion<fir::LoadOp> {
   mlir::LogicalResult
   matchAndRewrite(fir::LoadOp load, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
-    if (auto boxTy = load.getType().dyn_cast<fir::BoxType>()) {
+    if (auto boxTy = load.getType().dyn_cast<fir::BaseBoxType>()) {
       // fir.box is a special case because it is considered as an ssa values in
       // fir, but it is lowered as a pointer to a descriptor. So
       // fir.ref<fir.box> and fir.box end up being the same llvm types and

diff  --git a/flang/test/Lower/allocatable-polymorphic.f90 b/flang/test/Lower/allocatable-polymorphic.f90
index 3883a35c74853..bb26b88ac528b 100644
--- a/flang/test/Lower/allocatable-polymorphic.f90
+++ b/flang/test/Lower/allocatable-polymorphic.f90
@@ -7,12 +7,14 @@ module poly
     integer :: b
   contains
     procedure, nopass :: proc1 => proc1_p1
+    procedure :: proc2 => proc2_p1
   end type
 
   type, extends(p1) :: p2
     integer :: c
   contains
-      procedure, nopass :: proc1 => proc1_p2
+    procedure, nopass :: proc1 => proc1_p2
+    procedure :: proc2 => proc2_p2
   end type
 
 contains
@@ -23,6 +25,16 @@ subroutine proc1_p1()
   subroutine proc1_p2()
     print*, 'call proc1_p2'
   end subroutine
+
+  subroutine proc2_p1(this)
+    class(p1) :: this
+    print*, 'call proc2_p1'
+  end subroutine
+
+  subroutine proc2_p2(this)
+    class(p2) :: this
+    print*, 'call proc2_p2'
+  end subroutine
 end module
 
 program test_allocatable
@@ -42,6 +54,9 @@ program test_allocatable
 
   call c1%proc1()
   call c2%proc1()
+
+  call c1%proc2()
+  call c2%proc2()
 end
 
 ! CHECK-LABEL: func.func @_QQmain()
@@ -85,7 +100,6 @@ program test_allocatable
 ! CHECK: %[[RANK:.*]] = arith.constant 1 : i32
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: fir.call @_FortranAAllocatableInitDerived(%[[C3_CAST]], %[[TYPE_DESC_P1_CAST]], %[[RANK]], %[[C0]]) : (!fir.ref<!fir.box<none>>, !fir.ref<none>, i32, i32) -> none
-! CHECK: %[[C1:.*]] = arith.constant 1 : index
 ! CHECK: %[[C10:.*]] = arith.constant 10 : i32
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: %[[C3_CAST:.*]] = fir.convert %[[C3]] : (!fir.ref<!fir.class<!fir.heap<!fir.array<?x!fir.type<_QMpolyTp1{a:i32,b:i32}>>>>>) -> !fir.ref<!fir.box<none>>
@@ -101,22 +115,34 @@ program test_allocatable
 ! CHECK: %[[RANK:.*]] = arith.constant 1 : i32
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: fir.call @_FortranAAllocatableInitDerived(%[[C4_CAST]], %[[TYPE_DESC_P2_CAST]], %[[RANK]], %[[C0]]) : (!fir.ref<!fir.box<none>>, !fir.ref<none>, i32, i32) -> none
-! CHECK: %[[C1:.*]] = arith.constant 1 : index
+! CHECK: %[[CST1:.*]] = arith.constant 1 : index
 ! CHECK: %[[C20:.*]] = arith.constant 20 : i32
 ! CHECK: %[[C0:.*]] = arith.constant 0 : i32
 ! CHECK: %[[C4_CAST:.*]] = fir.convert %[[C4]] : (!fir.ref<!fir.class<!fir.heap<!fir.array<?x!fir.type<_QMpolyTp1{a:i32,b:i32}>>>>>) -> !fir.ref<!fir.box<none>>
-! CHECK: %[[C1_I64:.*]] = fir.convert %[[C1]] : (index) -> i64
+! CHECK: %[[C1_I64:.*]] = fir.convert %[[CST1]] : (index) -> i64
 ! CHECK: %[[C20_I64:.*]] = fir.convert %[[C20]] : (i32) -> i64
 ! CHECK: %{{.*}} = fir.call @_FortranAAllocatableSetBounds(%[[C4_CAST]], %[[C0]], %[[C1_I64]], %[[C20_I64]]) : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
 ! CHECK: %[[C4_CAST:.*]] = fir.convert %[[C4]] : (!fir.ref<!fir.class<!fir.heap<!fir.array<?x!fir.type<_QMpolyTp1{a:i32,b:i32}>>>>>) -> !fir.ref<!fir.box<none>>
 ! CHECK: %{{.*}} = fir.call @_FortranAAllocatableAllocate(%[[C4_CAST]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
 
+! Check fir.rebox for fir.class
+! CHECK: %[[C1_LOAD:.*]] = fir.load %[[C1]] : !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>>
+! CHECK: %[[C1_REBOX:.*]] = fir.rebox %[[C1_LOAD]] : (!fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>) -> !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>
+! CHECK: fir.dispatch "proc2"(%[[C1_REBOX]] : !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>) (%61 : !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
+
+! CHECK: %[[C2_LOAD:.*]] = fir.load %[[C2]] : !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>>
+! CHECK: %[[C2_REBOX:.*]] = fir.rebox %[[C2_LOAD]] : (!fir.class<!fir.heap<!fir.type<_QMpolyTp1{a:i32,b:i32}>>>) -> !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>
+! CHECK: fir.dispatch "proc2"(%[[C2_REBOX]] : !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>) (%63 : !fir.class<!fir.type<_QMpolyTp1{a:i32,b:i32}>>) {pass_arg_pos = 0 : i32}
 
 ! Check code generation of allocate runtime calls for polymoprhic entities. This
 ! is done from Fortran so we don't have a file full of auto-generated type info
 ! in order to perform the checks.
 
 ! LLVM-LABEL: define void @_QQmain()
+! LLVM: %[[TMP1:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }
+! LLVM: %[[TMP2:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }
+! LLVM: %[[TMP3:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }
+! LLVM: %[[TMP4:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }
 ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEp, ptr @_QMpolyE.dt.p1, i32 0, i32 0)
 ! LLVM: %{{.*}} = call i32 @_FortranAAllocatableAllocate(ptr @_QFEp, i1 false, ptr null, ptr @_QQcl.{{.*}}, i32 {{.*}})
 ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEc1, ptr @_QMpolyE.dt.p1, i32 0, i32 0)
@@ -129,3 +155,20 @@ program test_allocatable
 ! LLVM: %{{.*}} = call {} @_FortranAAllocatableInitDerived(ptr @_QFEc4, ptr @_QMpolyE.dt.p2, i32 1, i32 0)
 ! LLVM: %{{.*}} = call {} @_FortranAAllocatableSetBounds(ptr @_QFEc4, i32 0, i64 1, i64 20)
 ! LLVM: %{{.*}} = call i32 @_FortranAAllocatableAllocate(ptr @_QFEc4, i1 false, ptr null, ptr @_QQcl.{{.*}}, i32 {{.*}})
+! LLVM:  call void %{{.*}}()
+
+! LLVM: %[[C1_LOAD:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr @_QFEc1
+! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[C1_LOAD]], ptr %[[TMP4]]
+! LLVM: %[[GEP_TDESC_C1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[TMP4]], i32 0, i32 7
+! LLVM: %[[TDESC_C1:.*]] = load ptr, ptr %[[GEP_TDESC_C1]]
+! LLVM: %[[BOX0:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } { ptr undef, i64 ptrtoint (ptr getelementptr (%_QMpolyTp1, ptr null, i32 1) to i64), i32 20180515, i8 0, i8 42, i8 0, i8 1, ptr undef, [1 x i64] undef }, ptr %[[TDESC_C1]], 7
+! LLVM: %[[BOX1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX0]], ptr %{{.*}}, 0
+! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX1]], ptr %[[TMP3]]
+
+! LLVM: %[[LOAD_C2:.*]] = load { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr @_QFEc2
+! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[LOAD_C2]], ptr %[[TMP2]]
+! LLVM: %[[GEP_TDESC_C2:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] }, ptr %[[TMP2]], i32 0, i32 7
+! LLVM: %[[TDESC_C2:.*]] = load ptr, ptr %[[GEP_TDESC_C2]]
+! LLVM: %[[BOX0:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } { ptr undef, i64 ptrtoint (ptr getelementptr (%_QMpolyTp1, ptr null, i32 1) to i64), i32 20180515, i8 0, i8 42, i8 0, i8 1, ptr undef, [1 x i64] undef }, ptr %[[TDESC_C2]], 7
+! LLVM: %[[BOX1:.*]] = insertvalue { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX0]], ptr %{{.*}}, 0
+! LLVM: store { ptr, i64, i32, i8, i8, i8, i8, ptr, [1 x i64] } %[[BOX1]], ptr %[[TMP1]]


        


More information about the flang-commits mailing list