[Mlir-commits] [mlir] [mlir][arith] Add rounding mode flags to binary arithmetic operations (PR #188458)

Matthias Springer llvmlistbot at llvm.org
Wed Mar 25 03:38:37 PDT 2026


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/188458

Add rounding mode flags for `addf`, `subf`, `mulf`, `divf`, `remf`. This addresses a TODO in the op description.

Also improve the constant folder to take the rounding mode into account. If no rounding mode is specified, the folder uses `rmNearestTiesToEven`, exactly as it did before this change. (We may want to disable folding in the future when no rounding mode is provided and the result cannot be represented exactly.)

Also add a lowering to LLVM intrinsics such as `llvm.intr.experimental.constrained.fadd`.


From 306b777f87647a8a60cdd1729f049f929510f148 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 10:33:39 +0000
Subject: [PATCH] [mlir][arith] Add rounding mode flags to binary arithmetic
 operations

---
 .../ArithCommon/AttrToLLVMConverter.h         |  5 ++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 90 ++++++++++++++++---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 55 ++++++++----
 .../ComplexToStandard/ComplexToStandard.cpp   |  8 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  9 +-
 .../Dialect/Arith/IR/ArithCanonicalization.td | 12 +--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 36 ++++++--
 .../lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 ++--
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 61 +++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 75 ++++++++++++++++
 mlir/test/Dialect/Arith/ops.mlir              | 17 ++++
 11 files changed, 325 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index fccfe4897114e..c0773c9b69b6b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
       convertedAttr.set(TargetOp::getRoundingModeAttrName(),
                         convertArithRoundingModeAttrToLLVM(arithAttr));
     }
+    // Constrained intrinsics do not support fastmath flags. Remove the
+    // arith fastmath attribute if present.
+    if constexpr (SourceOp::template hasTrait<
+                      arith::ArithFastMathInterface::Trait>())
+      convertedAttr.erase(srcOp.getFastMathAttrName());
     convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
                       getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
   }
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 45cb3cecef3d8..864c947c005fa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -90,6 +90,37 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Base class for floating point binary operations with an optional rounding
+// mode.
+class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
+                                          list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic,
+      !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                   DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+                  traits)>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$result)> {
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+  ];
+  let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
+                          (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
+}
+
 // Checks that tensor input and outputs have identical shapes. This is stricker
 // than the verification done in `SameOperandsAndResultShape` that allows for
 // tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
@@ -957,7 +988,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
+def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
   let summary = "floating point addition operation";
   let description = [{
     The `addf` operation takes two operands and returns one result, each of
@@ -965,6 +996,9 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -976,10 +1010,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 
     // Tensor addition.
     %x = arith.addf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar addition with rounding mode.
+    %a = arith.addf %b, %c to_nearest_even : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -988,7 +1022,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
+def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
   let summary = "floating point subtraction operation";
   let description = [{
     The `subf` operation takes two operands and returns one result, each of
@@ -996,6 +1030,9 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
 
     // Tensor subtraction.
     %x = arith.subf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar subtraction with rounding mode.
+    %a = arith.subf %b, %c downward : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
+def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
   let summary = "floating point multiplication operation";
   let description = [{
     The `mulf` operation takes two operands and returns one result, each of
@@ -1147,6 +1184,9 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1158,10 +1198,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 
     // Tensor pointwise multiplication.
     %x = arith.mulf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar multiplication with rounding mode.
+    %a = arith.mulf %b, %c upward : f64
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -1171,8 +1211,27 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 // DivFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
+def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
   let summary = "floating point division operation";
+  let description = [{
+    The `divf` operation takes two operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
+    Example:
+
+    ```mlir
+    // Scalar division.
+    %a = arith.divf %b, %c : f64
+
+    // Scalar division with rounding mode.
+    %a = arith.divf %b, %c toward_zero : f64
+    ```
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
@@ -1181,11 +1240,14 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
   let summary = "floating point division remainder operation";
   let description = [{
     Returns the floating point division remainder.
     The remainder has the same sign as the dividend (lhs operand).
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
   }];
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a0346ec6f4fb6..9aba2b42926ce 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -81,9 +81,13 @@ struct IdentityBitcastLowering final
 //===----------------------------------------------------------------------===//
 
 using AddFOpLowering =
-    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -91,9 +95,13 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
-    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -139,9 +147,13 @@ using MinSIOpLowering =
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
-    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -151,9 +163,13 @@ using NegFOpLowering =
                                /*FailOnUnsupportedFP=*/true>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
-    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -170,9 +186,13 @@ using ShRUIOpLowering =
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
-    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -690,6 +710,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
   // clang-format off
   patterns.add<
     AddFOpLowering,
+    ConstrainedAddFOpLowering,
     AddIOpLowering,
     AndIOpLowering,
     AddUIExtendedOpLowering,
@@ -698,6 +719,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     CmpFOpLowering,
     CmpIOpLowering,
     DivFOpLowering,
+    ConstrainedDivFOpLowering,
     DivSIOpLowering,
     DivUIOpLowering,
     ExtFOpLowering,
@@ -717,12 +739,14 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     MinSIOpLowering,
     MinUIOpLowering,
     MulFOpLowering,
+    ConstrainedMulFOpLowering,
     MulIOpLowering,
     MulSIExtendedOpLowering,
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
     RemFOpLowering,
+    ConstrainedRemFOpLowering,
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
@@ -732,6 +756,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     ShRUIOpLowering,
     SIToFPOpLowering,
     SubFOpLowering,
+    ConstrainedSubFOpLowering,
     SubIOpLowering,
     TruncFOpLowering,
     ConstrainedTruncFOpLowering,
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..b899220f2e9af 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
 
     Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
     Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
-    Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
-                                                realRhs, fmf.getValue());
+    Value resultReal =
+        BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
     Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
-                                                imagRhs, fmf.getValue());
+    Value resultImag =
+        BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..11b3aabcbfeb4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -120,7 +120,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
     auto one =
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
+    return arith::DivFOp::create(rewriter, loc, one, args[0]);
   }
 
   // tosa::MulOp
@@ -140,8 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                           "Cannot have shift value for float");
         return nullptr;
       }
-      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
-                                   args[1]);
+      return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
     }
 
     if (isa<IntegerType>(elementTy)) {
@@ -538,8 +537,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
     auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
     auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
-    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
+    auto added = arith::AddFOp::create(rewriter, loc, exp, one);
+    return arith::DivFOp::create(rewriter, loc, one, added);
   }
 
   // tosa::CastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..488dac54569d5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -449,10 +449,10 @@ def UIToFPOfExtUI :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and rounding mode of original mulf)
 def MulFOfNegF :
-    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_MulFOp $x, $y, $fmf),
+    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_MulFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 //===----------------------------------------------------------------------===//
@@ -460,10 +460,10 @@ def MulFOfNegF :
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and rounding mode of original divf)
 def DivFOfNegF :
-    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_DivFOp $x, $y, $fmf),
+    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_DivFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 #endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..b00ae7bfc4724 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1107,9 +1107,14 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a + b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                         : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1121,9 +1126,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a - b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                              : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1312,9 +1322,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
       return getRhs();
   }
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a * b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                              : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1331,9 +1346,14 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a / b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                            : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index f76ddfae2a67a..ec56ab3ab42e5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   Type type = operand.getType();
   Value sin = math::SinOp::create(b, type, operand);
   Value cos = math::CosOp::create(b, type, operand);
-  Value div = arith::DivFOp::create(b, type, sin, cos);
+  Value div = arith::DivFOp::create(b, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
-  Value add = arith::AddFOp::create(b, type, mult, operandC);
+  Value mult = arith::MulFOp::create(b, operandA, operandB);
+  Value add = arith::AddFOp::create(b, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
   Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = arith::MulFOp::create(b, baseType, base, res);
+      res = arith::MulFOp::create(b, base, res);
     absPower >>= 1;
-    base = arith::MulFOp::create(b, baseType, base, base);
+    base = arith::MulFOp::create(b, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
-    res = arith::DivFOp::create(b, baseType, one, res);
+    res = arith::DivFOp::create(b, one, res);
     res =
         arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
     res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
   Value operand = op.getOperand();
   Type opType = operand.getType();
   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
-  Value mult = arith::MulFOp::create(b, opType, operand, ln2);
+  Value mult = arith::MulFOp::create(b, operand, ln2);
   Value exp = math::ExpOp::create(b, op->getLoc(), mult);
   rewriter.replaceOp(op, exp);
   return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
 
   Value incrValue = math::CopySignOp::create(b, half, operand);
-  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
+  Value add = arith::AddFOp::create(b, operand, incrValue);
   Value fpFixedConvert = createTruncatedFPValue(add, b);
 
   // There are three cases where adding 0.5 to the value and truncating by
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 6a6016c4f5b16..7b952cf3aadfc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -377,6 +377,67 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
 
 // -----
 
+// CHECK-LABEL: experimental_constrained_addf
+func.func @experimental_constrained_addf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore
+  %0 = arith.addf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 downward ignore
+  %1 = arith.addf %arg0, %arg1 downward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 upward ignore
+  %2 = arith.addf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 towardzero ignore
+  %3 = arith.addf %arg0, %arg1 toward_zero : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearestaway ignore
+  %4 = arith.addf %arg0, %arg1 to_nearest_away : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_subf
+func.func @experimental_constrained_subf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 tonearest ignore
+  %0 = arith.subf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 downward ignore
+  %1 = arith.subf %arg0, %arg1 downward : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_mulf
+func.func @experimental_constrained_mulf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 upward ignore
+  %0 = arith.mulf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 towardzero ignore
+  %1 = arith.mulf %arg0, %arg1 toward_zero : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_divf
+func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearest ignore
+  %0 = arith.divf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearestaway ignore
+  %1 = arith.divf %arg0, %arg1 to_nearest_away : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_remf
+func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
+  %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
+  %1 = arith.remf %arg0, %arg1 downward : f64
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @convertf_f16_to_bf16
 func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
 // CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..1e10be80a13cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2447,6 +2447,81 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
 
 // -----
 
+// Verify that constant folding respects rounding modes. 1.0000001 + 1.0 is not
+// exactly representable in f32. With upward rounding, the result is rounded up,
+// and with downward rounding it is rounded down.
+// CHECK-LABEL: @test_addf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_addf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 2.00000024 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant 1.0 : f32
+  // addf(x, -0) folds even with a rounding mode.
+  %c_neg0 = arith.constant -0.0 : f32
+  %0 = arith.addf %arg0, %c_neg0 to_nearest_even : f32
+  %1 = arith.addf %a, %b upward : f32
+  %2 = arith.addf %a, %b downward : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 2.00000024 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant -1.0 : f32
+  // subf(x, +0) folds even with a rounding mode.
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.subf %arg0, %c0 downward : f32
+  %1 = arith.subf %a, %b upward : f32
+  %2 = arith.subf %a, %b downward : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_mulf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 3.00000048 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 3.00000024 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant 3.0 : f32
+  // mulf(x, 1) folds even with a rounding mode.
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.mulf %arg0, %c1 upward : f32
+  %1 = arith.mulf %a, %b upward : f32
+  %2 = arith.mulf %a, %b downward : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_divf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 0.333333343 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 0.333333313 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0 : f32
+  %b = arith.constant 3.0 : f32
+  // divf(x, 1) folds even with a rounding mode.
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.divf %arg0, %c1 toward_zero : f32
+  %1 = arith.divf %a, %b upward : f32
+  %2 = arith.divf %a, %b downward : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
 func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
   %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
   %1 = arith.divui %0, %arg0 : index
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 2c5371de9ff24..f6e21c94969f2 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1234,6 +1234,23 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   return
 }
 
+// CHECK-LABEL: @roundingmode
+func.func @roundingmode(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even : f32
+  %0 = arith.addf %arg0, %arg1 to_nearest_even : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 downward : f32
+  %1 = arith.subf %arg0, %arg1 downward : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 upward : f32
+  %2 = arith.mulf %arg0, %arg1 upward : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
+  %3 = arith.divf %arg0, %arg1 toward_zero : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
+  %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+  %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+  return
+}
+
 // CHECK-LABEL: @select_tensor
 func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
   // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>



More information about the Mlir-commits mailing list