[Mlir-commits] [mlir] a0fc94a - [MLIR][Math] Add round operation

lorenzo chelini llvmlistbot at llvm.org
Wed Jun 8 04:07:44 PDT 2022


Author: lorenzo chelini
Date: 2022-06-08T13:07:39+02:00
New Revision: a0fc94ab618973dc4454d0695abb104f6a8644d2

URL: https://github.com/llvm/llvm-project/commit/a0fc94ab618973dc4454d0695abb104f6a8644d2
DIFF: https://github.com/llvm/llvm-project/commit/a0fc94ab618973dc4454d0695abb104f6a8644d2.diff

LOG: [MLIR][Math] Add round operation

Introduce RoundOp in the math dialect. The operation rounds the operand to the
nearest integer value in floating-point format, rounding halfway cases away from
zero (matching C `round`). RoundOp lowers either to the LLVM intrinsic
'llvm.intr.round' or to a libm function call (round or roundf).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D127286

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
    mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
    mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
    mlir/test/Dialect/Math/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1378135cb354..58cf55fefaff 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -652,4 +652,30 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RoundOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundOp : Math_FloatUnaryOp<"round"> {
+  let summary = "round of the specified value";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `math.round` ssa-use `:` type
+    ```
+
+    The `round` operation returns the operand rounded to the nearest integer
+    value in floating-point format. It takes one operand of floating point type
+    (i.e., scalar, tensor or vector) and produces one result of the same type.
+
+    Example:
+
+    ```mlir
+    // Scalar round operation.
+    %a = math.round %b : f64
+    ```
+  }];
+}
+
 #endif // MATH_OPS

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 189680c76df3..510540d6aa05 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,6 +37,8 @@ using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
 using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
 using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
 using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+using RoundOpLowering =
+    VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
 
 // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
 template <typename MathOp, typename LLVMOp>
@@ -285,7 +287,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     PowFOpLowering,
     RsqrtOpLowering,
     SinOpLowering,
-    SqrtOpLowering
+    SqrtOpLowering,
+    RoundOpLowering
   >(converter);
   // clang-format on
 }

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6c9d02c273e5..78835e12e734 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -152,6 +152,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                                                   "expm1f", "expm1", benefit);
   patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
                                                  "tanh", benefit);
+  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
+                                                  "roundf", "round", benefit);
 }
 
 namespace {

diff  --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index af32271ace9d..6378ea6475f2 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -172,3 +172,12 @@ func.func @powf(%arg0 : f64) {
   func.return
 }
 
+// -----
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME: f32
+func.func @round(%arg0 : f32) {
+  // CHECK: "llvm.intr.round"(%arg0) : (f32) -> f32
+  %0 = math.round %arg0 : f32
+  func.return
+}

diff  --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 7cdb56e783e7..cb09988b59e1 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -8,6 +8,8 @@
 // CHECK-DAG: @atan2f(f32, f32) -> f32
 // CHECK-DAG: @tanh(f64) -> f64
 // CHECK-DAG: @tanhf(f32) -> f32
+// CHECK-DAG: @round(f64) -> f64
+// CHECK-DAG: @roundf(f32) -> f32
 
 // CHECK-LABEL: func @tanh_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -21,7 +23,6 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64)  {
   return %float_result, %double_result : f32, f64
 }
 
-
 // CHECK-LABEL: func @atan2_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
 // CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -116,3 +117,15 @@ func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32
 // CHECK:           %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
 // CHECK:           return %[[VAL_4]] : vector<2x2xf32>
 // CHECK:         }
+
+// CHECK-LABEL: func @round_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.round %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @round(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.round %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}

diff  --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index a1bb9af1786f..7acd8933f7b4 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -194,3 +194,15 @@ func.func @tanh(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   %2 = math.tanh %t : tensor<4x4x?xf32>
   return
 }
+
+// CHECK-LABEL: func @round(
+// CHECK-SAME:             %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.round %[[F]] : f32
+  %0 = math.round %f : f32
+  // CHECK: %{{.*}} = math.round %[[V]] : vector<4xf32>
+  %1 = math.round %v : vector<4xf32>
+  // CHECK: %{{.*}} = math.round %[[T]] : tensor<4x4x?xf32>
+  %2 = math.round %t : tensor<4x4x?xf32>
+  return
+}


        


More information about the Mlir-commits mailing list