[Mlir-commits] [mlir] [mlir][arith] Match folding of `arith.remf` to `llvm.frem` semantics (PR #96537)

Felix Schneider llvmlistbot at llvm.org
Mon Jun 24 12:20:03 PDT 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/96537

There are multiple ways to define the remainder operation. Depending on the definition, the result either takes the sign of the dividend or can differ from it in sign (as with the IEEE 754 remainder, where the quotient is rounded to nearest).
The pattern lowering `arith.remf` to LLVM assumes that the semantics match `llvm.frem`, which seems reasonable. The folder, however, is implemented via `APFloat::remainder()`, which has different semantics.

This patch aligns the folding behavior with the lowering by using `APFloat::mod()`, which matches the behavior of `llvm.frem` and libm's `fmod()`. It also updates the documentation of `arith.remf` to explicitly state that it follows the semantics of `llvm.frem`.

frem documentation: https://llvm.org/docs/LangRef.html#frem-instruction

Fix https://github.com/llvm/llvm-project/issues/94431

From 3e7020eaf2f39ddc157ae8ba0891ae2777acb308 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 24 Jun 2024 21:12:17 +0200
Subject: [PATCH] [mlir][arith] Match folding of `arith.remf` to `llvm.frem`
 semantics

There are multiple ways to define the remainder operation. Depending
on the definition, the result either takes the sign of the dividend or
can differ from it in sign (as with the IEEE 754 remainder).
The pattern lowering `arith.remf` to LLVM assumes that the semantics
match `llvm.frem`, which seems reasonable. The folder, however, is
implemented via `APFloat::remainder()`, which has different semantics.

This patch aligns the folding behavior with the lowering by using
`APFloat::mod()`, which matches the behavior of `llvm.frem` and libm's
`fmod()`. It also updates the documentation of `arith.remf` to explicitly
state that it follows the semantics of `llvm.frem`.

frem documentation: https://llvm.org/docs/LangRef.html#frem-instruction

Fix https://github.com/llvm/llvm-project/issues/94431
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td |  4 ++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp         |  4 +++-
 mlir/test/Dialect/Arith/canonicalize.mlir      | 18 ++++++++++++++++--
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 6e24b607cebcf..01581869f9cd8 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1133,6 +1133,10 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
 
 def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
   let summary = "floating point division remainder operation";
+  let description = [{
+    The `arith.remf` operation returns the floating point division remainder
+    following the semantics of `llvm.frem`.
+  }];
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index f4781fcb546a3..b926ff008b3ce 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1204,7 +1204,9 @@ OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
   return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                       [](const APFloat &a, const APFloat &b) {
                                         APFloat result(a);
-                                        (void)result.remainder(b);
+                                        // Mirror behavior of llvm.frem which
+                                        // behaves like libm's fmod(a, b).
+                                        (void)result.mod(b);
                                         return result;
                                       });
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 4fe7cfb689be8..31ed8b2cafbc0 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2465,9 +2465,10 @@ func.func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
 }
 
 // -----
+// llvm.frem: "The remainder has the same sign as the dividend"
 
 // CHECK-LABEL: @test_remf(
-// CHECK: %[[res:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: %[[res:.+]] = arith.constant 1.000000e+00 : f32
 // CHECK: return %[[res]]
 func.func @test_remf() -> (f32) {
   %v1 = arith.constant 3.0 : f32
@@ -2476,11 +2477,24 @@ func.func @test_remf() -> (f32) {
   return %0 : f32
 }
 
+// CHECK-LABEL: @test_remf2(
+// CHECK: %[[respos:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[resneg:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK: return %[[respos]], %[[resneg]]
+func.func @test_remf2() -> (f32, f32) {
+  %v1 = arith.constant 3.0 : f32
+  %v2 = arith.constant -2.0 : f32
+  %v3 = arith.constant -3.0 : f32
+  %0 = arith.remf %v1, %v2 : f32
+  %1 = arith.remf %v3, %v2 : f32
+  return %0, %1 : f32, f32
+}
+
 // CHECK-LABEL: @test_remf_vec(
 // CHECK: %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, -1.000000e+00, 0.000000e+00]> : vector<4xf32>
 // CHECK: return %[[res]]
 func.func @test_remf_vec() -> (vector<4xf32>) {
-  %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  %v1 = arith.constant dense<[1.0, 2.0, -3.0, 4.0]> : vector<4xf32>
   %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32>
   %0 = arith.remf %v1, %v2 : vector<4xf32>
   return %0 : vector<4xf32>


