[Mlir-commits] [mlir] [MLIR][GPUToNVVM] support fastMath and other non-supported mathOp (PR #99890)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 22 09:01:21 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 2e6558b8bcdaa4c0924f1f49a9200cb2dea44bd4 1c5f7e83dc76d7eeca7284466ae3758ebc47d3da --extensions cpp,h -- mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 7ce17a69d7..219f66e65f 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -55,7 +55,8 @@ public:
     Type resultType = castedOperands.front().getType();
     Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName =
-        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op.getFastmath());
+        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
+                        op.getFastmath());
     if (funcName.empty())
       return failure();
 
@@ -120,7 +121,6 @@ private:
   const std::string f32FastFunc;
 };
 
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9cfad02538..2fe653aa99 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -309,11 +309,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
   target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
   target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
   target.addIllegalDialect<gpu::GPUDialect>();
-  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
-                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp, LLVM::FRemOp, LLVM::LogOp,
-                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::RoundEvenOp,
-                      LLVM::RoundOp, LLVM::SinOp, LLVM::SqrtOp>();
-
+  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
+                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
+                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
+                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
+                      LLVM::SinOp, LLVM::SqrtOp>();
 
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
@@ -324,7 +324,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
                                StringRef f64Func, StringRef f32FastFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           f32FastFunc);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
@@ -375,24 +376,33 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                     "__nv_fmod");
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                    "__nv_fabs");
-  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf", "__nv_acos");
-  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf", "__nv_acosh");
-  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf", "__nv_asin");
-  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf", "__nv_asinh");
+  populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
+                                   "__nv_acos");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, "__nv_acoshf",
+                                    "__nv_acosh");
+  populateOpPatterns<math::AsinOp>(converter, patterns, "__nv_asinf",
+                                   "__nv_asin");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, "__nv_asinhf",
+                                    "__nv_asinh");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                    "__nv_atan");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                     "__nv_atan2");
-  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf", "__nv_atanh");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, "__nv_atanhf",
+                                    "__nv_atanh");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                    "__nv_cbrt");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                    "__nv_ceil");
-  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf", "__nv_copysign");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos", "__nv_fast_cosf");
-  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf", "__nv_cosh");
+  populateOpPatterns<math::CopySignOp>(converter, patterns, "__nv_copysignf",
+                                       "__nv_copysign");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos",
+                                  "__nv_fast_cosf");
+  populateOpPatterns<math::CoshOp>(converter, patterns, "__nv_coshf",
+                                   "__nv_cosh");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp", "__nv_fast_expf");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp",
+                                  "__nv_fast_expf");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                    "__nv_exp2");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
@@ -400,24 +410,30 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
   populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                     "__nv_floor");
   populateOpPatterns<math::FmaOp>(converter, patterns, "__nv_fmaf", "__nv_fma");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log", "__nv_fast_logf");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log",
+                                  "__nv_fast_logf");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                     "__nv_log10", "__nv_fast_log10f");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                     "__nv_log1p");
   populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                    "__nv_log2", "__nv_fast_log2f");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
-                                   "__nv_pow", "__nv_fast_powf");
-  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf", "__nv_round");
-  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf", "__nv_rint");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
+                                   "__nv_fast_powf");
+  populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
+                                    "__nv_round");
+  populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
+                                        "__nv_rint");
   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                     "__nv_rsqrt");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin", "__nv_fast_sinf");
-  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf", "__nv_sinh");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin",
+                                  "__nv_fast_sinf");
+  populateOpPatterns<math::SinhOp>(converter, patterns, "__nv_sinhf",
+                                   "__nv_sinh");
   populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                    "__nv_sqrt");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan", "__nv_fast_tanf");
+  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan",
+                                  "__nv_fast_tanf");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                    "__nv_tanh");
 }
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 4344fdc142..d74c6aa7ae 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -40,7 +40,8 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
                                StringRef f64Func, StringRef f32FastFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, f32FastFunc);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           f32FastFunc);
 }
 
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,

``````````

</details>


https://github.com/llvm/llvm-project/pull/99890


More information about the Mlir-commits mailing list