[Mlir-commits] [mlir] [MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (PR #99890)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 22 08:57:37 PDT 2024
https://github.com/runseny created https://github.com/llvm/llvm-project/pull/99890
Support the `fastmath` attribute, and lower additional math ops that take only floating-point operands, by calling the corresponding libdevice functions directly when converting to NVVM.
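1. Lower math ops that carry the `fastmath<fast>` attribute to the fast libdevice variant (e.g. `__nv_fast_expf`, `__nv_fast_cosf`) when the operand is f32 and a fast variant is registered. Other fastmath flags, and f64 operands, still use the regular libdevice function.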
2. Some math dialect ops were already lowered to libdevice calls, but not all of them. This PR also lowers the remaining math ops that take only floating-point operands: acos, acosh, asin, asinh, atanh, copysign, cosh, fma, round, roundeven, and sinh.
>From 1c5f7e83dc76d7eeca7284466ae3758ebc47d3da Mon Sep 17 00:00:00 2001
From: runseny <runseny at nvidia.com>
Date: Mon, 22 Jul 2024 15:53:38 +0000
Subject: [PATCH] [MLIR][GPUToNVVM] support fastMath and other non-supported
mathOp
---
.../GPUCommon/OpToFuncCallLowering.h | 18 +-
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 46 ++--
.../Conversion/MathToROCDL/MathToROCDL.cpp | 4 +-
.../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 242 ++++++++++++++++--
4 files changed, 264 insertions(+), 46 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index ebce2d77310ae..7ce17a69d7e4d 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -31,9 +31,9 @@ template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
- StringRef f64Func)
+ StringRef f64Func, StringRef f32FastFunc)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func) {}
+ f64Func(f64Func), f32FastFunc(f32FastFunc) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -55,7 +55,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
- getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType());
+ getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
if (funcName.empty())
return failure();
@@ -90,9 +90,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}
- StringRef getFunctionName(Type type) const {
- if (isa<Float32Type>(type))
- return f32Func;
+ StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+ if (isa<Float32Type>(type)) {
+ if (arith::FastMathFlags::fast == flag && !f32FastFunc.empty())
+ return f32FastFunc;
+ else
+ return f32Func;
+ }
if (isa<Float64Type>(type))
return f64Func;
return "";
@@ -113,8 +117,10 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32Func;
const std::string f64Func;
+ const std::string f32FastFunc;
};
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index fea8a0ddc7f06..9cfad02538c98 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,10 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
- target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
+ LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
+
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -321,9 +322,9 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -370,42 +371,53 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getMaxntidAttrName()));
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
+ "__nv_fmod");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
"__nv_atan2");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil");
- populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
+ populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
"__nv_expm1");
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
- "__nv_fmod");
- populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
+ populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
+ "__nv_log10", "__nv_fast_log10f");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
"__nv_log1p");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
- "__nv_log10");
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
- "__nv_log2");
+ "__nv_log2", "__nv_fast_log2f");
populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
- "__nv_pow");
+ "__nv_pow", "__nv_fast_powf");
+ populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
"__nv_rsqrt");
- populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
"__nv_sqrt");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
- populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 03c7ce5dac0d1..4344fdc142cd2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,9 +38,9 @@ using namespace mlir;
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index d914790c05fe0..a3b79ae2561e1 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -254,13 +254,16 @@ gpu.module @test_module_9 {
gpu.module @test_module_10 {
// CHECK: llvm.func @__nv_cosf(f32) -> f32
// CHECK: llvm.func @__nv_cos(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_cosf(f32) -> f32
// CHECK-LABEL: func @gpu_cos
- func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__nv_cosf(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__nv_cos(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.cos %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_cosf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -268,13 +271,16 @@ gpu.module @test_module_10 {
gpu.module @test_module_11 {
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -297,13 +303,16 @@ gpu.module @test_module_12 {
gpu.module @test_module_13 {
// CHECK: llvm.func @__nv_logf(f32) -> f32
// CHECK: llvm.func @__nv_log(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_logf(f32) -> f32
// CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log %arg_f32 : f32
// CHECK: llvm.call @__nv_logf(%{{.*}}) : (f32) -> f32
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__nv_log(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_logf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -312,13 +321,16 @@ gpu.module @test_module_13 {
gpu.module @test_module_14 {
// CHECK: llvm.func @__nv_log10f(f32) -> f32
// CHECK: llvm.func @__nv_log10(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log10f(f32) -> f32
// CHECK-LABEL: func @gpu_log10
- func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log10 %arg_f32 : f32
// CHECK: llvm.call @__nv_log10f(%{{.*}}) : (f32) -> f32
%result64 = math.log10 %arg_f64 : f64
// CHECK: llvm.call @__nv_log10(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log10 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log10f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -342,13 +354,16 @@ gpu.module @test_module_15 {
gpu.module @test_module_16 {
// CHECK: llvm.func @__nv_log2f(f32) -> f32
// CHECK: llvm.func @__nv_log2(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_log2f(f32) -> f32
// CHECK-LABEL: func @gpu_log2
- func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.log2 %arg_f32 : f32
// CHECK: llvm.call @__nv_log2f(%{{.*}}) : (f32) -> f32
%result64 = math.log2 %arg_f64 : f64
// CHECK: llvm.call @__nv_log2(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.log2 %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_log2f(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -357,13 +372,16 @@ gpu.module @test_module_16 {
gpu.module @test_module_17 {
// CHECK: llvm.func @__nv_sinf(f32) -> f32
// CHECK: llvm.func @__nv_sin(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_sinf(f32) -> f32
// CHECK-LABEL: func @gpu_sin
- func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__nv_sinf(%{{.*}}) : (f32) -> f32
%result64 = math.sin %arg_f64 : f64
// CHECK: llvm.call @__nv_sin(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.sin %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_sinf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -372,8 +390,9 @@ gpu.module @test_module_17 {
gpu.module @test_module_18 {
// CHECK: llvm.func @__nv_tanf(f32) -> f32
// CHECK: llvm.func @__nv_tan(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_tanf(f32) -> f32
// CHECK-LABEL: func @gpu_tan
- func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64, f32) {
%result16 = math.tan %arg_f16 : f16
// CHECK: llvm.fpext %{{.*}} : f16 to f32
// CHECK-NEXT: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
@@ -382,7 +401,9 @@ gpu.module @test_module_18 {
// CHECK: llvm.call @__nv_tanf(%{{.*}}) : (f32) -> f32
%result64 = math.tan %arg_f64 : f64
// CHECK: llvm.call @__nv_tan(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
+ %result32Fast = math.tan %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_tanf(%{{.*}}) : (f32) -> f32
+ func.return %result16, %result32, %result64, %result32Fast : f16, f32, f64, f32
}
}
@@ -494,13 +515,16 @@ gpu.module @test_module_24 {
// CHECK: test.symbol_scope
// CHECK: llvm.func @__nv_expf(f32) -> f32
// CHECK: llvm.func @__nv_exp(f64) -> f64
+ // CHECK: llvm.func @__nv_fast_expf(f32) -> f32
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.call @__nv_expf(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__nv_exp(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.exp %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_expf(%{{.*}}) : (f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
"test.finish" () : () -> ()
}) : () -> ()
@@ -526,13 +550,16 @@ gpu.module @test_module_25 {
gpu.module @test_module_26 {
// CHECK: llvm.func @__nv_powf(f32, f32) -> f32
// CHECK: llvm.func @__nv_pow(f64, f64) -> f64
+ // CHECK: llvm.func @__nv_fast_powf(f32, f32) -> f32
// CHECK-LABEL: func @gpu_pow
- func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64, f32) {
%result32 = math.powf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__nv_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.powf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__nv_pow(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result32Fast = math.powf %arg_f32, %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_fast_powf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ func.return %result32, %result64, %result32Fast : f32, f64, f32
}
}
@@ -701,6 +728,179 @@ gpu.module @test_module_34 {
}
}
+gpu.module @test_module_35 {
+ // CHECK: llvm.func @__nv_acosf(f32) -> f32
+ // CHECK: llvm.func @__nv_acos(f64) -> f64
+ // CHECK-LABEL: func @gpu_acos
+ func.func @gpu_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acosf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acos(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_36 {
+ // CHECK: llvm.func @__nv_acoshf(f32) -> f32
+ // CHECK: llvm.func @__nv_acosh(f64) -> f64
+ // CHECK-LABEL: func @gpu_acosh
+ func.func @gpu_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_acoshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_acosh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_37 {
+ // CHECK: llvm.func @__nv_asinf(f32) -> f32
+ // CHECK: llvm.func @__nv_asin(f64) -> f64
+ // CHECK-LABEL: func @gpu_asin
+ func.func @gpu_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asin(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_38 {
+ // CHECK: llvm.func @__nv_asinhf(f32) -> f32
+ // CHECK: llvm.func @__nv_asinh(f64) -> f64
+ // CHECK-LABEL: func @gpu_asinh
+ func.func @gpu_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_asinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_asinh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_39 {
+ // CHECK: llvm.func @__nv_atanhf(f32) -> f32
+ // CHECK: llvm.func @__nv_atanh(f64) -> f64
+ // CHECK-LABEL: func @gpu_atanh
+ func.func @gpu_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
+ -> (f16, f32, f64) {
+ %result16 = math.atanh %arg_f16 : f16
+ // CHECK: llvm.fpext %{{.*}} : f16 to f32
+ // CHECK-NEXT: llvm.call @__nv_atanhf(%{{.*}}) : (f32) -> f32
+ // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
+ %result32 = math.atanh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_atanhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.atanh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_atanh(%{{.*}}) : (f64) -> f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
+ }
+}
+
+gpu.module @test_module_40 {
+ // CHECK: llvm.func @__nv_copysignf(f32, f32) -> f32
+ // CHECK: llvm.func @__nv_copysign(f64, f64) -> f64
+ // CHECK-LABEL: func @gpu_copysign
+ func.func @gpu_copysign(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.copysign %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__nv_copysignf(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = math.copysign %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__nv_copysign(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_41 {
+ // CHECK: llvm.func @__nv_coshf(f32) -> f32
+ // CHECK: llvm.func @__nv_cosh(f64) -> f64
+ // CHECK-LABEL: func @gpu_cosh
+ func.func @gpu_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.cosh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_coshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.cosh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_cosh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_42 {
+ // CHECK: llvm.func @__nv_fmaf(f32, f32, f32) -> f32
+ // CHECK: llvm.func @__nv_fma(f64, f64, f64) -> f64
+ // CHECK-LABEL: func @gpu_fma
+ func.func @gpu_fma(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.fma %arg_f32, %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__nv_fmaf(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, f32, f32) -> f32
+ %result64 = math.fma %arg_f64, %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__nv_fma(%{{.*}}, %{{.*}}, %{{.*}}) : (f64, f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_43 {
+ // CHECK: llvm.func @__nv_roundf(f32) -> f32
+ // CHECK: llvm.func @__nv_round(f64) -> f64
+ // CHECK-LABEL: func @gpu_round
+ func.func @gpu_round(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.round %arg_f32 : f32
+ // CHECK: llvm.call @__nv_roundf(%{{.*}}) : (f32) -> f32
+ %result64 = math.round %arg_f64 : f64
+ // CHECK: llvm.call @__nv_round(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_44 {
+ // CHECK: llvm.func @__nv_rintf(f32) -> f32
+ // CHECK: llvm.func @__nv_rint(f64) -> f64
+ // CHECK-LABEL: func @gpu_roundeven
+ func.func @gpu_roundeven(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.roundeven %arg_f32 : f32
+ // CHECK: llvm.call @__nv_rintf(%{{.*}}) : (f32) -> f32
+ %result64 = math.roundeven %arg_f64 : f64
+ // CHECK: llvm.call @__nv_rint(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_45 {
+ // CHECK: llvm.func @__nv_sinhf(f32) -> f32
+ // CHECK: llvm.func @__nv_sinh(f64) -> f64
+ // CHECK-LABEL: func @gpu_sinh
+ func.func @gpu_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.sinh %arg_f32 : f32
+ // CHECK: llvm.call @__nv_sinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.sinh %arg_f64 : f64
+ // CHECK: llvm.call @__nv_sinh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_46 {
+ // CHECK: llvm.func @__nv_coshf(f32) -> f32
+ // CHECK: llvm.func @__nv_cosh(f64) -> f64
+ // CHECK-LABEL: func @gpu_cosh_with_fastmath
+ func.func @gpu_cosh_with_fastmath(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.cosh %arg_f32 fastmath<fast> : f32
+ // CHECK: llvm.call @__nv_coshf(%{{.*}}) : (f32) -> f32
+ %result64 = math.cosh %arg_f64 fastmath<fast> : f64
+ // CHECK: llvm.call @__nv_cosh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+gpu.module @test_module_47 {
+ // CHECK: llvm.func @__nv_sinhf(f32) -> f32
+ // CHECK: llvm.func @__nv_sinh(f64) -> f64
+ // CHECK-LABEL: func @gpu_sinh_with_fastmath
+ func.func @gpu_sinh_with_fastmath(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.sinh %arg_f32 fastmath<contract> : f32
+ // CHECK: llvm.call @__nv_sinhf(%{{.*}}) : (f32) -> f32
+ %result64 = math.sinh %arg_f64 fastmath<none> : f64
+ // CHECK: llvm.call @__nv_sinh(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
@@ -729,9 +929,9 @@ module attributes {transform.with_named_sequence} {
legal_dialects = ["llvm", "memref", "nvvm", "test"],
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],
illegal_dialects = ["gpu"],
- illegal_ops = ["llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
- "llvm.ffloor", "llvm.log", "llvm.log10", "llvm.log2","llvm.pow",
- "llvm.sin", "llvm.sqrt"],
+ illegal_ops = ["llvm.copysign", "llvm.cos", "llvm.exp", "llvm.exp2", "llvm.fabs", "llvm.fceil",
+ "llvm.ffloor", "llvm.fma", "llvm.frem", "llvm.log", "llvm.log10", "llvm.log2", "llvm.pow",
+ "llvm.roundeven", "llvm.round", "llvm.sin", "llvm.sqrt"],
partial_conversion
} : !transform.any_op
transform.yield
More information about the Mlir-commits mailing list