[Mlir-commits] [mlir] 4639a85 - [mlir] Add math.roundeven and llvm.intr.roundeven

Tres Popp llvmlistbot at llvm.org
Thu Aug 25 04:39:14 PDT 2022


Author: Tres Popp
Date: 2022-08-25T13:39:01+02:00
New Revision: 4639a85f94be6355c55dcc802a9b01b622d6f7b1

URL: https://github.com/llvm/llvm-project/commit/4639a85f94be6355c55dcc802a9b01b622d6f7b1
DIFF: https://github.com/llvm/llvm-project/commit/4639a85f94be6355c55dcc802a9b01b622d6f7b1.diff

LOG: [mlir] Add math.roundeven and llvm.intr.roundeven

This is similar to math.round, but rounds halfway cases to the nearest even
integer instead of away from zero. This change also adds lowerings to libm
(roundeven/roundevenf) and to the LLVM intrinsic (llvm.intr.roundeven).

Differential Revision: https://reviews.llvm.org/D132375

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
index ca989bfc6b5dd..219efd57de357 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -61,6 +61,7 @@ def LLVM_Prefetch : LLVM_ZeroResultIntrOp<"prefetch", [0]> {
                    LLVM_Type:$cache);
 }
 def LLVM_SinOp : LLVM_UnaryIntrinsicOp<"sin">;
+def LLVM_RoundEvenOp : LLVM_UnaryIntrinsicOp<"roundeven">;
 def LLVM_RoundOp : LLVM_UnaryIntrinsicOp<"round">;
 def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">;
 def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">;

diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 706e2adb5852e..21b5db85f484c 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -741,6 +741,35 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// RoundEvenOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundEvenOp : Math_FloatUnaryOp<"roundeven"> {
+  let summary = "round of the specified value with halfway cases to even";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.roundeven` ssa-use `:` type
+    ```
+
+    The `roundeven` operation returns the operand rounded to the nearest integer
+    value in floating-point format. It takes one operand of floating point type
+    (i.e., scalar, tensor or vector) and produces one result of the same type.  The
+    operation rounds the argument to the nearest integer value in floating-point
+    format, rounding halfway cases to even, regardless of the current
+    rounding direction.
+
+    Example:
+
+    ```mlir
+    // Scalar round operation.
+    %a = math.roundeven %b : f64
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // RoundOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index cb34982d64d77..8161cc5e419ff 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -35,6 +35,8 @@ using Log10OpLowering =
 using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
 using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
+using RoundEvenOpLowering =
+    VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
 using RoundOpLowering =
     VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
@@ -285,6 +287,7 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     Log2OpLowering,
     LogOpLowering,
     PowFOpLowering,
+    RoundEvenOpLowering,
     RoundOpLowering,
     RsqrtOpLowering,
     SinOpLowering,

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6e3bd2af7bd62..5071d60f9e473 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -141,16 +141,19 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
 void mlir::populateMathToLibmConversionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit,
     llvm::Optional<PatternBenefit> log1pBenefit) {
-  patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
-               VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
-               VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
-               VecOpToScalarOp<math::TanOp>>(patterns.getContext(), benefit);
+  patterns
+      .add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
+           VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+           VecOpToScalarOp<math::SinOp>, VecOpToScalarOp<math::ErfOp>,
+           VecOpToScalarOp<math::RoundEvenOp>, VecOpToScalarOp<math::RoundOp>,
+           VecOpToScalarOp<math::AtanOp>, VecOpToScalarOp<math::TanOp>>(
+          patterns.getContext(), benefit);
   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
                PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
                PromoteOpToF32<math::SinOp>, PromoteOpToF32<math::ErfOp>,
-               PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
-               PromoteOpToF32<math::TanOp>>(patterns.getContext(), benefit);
+               PromoteOpToF32<math::RoundEvenOp>, PromoteOpToF32<math::RoundOp>,
+               PromoteOpToF32<math::AtanOp>, PromoteOpToF32<math::TanOp>>(
+      patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
                                                  "atan", benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
@@ -163,6 +166,8 @@ void mlir::populateMathToLibmConversionPatterns(
                                                 "tan", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
                                                  "tanh", benefit);
+  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
+      patterns.getContext(), "roundevenf", "roundeven", benefit);
   patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
                                                   "roundf", "round", benefit);
   patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 64e20188e0f2c..b87ddbacab539 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -190,3 +190,13 @@ func.func @round(%arg0 : f32) {
   %0 = math.round %arg0 : f32
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME: f32
+func.func @roundeven(%arg0 : f32) {
+  // CHECK: "llvm.intr.roundeven"(%arg0) : (f32) -> f32
+  %0 = math.roundeven %arg0 : f32
+  func.return
+}

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index b7e9dfcad8c13..641dd568fe494 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -14,6 +14,8 @@
 // CHECK-DAG: @tanhf(f32) -> f32
 // CHECK-DAG: @round(f64) -> f64
 // CHECK-DAG: @roundf(f32) -> f32
+// CHECK-DAG: @roundeven(f64) -> f64
+// CHECK-DAG: @roundevenf(f32) -> f32
 // CHECK-DAG: @cos(f64) -> f64
 // CHECK-DAG: @cosf(f32) -> f32
 // CHECK-DAG: @sin(f64) -> f64
@@ -213,6 +215,19 @@ func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
   return %float_result, %double_result : f32, f64
 }
 
+// CHECK-LABEL: func @roundeven_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @roundeven_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundevenf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.roundeven %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @roundeven(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.roundeven %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+
 // CHECK-LABEL: func @cos_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -261,6 +276,31 @@ func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (v
   return %float_result, %double_result : vector<2xf32>, vector<2xf64>
 }
 
+// CHECK-LABEL:   func @roundeven_vec_caller(
+// CHECK-SAME:                           %[[VAL_0:.*]]: vector<2xf32>,
+// CHECK-SAME:                           %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+func.func @roundeven_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) {
+  // CHECK-DAG:       %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+  // CHECK-DAG:       %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64>
+  // CHECK:           %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32>
+  // CHECK:           %[[OUT0_F32:.*]] = call @roundevenf(%[[IN0_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32>
+  // CHECK:           %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32>
+  // CHECK:           %[[OUT1_F32:.*]] = call @roundevenf(%[[IN1_F32]]) : (f32) -> f32
+  // CHECK:           %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32>
+  %float_result = math.roundeven %float : vector<2xf32>
+  // CHECK:           %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64>
+  // CHECK:           %[[OUT0_F64:.*]] = call @roundeven(%[[IN0_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64>
+  // CHECK:           %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64>
+  // CHECK:           %[[OUT1_F64:.*]] = call @roundeven(%[[IN1_F64]]) : (f64) -> f64
+  // CHECK:           %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64>
+  %double_result = math.roundeven %double : vector<2xf64>
+  // CHECK:           return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64>
+  return %float_result, %double_result : vector<2xf32>, vector<2xf64>
+}
+
+
 // CHECK-LABEL: func @tan_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index c25d09734e226..1af096c429409 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -233,6 +233,19 @@ func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @roundeven(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @roundeven(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.roundeven %[[F]] : f32
+  %0 = math.roundeven %f : f32
+  // CHECK: %{{.*}} = math.roundeven %[[V]] : vector<4xf32>
+  %1 = math.roundeven %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.roundeven %[[T]] : tensor<4x4x?xf32>
+  %2 = math.roundeven %t : tensor<4x4x?xf32>
+  return
+}
+
+
 // CHECK-LABEL: func @ipowi(
 // CHECK-SAME:             %[[I:.*]]: i32, %[[V:.*]]: vector<4xi32>, %[[T:.*]]: tensor<4x4x?xi32>)
 func.func @ipowi(%i: i32, %v: vector<4xi32>, %t: tensor<4x4x?xi32>) {


        


More information about the Mlir-commits mailing list