[Mlir-commits] [mlir] a0fc94a - [MLIR][Math] Add round operation
lorenzo chelini
llvmlistbot at llvm.org
Wed Jun 8 04:07:44 PDT 2022
Author: lorenzo chelini
Date: 2022-06-08T13:07:39+02:00
New Revision: a0fc94ab618973dc4454d0695abb104f6a8644d2
URL: https://github.com/llvm/llvm-project/commit/a0fc94ab618973dc4454d0695abb104f6a8644d2
DIFF: https://github.com/llvm/llvm-project/commit/a0fc94ab618973dc4454d0695abb104f6a8644d2.diff
LOG: [MLIR][Math] Add round operation
Introduce RoundOp in the math dialect. The operation rounds the operand to the
nearest integer value in floating-point format. RoundOp lowers either to the
LLVM intrinsic 'llvm.intr.round' or to a libm function call (round or roundf).
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D127286
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
mlir/test/Dialect/Math/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1378135cb354..58cf55fefaff 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -652,4 +652,30 @@ def Math_TanhOp : Math_FloatUnaryOp<"tanh"> {
}];
}
+//===----------------------------------------------------------------------===//
+// RoundOp
+//===----------------------------------------------------------------------===//
+
+def Math_RoundOp : Math_FloatUnaryOp<"round"> {
+ let summary = "round of the specified value";
+ let description = [{
+ Syntax:
+
+ ```
+ operation ::= ssa-id `=` `math.round` ssa-use `:` type
+ ```
+
+ The `round` operation returns the operand rounded to the nearest integer
+ value in floating-point format. It takes one operand of floating point type
+ (i.e., scalar, tensor or vector) and produces one result of the same type.
+ Halfway cases are rounded away from zero (e.g. 2.5 rounds to 3.0 and -2.5
+ rounds to -3.0), regardless of the current floating-point rounding mode.
+ This matches the LLVM `llvm.round` intrinsic and the C library functions
+ `round`/`roundf`, which are the lowering targets of this operation.
+
+ Example:
+
+ ```mlir
+ // Scalar round operation.
+ %a = math.round %b : f64
+ ```
+ }];
+}
+
#endif // MATH_OPS
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 189680c76df3..510540d6aa05 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,6 +37,8 @@ using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
+// Lowers `math.round` one-to-one onto `LLVM::RoundOp` (`llvm.intr.round`).
+using RoundOpLowering =
+ VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
// A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`.
template <typename MathOp, typename LLVMOp>
@@ -285,7 +287,8 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
PowFOpLowering,
RsqrtOpLowering,
SinOpLowering,
- SqrtOpLowering
+ SqrtOpLowering,
+ RoundOpLowering
>(converter);
// clang-format on
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 6c9d02c273e5..78835e12e734 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -152,6 +152,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
"expm1f", "expm1", benefit);
patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
"tanh", benefit);
+ patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
+ "roundf", "round", benefit);
}
namespace {
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index af32271ace9d..6378ea6475f2 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -172,3 +172,12 @@ func.func @powf(%arg0 : f64) {
func.return
}
+// -----
+
+// Verifies that a scalar f32 `math.round` is lowered to the
+// `llvm.intr.round` intrinsic.
+// CHECK-LABEL: func @round(
+// CHECK-SAME: f32
+func.func @round(%arg0 : f32) {
+ // CHECK: "llvm.intr.round"(%arg0) : (f32) -> f32
+ %0 = math.round %arg0 : f32
+ func.return
+}
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
index 7cdb56e783e7..cb09988b59e1 100644
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -8,6 +8,8 @@
// CHECK-DAG: @atan2f(f32, f32) -> f32
// CHECK-DAG: @tanh(f64) -> f64
// CHECK-DAG: @tanhf(f32) -> f32
+// CHECK-DAG: @round(f64) -> f64
+// CHECK-DAG: @roundf(f32) -> f32
// CHECK-LABEL: func @tanh_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -21,7 +23,6 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) {
return %float_result, %double_result : f32, f64
}
-
// CHECK-LABEL: func @atan2_caller
// CHECK-SAME: %[[FLOAT:.*]]: f32
// CHECK-SAME: %[[DOUBLE:.*]]: f64
@@ -116,3 +117,15 @@ func.func @expm1_multidim_vec_caller(%float: vector<2x2xf32>) -> (vector<2x2xf32
// CHECK: %[[VAL_4:.*]] = vector.insert %[[OUT1_1_F32]], %[[VAL_3]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: return %[[VAL_4]] : vector<2x2xf32>
// CHECK: }
+
+// Verifies that `math.round` on f32 and f64 scalars is rewritten into calls
+// to the libm functions `roundf` and `round`, respectively.
+// CHECK-LABEL: func @round_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @round_caller(%float: f32, %double: f64) -> (f32, f64) {
+ // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @roundf(%[[FLOAT]]) : (f32) -> f32
+ %float_result = math.round %float : f32
+ // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @round(%[[DOUBLE]]) : (f64) -> f64
+ %double_result = math.round %double : f64
+ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+ return %float_result, %double_result : f32, f64
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index a1bb9af1786f..7acd8933f7b4 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -194,3 +194,15 @@ func.func @tanh(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
%2 = math.tanh %t : tensor<4x4x?xf32>
return
}
+
+// Verifies that `math.round` accepts scalar, vector and tensor float operands
+// and is printed back in the same custom form.
+// CHECK-LABEL: func @round(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @round(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.round %[[F]] : f32
+ %0 = math.round %f : f32
+ // CHECK: %{{.*}} = math.round %[[V]] : vector<4xf32>
+ %1 = math.round %v : vector<4xf32>
+ // CHECK: %{{.*}} = math.round %[[T]] : tensor<4x4x?xf32>
+ %2 = math.round %t : tensor<4x4x?xf32>
+ return
+}
More information about the Mlir-commits
mailing list