[Mlir-commits] [mlir] [MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (PR #99890)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 22 09:01:21 PDT 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 2e6558b8bcdaa4c0924f1f49a9200cb2dea44bd4 1c5f7e83dc76d7eeca7284466ae3758ebc47d3da --extensions cpp,h -- mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 7ce17a69d7..219f66e65f 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -55,7 +55,8 @@ public:
Type resultType = castedOperands.front().getType();
Type funcType = getFunctionType(resultType, castedOperands);
StringRef funcName =
- getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
+ getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
+ op.getFastmath());
if (funcName.empty())
return failure();
@@ -120,7 +121,6 @@ private:
const std::string f32FastFunc;
};
-
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9cfad02538..2fe653aa99 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,11 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
- target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
- LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
- LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
- LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
-
+ target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
+ LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
+ LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
+ LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
+ LLVM::SinOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -324,7 +324,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ f32FastFunc);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -375,24 +376,33 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
"__nv_fmod");
populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
- populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
- populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
- populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
- populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
+ "__nv_acos");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
+ "__nv_acosh");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
+ "__nv_asin");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
+ "__nv_asinh");
populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
"__nv_atan");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
"__nv_atan2");
- populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
+ "__nv_atanh");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
"__nv_cbrt");
populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
"__nv_ceil");
- populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
- populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
- populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
+ populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
+ "__nv_copysign");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
+ "__nv_fast_cosf");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
+ "__nv_cosh");
populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
+ "__nv_fast_expf");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
"__nv_exp2");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
@@ -400,24 +410,30 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
"__nv_floor");
populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
- populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
+ "__nv_fast_logf");
populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
"__nv_log10", "__nv_fast_log10f");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
"__nv_log1p");
populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
"__nv_log2", "__nv_fast_log2f");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
- "__nv_pow", "__nv_fast_powf");
- populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
- populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
+ "__nv_fast_powf");
+ populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
+ "__nv_round");
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
+ "__nv_rint");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
"__nv_rsqrt");
- populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
- populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
+ "__nv_fast_sinf");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
+ "__nv_sinh");
populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
"__nv_sqrt");
- populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
+ "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
"__nv_tanh");
}
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 4344fdc142..d74c6aa7ae 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -40,7 +40,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func, StringRef f32FastFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ f32FastFunc);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
``````````
</details>
https://github.com/llvm/llvm-project/pull/99890
More information about the Mlir-commits mailing list