[flang-commits] [flang] [flang][cuda] Update target rewrite to work on gpu.func (PR #119283)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon Dec 9 14:56:46 PST 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/119283
Update the target rewrite pass so that it also performs the signature rewrite on gpu.func operations nested in gpu.module.
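As an illustration of the rewrite (a sketch based on the new test case added below; the value name %0 in the rewritten body is hypothetical, and the exact body produced by the pass may differ), a gpu.func taking and returning complex<f64> on an x86_64 target has its complex argument split into two f64 values and its result repacked as tuple<f64, f64>:

  // Input:
  gpu.module @testmod {
    gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
      gpu.return %arg0 : complex<f64>
    }
  }

  // After --target-rewrite="target=x86_64-unknown-linux-gnu" (signature and
  // return type as checked by the test; body elided):
  gpu.module @testmod {
    gpu.func @_QPtest(%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
      // ... rebuild the tuple<f64, f64> value %0 from %arg0 and %arg1 ...
      gpu.return %0 : tuple<f64, f64>
    }
  }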
From 91f124245a6408723376dc241f35484c34c774cf Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Sat, 7 Dec 2024 14:49:48 -0800
Subject: [PATCH] [flang][cuda] Update target rewrite to work on gpu.func
---
flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 163 +++++++++++-------
flang/test/Fir/CUDA/cuda-target-rewrite.mlir | 14 +-
2 files changed, 112 insertions(+), 65 deletions(-)
diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
index 1b86d5241704b1..4603a7d6cb4ec5 100644
--- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
+++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
@@ -62,14 +62,21 @@ struct FixupTy {
FixupTy(Codes code, std::size_t index,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, finalizer{finalizer} {}
+ FixupTy(Codes code, std::size_t index,
+ std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+ : code{code}, index{index}, gpuFinalizer{finalizer} {}
FixupTy(Codes code, std::size_t index, std::size_t second,
std::function<void(mlir::func::FuncOp)> &&finalizer)
: code{code}, index{index}, second{second}, finalizer{finalizer} {}
+ FixupTy(Codes code, std::size_t index, std::size_t second,
+ std::function<void(mlir::gpu::GPUFuncOp)> &&finalizer)
+ : code{code}, index{index}, second{second}, gpuFinalizer{finalizer} {}
Codes code;
std::size_t index;
std::size_t second{};
std::optional<std::function<void(mlir::func::FuncOp)>> finalizer{};
+ std::optional<std::function<void(mlir::gpu::GPUFuncOp)>> gpuFinalizer{};
}; // namespace
/// Target-specific rewriting of the FIR. This is a prerequisite pass to code
@@ -722,9 +729,12 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
convertSignature(fn);
}
- for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>())
+ for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) {
for (auto fn : gpuMod.getOps<mlir::func::FuncOp>())
convertSignature(fn);
+ for (auto fn : gpuMod.getOps<mlir::gpu::GPUFuncOp>())
+ convertSignature(fn);
+ }
return mlir::success();
}
@@ -770,17 +780,20 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Determine if the signature has host associations. The host association
/// argument may need special target specific rewriting.
- static bool hasHostAssociations(mlir::func::FuncOp func) {
+ template <typename OpTy>
+ static bool hasHostAssociations(OpTy func) {
std::size_t end = func.getFunctionType().getInputs().size();
for (std::size_t i = 0; i < end; ++i)
- if (func.getArgAttrOfType<mlir::UnitAttr>(i, fir::getHostAssocAttrName()))
+ if (func.template getArgAttrOfType<mlir::UnitAttr>(
+ i, fir::getHostAssocAttrName()))
return true;
return false;
}
/// Rewrite the signatures and body of the `FuncOp`s in the module for
/// the immediately subsequent target code gen.
- void convertSignature(mlir::func::FuncOp func) {
+ template <typename OpTy>
+ void convertSignature(OpTy func) {
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasPortableSignature(funcTy, func) && !hasHostAssociations(func))
return;
@@ -805,13 +818,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
- .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -825,10 +838,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
rewriter->getUnitAttr()));
newResTys.push_back(retTy);
})
- .Case<fir::RecordType>([&](fir::RecordType recTy) {
+ .template Case<fir::RecordType>([&](fir::RecordType recTy) {
doStructReturn(func, recTy, newResTys, newInTyAndAttrs, fixups);
})
- .Default([&](mlir::Type ty) { newResTys.push_back(ty); });
+ .template Default([&](mlir::Type ty) { newResTys.push_back(ty); });
// Saved potential shift in argument. Handling of result can add arguments
// at the beginning of the function signature.
@@ -840,7 +853,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto ty = e.value();
unsigned index = e.index();
llvm::TypeSwitch<mlir::Type>(ty)
- .Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
+ .template Case<fir::BoxCharType>([&](fir::BoxCharType boxTy) {
if (noCharacterConversion) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(boxTy));
@@ -863,10 +876,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
})
- .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
+ .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
- .Case<mlir::TupleType>([&](mlir::TupleType tuple) {
+ .template Case<mlir::TupleType>([&](mlir::TupleType tuple) {
if (fir::isCharacterProcedureTuple(tuple)) {
fixups.emplace_back(FixupTy::Codes::TrailingCharProc,
newInTyAndAttrs.size(), trailingTys.size());
@@ -878,7 +891,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
fir::CodeGenSpecifics::getTypeAndAttr(ty));
}
})
- .Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
+ .template Case<mlir::IntegerType>([&](mlir::IntegerType intTy) {
auto m = specifics->integerArgumentType(func.getLoc(), intTy);
assert(m.size() == 1);
auto attr = std::get<fir::CodeGenSpecifics::Attributes>(m[0]);
@@ -887,7 +900,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (!extensionAttrName.empty() &&
isFuncWithCCallingConvention(func))
fixups.emplace_back(FixupTy::Codes::ArgumentType, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](OpTy func) {
func.setArgAttr(
argNo, extensionAttrName,
mlir::UnitAttr::get(func.getContext()));
@@ -898,13 +911,13 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
.template Case<fir::RecordType>([&](fir::RecordType recTy) {
doStructArg(func, recTy, newInTyAndAttrs, fixups);
})
- .Default([&](mlir::Type ty) {
+ .template Default([&](mlir::Type ty) {
newInTyAndAttrs.push_back(
fir::CodeGenSpecifics::getTypeAndAttr(ty));
});
- if (func.getArgAttrOfType<mlir::UnitAttr>(index,
- fir::getHostAssocAttrName())) {
+ if (func.template getArgAttrOfType<mlir::UnitAttr>(
+ index, fir::getHostAssocAttrName())) {
extraAttrs.push_back(
{newInTyAndAttrs.size() - 1,
rewriter->getNamedAttr("llvm.nest", rewriter->getUnitAttr())});
@@ -979,29 +992,52 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto newArg =
func.front().insertArgument(fixup.index, fixupType, loc);
offset++;
- func.walk([&](mlir::func::ReturnOp ret) {
- rewriter->setInsertionPoint(ret);
- auto oldOper = ret.getOperand(0);
- auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
- auto cast =
- rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
- rewriter->create<fir::StoreOp>(loc, oldOper, cast);
- rewriter->create<mlir::func::ReturnOp>(loc);
- ret.erase();
- });
+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+ func.walk([&](mlir::func::ReturnOp ret) {
+ rewriter->setInsertionPoint(ret);
+ auto oldOper = ret.getOperand(0);
+ auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+ auto cast =
+ rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+ rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+ rewriter->create<mlir::func::ReturnOp>(loc);
+ ret.erase();
+ });
+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+ func.walk([&](mlir::gpu::ReturnOp ret) {
+ rewriter->setInsertionPoint(ret);
+ auto oldOper = ret.getOperand(0);
+ auto oldOperTy = fir::ReferenceType::get(oldOper.getType());
+ auto cast =
+ rewriter->create<fir::ConvertOp>(loc, oldOperTy, newArg);
+ rewriter->create<fir::StoreOp>(loc, oldOper, cast);
+ rewriter->create<mlir::gpu::ReturnOp>(loc);
+ ret.erase();
+ });
} break;
case FixupTy::Codes::ReturnType: {
// The function is still returning a value, but its type has likely
// changed to suit the target ABI convention.
- func.walk([&](mlir::func::ReturnOp ret) {
- rewriter->setInsertionPoint(ret);
- auto oldOper = ret.getOperand(0);
- mlir::Value bitcast =
- convertValueInMemory(loc, oldOper, newResTys[fixup.index],
- /*inputMayBeBigger=*/false);
- rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
- ret.erase();
- });
+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+ func.walk([&](mlir::func::ReturnOp ret) {
+ rewriter->setInsertionPoint(ret);
+ auto oldOper = ret.getOperand(0);
+ mlir::Value bitcast =
+ convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+ /*inputMayBeBigger=*/false);
+ rewriter->create<mlir::func::ReturnOp>(loc, bitcast);
+ ret.erase();
+ });
+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+ func.walk([&](mlir::gpu::ReturnOp ret) {
+ rewriter->setInsertionPoint(ret);
+ auto oldOper = ret.getOperand(0);
+ mlir::Value bitcast =
+ convertValueInMemory(loc, oldOper, newResTys[fixup.index],
+ /*inputMayBeBigger=*/false);
+ rewriter->create<mlir::gpu::ReturnOp>(loc, bitcast);
+ ret.erase();
+ });
} break;
case FixupTy::Codes::Split: {
// The FIR argument has been split into a pair of distinct arguments
@@ -1101,13 +1137,18 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
- for (auto &fixup : fixups)
- if (fixup.finalizer)
- (*fixup.finalizer)(func);
+ for (auto &fixup : fixups) {
+ if constexpr (std::is_same_v<OpTy, mlir::func::FuncOp>)
+ if (fixup.finalizer)
+ (*fixup.finalizer)(func);
+ if constexpr (std::is_same_v<OpTy, mlir::gpu::GPUFuncOp>)
+ if (fixup.gpuFinalizer)
+ (*fixup.gpuFinalizer)(func);
+ }
}
- template <typename Ty, typename FIXUPS>
- void doReturn(mlir::func::FuncOp func, Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doReturn(OpTy func, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups, fir::CodeGenSpecifics::Marshalling &m) {
assert(m.size() == 1 &&
@@ -1119,7 +1160,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
unsigned argNo = newInTyAndAttrs.size();
if (auto align = attr.getAlignment())
fixups.emplace_back(
- FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::func::FuncOp func) {
+ FixupTy::Codes::ReturnAsStore, argNo, [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
@@ -1130,7 +1171,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
});
else
fixups.emplace_back(FixupTy::Codes::ReturnAsStore, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.sret",
@@ -1141,8 +1182,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
if (auto align = attr.getAlignment())
fixups.emplace_back(
- FixupTy::Codes::ReturnType, newResTys.size(),
- [=](mlir::func::FuncOp func) {
+ FixupTy::Codes::ReturnType, newResTys.size(), [=](OpTy func) {
func.setArgAttr(
newResTys.size(), "llvm.align",
rewriter->getIntegerAttr(rewriter->getIntegerType(32), align));
@@ -1155,9 +1195,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
- template <typename Ty, typename FIXUPS>
- void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
- Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doComplexReturn(OpTy func, mlir::ComplexType cmplx, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
@@ -1169,9 +1208,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}
- template <typename Ty, typename FIXUPS>
- void doStructReturn(mlir::func::FuncOp func, fir::RecordType recTy,
- Ty &newResTys,
+ template <typename OpTy, typename Ty, typename FIXUPS>
+ void doStructReturn(OpTy func, fir::RecordType recTy, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
@@ -1182,12 +1220,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
doReturn(func, newResTys, newInTyAndAttrs, fixups, m);
}
- template <typename FIXUPS>
- void
- createFuncOpArgFixups(mlir::func::FuncOp func,
- fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
- fir::CodeGenSpecifics::Marshalling &argsInTys,
- FIXUPS &fixups) {
+ template <typename OpTy, typename FIXUPS>
+ void createFuncOpArgFixups(
+ OpTy func, fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
+ fir::CodeGenSpecifics::Marshalling &argsInTys, FIXUPS &fixups) {
const auto fixupCode = argsInTys.size() > 1 ? FixupTy::Codes::Split
: FixupTy::Codes::ArgumentType;
for (auto e : llvm::enumerate(argsInTys)) {
@@ -1198,7 +1234,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (attr.isByVal()) {
if (auto align = attr.getAlignment())
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, argNo,
- [=](mlir::func::FuncOp func) {
+ [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
@@ -1210,8 +1246,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
});
else
fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad,
- newInTyAndAttrs.size(),
- [=](mlir::func::FuncOp func) {
+ newInTyAndAttrs.size(), [=](OpTy func) {
auto elemType = fir::dyn_cast_ptrOrBoxEleTy(
func.getFunctionType().getInput(argNo));
func.setArgAttr(argNo, "llvm.byval",
@@ -1220,7 +1255,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
} else {
if (auto align = attr.getAlignment())
fixups.emplace_back(
- fixupCode, argNo, index, [=](mlir::func::FuncOp func) {
+ fixupCode, argNo, index, [=](OpTy func) {
func.setArgAttr(argNo, "llvm.align",
rewriter->getIntegerAttr(
rewriter->getIntegerType(32), align));
@@ -1235,8 +1270,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex argument value. This can involve storing the value to
/// a temporary memory location or factoring the value into two distinct
/// arguments.
- template <typename FIXUPS>
- void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
+ template <typename OpTy, typename FIXUPS>
+ void doComplexArg(OpTy func, mlir::ComplexType cmplx,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noComplexConversion) {
@@ -1248,8 +1283,8 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
createFuncOpArgFixups(func, newInTyAndAttrs, cplxArgs, fixups);
}
- template <typename FIXUPS>
- void doStructArg(mlir::func::FuncOp func, fir::RecordType recTy,
+ template <typename OpTy, typename FIXUPS>
+ void doStructArg(OpTy func, fir::RecordType recTy,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
FIXUPS &fixups) {
if (noStructConversion) {
diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
index c14efc8b13f66b..d88b6776795a0b 100644
--- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
+++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir
@@ -1,5 +1,5 @@
// REQUIRES: x86-registered-target
-// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
+// RUN: fir-opt --split-input-file --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s
gpu.module @testmod {
gpu.func @_QPvcpowdk(%arg0: !fir.ref<complex<f64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
@@ -15,3 +15,15 @@ gpu.module @testmod {
// CHECK-LABEL: gpu.func @_QPvcpowdk
// CHECK: %{{.*}} = fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, i64) -> tuple<f64, f64>
// CHECK: func.func private @_FortranAzpowk(f64, f64, i64) -> tuple<f64, f64> attributes {fir.bindc_name = "_FortranAzpowk", fir.runtime}
+
+// -----
+
+gpu.module @testmod {
+ gpu.func @_QPtest(%arg0: complex<f64>) -> (complex<f64>) {
+ gpu.return %arg0 : complex<f64>
+ }
+}
+
+// CHECK-LABEL: gpu.func @_QPtest
+// CHECK-SAME: (%arg0: f64, %arg1: f64) -> tuple<f64, f64> {
+// CHECK: gpu.return %{{.*}} : tuple<f64, f64>