[Mlir-commits] [mlir] 196d897 - [mlir][llvm] Add rounding intrinsics
Markus Böck
llvmlistbot at llvm.org
Mon May 29 09:12:58 PDT 2023
Author: Lukas Sommer
Date: 2023-05-29T18:13:08+02:00
New Revision: 196d89740c5e8bf238200b7f95e6173b231aa5d2
URL: https://github.com/llvm/llvm-project/commit/196d89740c5e8bf238200b7f95e6173b231aa5d2
DIFF: https://github.com/llvm/llvm-project/commit/196d89740c5e8bf238200b7f95e6173b231aa5d2.diff
LOG: [mlir][llvm] Add rounding intrinsics
Add some of the missing libm rounding intrinsics to the LLVM dialect:
* `llvm.rint`
* `llvm.nearbyint`
* `llvm.lround`
* `llvm.llround`
* `llvm.lrint`
* `llvm.llrint`
Differential Revision: https://reviews.llvm.org/D151558
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/test/Target/LLVMIR/Import/intrinsic.ll
mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index eb815b3f0b0d4..a409223ade155 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -130,6 +130,18 @@ def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1],
let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
"functional-type(operands, results)";
}
+def LLVM_RintOp : LLVM_UnaryIntrOpF<"rint">; // Unary float intrinsic: result type equals operand type (f32, f64, vector<8xf32>, vector<8xf64> are tested).
+def LLVM_NearbyintOp : LLVM_UnaryIntrOpF<"nearbyint">; // Same shape as rint: result type equals operand type (same four overloads tested).
+class LLVM_IntRoundIntrOpBase<string func> : // Shared base for the float -> integer rounding intrinsics (lround, llround, lrint, llrint).
+ LLVM_OneResultIntrOp<func, [0], [0], [Pure]> { // Presumably marks result 0 and operand 0 as overloaded, so both types appear in the intrinsic name (e.g. llvm.lround.i32.f32) -- confirm against LLVM_OneResultIntrOp.
+ let arguments = (ins LLVM_AnyFloat:$val); // Single floating-point operand. NOTE(review): no `results` override here, so the integer result type is not constrained by this class -- confirm that is intended.
+ let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
+ "functional-type(operands, results)"; // Prints e.g. llvm.intr.lround(%x) : (f32) -> i32.
+}
+def LLVM_LroundOp : LLVM_IntRoundIntrOpBase<"lround">; // Tested with i32 and i64 results from f32 and f64 operands.
+def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">; // Tested with i64 results only (from f32 and f64).
+def LLVM_LrintOp : LLVM_IntRoundIntrOpBase<"lrint">; // Tested with i32 and i64 results from f32 and f64 operands.
+def LLVM_LlrintOp : LLVM_IntRoundIntrOpBase<"llrint">; // Tested with i64 results only (from f32 and f64).
def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">;
def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">;
def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">;
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
index 811dc44973410..e9b361509d037 100644
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -117,6 +117,72 @@ define void @pow_test(float %0, float %1, <8 x float> %2, <8 x float> %3) {
%6 = call <8 x float> @llvm.pow.v8f32(<8 x float> %2, <8 x float> %3)
ret void
}
+
+; CHECK-LABEL: llvm.func @rint_test
+define void @rint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) { ; one call per overload: f32, f64, <8 x float>, <8 x double>
+ ; CHECK: llvm.intr.rint(%{{.*}}) : (f32) -> f32
+ %5 = call float @llvm.rint.f32(float %0) ; %4 is taken by the unnamed entry block, so numbering resumes at %5
+ ; CHECK: llvm.intr.rint(%{{.*}}) : (f64) -> f64
+ %6 = call double @llvm.rint.f64(double %1)
+ ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+ %7 = call <8 x float> @llvm.rint.v8f32(<8 x float> %2)
+ ; CHECK: llvm.intr.rint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+ %8 = call <8 x double> @llvm.rint.v8f64(<8 x double> %3)
+ ret void
+}
+; CHECK-LABEL: llvm.func @nearbyint_test
+define void @nearbyint_test(float %0, double %1, <8 x float> %2, <8 x double> %3) { ; same overload coverage as rint_test
+ ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f32) -> f32
+ %5 = call float @llvm.nearbyint.f32(float %0)
+ ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (f64) -> f64
+ %6 = call double @llvm.nearbyint.f64(double %1)
+ ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf32>) -> vector<8xf32>
+ %7 = call <8 x float> @llvm.nearbyint.v8f32(<8 x float> %2)
+ ; CHECK: llvm.intr.nearbyint(%{{.*}}) : (vector<8xf64>) -> vector<8xf64>
+ %8 = call <8 x double> @llvm.nearbyint.v8f64(<8 x double> %3)
+ ret void
+}
+; CHECK-LABEL: llvm.func @lround_test
+define void @lround_test(float %0, double %1) { ; i32 and i64 results from both f32 and f64 operands
+ ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32
+ %3 = call i32 @llvm.lround.i32.f32(float %0) ; %2 is taken by the unnamed entry block, so numbering resumes at %3
+ ; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i64
+ %4 = call i64 @llvm.lround.i64.f32(float %0)
+ ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i32
+ %5 = call i32 @llvm.lround.i32.f64(double %1)
+ ; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i64
+ %6 = call i64 @llvm.lround.i64.f64(double %1)
+ ret void
+}
+; CHECK-LABEL: llvm.func @llround_test
+define void @llround_test(float %0, double %1) { ; only i64 results are exercised
+ ; CHECK: llvm.intr.llround(%{{.*}}) : (f32) -> i64
+ %3 = call i64 @llvm.llround.i64.f32(float %0)
+ ; CHECK: llvm.intr.llround(%{{.*}}) : (f64) -> i64
+ %4 = call i64 @llvm.llround.i64.f64(double %1)
+ ret void
+}
+; CHECK-LABEL: llvm.func @lrint_test
+define void @lrint_test(float %0, double %1) { ; i32 and i64 results from both f32 and f64 operands
+ ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i32
+ %3 = call i32 @llvm.lrint.i32.f32(float %0)
+ ; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i64
+ %4 = call i64 @llvm.lrint.i64.f32(float %0)
+ ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i32
+ %5 = call i32 @llvm.lrint.i32.f64(double %1)
+ ; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i64
+ %6 = call i64 @llvm.lrint.i64.f64(double %1)
+ ret void
+}
+; CHECK-LABEL: llvm.func @llrint_test
+define void @llrint_test(float %0, double %1) { ; only i64 results are exercised
+ ; CHECK: llvm.intr.llrint(%{{.*}}) : (f32) -> i64
+ %3 = call i64 @llvm.llrint.i64.f32(float %0)
+ ; CHECK: llvm.intr.llrint(%{{.*}}) : (f64) -> i64
+ %4 = call i64 @llvm.llrint.i64.f64(double %1)
+ ret void
+}
+
; CHECK-LABEL: llvm.func @bitreverse_test
define void @bitreverse_test(i32 %0, <8 x i32> %1) {
; CHECK: llvm.intr.bitreverse(%{{.*}}) : (i32) -> i32
@@ -781,6 +847,26 @@ declare float @llvm.copysign.f32(float, float)
declare <8 x float> @llvm.copysign.v8f32(<8 x float>, <8 x float>)
declare float @llvm.pow.f32(float, float)
declare <8 x float> @llvm.pow.v8f32(<8 x float>, <8 x float>)
+declare float @llvm.rint.f32(float) ; declarations for the calls in rint_test through llrint_test
+declare double @llvm.rint.f64(double)
+declare <8 x float> @llvm.rint.v8f32(<8 x float>)
+declare <8 x double> @llvm.rint.v8f64(<8 x double>)
+declare float @llvm.nearbyint.f32(float)
+declare double @llvm.nearbyint.f64(double)
+declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>)
+declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>)
+declare i32 @llvm.lround.i32.f32(float)
+declare i64 @llvm.lround.i64.f32(float)
+declare i32 @llvm.lround.i32.f64(double)
+declare i64 @llvm.lround.i64.f64(double)
+declare i64 @llvm.llround.i64.f32(float)
+declare i64 @llvm.llround.i64.f64(double)
+declare i32 @llvm.lrint.i32.f32(float)
+declare i64 @llvm.lrint.i64.f32(float)
+declare i32 @llvm.lrint.i32.f64(double)
+declare i64 @llvm.lrint.i64.f64(double)
+declare i64 @llvm.llrint.i64.f32(float)
+declare i64 @llvm.llrint.i64.f64(double)
declare i32 @llvm.bitreverse.i32(i32)
declare <8 x i32> @llvm.bitreverse.v8i32(<8 x i32>)
declare i32 @llvm.bswap.i32(i32)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
index c6a3c7fbb4450..ec619b9a9d367 100644
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -134,6 +134,76 @@ llvm.func @pow_test(%arg0: f32, %arg1: f32, %arg2: vector<8xf32>, %arg3: vector<
llvm.return
}
+// CHECK-LABEL: @rint_test
+llvm.func @rint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) { // one op per overload: f32, f64, vector<8xf32>, vector<8xf64>
+ // CHECK: call float @llvm.rint.f32
+ "llvm.intr.rint"(%arg0) : (f32) -> f32 // generic op syntax; the results are not used
+ // CHECK: call double @llvm.rint.f64
+ "llvm.intr.rint"(%arg1) : (f64) -> f64
+ // CHECK: call <8 x float> @llvm.rint.v8f32
+ "llvm.intr.rint"(%arg2) : (vector<8xf32>) -> vector<8xf32>
+ // CHECK: call <8 x double> @llvm.rint.v8f64
+ "llvm.intr.rint"(%arg3) : (vector<8xf64>) -> vector<8xf64>
+ llvm.return
+}
+
+// CHECK-LABEL: @nearbyint_test
+llvm.func @nearbyint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3 : vector<8xf64>) { // same overload coverage as rint_test
+ // CHECK: call float @llvm.nearbyint.f32
+ "llvm.intr.nearbyint"(%arg0) : (f32) -> f32
+ // CHECK: call double @llvm.nearbyint.f64
+ "llvm.intr.nearbyint"(%arg1) : (f64) -> f64
+ // CHECK: call <8 x float> @llvm.nearbyint.v8f32
+ "llvm.intr.nearbyint"(%arg2) : (vector<8xf32>) -> vector<8xf32>
+ // CHECK: call <8 x double> @llvm.nearbyint.v8f64
+ "llvm.intr.nearbyint"(%arg3) : (vector<8xf64>) -> vector<8xf64>
+ llvm.return
+}
+
+// CHECK-LABEL: @lround_test
+llvm.func @lround_test(%arg0 : f32, %arg1 : f64) { // i32 and i64 results from both f32 and f64 operands
+ // CHECK: call i32 @llvm.lround.i32.f32
+ "llvm.intr.lround"(%arg0) : (f32) -> i32
+ // CHECK: call i64 @llvm.lround.i64.f32
+ "llvm.intr.lround"(%arg0) : (f32) -> i64
+ // CHECK: call i32 @llvm.lround.i32.f64
+ "llvm.intr.lround"(%arg1) : (f64) -> i32
+ // CHECK: call i64 @llvm.lround.i64.f64
+ "llvm.intr.lround"(%arg1) : (f64) -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @llround_test
+llvm.func @llround_test(%arg0 : f32, %arg1 : f64) { // only i64 results are exercised
+ // CHECK: call i64 @llvm.llround.i64.f32
+ "llvm.intr.llround"(%arg0) : (f32) -> i64
+ // CHECK: call i64 @llvm.llround.i64.f64
+ "llvm.intr.llround"(%arg1) : (f64) -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @lrint_test
+llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) { // i32 and i64 results from both f32 and f64 operands
+ // CHECK: call i32 @llvm.lrint.i32.f32
+ "llvm.intr.lrint"(%arg0) : (f32) -> i32
+ // CHECK: call i64 @llvm.lrint.i64.f32
+ "llvm.intr.lrint"(%arg0) : (f32) -> i64
+ // CHECK: call i32 @llvm.lrint.i32.f64
+ "llvm.intr.lrint"(%arg1) : (f64) -> i32
+ // CHECK: call i64 @llvm.lrint.i64.f64
+ "llvm.intr.lrint"(%arg1) : (f64) -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @llrint_test
+llvm.func @llrint_test(%arg0 : f32, %arg1 : f64) { // only i64 results are exercised
+ // CHECK: call i64 @llvm.llrint.i64.f32
+ "llvm.intr.llrint"(%arg0) : (f32) -> i64
+ // CHECK: call i64 @llvm.llrint.i64.f64
+ "llvm.intr.llrint"(%arg1) : (f64) -> i64
+ llvm.return
+}
// CHECK-LABEL: @bitreverse_test
llvm.func @bitreverse_test(%arg0: i32, %arg1: vector<8xi32>) {
// CHECK: call i32 @llvm.bitreverse.i32
@@ -865,6 +935,26 @@ llvm.func @lifetime(%p: !llvm.ptr) {
// CHECK-DAG: declare float @llvm.cos.f32(float)
// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
// CHECK-DAG: declare float @llvm.copysign.f32(float, float)
+// CHECK-DAG: declare float @llvm.rint.f32(float)
+// CHECK-DAG: declare double @llvm.rint.f64(double)
+// CHECK-DAG: declare <8 x float> @llvm.rint.v8f32(<8 x float>)
+// CHECK-DAG: declare <8 x double> @llvm.rint.v8f64(<8 x double>)
+// CHECK-DAG: declare float @llvm.nearbyint.f32(float)
+// CHECK-DAG: declare double @llvm.nearbyint.f64(double)
+// CHECK-DAG: declare <8 x float> @llvm.nearbyint.v8f32(<8 x float>)
+// CHECK-DAG: declare <8 x double> @llvm.nearbyint.v8f64(<8 x double>)
+// CHECK-DAG: declare i32 @llvm.lround.i32.f32(float)
+// CHECK-DAG: declare i64 @llvm.lround.i64.f32(float)
+// CHECK-DAG: declare i32 @llvm.lround.i32.f64(double)
+// CHECK-DAG: declare i64 @llvm.lround.i64.f64(double)
+// CHECK-DAG: declare i64 @llvm.llround.i64.f32(float)
+// CHECK-DAG: declare i64 @llvm.llround.i64.f64(double)
+// CHECK-DAG: declare i32 @llvm.lrint.i32.f32(float)
+// CHECK-DAG: declare i64 @llvm.lrint.i64.f32(float)
+// CHECK-DAG: declare i32 @llvm.lrint.i32.f64(double)
+// CHECK-DAG: declare i64 @llvm.lrint.i64.f64(double)
+// CHECK-DAG: declare i64 @llvm.llrint.i64.f32(float)
+// CHECK-DAG: declare i64 @llvm.llrint.i64.f64(double)
// CHECK-DAG: declare <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float>, <48 x float>, i32 immarg, i32 immarg, i32 immarg)
// CHECK-DAG: declare <48 x float> @llvm.matrix.transpose.v48f32(<48 x float>, i32 immarg, i32 immarg)
// CHECK-DAG: declare <48 x float> @llvm.matrix.column.major.load.v48f32.i64(ptr nocapture, i64, i1 immarg, i32 immarg, i32 immarg)
More information about the Mlir-commits mailing list