[flang-commits] [flang] [flang][cuda] Update target rewrite to work on gpu.func (PR #119283)

via flang-commits flang-commits at lists.llvm.org
Tue Dec 10 01:23:10 PST 2024


================
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
           auto newArg =
               func.front().insertArgument(fixup.index, fixupType, loc);
           offset++;
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
-            auto cast =
-                rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
-            rewriter->create<fir::StoreOp>(loc, oldOper, cast);
-            rewriter->create<mlir::func::ReturnOp>(loc);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+            func.walk([&](mlir::func::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::func::ReturnOp>(loc);
+              ret.erase();
+            });
+          if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+            func.walk([&](mlir::gpu::ReturnOp ret) {
+              rewriter->setInsertionPoint(ret);
+              auto oldOper = ret.getOperand(0);
+              auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+              auto cast =
+                  rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+              rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+              rewriter->create<mlir::gpu::ReturnOp>(loc);
+              ret.erase();
+            });
         } break;
         case FixupTy::Codes::ReturnType: {
           // The function is still returning a value, but its type has likely
           // changed to suit the target ABI convention.
-          func.walk([&](mlir::func::ReturnOp ret) {
-            rewriter->setInsertionPoint(ret);
-            auto oldOper = ret.getOperand(0);
-            mlir::Value bitcast =
-                convertValueInMemory(loc, oldOper, newResTys[fixup.index],
-                                     /*inputMayBeBigger=*/false);
-            rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
-            ret.erase();
-          });
+          if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
----------------
jeanPerier wrote:

Can you add a `ReturnOpTy` template parameter to `convertSignature` to avoid the duplication here and above?

And maybe rename `OpTy` to `FuncOpTy` to make clear which template types are expected.

https://github.com/llvm/llvm-project/pull/119283


More information about the flang-commits mailing list