[flang-commits] [flang] 3fd250d - [fir] TargetRewrite: Rewrite fir.address_of(func)

Diana Picus via flang-commits flang-commits at lists.llvm.org
Fri Dec 3 02:56:46 PST 2021


Author: Diana Picus
Date: 2021-12-03T10:56:24Z
New Revision: 3fd250d25858c38cd83c71be47c95e4b20a38171

URL: https://github.com/llvm/llvm-project/commit/3fd250d25858c38cd83c71be47c95e4b20a38171
DIFF: https://github.com/llvm/llvm-project/commit/3fd250d25858c38cd83c71be47c95e4b20a38171.diff

LOG: [fir] TargetRewrite: Rewrite fir.address_of(func)

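Rewrite fir.address_of (AddrOfOp) when it takes the address of a function whose signature is not portable, so that the op's function type matches the target-specific signature.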

Differential Revision: https://reviews.llvm.org/D114925

Co-authored-by: Eric Schweitz <eschweitz at nvidia.com>

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
    flang/test/Fir/target-rewrite-boxchar.fir
    flang/test/Fir/target-rewrite-complex.fir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 25e1e44671d1..7a762fb181bf 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -100,6 +100,10 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
       } else if (auto dispatch = dyn_cast<DispatchOp>(op)) {
         if (!hasPortableSignature(dispatch.getFunctionType()))
           convertCallOp(dispatch);
+      } else if (auto addr = dyn_cast<AddrOfOp>(op)) {
+        if (addr.getType().isa<mlir::FunctionType>() &&
+            !hasPortableSignature(addr.getType()))
+          convertAddrOp(addr);
       }
     });
 
@@ -319,6 +323,55 @@ class TargetRewrite : public TargetRewriteBase<TargetRewrite> {
         newInTys.push_back(std::get<mlir::Type>(tup));
   }
 
+  /// Taking the address of a function: rebuild the fir.address_of with the
+  /// target-lowered function type. Complex results/arguments are lowered via
+  /// lowerComplexSignatureRes/Arg; boxchar arguments are expanded with
+  /// specifics->boxcharArgumentType() unless noCharacterConversion is set.
+  void convertAddrOp(AddrOfOp addrOp) {
+    rewriter->setInsertionPoint(addrOp);
+    auto addrTy = addrOp.getType().cast<mlir::FunctionType>();
+    llvm::SmallVector<mlir::Type> newResTys;
+    llvm::SmallVector<mlir::Type> newInTys;
+    // Results are lowered first: lowerComplexSignatureRes may also add input
+    // types, which must then come before the original arguments.
+    for (mlir::Type resTy : addrTy.getResults()) {
+      llvm::TypeSwitch<mlir::Type>(resTy)
+          .Case<fir::ComplexType, mlir::ComplexType>([&](auto cmplx) {
+            lowerComplexSignatureRes(cmplx, newResTys, newInTys);
+          })
+          .Default([&](mlir::Type other) { newResTys.push_back(other); });
+    }
+    // Boxchar components whose attribute isAppend() go after all arguments.
+    llvm::SmallVector<mlir::Type> trailingInTys;
+    for (mlir::Type argTy : addrTy.getInputs()) {
+      llvm::TypeSwitch<mlir::Type>(argTy)
+          .Case<BoxCharType>([&](BoxCharType box) {
+            if (noCharacterConversion) {
+              newInTys.push_back(box);
+            } else {
+              for (auto &tup : specifics->boxcharArgumentType(box.getEleTy())) {
+                auto attr = std::get<CodeGenSpecifics::Attributes>(tup);
+                auto partTy = std::get<mlir::Type>(tup);
+                llvm::SmallVector<mlir::Type> &vec =
+                    attr.isAppend() ? trailingInTys : newInTys;
+                vec.push_back(partTy);
+              }
+            }
+          })
+          .Case<fir::ComplexType, mlir::ComplexType>([&](auto cmplx) {
+            lowerComplexSignatureArg(cmplx, newInTys);
+          })
+          .Default([&](mlir::Type other) { newInTys.push_back(other); });
+    }
+    // Append the trailing argument types after all other arguments.
+    newInTys.append(trailingInTys.begin(), trailingInTys.end());
+    // Replace this op with a new one carrying the rewritten signature.
+    auto newTy = rewriter->getFunctionType(newInTys, newResTys);
+    auto newOp =
+        rewriter->create<AddrOfOp>(addrOp.getLoc(), newTy, addrOp.symbol());
+    replaceOp(addrOp, newOp.getResult());
+  }
+
   /// Convert the type signatures on all the functions present in the module.
   /// As the type signature is being changed, this must also update the
   /// function itself to use any new arguments, etc.

diff  --git a/flang/test/Fir/target-rewrite-boxchar.fir b/flang/test/Fir/target-rewrite-boxchar.fir
index e2fb31ffecea..400ae548c529 100644
--- a/flang/test/Fir/target-rewrite-boxchar.fir
+++ b/flang/test/Fir/target-rewrite-boxchar.fir
@@ -93,3 +93,13 @@ fir.global @name constant : !fir.char<1,9> {
   //constant 1
   fir.has_value %str : !fir.char<1,9>
 }
+
+// Test fir.address_of rewriting: a !fir.boxchar<1> argument becomes a char reference plus an integer length.
+// INT32-LABEL: @addrof
+// INT64-LABEL: @addrof
+func @addrof() {
+  // INT32: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i32) -> ()
+  // INT64: {{.*}} = fir.address_of(@boxcharcallee) : (!fir.ref<!fir.char<1,?>>, i64) -> ()
+  %f = fir.address_of(@boxcharcallee) : (!fir.boxchar<1>) -> ()
+  return
+}

diff  --git a/flang/test/Fir/target-rewrite-complex.fir b/flang/test/Fir/target-rewrite-complex.fir
index 54fd2f2adf53..49c9586108bc 100644
--- a/flang/test/Fir/target-rewrite-complex.fir
+++ b/flang/test/Fir/target-rewrite-complex.fir
@@ -452,3 +452,23 @@ func private @mlircomplexf32(%z1: complex<f32>, %z2: complex<f32>) -> complex<f3
   // PPC: return [[RES]] : tuple<f32, f32>
   return %0 : complex<f32>
 }
+
+// Test fir.address_of rewriting for functions returning or taking !fir.complex<4>.
+// I32-LABEL: func @addrof()
+// X64-LABEL: func @addrof()
+// AARCH64-LABEL: func @addrof()
+// PPC-LABEL: func @addrof()
+func @addrof() {
+  // I32: {{%.*}} = fir.address_of(@returncomplex4) : () -> i64
+  // X64: {{%.*}} = fir.address_of(@returncomplex4) : () -> !fir.vector<2:!fir.real<4>>
+  // AARCH64: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  // PPC: {{%.*}} = fir.address_of(@returncomplex4) : () -> tuple<!fir.real<4>, !fir.real<4>>
+  %r = fir.address_of(@returncomplex4) : () -> !fir.complex<4>
+  // Complex argument: the expected lowered parameter type differs per target.
+  // I32: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.ref<tuple<!fir.real<4>, !fir.real<4>>>) -> ()
+  // X64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.vector<2:!fir.real<4>>) -> ()
+  // AARCH64: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.array<2x!fir.real<4>>) -> ()
+  // PPC: {{%.*}} = fir.address_of(@paramcomplex4) : (!fir.real<4>, !fir.real<4>) -> ()
+  %p = fir.address_of(@paramcomplex4) : (!fir.complex<4>) -> ()
+  return
+}


        


More information about the flang-commits mailing list