[Mlir-commits] [mlir] Add Math-to-LLVM lowering support for missing trigonometric and hyperbolic ops (PR #125753)
Paul Carabas
llvmlistbot at llvm.org
Tue Feb 4 12:10:27 PST 2025
https://github.com/PaulCarabas created https://github.com/llvm/llvm-project/pull/125753
(No PR description provided.)
>From 94fc09c7c9dd31dee5062cf5ac484b527e685af8 Mon Sep 17 00:00:00 2001
From: PaulCarabas <paulcaraa at gmail.com>
Date: Tue, 4 Feb 2025 21:24:19 +0200
Subject: [PATCH 1/2] [mlir][LLVMIR] Add support for tan intrinsic op
---
.../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 7 +++++--
mlir/test/Target/LLVMIR/Import/intrinsic.ll | 19 +++++++++++++++----
.../test/Target/LLVMIR/llvmir-intrinsics.mlir | 18 ++++++++++++++----
3 files changed, 34 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index a7d683438ee8ab0..72fae1bdbf461df 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -107,7 +107,6 @@ def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
}
def LLVM_CopySignOp : LLVM_BinarySameArgsIntrOpF<"copysign">;
-def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">;
def LLVM_ExpOp : LLVM_UnaryIntrOpF<"exp">;
def LLVM_Exp2Op : LLVM_UnaryIntrOpF<"exp2">;
def LLVM_FAbsOp : LLVM_UnaryIntrOpF<"fabs">;
@@ -125,7 +124,6 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0],
> {
let arguments = (ins LLVM_AnyPointer:$addr, I32Attr:$rw, I32Attr:$hint, I32Attr:$cache);
}
-def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
def LLVM_RoundEvenOp : LLVM_UnaryIntrOpF<"roundeven">;
def LLVM_RoundOp : LLVM_UnaryIntrOpF<"round">;
def LLVM_FTruncOp : LLVM_UnaryIntrOpF<"trunc">;
@@ -167,6 +165,11 @@ def LLVM_SMaxOp : LLVM_BinarySameArgsIntrOpI<"smax">;
def LLVM_SMinOp : LLVM_BinarySameArgsIntrOpI<"smin">;
def LLVM_UMaxOp : LLVM_BinarySameArgsIntrOpI<"umax">;
def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">;
+
+def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
+def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">;
+def LLVM_TanOp : LLVM_UnaryIntrOpF<"tan">;
+
def LLVM_SinhOp : LLVM_UnaryIntrOpF<"sinh">;
def LLVM_CoshOp : LLVM_UnaryIntrOpF<"cosh">;
def LLVM_TanhOp : LLVM_UnaryIntrOpF<"tanh">;
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index bd335323a2e1c93..249a0552c87f380 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -101,12 +101,23 @@ define void @floor_test(float %0, <8 x float> %1) {
%4 = call <8 x float> @llvm.floor.v8f32(<8 x float> %1)
ret void
}
-; CHECK-LABEL: llvm.func @cos_test
-define void @cos_test(float %0, <8 x float> %1) {
+; CHECK-LABEL: llvm.func @trig_test
+define void @trig_test(float %0, <8 x float> %1) {
+ ; CHECK: llvm.intr.sin(%{{.*}}) : (f32) -> f32
+ %3 = call float @llvm.sin.f32(float %0)
+ ; CHECK: llvm.intr.sin(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+ %4 = call <8 x float> @llvm.sin.v8f32(<8 x float> %1)
+
; CHECK: llvm.intr.cos(%{{.*}}) : (f32) -> f32
- %3 = call float @llvm.cos.f32(float %0)
+ %5 = call float @llvm.cos.f32(float %0)
; CHECK: llvm.intr.cos(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
- %4 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1)
+ %6 = call <8 x float> @llvm.cos.v8f32(<8 x float> %1)
+
+ ; CHECK: llvm.intr.tan(%{{.*}}) : (f32) -> f32
+ %7 = call float @llvm.tan.f32(float %0)
+ ; CHECK: llvm.intr.tan(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+ %8 = call <8 x float> @llvm.tan.v8f32(<8 x float> %1)
+
ret void
}
; CHECK-LABEL: llvm.func @hyperbolic_trig_test
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index 382b2b9f3cd732e..2c208789e36ddbe 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -103,12 +103,22 @@ llvm.func @floor_test(%arg0: f32, %arg1: vector<8xf32>) {
llvm.return
}
-// CHECK-LABEL: @cos_test
-llvm.func @cos_test(%arg0: f32, %arg1: vector<8xf32>) {
+// CHECK-LABEL: @trig_test
+llvm.func @trig_test(%arg0: f32, %arg1: vector<8xf32>) {
+ // CHECK: call float @llvm.sin.f32
+ llvm.intr.sin(%arg0) : (f32) -> f32
+ // CHECK: call <8 x float> @llvm.sin.v8f32
+ llvm.intr.sin(%arg1) : (vector<8xf32>) -> vector<8xf32>
+
// CHECK: call float @llvm.cos.f32
- "llvm.intr.cos"(%arg0) : (f32) -> f32
+ llvm.intr.cos(%arg0) : (f32) -> f32
// CHECK: call <8 x float> @llvm.cos.v8f32
- "llvm.intr.cos"(%arg1) : (vector<8xf32>) -> vector<8xf32>
+ llvm.intr.cos(%arg1) : (vector<8xf32>) -> vector<8xf32>
+
+ // CHECK: call float @llvm.tan.f32
+ llvm.intr.tan(%arg0) : (f32) -> f32
+ // CHECK: call <8 x float> @llvm.tan.v8f32
+ llvm.intr.tan(%arg1) : (vector<8xf32>) -> vector<8xf32>
llvm.return
}
>From db8b3ab6e27366e680e84e34dd56a4475098225a Mon Sep 17 00:00:00 2001
From: PaulCarabas <paulcaraa at gmail.com>
Date: Tue, 4 Feb 2025 22:09:23 +0200
Subject: [PATCH 2/2] [mlir][MathToLLVM] Add lowerings for missing
 trigonometric & hyperbolic ops (tan, sinh, cosh, tanh)
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 10 ++++++-
.../Conversion/MathToLLVM/math-to-llvm.mlir | 26 +++++++++++++++++--
2 files changed, 33 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 668f8385ac2dcf4..98680773e00d2ac 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -39,6 +39,7 @@ using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
+using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
using CtPopFOpLowering =
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
@@ -58,9 +59,12 @@ using RoundEvenOpLowering =
using RoundOpLowering =
ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
+using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
+using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
+using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
template <typename MathOp, typename LLVMOp>
@@ -310,6 +314,7 @@ void mlir::populateMathToLLVMConversionPatterns(
CeilOpLowering,
CopySignOpLowering,
CosOpLowering,
+ CoshOpLowering,
CountLeadingZerosOpLowering,
CountTrailingZerosOpLowering,
CtPopFOpLowering,
@@ -327,8 +332,11 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundOpLowering,
RsqrtOpLowering,
SinOpLowering,
+ SinhOpLowering,
SqrtOpLowering,
- FTruncOpLowering
+ FTruncOpLowering,
+ TanOpLowering,
+ TanhOpLowering
>(converter);
// clang-format on
}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 56129dbd2788923..24eef9341bf74f2 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -161,11 +161,33 @@ func.func @rsqrt(%arg0 : f32) {
// -----
-// CHECK-LABEL: func @sine(
+// CHECK-LABEL: func @trigonometrics(
// CHECK-SAME: f32
-func.func @sine(%arg0 : f32) {
+func.func @trigonometrics(%arg0 : f32) {
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32
%0 = math.sin %arg0 : f32
+
+ // CHECK: llvm.intr.cos(%arg0) : (f32) -> f32
+ %1 = math.cos %arg0 : f32
+
+ // CHECK: llvm.intr.tan(%arg0) : (f32) -> f32
+ %2 = math.tan %arg0 : f32
+ func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @hyperbolics(
+// CHECK-SAME: f32
+func.func @hyperbolics(%arg0 : f32) {
+ // CHECK: llvm.intr.sinh(%arg0) : (f32) -> f32
+ %0 = math.sinh %arg0 : f32
+
+ // CHECK: llvm.intr.cosh(%arg0) : (f32) -> f32
+ %1 = math.cosh %arg0 : f32
+
+ // CHECK: llvm.intr.tanh(%arg0) : (f32) -> f32
+ %2 = math.tanh %arg0 : f32
func.return
}
More information about the Mlir-commits
mailing list