[flang-commits] [flang] [flang][cuda] Update target rewrite to work on gpu.func (PR #119283)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Dec 10 08:57:10 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/119283

>From 91f124245a6408723376dc241f35484c34c774cf Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Sat, 7 Dec 2024 14:49:48 -0800
Subject: [PATCH 1/2] [flang][cuda] Update target rewrite to work on gpu.func

---
 flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 163 +++++++++++-------
 flang/test/Fir/CUDA/cuda-target-rewrite.mlir  |  14 +-
 2 files changed, 112 insertions(+), 65 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 1b86d5241704b1..4603a7d6cb4ec5 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -62,14 +62,21 @@ struct FixupTy {
   FixupTy(Codes code, std::size_t index,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, gpuFinalizer{finalizer} {}
   FixupTy(Codes code, std::size_t index, std::size_t second,
           std::function<void(mlir::func::FuncOp)> &&finalizer)
       : code{code}, index{index}, second{second}, finalizer{finalizer} {}
+  FixupTy(Codes code, std::size_t index, std::size_t second,
+          std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+      : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
 
   Codes code;
   std::size_t index;
   std::size_t second{};
   std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
+  std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
 }; // namespace
 
 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       convertSignature(fn);
     }
 
-    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
+    for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
       for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
         convertSignature(fn);
+      for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
+        convertSignature(fn);
+    }
 
     return mlir::success();
   }
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   /// Determine if the signature has host associations. The host association
   /// argument may need special target specific rewriting.
-  static bool hasHostAssociations(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  static bool hasHostAssociations(OpTy func) {
     std::size_t end = func.getFunctionType().getInputs().size();
     for (std::size_t i = 0; i < end; ++i)
-      if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              i, fir::getHostAssocAttrName()))
         return true;
     return false;
   }
 
   /// Rewrite the signatures and body of the `FuncOp`s in the module for
   /// the immediately subsequent target code gen.
-  void convertSignature(mlir::func::FuncOp func) {
+  template <typename OpTy>
+  void convertSignature(OpTy func) {
     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
     if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
       return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     // Convert return value(s)
     for (auto ty : funcTy.getResults())
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             if (noComplexConversion)
               newResTys.push_back(cmplx);
             else
               doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                                                 rewriter->getUnitAttr()));
             newResTys.push_back(retTy);
           })
-          .Case<fir::RecordType>([&](fir::RecordType recTy) {
+          .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+          .template Default([&](mlir::Type ty) { newResTys.push_back(ty); });
 
     // Saved potential shift in argument. Handling of result can add arguments
     // at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       auto ty = e.value();
       unsigned index = e.index();
       llvm::TypeSwitch<mlir::Type>(ty)
-          .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+          .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
             if (noCharacterConversion) {
               newInTyAndAttrs.push_back(
                   fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
               }
             }
           })
-          .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+          .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
             doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
           })
-          .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+          .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
             if (fir::isCharacterProcedureTuple(tuple)) {
               fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
                                   newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                   fir::CodeGenSpecifics::getTypeAndAttr(ty));
             }
           })
-          .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+          .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
             auto m = specifics->integerArgumentType(func.getLoc(), intTy);
             assert(m.size() == 1);
             auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             if (!extensionAttrName.empty() &&
                 isFuncWithCCallingConvention(func))
               fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
-                                  [=](mlir::func::FuncOp func) {
+                                  [=](OpTy func) {
                                     func.setArgAttr(
                                         argNo, extensionAttrName,
                                         mlir::UnitAttr::get(func.getContext()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           .template Case<fir::RecordType>([&](fir::RecordType recTy) {
             doStructArg(func, recTy, newInTyAndAttrs, fixups);
           })
-          .Default([&](mlir::Type ty) {
+          .template Default([&](mlir::Type ty) {
             newInTyAndAttrs.push_back(
                 fir::CodeGenSpecifics::getTypeAndAttr(ty));
           });
 
-      if (func.getArgAttrOfType<mlir::UnitAttr>(index,
-                                                fir::getHostAssocAttrName())) {
+      if (func.template getArgAttrOfType<mlir::UnitAttr>(
+              index, fir::getHostAssocAttrName())) {
         extraAttrs.push_back(
             {newInTyAndAttrs.size() - 1,
              rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           auto newArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
           offset++;
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-            auto cast =
-                rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-            rewriter->create<mlir::func::ReturnOp>(loc);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::func::ReturnOp>(loc);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::gpu::ReturnOp>(loc);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::ReturnType: {
           // The function is still returning a value, but its type has likely
           // changed to suit the target ABI convention.
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            mlir::Value bitcast =
-                convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                     /*inputMayBeBigger=*/false);
-            rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              mlir::Value bitcast =
+                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                       /*inputMayBeBigger=*/false);
+              rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::Split: {
           // The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       }
     }
 
-    for (auto &fixup : fixups)
-      if (fixup.finalizer)
-        (*fixup.finalizer)(func);
+    for (auto &fixup : fixups) {
+      if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+        if (fixup.finalizer)
+          (*fixup.finalizer)(func);
+      if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+        if (fixup.gpuFinalizer)
+          (*fixup.gpuFinalizer)(func);
+    }
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doReturn(OpTy func, Ty &newResTys,
                 fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                 FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
     assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       unsigned argNo = newInTyAndAttrs.size();
       if (auto align = attr.getAlignment())
         fixups.emplace_back(
-            FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
+            FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                   func.getFunctionType().getInput(argNo));
               func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             });
       else
         fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
-                            [=](mlir::func::FuncOp func) {
+                            [=](OpTy func) {
                               auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                   func.getFunctionType().getInput(argNo));
                               func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     }
     if (auto align = attr.getAlignment())
       fixups.emplace_back(
-          FixupTy::Codes::ReturnType, newResTys.size(),
-          [=](mlir::func::FuncOp func) {
+          FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
             func.setArgAttr(
                 newResTys.size(), "llvm.align",
                 rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex return value. This can involve converting the return
   /// value to a "hidden" first argument or packing the complex into a wide
   /// GPR.
-  template <typename Ty, typename FIXUPS>
-  void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
-                       Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                        FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename Ty, typename FIXUPS>
-  void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
-                      Ty &newResTys,
+  template <typename OpTy, typename Ty, typename FIXUPS>
+  void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
                       fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                       FIXUPS &fixups) {
     if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
   }
 
-  template <typename FIXUPS>
-  void
-  createFuncOpArgFixups(mlir::func::FuncOp func,
-                        fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
-                        fir::CodeGenSpecifics::Marshalling &argsInTys,
-                        FIXUPS &fixups) {
+  template <typename OpTy, typename FIXUPS>
+  void createFuncOpArgFixups(
+      OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+      fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
     const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
                                                 : FixupTy::Codes::ArgumentType;
     for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       if (attr.isByVal()) {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
-                              [=](mlir::func::FuncOp func) {
+                              [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
                               });
         else
           fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
-                              newInTyAndAttrs.size(),
-                              [=](mlir::func::FuncOp func) {
+                              newInTyAndAttrs.size(), [=](OpTy func) {
                                 auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
                                     func.getFunctionType().getInput(argNo));
                                 func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       } else {
         if (auto align = attr.getAlignment())
           fixups.emplace_back(
-              fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
+              fixupCode, argNo, index, [=](OpTy func) {
                 func.setArgAttr(argNo, "llvm.align",
                                 rewriter->getIntegerAttr(
                                     rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
   /// Convert a complex argument value. This can involve storing the value to
   /// a temporary memory location or factoring the value into two distinct
   /// arguments.
-  template <typename FIXUPS>
-  void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+  template <typename OpTy, typename FIXUPS>
+  void doComplexArg(OpTy func, mlir::ComplexType cmplx,
                     fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                     FIXUPS &fixups) {
     if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
   }
 
-  template <typename FIXUPS>
-  void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+  template <typename OpTy, typename FIXUPS>
+  void doStructArg(OpTy func, fir::RecordType recTy,
                    fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
                    FIXUPS &fixups) {
     if (noStructConversion) {
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index c14efc8b13f66b..d88b6776795a0b 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -1,5 +1,5 @@
 // REQUIRES: x86-registered-target
-// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
 
 gpu.module @testmod {
   gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
 // CHECK-LABEL: gpu.func @_QPvcpowdk
 // CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
 // CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
+
+// -----
+
+gpu.module @testmod {
+  gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
+    gpu.return %arg0 : complex<f64>
+  }
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
+// CHECK: gpu.return %{{.*}} : tuple<f64, f64>

>From d92c612cd85ef480ba34292537cd6bd5b30b872d Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 10 Dec 2024 08:56:29 -0800
Subject: [PATCH 2/2] Remove duplication of code

---
 flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 77 +++++++------------
 1 file changed, 27 insertions(+), 50 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 4603a7d6cb4ec5..62899ffa3ae416 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -726,14 +726,14 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
       if (targetFeaturesAttr)
         fn->setAttr("target_features", targetFeaturesAttr);
 
-      convertSignature(fn);
+      convertSignature<mlir::func::ReturnOp>(fn);
     }
 
     for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
       for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
-        convertSignature(fn);
+        convertSignature<mlir::func::ReturnOp>(fn);
       for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
-        convertSignature(fn);
+        convertSignature<mlir::gpu::ReturnOp>(fn);
     }
 
     return mlir::success();
@@ -792,8 +792,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
 
   /// Rewrite the signatures and body of the `FuncOp`s in the module for
   /// the immediately subsequent target code gen.
-  template <typename OpTy>
-  void convertSignature(OpTy func) {
+  template <typename ReturnOpTy, typename FuncOpTy>
+  void convertSignature(FuncOpTy func) {
     auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
     if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
       return;
@@ -900,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
             if (!extensionAttrName.empty() &&
                 isFuncWithCCallingConvention(func))
               fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
-                                  [=](OpTy func) {
+                                  [=](FuncOpTy func) {
                                     func.setArgAttr(
                                         argNo, extensionAttrName,
                                         mlir::UnitAttr::get(func.getContext()));
@@ -992,52 +992,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           auto newArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
           offset++;
-          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
-            func.walk([&](mlir::func::ReturnOp ret) {
-              rewriter->setInsertionPoint(ret);
-              auto oldOper = ret.getOperand(0);
-              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-              auto cast =
-                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-              rewriter->create<mlir::func::ReturnOp>(loc);
-              ret.erase();
-            });
-          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
-            func.walk([&](mlir::gpu::ReturnOp ret) {
-              rewriter->setInsertionPoint(ret);
-              auto oldOper = ret.getOperand(0);
-              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-              auto cast =
-                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-              rewriter->create<mlir::gpu::ReturnOp>(loc);
-              ret.erase();
-            });
+          func.walk([&](ReturnOpTy ret) {
+            rewriter->setInsertionPoint(ret);
+            auto oldOper = ret.getOperand(0);
+            auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+            auto cast =
+                rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+            rewriter->create<ReturnOpTy>(loc);
+            ret.erase();
+          });
         } break;
         case FixupTy::Codes::ReturnType: {
           // The function is still returning a value, but its type has likely
           // changed to suit the target ABI convention.
-          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
-            func.walk([&](mlir::func::ReturnOp ret) {
-              rewriter->setInsertionPoint(ret);
-              auto oldOper = ret.getOperand(0);
-              mlir::Value bitcast =
-                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                       /*inputMayBeBigger=*/false);
-              rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
-              ret.erase();
-            });
-          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
-            func.walk([&](mlir::gpu::ReturnOp ret) {
-              rewriter->setInsertionPoint(ret);
-              auto oldOper = ret.getOperand(0);
-              mlir::Value bitcast =
-                  convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                       /*inputMayBeBigger=*/false);
-              rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
-              ret.erase();
-            });
+          func.walk([&](ReturnOpTy ret) {
+            rewriter->setInsertionPoint(ret);
+            auto oldOper = ret.getOperand(0);
+            mlir::Value bitcast =
+                convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+                                     /*inputMayBeBigger=*/false);
+            rewriter->create<ReturnOpTy>(loc, bitcast);
+            ret.erase();
+          });
         } break;
         case FixupTy::Codes::Split: {
           // The FIR argument has been split into a pair of distinct arguments
@@ -1138,10 +1115,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
     }
 
     for (auto &fixup : fixups) {
-      if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+      if constexpr (std::is_same_v<FuncOpTy, mlir::func::FuncOp>)
         if (fixup.finalizer)
           (*fixup.finalizer)(func);
-      if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+      if constexpr (std::is_same_v<FuncOpTy, mlir::gpu::GPUFuncOp>)
         if (fixup.gpuFinalizer)
           (*fixup.gpuFinalizer)(func);
     }



More information about the flang-commits mailing list