[flang-commits] [flang] [flang] Implement conversion of compatible derived types (PR #111165)

via flang-commits flang-commits at lists.llvm.org
Fri Oct 4 07:24:45 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Leandro Lupori (luporl)

<details>
<summary>Changes</summary>

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes https://github.com/llvm/llvm-project/issues/107783


---
Full diff: https://github.com/llvm/llvm-project/pull/111165.diff


5 Files Affected:

- (modified) flang/lib/Optimizer/Builder/FIRBuilder.cpp (+4-1) 
- (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+25) 
- (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+11-1) 
- (modified) flang/test/Fir/convert-to-llvm.fir (+25) 
- (modified) flang/test/Fir/invalid.fir (+8) 


``````````diff
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 9ad37c8df434a2..8fa695a5c0c2e1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder,
                                         mlir::Location loc, mlir::Type toTy,
                                         mlir::Value val) {
   if (val.getType() != toTy) {
-    assert(!fir::isa_derived(toTy));
+    assert((!fir::isa_derived(toTy) ||
+            mlir::cast<fir::RecordType>(val.getType()).getTypeList() ==
+                mlir::cast<fir::RecordType>(toTy).getTypeList()) &&
+           "incompatible record types");
     return builder.create<fir::ConvertOp>(loc, toTy, val);
   }
   return val;
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1cb869bfeb95a8..19c38a1ba6be26 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
     auto loc = convert.getLoc();
     auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
 
+    if (mlir::isa<fir::RecordType>(toFirTy)) {
+      // Convert to compatible BIND(C) record type.
+      // Double check that the record types are compatible (it should have
+      // already been checked by the verifier).
+      assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() ==
+                 mlir::cast<fir::RecordType>(toFirTy).getTypeList() &&
+             "incompatible record types");
+
+      auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy);
+      mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy);
+      auto indexTypeMap = toStTy.getSubelementIndexMap();
+      assert(indexTypeMap.has_value() && "invalid record type");
+
+      for (auto [attr, type] : indexTypeMap.value()) {
+        int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt();
+        auto extVal =
+            rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index);
+        val =
+            rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index);
+      }
+
+      rewriter.replaceOp(convert, val);
+      return mlir::success();
+    }
+
     if (mlir::isa<fir::LogicalType>(fromFirTy) ||
         mlir::isa<fir::LogicalType>(toFirTy)) {
       // By specification fir::LogicalType value may be any number,
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8fdc06f6fce3f5..90ce8b87605912 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) {
   return true;
 }
 
+static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) {
+  // Both records must have the same field types.
+  // Trust frontend semantics for in-depth checks, such as if both records
+  // have the BIND(C) attribute.
+  auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy);
+  auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy);
+  return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList();
+}
+
 bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
   if (inType == outType)
     return true;
@@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
          (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) ||
-         areVectorsCompatible(inType, outType);
+         areVectorsCompatible(inType, outType) ||
+         areRecordsCompatible(inType, outType);
 }
 
 llvm::LogicalResult fir::ConvertOp::verify() {
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 0c17d7c25a8c8d..1182a0a10f218b 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex<f128>) -> complex<f16> {
 
 // -----
 
+// Test `fir.convert` operation conversion between compatible fir.record types.
+
+func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                                  !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> {
+    %0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                              !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+  return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+}
+
+// CHECK-LABEL: func @convert_record(
+// CHECK-SAME:    %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) ->
+// CHECK-SAME:                  [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]
+// CHECK:         %{{.*}} = llvm.mlir.undef : [[MOD2_REC]]
+// CHECK-DAG:     %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]]
+// CHECK-DAG:     %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]]
+// CHECK-DAG:     %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]]
+// CHECK-DAG:     %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]]
+// CHECK:         llvm.return %{{.*}} : [[MOD2_REC]]
+
+// -----
+
 // Test `fir.store` --> `llvm.store` conversion
 
 func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) {
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 086a426db5642e..7e3f9d64984129 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
 
 // -----
 
+func.func @rec_to_rec(%arg0: !fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}> {
+  // expected-error@+1{{'fir.convert' op invalid type conversion}}
+  %0 = fir.convert %arg0 : (!fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}>
+  return %0 : !fir.type<t2{f:f32, i:i32}>
+}
+
+// -----
+
 func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
   // expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
   %addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/111165


More information about the flang-commits mailing list