[flang-commits] [flang] [flang] Implement conversion of compatible derived types (PR #111165)

Leandro Lupori via flang-commits flang-commits at lists.llvm.org
Fri Oct 4 07:24:10 PDT 2024


https://github.com/luporl created https://github.com/llvm/llvm-project/pull/111165

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes https://github.com/llvm/llvm-project/issues/107783


From 4fa873aefbeea19a5c5520443d19a2acd3c8b73c Mon Sep 17 00:00:00 2001
From: Leandro Lupori <leandro.lupori at linaro.org>
Date: Thu, 3 Oct 2024 14:26:00 -0300
Subject: [PATCH] [flang] Implement conversion of compatible derived types

With some restrictions, BIND(C) derived types can be converted to
compatible BIND(C) derived types.
Semantics already support this, but ConvertOp was missing the
conversion of such types.

Fixes https://github.com/llvm/llvm-project/issues/107783
---
 flang/lib/Optimizer/Builder/FIRBuilder.cpp |  5 ++++-
 flang/lib/Optimizer/CodeGen/CodeGen.cpp    | 25 ++++++++++++++++++++++
 flang/lib/Optimizer/Dialect/FIROps.cpp     | 12 ++++++++++-
 flang/test/Fir/convert-to-llvm.fir         | 25 ++++++++++++++++++++++
 flang/test/Fir/invalid.fir                 |  8 +++++++
 5 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
index 9ad37c8df434a2..8fa695a5c0c2e1 100644
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder,
                                         mlir::Location loc, mlir::Type toTy,
                                         mlir::Value val) {
   if (val.getType() != toTy) {
-    assert(!fir::isa_derived(toTy));
+    assert((!fir::isa_derived(toTy) ||
+            mlir::cast<fir::RecordType>(val.getType()).getTypeList() ==
+                mlir::cast<fir::RecordType>(toTy).getTypeList()) &&
+           "incompatible record types");
     return builder.create<fir::ConvertOp>(loc, toTy, val);
   }
   return val;
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 1cb869bfeb95a8..19c38a1ba6be26 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion<fir::ConvertOp> {
     auto loc = convert.getLoc();
     auto i1Type = mlir::IntegerType::get(convert.getContext(), 1);
 
+    if (mlir::isa<fir::RecordType>(toFirTy)) {
+      // Convert to compatible BIND(C) record type.
+      // Double check that the record types are compatible (it should have
+      // already been checked by the verifier).
+      assert(mlir::cast<fir::RecordType>(fromFirTy).getTypeList() ==
+                 mlir::cast<fir::RecordType>(toFirTy).getTypeList() &&
+             "incompatible record types");
+
+      auto toStTy = mlir::cast<mlir::LLVM::LLVMStructType>(toTy);
+      mlir::Value val = rewriter.create<mlir::LLVM::UndefOp>(loc, toStTy);
+      auto indexTypeMap = toStTy.getSubelementIndexMap();
+      assert(indexTypeMap.has_value() && "invalid record type");
+
+      for (auto [attr, type] : indexTypeMap.value()) {
+        int64_t index = mlir::cast<mlir::IntegerAttr>(attr).getInt();
+        auto extVal =
+            rewriter.create<mlir::LLVM::ExtractValueOp>(loc, op0, index);
+        val =
+            rewriter.create<mlir::LLVM::InsertValueOp>(loc, val, extVal, index);
+      }
+
+      rewriter.replaceOp(convert, val);
+      return mlir::success();
+    }
+
     if (mlir::isa<fir::LogicalType>(fromFirTy) ||
         mlir::isa<fir::LogicalType>(toFirTy)) {
       // By specification fir::LogicalType value may be any number,
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 8fdc06f6fce3f5..90ce8b87605912 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) {
   return true;
 }
 
+static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) {
+  // True only when both types are fir::RecordType and their getTypeList()
+  // results compare equal. Deeper checks (e.g. that both records carry the
+  // BIND(C) attribute) are left to frontend semantics.
+  auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy);
+  auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy);
+  return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList();
+}
+
 bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
   if (inType == outType)
     return true;
@@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) {
          (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) ||
          (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) ||
-         areVectorsCompatible(inType, outType);
+         areVectorsCompatible(inType, outType) ||
+         areRecordsCompatible(inType, outType);
 }
 
 llvm::LogicalResult fir::ConvertOp::verify() {
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index 0c17d7c25a8c8d..1182a0a10f218b 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex<f128>) -> complex<f16> {
 
 // -----
 
+// Test `fir.convert` operation conversion between compatible fir.record types.
+
+func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                                  !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> {
+    %0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) ->
+                              !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+  return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>
+}
+
+// CHECK-LABEL: func @convert_record(
+// CHECK-SAME:    %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) ->
+// CHECK-SAME:                  [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]
+// CHECK:         %{{.*}} = llvm.mlir.undef : [[MOD2_REC]]
+// CHECK-DAG:     %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]]
+// CHECK-DAG:     %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]]
+// CHECK-DAG:     %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]]
+// CHECK-DAG:     %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]]
+// CHECK-DAG:     %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]]
+// CHECK:         llvm.return %{{.*}} : [[MOD2_REC]]
+
+// -----
+
 // Test `fir.store` --> `llvm.store` conversion
 
 func.func @test_store_index(%val_to_store : index, %addr : !fir.ref<index>) {
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 086a426db5642e..7e3f9d64984129 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> {
 
 // -----
 
+func.func @rec_to_rec(%arg0: !fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}> {
+  // expected-error@+1{{'fir.convert' op invalid type conversion}}
+  %0 = fir.convert %arg0 : (!fir.type<t1{i:i32, f:f32}>) -> !fir.type<t2{f:f32, i:i32}>
+  return %0 : !fir.type<t2{f:f32, i:i32}>
+}
+
+// -----
+
 func.func @bad_box_offset(%not_a_box : !fir.ref<i32>) {
   // expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref<!fir.box<T>> type}}
   %addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref<i32>) -> !fir.llvm_ptr<!fir.ref<i32>>



More information about the flang-commits mailing list