[Mlir-commits] [mlir] [MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (PR #99890)
llvmlistbot at llvm.org
Mon Jul 22 08:58:10 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: None (runseny)
Changes:
Support the fastMath attribute, plus additional previously unsupported math ops that take only float operands, by lowering them to direct libdevice calls in the GPU-to-NVVM conversion:
1. Lower math ops carrying the fastMath attribute to the corresponding fast libdevice intrinsics (e.g. `__nv_fast_cosf`).
2. Some ops in the math dialect are already lowered to libdevice, but coverage is incomplete, so this PR lowers the remaining math ops that take only float operands.
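
For example (a condensed sketch of the `test_module_10` update in the diff below), a `math.cos` carrying `fastmath<fast>` now lowers to the fast libdevice variant, while the plain op keeps its existing lowering:

```mlir
gpu.module @example {
  func.func @gpu_cos(%arg_f32 : f32) -> (f32, f32) {
    // Lowers to: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32
    %result32 = math.cos %arg_f32 : f32
    // Lowers to: llvm.call @__nv_fast_cosf(%{{.*}}) : (f32) -> f32
    %result32Fast = math.cos %arg_f32 fastmath<fast> : f32
    func.return %result32, %result32Fast : f32, f32
  }
}
```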
---
Patch is 25.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99890.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+12-6)
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+29-17)
- (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+2-2)
- (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+221-21)
``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index ebce2d77310ae..7ce17a69d7e4d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -31,9 +31,9 @@ template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
- StringRef f64Func)
+ StringRef f64Func, StringRef f32FastFunc)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func) {}
+ f64Func(f64Func), f32FastFunc(f32FastFunc) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
- getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
+ getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
if (funcName.empty())
return failure();
@@ -90,9 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
- StringRef getFunctionName(Type type) const {
- if (isa<Float32Type>(type))
- return f32Func;
+ StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+ if (isa<Float32Type>(type)) {
+ if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
+ return f32FastFunc;
+ else
+ return f32Func;
+ }
if (isa<Float64Type>(type))
return f64Func;
return "";
@@ -113,8 +117,10 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32Func;
const std::string f64Func;
+ const std::string f32FastFunc;
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..9cfad02538c98 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
- target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
+ LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
+
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +371,53 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getMaxntidAttrName()));
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
+ "__nv_fmod");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
"__nv_atan2");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil");
- populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+ populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
"__nv_expm1");
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
- "__nv_fmod");
- populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+ populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+ "__nv_log10", "__nv_fast_log10f");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
"__nv_log1p");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
- "__nv_log10");
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
- "__nv_log2");
+ "__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
- "__nv_pow");
+ "__nv_pow", "__nv_fast_powf");
+ populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
"__nv_rsqrt");
- populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
"__nv_sqrt");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
- populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 03c7ce5dac0d1..4344fdc142cd2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,9 +38,9 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index d914790c05fe0..a3b79ae2561e1 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -254,13 +254,16 @@ gpu.module @test_module_9 {
gpu.module @test_module_10 {
// CHECK: llvm.func @__nv_cosf(f32) -> f32
// CHECK: llvm.func @__nv_cos(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_cosf(f32) -> f32
// CHECK-LABEL: func @gpu_cos
- func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__nv_cos(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.cos %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_cosf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -268,13 +271,16 @@ gpu.module @test_module_10 {
gpu.module @test_module_11 {
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -297,13 +303,16 @@ gpu.module @test_module_12 {
gpu.module @test_module_13 {
// CHECK: llvm.func @__nv_logf(f32) -> f32
// CHECK: llvm.func @__nv_log(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_logf(f32) -> f32
// CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log %arg_f32 : f32
// CHECK: llvm.call @__nv_logf(%{{.*}}) : (f32) -> f32
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__nv_log(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_logf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -312,13 +321,16 @@ gpu.module @test_module_13 {
gpu.module @test_module_14 {
// CHECK: llvm.func @__nv_log10f(f32) -> f32
// CHECK: llvm.func @__nv_log10(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log10f(f32) -> f32
// CHECK-LABEL: func @gpu_log10
- func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log10 %arg_f32 : f32
// CHECK: llvm.call @__nv_log10f(%{{.*}}) : (f32) -> f32
%result64 = math.log10 %arg_f64 : f64
// CHECK: llvm.call @__nv_log10(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log10 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log10f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -342,13 +354,16 @@ gpu.module @test_module_15 {
gpu.module @test_module_16 {
// CHECK: llvm.func @__nv_log2f(f32) -> f32
// CHECK: llvm.func @__nv_log2(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log2f(f32) -> f32
// CHECK-LABEL: func @gpu_log2
- func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log2 %arg_f32 : f32
// CHECK: llvm.call @__nv_log2f(%{{.*}}) : (f32) -> f32
%result64 = math.log2 %arg_f64 : f64
// CHECK: llvm.call @__nv_log2(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log2 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log2f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -357,13 +372,16 @@ gpu.module @test_module_16 {
gpu.module @test_module_17 {
// CHECK: llvm.func @__nv_sinf(f32) -> f32
// CHECK: llvm.func @__nv_sin(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_sinf(f32) -> f32
// CHECK-LABEL: func @gpu_sin
- func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__nv_sinf(%{{.*}}) : (f32) -> f32
%result64 = math.sin %arg_f64 : f64
// CHECK: llvm.call @__nv_sin(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.sin %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_sinf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -372,8 +390,9 @@ gpu.module @test_module_17 {
gpu.module @test_module_18 {
// CHECK: llvm.func @__nv_tanf(f32) -> f32
// CHECK: llvm.func @__nv_tan(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_tanf(f32) -> f32
// CHECK-LABEL: func @gpu_tan
- func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64, f32) {
%result16 = math.tan %arg_f16 : f16
// CHECK: llvm.fpext %{{.*}} : f16 to f32
// CHECK-NEXT: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
@@ -382,7 +401,9 @@ gpu.module @test_module_18 {
// CHECK: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
%result64 = math.tan %arg_f64 : f64
// CHECK: llvm.call @__nv_tan(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ %result32Fast = math.tan %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_tanf(%{{.*}}) : (f32) -> f32
+ func.return %result16, %result32, %result64, %result32Fast : f16, f32, f64, f32
}
}
@@ -494,13 +515,16 @@ gpu.module @test_module_24 {
// CHECK: test.symbol_scope
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
"test.finish" () : () -> ()
}) : () -> ()
@@ -526,13 +550,16 @@ gpu.module @test_module_25 {
gpu.module @test_module_26 {
// CHECK: llvm.func @__nv_powf(f32, f32) -> f32
// CHECK: llvm.func @__nv_pow(f64, f64) -> f64
+ // CHECK: llvm.func @__nv_fast_powf(f32, f32) -> f32
// CHECK-LABEL: func @gpu_pow
- func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.powf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__nv_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.powf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__nv_pow(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.powf %arg_f32, %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -701,6 +728,179 @@ gpu.module @test_module_34 {
}
}
+gpu.module @test_module_35 {
+ // CHECK: llvm.func @__nv_acosf(f32) -> f32
+ // CHECK: llvm.func @__nv_acos(f64) -> f64
+ // CHECK-LABEL: func @gpu_acos
+ func.func @gpu_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acosf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acos(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_36 {
+ // CHECK: llvm.func @__nv_acoshf(f32) -> f32
+ // CHECK: llvm.func @__nv_acosh(f64) -> f64
+ // CHECK-LABEL: func @gpu_acosh
+ func.func @gpu_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acoshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acosh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_37 {
+ // CHECK: llvm.func @__nv_asinf(f32) -> f32
+ // CHECK: llvm.func @__nv_asin(f64) -> f64
+ // CHECK-LABEL: func @gpu_asin
+ func.func @gpu_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asin(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_38 {
+ // CHECK: llvm.func @__nv_asinhf(f32) -> f32
+ // CHECK: llvm.func @__nv_asinh(f64) -> f64
+ // CHECK-LABEL: func @gpu_asinh
+ func.func @gpu_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asinh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_39 {
+ // CHECK: llvm.func @__nv_atanhf(f32) -> f32
+ // CHECK: llvm.func @__nv_atanh(f64) -> f64
+ // CHECK-LABEL: func @gpu_atanh
+ func.func @gpu_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
+ -> (f16, f32, f64) {
+ %result16 = math.atanh %arg_f16 : f16
+ // CHECK: llvm.fpext %{{.*}} : f16 to f32
+ // CHECK-NEXT: llvm.call @__nv_atanhf(%{{.*}}) : (f32) -> f3...
[truncated]
``````````
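Note (my reading of the `getFunctionName` change above, not stated explicitly in the patch): the fast variant is selected only when the op's fastmath flags equal the full `fast` set and a fast function name was registered for that op; any partial flag set falls back to the regular libdevice call. A hypothetical example:

```mlir
// fastmath<afn> is not the full `fast` set, so under the check
// `arith::FastMathFlags::fast == flag` this should still lower to
// llvm.call @__nv_cosf rather than @__nv_fast_cosf.
%r = math.cos %x fastmath<afn> : f32
```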
https://github.com/llvm/llvm-project/pull/99890
More information about the Mlir-commits mailing list