[Mlir-commits] [mlir] b003c60 - [mlir][arith] Match folding of `arith.remf` to `llvm.frem` semantics (#96537)
llvmlistbot at llvm.org
Tue Jun 25 12:01:22 PDT 2024
Author: Felix Schneider
Date: 2024-06-25T21:01:18+02:00
New Revision: b003c60904a78eb62702c613f9e61155bad56798
URL: https://github.com/llvm/llvm-project/commit/b003c60904a78eb62702c613f9e61155bad56798
DIFF: https://github.com/llvm/llvm-project/commit/b003c60904a78eb62702c613f9e61155bad56798.diff
LOG: [mlir][arith] Match folding of `arith.remf` to `llvm.frem` semantics (#96537)
There are multiple ways to define a remainder operation. Depending on
the definition, the result could be either always positive or have the
sign of the dividend.
The pattern lowering `arith.remf` to LLVM assumes that the semantics
match `llvm.frem`, which seems to be reasonable. The folder, however, is
implemented via `APFloat::remainder()`, which has different semantics.
This patch aligns the folding behavior with the lowering behavior by using
`APFloat::mod()`, which matches the behavior of `llvm.frem` and libm's
`fmod()`. It also updates the documentation of `arith.remf` to explain
this behavior: The sign of the result of the remainder operation always
matches the sign of the dividend (LHS operand).
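To make the difference concrete, here is a minimal sketch using the C++ standard library rather than `APFloat`. It assumes that `std::fmod` follows the same truncated-quotient rule as `APFloat::mod()` and `llvm.frem`, and that `std::remainder` follows the IEEE 754 round-to-nearest rule used by `APFloat::remainder()`. The helper names `truncatedRem` and `ieeeRem` are hypothetical and do not appear in the patch.

```cpp
#include <cmath>

// Hypothetical helper (not part of the patch): truncated remainder.
// The quotient is rounded toward zero, so a nonzero result carries the
// sign of the dividend (lhs). Assumed to mirror APFloat::mod().
inline float truncatedRem(float lhs, float rhs) {
  return std::fmod(lhs, rhs);
}

// Hypothetical helper (not part of the patch): IEEE 754 remainder.
// The quotient is rounded to the nearest integer (ties to even), so the
// result can be negative even when the dividend is positive.
// Assumed to mirror APFloat::remainder().
inline float ieeeRem(float lhs, float rhs) {
  return std::remainder(lhs, rhs);
}
```

With these helpers, `truncatedRem(3.0f, 2.0f)` yields `1.0f` while `ieeeRem(3.0f, 2.0f)` yields `-1.0f`, which corresponds to the change in the expected constant of `@test_remf` in this patch.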
frem documentation: https://llvm.org/docs/LangRef.html#frem-instruction
Fix https://github.com/llvm/llvm-project/issues/94431
---------
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6e24b607cebcf..477478a4651ce 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1133,6 +1133,10 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
let summary = "floating point division remainder operation";
+ let description = [{
+ Returns the floating point division remainder.
+ The remainder has the same sign as the dividend (lhs operand).
+ }];
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f4781fcb546a3..0e7beffaf2f42 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1204,7 +1204,10 @@ OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) {
APFloat result(a);
- (void)result.remainder(b);
+ // APFloat::mod() offers the remainder
+ // behavior we want, i.e. the result has
+ // the sign of LHS operand.
+ (void)result.mod(b);
return result;
});
}
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..a386a178b7899 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2467,7 +2467,7 @@ func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
// -----
// CHECK-LABEL: @test_remf(
-// CHECK: %[[res:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[res:.+]] = arith.constant 1.000000e+00 : f32
// CHECK: return %[[res]]
func.func @test_remf() -> (f32) {
%v1 = arith.constant 3.0 : f32
@@ -2476,11 +2476,24 @@ func.func @test_remf() -> (f32) {
return %0 : f32
}
+// CHECK-LABEL: @test_remf2(
+// CHECK: %[[respos:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[resneg:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: return %[[respos]], %[[resneg]]
+func.func @test_remf2() -> (f32, f32) {
+ %v1 = arith.constant 3.0 : f32
+ %v2 = arith.constant -2.0 : f32
+ %v3 = arith.constant -3.0 : f32
+ %0 = arith.remf %v1, %v2 : f32
+ %1 = arith.remf %v3, %v2 : f32
+ return %0, %1 : f32, f32
+}
+
// CHECK-LABEL: @test_remf_vec(
// CHECK: %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]> : vector<4xf32>
// CHECK: return %[[res]]
func.func @test_remf_vec() -> (vector<4xf32>) {
- %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+ %v1 = arith.constant dense<[1.0, 2.0, -3.0, 4.0]> : vector<4xf32>
%v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32>
%0 = arith.remf %v1, %v2 : vector<4xf32>
return %0 : vector<4xf32>