[Mlir-commits] [mlir] [mlir][arith] Add rounding mode flags to binary arithmetic operations (PR #188458)

Matthias Springer llvmlistbot at llvm.org
Tue Apr 7 07:06:46 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/188458

>From 59d4c2efeea9509fc28c39e4176e8d0a629eb588 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 10:33:39 +0000
Subject: [PATCH 1/7] [mlir][arith] Add rounding mode flags to binary
 arithmetic operations

---
 .../ArithCommon/AttrToLLVMConverter.h         |  5 ++
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 90 ++++++++++++++++---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 55 ++++++++----
 .../ComplexToStandard/ComplexToStandard.cpp   |  8 +-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  9 +-
 .../Dialect/Arith/IR/ArithCanonicalization.td | 12 +--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 36 ++++++--
 .../lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 ++--
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 61 +++++++++++++
 mlir/test/Dialect/Arith/canonicalize.mlir     | 75 ++++++++++++++++
 mlir/test/Dialect/Arith/ops.mlir              | 17 ++++
 11 files changed, 325 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index fccfe4897114e..c0773c9b69b6b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
       convertedAttr.set(TargetOp::getRoundingModeAttrName(),
                         convertArithRoundingModeAttrToLLVM(arithAttr));
     }
+    // Constrained intrinsics do not support fastmath flags. Remove the
+    // arith fastmath attribute if present.
+    if constexpr (SourceOp::template hasTrait<
+                      arith::ArithFastMathInterface::Trait>())
+      convertedAttr.erase(srcOp.getFastMathAttrName());
     convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
                       getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
   }
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 45cb3cecef3d8..864c947c005fa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -90,6 +90,37 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Base class for floating point binary operations with an optional rounding
+// mode.
+class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
+                                          list<Trait> traits = []> :
+    Arith_BinaryOp<mnemonic,
+      !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                   DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+                  traits)>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+    Results<(outs FloatLike:$result)> {
+  let builders = [
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+      "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+      build($_builder, $_state, lhs, rhs, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+  ];
+  let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
+                          (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
+}
+
 // Checks that tensor input and outputs have identical shapes. This is stricker
 // than the verification done in `SameOperandsAndResultShape` that allows for
 // tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
@@ -957,7 +988,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
 // AddFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
+def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
   let summary = "floating point addition operation";
   let description = [{
     The `addf` operation takes two operands and returns one result, each of
@@ -965,6 +996,9 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -976,10 +1010,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 
     // Tensor addition.
     %x = arith.addf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar addition with rounding mode.
+    %a = arith.addf %b, %c to_nearest_even : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -988,7 +1022,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
 // SubFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
+def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
   let summary = "floating point subtraction operation";
   let description = [{
     The `subf` operation takes two operands and returns one result, each of
@@ -996,6 +1030,9 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
 
     // Tensor subtraction.
     %x = arith.subf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar subtraction with rounding mode.
+    %a = arith.subf %b, %c downward : f64
+    ```
   }];
   let hasFolder = 1;
 }
@@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
 // MulFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
+def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
   let summary = "floating point multiplication operation";
   let description = [{
     The `mulf` operation takes two operands and returns one result, each of
@@ -1147,6 +1184,9 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
     scalar type, a vector whose element type is a floating point type, or a
     floating point tensor.
 
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
     Example:
 
     ```mlir
@@ -1158,10 +1198,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 
     // Tensor pointwise multiplication.
     %x = arith.mulf %y, %z : tensor<4x?xbf16>
-    ```
 
-    TODO: In the distant future, this will accept optional attributes for fast
-    math, contraction, rounding mode, and other controls.
+    // Scalar multiplication with rounding mode.
+    %a = arith.mulf %b, %c upward : f64
+    ```
   }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
@@ -1171,8 +1211,27 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
 // DivFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
+def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
   let summary = "floating point division operation";
+  let description = [{
+    The `divf` operation takes two operands and returns one result, each of
+    these is required to be the same type. This type may be a floating point
+    scalar type, a vector whose element type is a floating point type, or a
+    floating point tensor.
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
+
+    Example:
+
+    ```mlir
+    // Scalar division.
+    %a = arith.divf %b, %c : f64
+
+    // Scalar division with rounding mode.
+    %a = arith.divf %b, %c toward_zero : f64
+    ```
+  }];
   let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
@@ -1181,11 +1240,14 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
   let summary = "floating point division remainder operation";
   let description = [{
     Returns the floating point division remainder.
     The remainder has the same sign as the dividend (lhs operand).
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or the default one if no rounding mode is provided.
   }];
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f9ea8dba105a4..51001cb55a627 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -81,9 +81,13 @@ struct IdentityBitcastLowering final
 //===----------------------------------------------------------------------===//
 
 using AddFOpLowering =
-    VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -91,9 +95,13 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
 using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
-    VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
 using DivUIOpLowering =
@@ -139,9 +147,13 @@ using MinSIOpLowering =
 using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
-    VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -151,9 +163,13 @@ using NegFOpLowering =
                                /*FailOnUnsupportedFP=*/true>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
-    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -170,9 +186,13 @@ using ShRUIOpLowering =
 using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
-    VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                               arith::AttrConvertFastMathToLLVM,
-                               /*FailOnUnsupportedFP=*/true>;
+    ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+                                          /*Constrained=*/false,
+                                          arith::AttrConvertFastMathToLLVM,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
@@ -700,6 +720,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
   // clang-format off
   patterns.add<
     AddFOpLowering,
+    ConstrainedAddFOpLowering,
     AddIOpLowering,
     AndIOpLowering,
     AddUIExtendedOpLowering,
@@ -708,6 +729,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     CmpFOpLowering,
     CmpIOpLowering,
     DivFOpLowering,
+    ConstrainedDivFOpLowering,
     DivSIOpLowering,
     DivUIOpLowering,
     ExtFOpLowering,
@@ -727,12 +749,14 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     MinSIOpLowering,
     MinUIOpLowering,
     MulFOpLowering,
+    ConstrainedMulFOpLowering,
     MulIOpLowering,
     MulSIExtendedOpLowering,
     MulUIExtendedOpLowering,
     NegFOpLowering,
     OrIOpLowering,
     RemFOpLowering,
+    ConstrainedRemFOpLowering,
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
@@ -742,6 +766,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     ShRUIOpLowering,
     SIToFPOpLowering,
     SubFOpLowering,
+    ConstrainedSubFOpLowering,
     SubIOpLowering,
     TruncFOpLowering,
     ConstrainedTruncFOpLowering,
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..b899220f2e9af 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
 
     Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
     Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
-    Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
-                                                realRhs, fmf.getValue());
+    Value resultReal =
+        BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
     Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
-                                                imagRhs, fmf.getValue());
+    Value resultImag =
+        BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..11b3aabcbfeb4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -120,7 +120,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
     auto one =
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
+    return arith::DivFOp::create(rewriter, loc, one, args[0]);
   }
 
   // tosa::MulOp
@@ -140,8 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
                                           "Cannot have shift value for float");
         return nullptr;
       }
-      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
-                                   args[1]);
+      return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
     }
 
     if (isa<IntegerType>(elementTy)) {
@@ -538,8 +537,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
     auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
     auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
-    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
-    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
+    auto added = arith::AddFOp::create(rewriter, loc, exp, one);
+    return arith::DivFOp::create(rewriter, loc, one, added);
   }
 
   // tosa::CastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index a15e19b24e54b..03200c16ac378 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -437,10 +437,10 @@ def UIToFPOfExtUI :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and rounding mode of original mulf)
 def MulFOfNegF :
-    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_MulFOp $x, $y, $fmf),
+    Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_MulFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 //===----------------------------------------------------------------------===//
@@ -448,10 +448,10 @@ def MulFOfNegF :
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and rounding mode of original divf)
 def DivFOfNegF :
-    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
-        (Arith_DivFOp $x, $y, $fmf),
+    Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+        (Arith_DivFOp $x, $y, $fmf, $rm),
         [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
 
 #endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 569d1869a5abe..7d1d6f4ae1207 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1107,9 +1107,14 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a + b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                         : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1121,9 +1126,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a - b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                              : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1312,9 +1322,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
       return getRhs();
   }
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a * b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                              : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1331,9 +1346,14 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
 
+  auto rm = getRoundingmodeAttr();
   return constFoldBinaryOp<FloatAttr>(
-      adaptor.getOperands(),
-      [](const APFloat &a, const APFloat &b) { return a / b; });
+      adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+        APFloat result(a);
+        result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+                            : llvm::RoundingMode::NearestTiesToEven);
+        return result;
+      });
 }
 
 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index f76ddfae2a67a..ec56ab3ab42e5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   Type type = operand.getType();
   Value sin = math::SinOp::create(b, type, operand);
   Value cos = math::CosOp::create(b, type, operand);
-  Value div = arith::DivFOp::create(b, type, sin, cos);
+  Value div = arith::DivFOp::create(b, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
-  Value add = arith::AddFOp::create(b, type, mult, operandC);
+  Value mult = arith::MulFOp::create(b, operandA, operandB);
+  Value add = arith::AddFOp::create(b, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
   Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = arith::MulFOp::create(b, baseType, base, res);
+      res = arith::MulFOp::create(b, base, res);
     absPower >>= 1;
-    base = arith::MulFOp::create(b, baseType, base, base);
+    base = arith::MulFOp::create(b, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
-    res = arith::DivFOp::create(b, baseType, one, res);
+    res = arith::DivFOp::create(b, one, res);
     res =
         arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
     res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
   Value operand = op.getOperand();
   Type opType = operand.getType();
   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
-  Value mult = arith::MulFOp::create(b, opType, operand, ln2);
+  Value mult = arith::MulFOp::create(b, operand, ln2);
   Value exp = math::ExpOp::create(b, op->getLoc(), mult);
   rewriter.replaceOp(op, exp);
   return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
 
   Value incrValue = math::CopySignOp::create(b, half, operand);
-  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
+  Value add = arith::AddFOp::create(b, operand, incrValue);
   Value fpFixedConvert = createTruncatedFPValue(add, b);
 
   // There are three cases where adding 0.5 to the value and truncating by
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 75601e215744c..d4d425ad72970 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -401,6 +401,67 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
 
 // -----
 
+// CHECK-LABEL: experimental_constrained_addf
+func.func @experimental_constrained_addf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore
+  %0 = arith.addf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 downward ignore
+  %1 = arith.addf %arg0, %arg1 downward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 upward ignore
+  %2 = arith.addf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 towardzero ignore
+  %3 = arith.addf %arg0, %arg1 toward_zero : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearestaway ignore
+  %4 = arith.addf %arg0, %arg1 to_nearest_away : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_subf
+func.func @experimental_constrained_subf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 tonearest ignore
+  %0 = arith.subf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 downward ignore
+  %1 = arith.subf %arg0, %arg1 downward : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_mulf
+func.func @experimental_constrained_mulf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 upward ignore
+  %0 = arith.mulf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 towardzero ignore
+  %1 = arith.mulf %arg0, %arg1 toward_zero : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_divf
+func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearest ignore
+  %0 = arith.divf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearestaway ignore
+  %1 = arith.divf %arg0, %arg1 to_nearest_away : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_remf
+func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
+  %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
+  %1 = arith.remf %arg0, %arg1 downward : f64
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @convertf_f16_to_bf16
 func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
 // CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ee3e713f8481e..b153bd7c32261 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2554,6 +2554,81 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
 
 // -----
 
+// Verify that constant folding respects rounding modes. In f32, 1.0000001 is
+// 1 + 2^-23, so the exact sum with 1.0 (2 + 2^-23) is not representable: it is
+// rounded up to 2 + 2^-22 with upward rounding and down to 2.0 with downward.
+// CHECK-LABEL: @test_addf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_addf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 2.00000024 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant 1.0 : f32
+  // addf(x, -0) == x under to_nearest_even; not under downward (+0 + -0 = -0).
+  %c_neg0 = arith.constant -0.0 : f32
+  %0 = arith.addf %arg0, %c_neg0 to_nearest_even : f32
+  %1 = arith.addf %a, %b upward : f32  // exact 2 + 2^-23 rounds up to 2 + 2^-22
+  %2 = arith.addf %a, %b downward : f32  // exact 2 + 2^-23 rounds down to 2.0
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 2.00000024 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant -1.0 : f32
+  // subf(x, +0) == x under upward; not under downward (+0 - +0 = -0).
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.subf %arg0, %c0 upward : f32
+  %1 = arith.subf %a, %b upward : f32
+  %2 = arith.subf %a, %b downward : f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_mulf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 3.00000048 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 3.00000024 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0000001 : f32
+  %b = arith.constant 3.0 : f32
+  // mulf(x, 1) is exact, so folding it to x is valid under any rounding mode.
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.mulf %arg0, %c1 upward : f32
+  %1 = arith.mulf %a, %b upward : f32  // exact 3 + 3*2^-23 rounds up to 3 + 2^-21
+  %2 = arith.mulf %a, %b downward : f32  // exact 3 + 3*2^-23 rounds down to 3 + 2^-22
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_divf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+  // CHECK-DAG:  %[[UP:.+]] = arith.constant 0.333333343 : f32
+  // CHECK-DAG:  %[[DOWN:.+]] = arith.constant 0.333333313 : f32
+  // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+  %a = arith.constant 1.0 : f32
+  %b = arith.constant 3.0 : f32
+  // divf(x, 1) is exact, so folding it to x is valid under any rounding mode.
+  %c1 = arith.constant 1.0 : f32
+  %0 = arith.divf %arg0, %c1 toward_zero : f32
+  %1 = arith.divf %a, %b upward : f32  // 1/3 rounded up to the next f32
+  %2 = arith.divf %a, %b downward : f32  // 1/3 rounded down to the previous f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
 func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
   %0 = arith.muli %arg0, %arg1 overflow<nuw> : index
   %1 = arith.divui %0, %arg0 : index
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 2c5371de9ff24..f6e21c94969f2 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1234,6 +1234,23 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
   return
 }
 
+// CHECK-LABEL: @roundingmode
+func.func @roundingmode(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even : f32
+  %0 = arith.addf %arg0, %arg1 to_nearest_even : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 downward : f32
+  %1 = arith.subf %arg0, %arg1 downward : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 upward : f32
+  %2 = arith.mulf %arg0, %arg1 upward : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
+  %3 = arith.divf %arg0, %arg1 toward_zero : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
+  %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+  %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+  return
+}
+
 // CHECK-LABEL: @select_tensor
 func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
   // CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>

>From 360b1dbce7a7fad3408925fda871e6ccc065bbd2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 12:56:53 +0000
Subject: [PATCH 2/7] Rename Constrained to HasRoundingMode; test fastmath stripping

---
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 46 +++++++++++--------
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 12 +++++
 2 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 51001cb55a627..30f32b8e248f4 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -29,13 +29,20 @@ using namespace mlir;
 
 namespace {
 
-/// Operations whose conversion will depend on whether they are passed a
-/// rounding mode attribute or not.
+/// Lowering pattern that matches only when the source op's rounding mode
+/// presence agrees with `HasRoundingMode`. This allows registering two
+/// instances of the same pattern for one source op: one that handles the
+/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
+/// one that handles the constrained case (rounding mode present, lowering to
+/// a constrained LLVM intrinsic).
 ///
-/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
-/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
-/// attribute.
-template <typename SourceOp, typename TargetOp, bool Constrained,
+/// * `HasRoundingMode`: the pattern matches if and only if the source op has
+///   a rounding mode attribute.
+/// * `AttrConvert`: attribute converter to translate source attributes to
+///   target attributes.
+/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
+///   floating point types.
+template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
           template <typename, typename> typename AttrConvert =
               AttrConvertPassThrough,
           bool FailOnUnsupportedFP = false>
@@ -49,7 +56,7 @@ struct ConstrainedVectorConvertToLLVMPattern
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+    if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
       return failure();
     return VectorConvertToLLVMPattern<
         SourceOp, TargetOp, AttrConvert,
@@ -82,11 +89,11 @@ struct IdentityBitcastLowering final
 
 using AddFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
-                                          /*Constrained=*/false,
+                                          /*HasRoundingMode=*/false,
                                           arith::AttrConvertFastMathToLLVM,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+    arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using AddIOpLowering =
     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
@@ -96,11 +103,11 @@ using BitcastOpLowering =
     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
 using DivFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
-                                          /*Constrained=*/false,
+                                          /*HasRoundingMode=*/false,
                                           arith::AttrConvertFastMathToLLVM,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+    arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using DivSIOpLowering =
     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
@@ -148,11 +155,11 @@ using MinUIOpLowering =
     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
 using MulFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
-                                          /*Constrained=*/false,
+                                          /*HasRoundingMode=*/false,
                                           arith::AttrConvertFastMathToLLVM,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+    arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using MulIOpLowering =
     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
@@ -164,11 +171,11 @@ using NegFOpLowering =
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                                          /*Constrained=*/false,
+                                          /*HasRoundingMode=*/false,
                                           arith::AttrConvertFastMathToLLVM,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
@@ -187,21 +194,22 @@ using SIToFPOpLowering =
     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
 using SubFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
-                                          /*Constrained=*/false,
+                                          /*HasRoundingMode=*/false,
                                           arith::AttrConvertFastMathToLLVM,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+    arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using SubIOpLowering =
     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
                                arith::AttrConvertOverflowToLLVM>;
 using TruncFOpLowering =
     ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
-                                          false, AttrConvertPassThrough,
+                                          /*HasRoundingMode=*/false,
+                                          AttrConvertPassThrough,
                                           /*FailOnUnsupportedFP=*/true>;
 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+    arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
     arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using TruncIOpLowering =
     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index d4d425ad72970..a13f850a7cc2a 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -462,6 +462,18 @@ func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
 
 // -----
 
+// Verify that fastmath flags are stripped when lowering to constrained
+// intrinsics (constrained FP and fastmath are contradictory).
+// CHECK-LABEL: constrained_addf_with_fastmath
+func.func @constrained_addf_with_fastmath(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore : f64
+// CHECK-NOT: fastmath
+  %0 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f64  // fastmath<fast> must be dropped by the lowering.
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @convertf_f16_to_bf16
 func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
 // CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32

>From 8d2e4ec80380f163f24daaae2c25f73d6639e8f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 13:05:58 +0000
Subject: [PATCH 3/7] Factor out kDefaultRoundingMode constant in ArithOps.cpp

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7d1d6f4ae1207..8eb47fea92f8b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -34,6 +34,10 @@
 using namespace mlir;
 using namespace mlir::arith;
 
+/// Default rounding mode according to IEEE-754.
+static constexpr llvm::RoundingMode kDefaultRoundingMode =
+    llvm::RoundingMode::NearestTiesToEven;
+
 //===----------------------------------------------------------------------===//
 // Pattern helpers
 //===----------------------------------------------------------------------===//
@@ -1112,7 +1116,7 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
         result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                         : llvm::RoundingMode::NearestTiesToEven);
+                         : kDefaultRoundingMode);
         return result;
       });
 }
@@ -1131,7 +1135,7 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
         result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                              : llvm::RoundingMode::NearestTiesToEven);
+                              : kDefaultRoundingMode);
         return result;
       });
 }
@@ -1327,7 +1331,7 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
         result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                              : llvm::RoundingMode::NearestTiesToEven);
+                              : kDefaultRoundingMode);
         return result;
       });
 }
@@ -1351,7 +1355,7 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
         result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                            : llvm::RoundingMode::NearestTiesToEven);
+                            : kDefaultRoundingMode);
         return result;
       });
 }
@@ -1481,9 +1485,10 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
 
 /// Attempts to convert `sourceValue` to an APFloat value with
 /// `targetSemantics` and `roundingMode`, without any information loss.
-static FailureOr<APFloat> convertFloatValue(
-    APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
-    llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
+static FailureOr<APFloat>
+convertFloatValue(APFloat sourceValue,
+                  const llvm::fltSemantics &targetSemantics,
+                  llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
   // Reject special values that are not representable in the target type before
   // calling APFloat::convert, which would llvm_unreachable on them.
   using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;

>From 1394971c61546bf5445cb7663948dbaffc1d01c6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 13:32:23 +0000
Subject: [PATCH 4/7] Add result-type builder for float binary ops and update callers

---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td |  8 ++++++++
 .../ComplexToStandard/ComplexToStandard.cpp    |  8 ++++----
 mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 +++++++++---------
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 864c947c005fa..3ae1622c71e04 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -115,6 +115,14 @@ class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
       build($_builder, $_state, lhs, rhs, fastmath,
             ::mlir::arith::RoundingModeAttr{});
     }]>,
+    OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, type, lhs, rhs,
+            ::mlir::arith::FastMathFlagsAttr::get(
+                $_builder.getContext(), fastmath),
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
   ];
   let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
                           (`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index b899220f2e9af..9e46b7d78baca 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
 
     Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
     Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
-    Value resultReal =
-        BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
+    Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
+                                                realRhs, fmf.getValue());
     Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
     Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
-    Value resultImag =
-        BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
+    Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
+                                                imagRhs, fmf.getValue());
     rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                    resultImag);
     return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index ec56ab3ab42e5..f76ddfae2a67a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
   Type type = operand.getType();
   Value sin = math::SinOp::create(b, type, operand);
   Value cos = math::CosOp::create(b, type, operand);
-  Value div = arith::DivFOp::create(b, sin, cos);
+  Value div = arith::DivFOp::create(b, type, sin, cos);
   rewriter.replaceOp(op, div);
   return success();
 }
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
   Value operandB = op.getOperand(1);
   Value operandC = op.getOperand(2);
   Type type = op.getType();
-  Value mult = arith::MulFOp::create(b, operandA, operandB);
-  Value add = arith::AddFOp::create(b, mult, operandC);
+  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+  Value add = arith::AddFOp::create(b, type, mult, operandC);
   rewriter.replaceOp(op, add);
   return success();
 }
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
   Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
 
   while (absPower > 0) {
     if (absPower & 1)
-      res = arith::MulFOp::create(b, base, res);
+      res = arith::MulFOp::create(b, baseType, base, res);
     absPower >>= 1;
-    base = arith::MulFOp::create(b, base, base);
+    base = arith::MulFOp::create(b, baseType, base, base);
   }
 
   // Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
     Value negZeroEqCheck =
         arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
-    res = arith::DivFOp::create(b, one, res);
+    res = arith::DivFOp::create(b, baseType, one, res);
     res =
         arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
     res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
   Value operand = op.getOperand();
   Type opType = operand.getType();
   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
-  Value mult = arith::MulFOp::create(b, operand, ln2);
+  Value mult = arith::MulFOp::create(b, opType, operand, ln2);
   Value exp = math::ExpOp::create(b, op->getLoc(), mult);
   rewriter.replaceOp(op, exp);
   return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
   Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
 
   Value incrValue = math::CopySignOp::create(b, half, operand);
-  Value add = arith::AddFOp::create(b, operand, incrValue);
+  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
   Value fpFixedConvert = createTruncatedFPValue(add, b);
 
   // There are three cases where adding 0.5 to the value and truncating by

>From 11a73a190d070e00a95d70f9d00a91ed529d1298 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 13:18:25 +0000
Subject: [PATCH 5/7] Clarify default rounding docs; derive from Arith_FloatBinaryOp

---
 .../ArithCommon/AttrToLLVMConverter.h         |  4 +-
 .../include/mlir/Dialect/Arith/IR/ArithOps.td | 49 +++++++++++--------
 2 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index c0773c9b69b6b..feb74c86e349f 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,8 +157,8 @@ class AttrConverterConstrainedFPToLLVM {
       convertedAttr.set(TargetOp::getRoundingModeAttrName(),
                         convertArithRoundingModeAttrToLLVM(arithAttr));
     }
-    // Constrained intrinsics do not support fastmath flags. Remove the
-    // arith fastmath attribute if present.
+    // Constrained intrinsics (llvm.intr.experimental.constrained.*) do not
+    // support fastmath flags. Remove the arith fastmath attribute if present.
     if constexpr (SourceOp::template hasTrait<
                       arith::ArithFastMathInterface::Trait>())
       convertedAttr.erase(srcOp.getFastMathAttrName());
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3ae1622c71e04..9fe8cde876b0c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -81,11 +81,11 @@ class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
 class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
     Arith_BinaryOp<mnemonic,
       !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
-                  traits)>,
-    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+                  traits)> {
+  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
       DefaultValuedAttr<
-        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
-    Results<(outs FloatLike:$result)> {
+        Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let results = (outs FloatLike:$result);
   let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
                           attr-dict `:` type($result) }];
 }
@@ -94,15 +94,13 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
 // mode.
 class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
                                           list<Trait> traits = []> :
-    Arith_BinaryOp<mnemonic,
-      !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
-                   DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
-                  traits)>,
-    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+    Arith_FloatBinaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+                  traits)> {
+  let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
       DefaultValuedAttr<
         Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
-      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
-    Results<(outs FloatLike:$result)> {
+      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode);
   let builders = [
     OpBuilder<(ins "Value":$lhs, "Value":$rhs,
       CArg<"::mlir::arith::FastMathFlags",
@@ -1005,7 +1003,8 @@ def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
     floating point tensor.
 
     If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
 
     Example:
 
@@ -1039,7 +1038,8 @@ def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
     floating point tensor.
 
     If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
 
     Example:
 
@@ -1193,7 +1193,8 @@ def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
     floating point tensor.
 
     If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
 
     Example:
 
@@ -1228,7 +1229,8 @@ def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
     floating point tensor.
 
     If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
 
     Example:
 
@@ -1255,7 +1257,8 @@ def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
     The remainder has the same sign as the dividend (lhs operand).
 
     If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
   }];
   let hasFolder = 1;
 }
@@ -1490,9 +1493,11 @@ def Arith_TruncFOp :
   let description = [{
     Truncate a floating-point value to a smaller floating-point-typed value.
     The destination type must be strictly narrower than the source type.
-    If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or the default one if no rounding mode is provided.
     When operating on vectors, casts elementwise.
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
   }];
   let builders = [
     OpBuilder<(ins "Type":$out, "Value":$in), [{
@@ -1531,9 +1536,11 @@ def Arith_ConvertFOp :
     be represented by `arith.extf` or `arith.truncf`.
 
     The source and destination element types must be different and must have
-    the same bitwidth. If the value cannot be exactly represented, it is
-    rounded using the provided rounding mode or the default one if no rounding
-    mode is provided. When operating on vectors, casts elementwise.
+    the same bitwidth. When operating on vectors, casts elementwise.
+
+    If the value cannot be exactly represented, it is rounded using the
+    provided rounding mode or, if no rounding mode is provided, according to
+    the default LLVM floating-point environment.
   }];
 
   let hasFolder = 1;

>From 23a52279abed2a24c68c146654fdd3795491ec54 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 13:50:07 +0000
Subject: [PATCH 6/7] Drop rounding mode support from arith.remf

---
 .../include/mlir/Dialect/Arith/IR/ArithOps.td |  6 ++--
 .../Conversion/ArithToLLVM/ArithToLLVM.cpp    | 11 ++----
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 36 ++++++++-----------
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 11 ------
 mlir/test/Dialect/Arith/ops.mlir              |  4 +--
 5 files changed, 21 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 9fe8cde876b0c..6b447960d74b7 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1250,15 +1250,13 @@ def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
 // RemFOp
 //===----------------------------------------------------------------------===//
 
-def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
   let summary = "floating point division remainder operation";
   let description = [{
     Returns the floating point division remainder.
     The remainder has the same sign as the dividend (lhs operand).
 
-    If the value cannot be exactly represented, it is rounded using the
-    provided rounding mode or, if no rounding mode is provided, according to
-    the default LLVM floating-point environment.
+    TODO: Add support for rounding modes.
   }];
   let hasFolder = 1;
 }
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 30f32b8e248f4..2624420cf5318 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -170,13 +170,9 @@ using NegFOpLowering =
                                /*FailOnUnsupportedFP=*/true>;
 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
 using RemFOpLowering =
-    ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
-                                          /*HasRoundingMode=*/false,
-                                          arith::AttrConvertFastMathToLLVM,
-                                          /*FailOnUnsupportedFP=*/true>;
-using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
-    arith::RemFOp, LLVM::ConstrainedFRemIntr, /*HasRoundingMode=*/true,
-    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
+    VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+                               arith::AttrConvertFastMathToLLVM,
+                               /*FailOnUnsupportedFP=*/true>;
 using RemSIOpLowering =
     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
 using RemUIOpLowering =
@@ -764,7 +760,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
     NegFOpLowering,
     OrIOpLowering,
     RemFOpLowering,
-    ConstrainedRemFOpLowering,
     RemSIOpLowering,
     RemUIOpLowering,
     SelectOpLowering,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8eb47fea92f8b..e11a38ffec50c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -34,7 +34,7 @@
 using namespace mlir;
 using namespace mlir::arith;
 
-/// Default rounding mode according to IEEE-754.
+/// Default rounding mode according to default LLVM floating-point environment.
 static constexpr llvm::RoundingMode kDefaultRoundingMode =
     llvm::RoundingMode::NearestTiesToEven;
 
@@ -109,8 +109,10 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
 /// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
 /// on the LLVM dialect and on translation to LLVM.
 static llvm::RoundingMode
-convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
-  switch (roundingMode) {
+convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
+  if (!roundingMode)
+    return kDefaultRoundingMode;
+  switch (*roundingMode) {
   case RoundingMode::downward:
     return llvm::RoundingMode::TowardNegative;
   case RoundingMode::to_nearest_away:
@@ -1111,12 +1113,11 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
     return getLhs();
 
-  auto rm = getRoundingmodeAttr();
+  auto rm = getRoundingmode();
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
-        result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                         : kDefaultRoundingMode);
+        result.add(b, convertArithRoundingModeToLLVMIR(rm));
         return result;
       });
 }
@@ -1130,12 +1131,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
     return getLhs();
 
-  auto rm = getRoundingmodeAttr();
+  auto rm = getRoundingmode();
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
-        result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                              : kDefaultRoundingMode);
+        result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
         return result;
       });
 }
@@ -1326,12 +1326,11 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
       return getRhs();
   }
 
-  auto rm = getRoundingmodeAttr();
+  auto rm = getRoundingmode();
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
-        result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                              : kDefaultRoundingMode);
+        result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
         return result;
       });
 }
@@ -1350,12 +1349,11 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
 
-  auto rm = getRoundingmodeAttr();
+  auto rm = getRoundingmode();
   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
         APFloat result(a);
-        result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
-                            : kDefaultRoundingMode);
+        result.divide(b, convertArithRoundingModeToLLVMIR(rm));
         return result;
       });
 }
@@ -1710,10 +1708,8 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
       [this, &targetSemantics](const APFloat &a, bool &castStatus) {
-        RoundingMode roundingMode =
-            getRoundingmode().value_or(RoundingMode::to_nearest_even);
         llvm::RoundingMode llvmRoundingMode =
-            convertArithRoundingModeToLLVMIR(roundingMode);
+            convertArithRoundingModeToLLVMIR(getRoundingmode());
         FailureOr<APFloat> result =
             convertFloatValue(a, targetSemantics, llvmRoundingMode);
         if (failed(result)) {
@@ -1747,10 +1743,8 @@ OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
       [this, &targetSemantics](const APFloat &a, bool &castStatus) {
-        RoundingMode roundingMode =
-            getRoundingmode().value_or(RoundingMode::to_nearest_even);
         llvm::RoundingMode llvmRoundingMode =
-            convertArithRoundingModeToLLVMIR(roundingMode);
+            convertArithRoundingModeToLLVMIR(getRoundingmode());
         FailureOr<APFloat> result =
             convertFloatValue(a, targetSemantics, llvmRoundingMode);
         if (failed(result)) {
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a13f850a7cc2a..df58d4ffcaf51 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -451,17 +451,6 @@ func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
 
 // -----
 
-// CHECK-LABEL: experimental_constrained_remf
-func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
-  %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
-// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
-  %1 = arith.remf %arg0, %arg1 downward : f64
-  return
-}
-
-// -----
-
 // Verify that fastmath flags are stripped when lowering to constrained
 // intrinsics (constrained FP and fastmath are contradictory).
 // CHECK-LABEL: constrained_addf_with_fastmath
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f6e21c94969f2..3874c85818eb4 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1244,10 +1244,8 @@ func.func @roundingmode(%arg0: f32, %arg1: f32) {
   %2 = arith.mulf %arg0, %arg1 upward : f32
 // CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
   %3 = arith.divf %arg0, %arg1 toward_zero : f32
-// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
-  %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
 // CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
-  %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+  %4 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
   return
 }
 

>From 2b7ce18d850d82050d022f9610b0249bc1e515ca Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 14:03:56 +0000
Subject: [PATCH 7/7] skip canonicalization if rounding mode is set

---
 mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 03200c16ac378..b822f1eadf0eb 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -437,21 +437,25 @@ def UIToFPOfExtUI :
 //===----------------------------------------------------------------------===//
 
 // mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags and rounding mode of original mulf)
+// TODO: Verify whether this canonicalization is safe when a rounding mode is
+// specified. For now, bail out whenever a rounding mode is explicitly set.
 def MulFOfNegF :
     Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
         (Arith_MulFOp $x, $y, $fmf, $rm),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y),
+         (Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rm)]>;
 
 //===----------------------------------------------------------------------===//
 // DivFOp
 //===----------------------------------------------------------------------===//
 
 // divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags and rounding mode of original divf)
+// TODO: Verify whether this canonicalization is safe when a rounding mode is
+// specified. For now, bail out whenever a rounding mode is explicitly set.
 def DivFOfNegF :
     Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
         (Arith_DivFOp $x, $y, $fmf, $rm),
-        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+        [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y),
+         (Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rm)]>;
 
 #endif // ARITH_PATTERNS



More information about the Mlir-commits mailing list