[Mlir-commits] [mlir] 4639a85 - [mlir] Add math.roundeven and llvm.intr.roundeven
Tres Popp
llvmlistbot at llvm.org
Thu Aug 25 04:39:14 PDT 2022
Author: Tres Popp
Date: 2022-08-25T13:39:01+02:00
New Revision: 4639a85f94be6355c55dcc802a9b01b622d6f7b1
URL: https://github.com/llvm/llvm-project/commit/4639a85f94be6355c55dcc802a9b01b622d6f7b1
DIFF: https://github.com/llvm/llvm-project/commit/4639a85f94be6355c55dcc802a9b01b622d6f7b1.diff
LOG: [mlir] Add math.roundeven and llvm.intr.roundeven
This is similar to math.round, but rounds to even instead of rounding away from
zero in the case of halfway values. This CL also adds lowerings to libm and
to the LLVM intrinsic.
Differential Revision: https://reviews.llvm.org/D132375
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
mlir/test/Dialect/Math/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index ca989bfc6b5dd..219efd57de357 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -61,6 +61,7 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
LLVM_Type:$cache);
}
def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
+def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">;
def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">;
def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 706e2adb5852e..21b5db85f484c 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -741,6 +741,35 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// RoundEvenOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundEvenOp : Math_FloatUnaryOp<"roundeven"> {
+ let summary = "round of the specified value with halfway cases to even";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `math.roundeven` ssa-use `:` type
+ ```
+
+ The `roundeven` operation returns the operand rounded to the nearest integer
+ value in floating-point format. It takes one operand of floating point type
+ (i.e., scalar, tensor or vector) and produces one result of the same type. The
+ operation rounds the argument to the nearest integer value in floating-point
+ format, rounding halfway cases to even, regardless of the current
+ rounding direction.
+
+ Example:
+
+ ```mlir
+ // Scalar roundeven operation.
+ %a = math.roundeven %b : f64
+ ```
+ }];
+}
+
//===----------------------------------------------------------------------===//
// RoundOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index cb34982d64d77..8161cc5e419ff 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -35,6 +35,8 @@ using Log10OpLowering =
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+using RoundEvenOpLowering =
+ VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
using RoundOpLowering =
VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
@@ -285,6 +287,7 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
Log2OpLowering,
LogOpLowering,
PowFOpLowering,
+ RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
SinOpLowering,
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6e3bd2af7bd62..5071d60f9e473 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -141,16 +141,19 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
void mlir::populateMathToLibmConversionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
llvm::Optional<PatternBenefit> log1pBenefit) {
- patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
- VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
- VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
- VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
- VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
+ patterns
+ .add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+ VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+ VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
+ VecOpToScalarOp<math::RoundEvenOp>, VecOpToScalarOp<math::RoundOp>,
+ VecOpToScalarOp<math::AtanOp>, VecOpToScalarOp<math::TanOp>>(
+ patterns.getContext(), benefit);
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
- PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
- PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
+ PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
+ PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>>(
+ patterns.getContext(), benefit);
patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
"atan", benefit);
patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
@@ -163,6 +166,8 @@ void mlir::populateMathToLibmConversionPatterns(
"tan", benefit);
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
"tanh", benefit);
+ patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
+ patterns.getContext(), "roundevenf", "roundeven", benefit);
patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
"roundf", "round", benefit);
patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 64e20188e0f2c..b87ddbacab539 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -190,3 +190,13 @@ func.func @round(%arg0 : f32) {
%0 = math.round %arg0 : f32
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME: f32
+func.func @roundeven(%arg0 : f32) {
+ // CHECK: "llvm.intr.roundeven"(%arg0) : (f32) -> f32
+ %0 = math.roundeven %arg0 : f32
+ func.return
+}
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index b7e9dfcad8c13..641dd568fe494 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -14,6 +14,8 @@
// CHECK-DAG: @tanhf(f32) -> f32
// CHECK-DAG: @round(f64) -> f64
// CHECK-DAG: @roundf(f32) -> f32
+// CHECK-DAG: @roundeven(f64) -> f64
+// CHECK-DAG: @roundevenf(f32) -> f32
// CHECK-DAG: @cos(f64) -> f64
// CHECK-DAG: @cosf(f32) -> f32
// CHECK-DAG: @sin(f64) -> f64
@@ -213,6 +215,19 @@ func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
+// CHECK-LABEL: func @roundeven_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @roundeven_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundevenf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.roundeven %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @roundeven(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.roundeven %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
+
+
// CHECK-LABEL: func @cos_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -261,6 +276,31 @@ func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
return %float_result, %double_result : vector<2xf32>, vector<2xf64>
}
+// CHECK-LABEL: func @roundeven_vec_caller(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+ // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+ // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+ // CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
+ // CHECK: %[[OUT0_F32:.*]] = call @roundevenf(%[[IN0_F32]]) : (f32) -> f32
+ // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+ // CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
+ // CHECK: %[[OUT1_F32:.*]] = call @roundevenf(%[[IN1_F32]]) : (f32) -> f32
+ // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+ %float_result = math.roundeven %float : vector<2xf32>
+ // CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
+ // CHECK: %[[OUT0_F64:.*]] = call @roundeven(%[[IN0_F64]]) : (f64) -> f64
+ // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+ // CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
+ // CHECK: %[[OUT1_F64:.*]] = call @roundeven(%[[IN1_F64]]) : (f64) -> f64
+ // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+ %double_result = math.roundeven %double : vector<2xf64>
+ // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+ return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+
+
// CHECK-LABEL: func @tan_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index c25d09734e226..1af096c429409 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -233,6 +233,19 @@ func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @roundeven(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.roundeven %[[F]] : f32
+ %0 = math.roundeven %f : f32
+ // CHECK: %{{.*}} = math.roundeven %[[V]] : vector<4xf32>
+ %1 = math.roundeven %v : vector<4xf32>
+ // CHECK: %{{.*}} = math.roundeven %[[T]] : tensor<4x4x?xf32>
+ %2 = math.roundeven %t : tensor<4x4x?xf32>
+ return
+}
+
+
// CHECK-LABEL: func @ipowi(
// CHECK-SAME: %[[I:.*]]: i32, %[[V:.*]]: vector<4xi32>, %[[T:.*]]: tensor<4x4x?xi32>)
func.func @ipowi(%i: i32, %v: vector<4xi32>, %t: tensor<4x4x?xi32>) {
More information about the Mlir-commits mailing list