[Mlir-commits] [mlir] [mlir][arith] Add rounding mode flags to binary arithmetic operations (PR #188458)
Matthias Springer
llvmlistbot at llvm.org
Tue Apr 7 07:06:46 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/188458
>From 59d4c2efeea9509fc28c39e4176e8d0a629eb588 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 10:33:39 +0000
Subject: [PATCH 1/7] [mlir][arith] Add rounding mode flags to binary
arithmetic operations
---
.../ArithCommon/AttrToLLVMConverter.h | 5 ++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 90 ++++++++++++++++---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 55 ++++++++----
.../ComplexToStandard/ComplexToStandard.cpp | 8 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 9 +-
.../Dialect/Arith/IR/ArithCanonicalization.td | 12 +--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 ++++++--
.../lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 ++--
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 61 +++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 75 ++++++++++++++++
mlir/test/Dialect/Arith/ops.mlir | 17 ++++
11 files changed, 325 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index fccfe4897114e..c0773c9b69b6b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
convertedAttr.set(TargetOp::getRoundingModeAttrName(),
convertArithRoundingModeAttrToLLVM(arithAttr));
}
+ // Constrained intrinsics do not support fastmath flags. Remove the
+ // arith fastmath attribute if present.
+ if constexpr (SourceOp::template hasTrait<
+ arith::ArithFastMathInterface::Trait>())
+ convertedAttr.erase(srcOp.getFastMathAttrName());
convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 45cb3cecef3d8..864c947c005fa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -90,6 +90,37 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Base class for floating point binary operations with an optional rounding
+// mode.
+class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
+ list<Trait> traits = []> :
+ Arith_BinaryOp<mnemonic,
+ !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+ traits)>,
+ Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+ DefaultValuedAttr<
+ Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$result)> {
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+ build($_builder, $_state, lhs, rhs, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+ "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+ build($_builder, $_state, lhs, rhs, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ ];
+ let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
+}
+
// Checks that tensor input and outputs have identical shapes. This is stricker
// than the verification done in `SameOperandsAndResultShape` that allows for
// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
@@ -957,7 +988,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
// AddFOp
//===----------------------------------------------------------------------===//
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
+def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
let summary = "floating point addition operation";
let description = [{
The `addf` operation takes two operands and returns one result, each of
@@ -965,6 +996,9 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+
Example:
```mlir
@@ -976,10 +1010,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
// Tensor addition.
%x = arith.addf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar addition with rounding mode.
+ %a = arith.addf %b, %c to_nearest_even : f64
+ ```
}];
let hasFolder = 1;
}
@@ -988,7 +1022,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
// SubFOp
//===----------------------------------------------------------------------===//
-def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
+def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
let summary = "floating point subtraction operation";
let description = [{
The `subf` operation takes two operands and returns one result, each of
@@ -996,6 +1030,9 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+
Example:
```mlir
@@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
// Tensor subtraction.
%x = arith.subf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar subtraction with rounding mode.
+ %a = arith.subf %b, %c downward : f64
+ ```
}];
let hasFolder = 1;
}
@@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
// MulFOp
//===----------------------------------------------------------------------===//
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
+def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
let summary = "floating point multiplication operation";
let description = [{
The `mulf` operation takes two operands and returns one result, each of
@@ -1147,6 +1184,9 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+
Example:
```mlir
@@ -1158,10 +1198,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
// Tensor pointwise multiplication.
%x = arith.mulf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar multiplication with rounding mode.
+ %a = arith.mulf %b, %c upward : f64
+ ```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
@@ -1171,8 +1211,27 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
// DivFOp
//===----------------------------------------------------------------------===//
-def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
+def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
let summary = "floating point division operation";
+ let description = [{
+ The `divf` operation takes two operands and returns one result, each of
+ these is required to be the same type. This type may be a floating point
+ scalar type, a vector whose element type is a floating point type, or a
+ floating point tensor.
+
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
+
+ Example:
+
+ ```mlir
+ // Scalar division.
+ %a = arith.divf %b, %c : f64
+
+ // Scalar division with rounding mode.
+ %a = arith.divf %b, %c toward_zero : f64
+ ```
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1181,11 +1240,14 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
// RemFOp
//===----------------------------------------------------------------------===//
-def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
let summary = "floating point division remainder operation";
let description = [{
Returns the floating point division remainder.
The remainder has the same sign as the dividend (lhs operand).
+
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or the default one if no rounding mode is provided.
}];
let hasFolder = 1;
}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index f9ea8dba105a4..51001cb55a627 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -81,9 +81,13 @@ struct IdentityBitcastLowering final
//===----------------------------------------------------------------------===//
using AddFOpLowering =
- VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -91,9 +95,13 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
- VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
@@ -139,9 +147,13 @@ using MinSIOpLowering =
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
- VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
@@ -151,9 +163,13 @@ using NegFOpLowering =
/*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
- VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -170,9 +186,13 @@ using ShRUIOpLowering =
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
- VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
@@ -700,6 +720,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
AddFOpLowering,
+ ConstrainedAddFOpLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
@@ -708,6 +729,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
CmpFOpLowering,
CmpIOpLowering,
DivFOpLowering,
+ ConstrainedDivFOpLowering,
DivSIOpLowering,
DivUIOpLowering,
ExtFOpLowering,
@@ -727,12 +749,14 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
MinSIOpLowering,
MinUIOpLowering,
MulFOpLowering,
+ ConstrainedMulFOpLowering,
MulIOpLowering,
MulSIExtendedOpLowering,
MulUIExtendedOpLowering,
NegFOpLowering,
OrIOpLowering,
RemFOpLowering,
+ ConstrainedRemFOpLowering,
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
@@ -742,6 +766,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ShRUIOpLowering,
SIToFPOpLowering,
SubFOpLowering,
+ ConstrainedSubFOpLowering,
SubIOpLowering,
TruncFOpLowering,
ConstrainedTruncFOpLowering,
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..b899220f2e9af 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
- Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
- realRhs, fmf.getValue());
+ Value resultReal =
+ BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
- Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
- imagRhs, fmf.getValue());
+ Value resultImag =
+ BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..11b3aabcbfeb4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -120,7 +120,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
auto one =
arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
- return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
+ return arith::DivFOp::create(rewriter, loc, one, args[0]);
}
// tosa::MulOp
@@ -140,8 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
"Cannot have shift value for float");
return nullptr;
}
- return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
- args[1]);
+ return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<IntegerType>(elementTy)) {
@@ -538,8 +537,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
- auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
- return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
+ auto added = arith::AddFOp::create(rewriter, loc, exp, one);
+ return arith::DivFOp::create(rewriter, loc, one, added);
}
// tosa::CastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index a15e19b24e54b..03200c16ac378 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -437,10 +437,10 @@ def UIToFPOfExtUI :
//===----------------------------------------------------------------------===//
// mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and rounding mode of original mulf)
def MulFOfNegF :
- Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_MulFOp $x, $y, $fmf),
+ Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+ (Arith_MulFOp $x, $y, $fmf, $rm),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
//===----------------------------------------------------------------------===//
@@ -448,10 +448,10 @@ def MulFOfNegF :
//===----------------------------------------------------------------------===//
// divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and rounding mode of original divf)
def DivFOfNegF :
- Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_DivFOp $x, $y, $fmf),
+ Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+ (Arith_DivFOp $x, $y, $fmf, $rm),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
#endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 569d1869a5abe..7d1d6f4ae1207 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1107,9 +1107,14 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a + b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1121,9 +1126,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a - b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1312,9 +1322,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
return getRhs();
}
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a * b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1331,9 +1346,14 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a / b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index f76ddfae2a67a..ec56ab3ab42e5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
Type type = operand.getType();
Value sin = math::SinOp::create(b, type, operand);
Value cos = math::CosOp::create(b, type, operand);
- Value div = arith::DivFOp::create(b, type, sin, cos);
+ Value div = arith::DivFOp::create(b, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = arith::MulFOp::create(b, type, operandA, operandB);
- Value add = arith::AddFOp::create(b, type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, operandA, operandB);
+ Value add = arith::AddFOp::create(b, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = arith::MulFOp::create(b, baseType, base, res);
+ res = arith::MulFOp::create(b, base, res);
absPower >>= 1;
- base = arith::MulFOp::create(b, baseType, base, base);
+ base = arith::MulFOp::create(b, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
- res = arith::DivFOp::create(b, baseType, one, res);
+ res = arith::DivFOp::create(b, one, res);
res =
arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
Value operand = op.getOperand();
Type opType = operand.getType();
Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
- Value mult = arith::MulFOp::create(b, opType, operand, ln2);
+ Value mult = arith::MulFOp::create(b, operand, ln2);
Value exp = math::ExpOp::create(b, op->getLoc(), mult);
rewriter.replaceOp(op, exp);
return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
Value incrValue = math::CopySignOp::create(b, half, operand);
- Value add = arith::AddFOp::create(b, opType, operand, incrValue);
+ Value add = arith::AddFOp::create(b, operand, incrValue);
Value fpFixedConvert = createTruncatedFPValue(add, b);
// There are three cases where adding 0.5 to the value and truncating by
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 75601e215744c..d4d425ad72970 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -401,6 +401,67 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// -----
+// CHECK-LABEL: experimental_constrained_addf
+func.func @experimental_constrained_addf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore
+ %0 = arith.addf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 downward ignore
+ %1 = arith.addf %arg0, %arg1 downward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 upward ignore
+ %2 = arith.addf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 towardzero ignore
+ %3 = arith.addf %arg0, %arg1 toward_zero : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearestaway ignore
+ %4 = arith.addf %arg0, %arg1 to_nearest_away : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_subf
+func.func @experimental_constrained_subf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 tonearest ignore
+ %0 = arith.subf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 downward ignore
+ %1 = arith.subf %arg0, %arg1 downward : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_mulf
+func.func @experimental_constrained_mulf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 upward ignore
+ %0 = arith.mulf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 towardzero ignore
+ %1 = arith.mulf %arg0, %arg1 toward_zero : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_divf
+func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearest ignore
+ %0 = arith.divf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearestaway ignore
+ %1 = arith.divf %arg0, %arg1 to_nearest_away : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_remf
+func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
+ %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
+ %1 = arith.remf %arg0, %arg1 downward : f64
+ return
+}
+
+// -----
+
// CHECK-LABEL: @convertf_f16_to_bf16
func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ee3e713f8481e..b153bd7c32261 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2554,6 +2554,81 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
+// Verify that constant folding respects rounding modes. 1.0000001 + 1.0 is not
+// exactly representable in f32. With upward rounding, the result is rounded up,
+// and with downward rounding it is rounded down.
+// CHECK-LABEL: @test_addf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_addf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 2.00000024 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant 1.0 : f32
+ // addf(x, -0) folds even with a rounding mode.
+ %c_neg0 = arith.constant -0.0 : f32
+ %0 = arith.addf %arg0, %c_neg0 to_nearest_even : f32
+ %1 = arith.addf %a, %b upward : f32
+ %2 = arith.addf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 2.00000024 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant -1.0 : f32
+ // subf(x, +0) folds even with a rounding mode.
+ %c0 = arith.constant 0.0 : f32
+ %0 = arith.subf %arg0, %c0 downward : f32
+ %1 = arith.subf %a, %b upward : f32
+ %2 = arith.subf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_mulf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 3.00000048 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 3.00000024 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant 3.0 : f32
+ // mulf(x, 1) folds even with a rounding mode.
+ %c1 = arith.constant 1.0 : f32
+ %0 = arith.mulf %arg0, %c1 upward : f32
+ %1 = arith.mulf %a, %b upward : f32
+ %2 = arith.mulf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_divf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 0.333333343 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 0.333333313 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0 : f32
+ %b = arith.constant 3.0 : f32
+ // divf(x, 1) folds even with a rounding mode.
+ %c1 = arith.constant 1.0 : f32
+ %0 = arith.divf %arg0, %c1 toward_zero : f32
+ %1 = arith.divf %a, %b upward : f32
+ %2 = arith.divf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
%1 = arith.divui %0, %arg0 : index
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 2c5371de9ff24..f6e21c94969f2 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1234,6 +1234,23 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
return
}
+// CHECK-LABEL: @roundingmode
+func.func @roundingmode(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even : f32
+ %0 = arith.addf %arg0, %arg1 to_nearest_even : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 downward : f32
+ %1 = arith.subf %arg0, %arg1 downward : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 upward : f32
+ %2 = arith.mulf %arg0, %arg1 upward : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
+ %3 = arith.divf %arg0, %arg1 toward_zero : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
+ %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+ %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+ return
+}
+
// CHECK-LABEL: @select_tensor
func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
// CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
>From 360b1dbce7a7fad3408925fda871e6ccc065bbd2 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 12:56:53 +0000
Subject: [PATCH 2/7] address comments
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 46 +++++++++++--------
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 12 +++++
2 files changed, 39 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 51001cb55a627..30f32b8e248f4 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -29,13 +29,20 @@ using namespace mlir;
namespace {
-/// Operations whose conversion will depend on whether they are passed a
-/// rounding mode attribute or not.
+/// Lowering pattern that matches only when the source op's rounding mode
+/// presence agrees with `HasRoundingMode`. This allows registering two
+/// instances of the same pattern for one source op: one that handles the
+/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
+/// one that handles the constrained case (rounding mode present, lowering to
+/// a constrained LLVM intrinsic).
///
-/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
-/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
-/// attribute.
-template <typename SourceOp, typename TargetOp, bool Constrained,
+/// * `HasRoundingMode`: the pattern matches if and only if the source op has
+/// a rounding mode attribute.
+/// * `AttrConvert`: attribute converter to translate source attributes to
+/// target attributes.
+/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
+/// floating point types.
+template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
@@ -49,7 +56,7 @@ struct ConstrainedVectorConvertToLLVMPattern
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
return VectorConvertToLLVMPattern<
SourceOp, TargetOp, AttrConvert,
@@ -82,11 +89,11 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+ arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
@@ -96,11 +103,11 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+ arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
@@ -148,11 +155,11 @@ using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+ arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
@@ -164,11 +171,11 @@ using NegFOpLowering =
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+ arith::RemFOp, LLVM::ConstrainedFRemIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
@@ -187,21 +194,22 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+ arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false, AttrConvertPassThrough,
+ /*HasRoundingMode=*/false,
+ AttrConvertPassThrough,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+ arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index d4d425ad72970..a13f850a7cc2a 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -462,6 +462,18 @@ func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
// -----
+// Verify that fastmath flags are stripped when lowering to constrained
+// intrinsics (constrained FP and fastmath are contradictory).
+// CHECK-LABEL: constrained_addf_with_fastmath
+func.func @constrained_addf_with_fastmath(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore : f64
+// CHECK-NOT: fastmath
+ %0 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f64
+ return
+}
+
+// -----
+
// CHECK-LABEL: @convertf_f16_to_bf16
func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
>From 8d2e4ec80380f163f24daaae2c25f73d6639e8f5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 13:05:58 +0000
Subject: [PATCH 3/7] update
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 ++++++++++++-------
1 file changed, 12 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7d1d6f4ae1207..8eb47fea92f8b 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -34,6 +34,10 @@
using namespace mlir;
using namespace mlir::arith;
+/// Default rounding mode according to IEEE-754.
+static constexpr llvm::RoundingMode kDefaultRoundingMode =
+ llvm::RoundingMode::NearestTiesToEven;
+
//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//
@@ -1112,7 +1116,7 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1131,7 +1135,7 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1327,7 +1331,7 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1351,7 +1355,7 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1481,9 +1485,10 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
-static FailureOr<APFloat> convertFloatValue(
- APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
- llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
+static FailureOr<APFloat>
+convertFloatValue(APFloat sourceValue,
+ const llvm::fltSemantics &targetSemantics,
+ llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
// Reject special values that are not representable in the target type before
// calling APFloat::convert, which would llvm_unreachable on them.
using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
>From 1394971c61546bf5445cb7663948dbaffc1d01c6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 13:32:23 +0000
Subject: [PATCH 4/7] fix
---
mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 8 ++++++++
.../ComplexToStandard/ComplexToStandard.cpp | 8 ++++----
mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 +++++++++---------
3 files changed, 21 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 864c947c005fa..3ae1622c71e04 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -115,6 +115,14 @@ class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
build($_builder, $_state, lhs, rhs, fastmath,
::mlir::arith::RoundingModeAttr{});
}]>,
+ OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+ build($_builder, $_state, type, lhs, rhs,
+ ::mlir::arith::FastMathFlagsAttr::get(
+ $_builder.getContext(), fastmath),
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
];
let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
(`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index b899220f2e9af..9e46b7d78baca 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
- Value resultReal =
- BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
+ Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
+ realRhs, fmf.getValue());
Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
- Value resultImag =
- BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
+ Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
+ imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index ec56ab3ab42e5..f76ddfae2a67a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
Type type = operand.getType();
Value sin = math::SinOp::create(b, type, operand);
Value cos = math::CosOp::create(b, type, operand);
- Value div = arith::DivFOp::create(b, sin, cos);
+ Value div = arith::DivFOp::create(b, type, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = arith::MulFOp::create(b, operandA, operandB);
- Value add = arith::AddFOp::create(b, mult, operandC);
+ Value mult = arith::MulFOp::create(b, type, operandA, operandB);
+ Value add = arith::AddFOp::create(b, type, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
+ Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = arith::MulFOp::create(b, base, res);
+ res = arith::MulFOp::create(b, baseType, base, res);
absPower >>= 1;
- base = arith::MulFOp::create(b, base, base);
+ base = arith::MulFOp::create(b, baseType, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
- res = arith::DivFOp::create(b, one, res);
+ res = arith::DivFOp::create(b, baseType, one, res);
res =
arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
Value operand = op.getOperand();
Type opType = operand.getType();
Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
- Value mult = arith::MulFOp::create(b, operand, ln2);
+ Value mult = arith::MulFOp::create(b, opType, operand, ln2);
Value exp = math::ExpOp::create(b, op->getLoc(), mult);
rewriter.replaceOp(op, exp);
return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
Value incrValue = math::CopySignOp::create(b, half, operand);
- Value add = arith::AddFOp::create(b, operand, incrValue);
+ Value add = arith::AddFOp::create(b, opType, operand, incrValue);
Value fpFixedConvert = createTruncatedFPValue(add, b);
// There are three cases where adding 0.5 to the value and truncating by
>From 11a73a190d070e00a95d70f9d00a91ed529d1298 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 13:18:25 +0000
Subject: [PATCH 5/7] address comments
---
.../ArithCommon/AttrToLLVMConverter.h | 4 +-
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 49 +++++++++++--------
2 files changed, 30 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index c0773c9b69b6b..feb74c86e349f 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,8 +157,8 @@ class AttrConverterConstrainedFPToLLVM {
convertedAttr.set(TargetOp::getRoundingModeAttrName(),
convertArithRoundingModeAttrToLLVM(arithAttr));
}
- // Constrained intrinsics do not support fastmath flags. Remove the
- // arith fastmath attribute if present.
+ // Constrained intrinsics (llvm.intr.experimental.constrained.*) do not
+ // support fastmath flags. Remove the arith fastmath attribute if present.
if constexpr (SourceOp::template hasTrait<
arith::ArithFastMathInterface::Trait>())
convertedAttr.erase(srcOp.getFastMathAttrName());
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 3ae1622c71e04..9fe8cde876b0c 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -81,11 +81,11 @@ class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic,
!listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>],
- traits)>,
- Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+ traits)> {
+ let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
- Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath)>,
- Results<(outs FloatLike:$result)> {
+ Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
+ let results = (outs FloatLike:$result);
let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($result) }];
}
@@ -94,15 +94,13 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
// mode.
class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
list<Trait> traits = []> :
- Arith_BinaryOp<mnemonic,
- !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
- DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
- traits)>,
- Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+ Arith_FloatBinaryOp<mnemonic,
+ !listconcat([DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+ traits)> {
+ let arguments = (ins FloatLike:$lhs, FloatLike:$rhs,
DefaultValuedAttr<
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
- OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
- Results<(outs FloatLike:$result)> {
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
CArg<"::mlir::arith::FastMathFlags",
@@ -1005,7 +1003,8 @@ def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
floating point tensor.
If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
Example:
@@ -1039,7 +1038,8 @@ def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
floating point tensor.
If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
Example:
@@ -1193,7 +1193,8 @@ def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
floating point tensor.
If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
Example:
@@ -1228,7 +1229,8 @@ def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
floating point tensor.
If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
Example:
@@ -1255,7 +1257,8 @@ def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
The remainder has the same sign as the dividend (lhs operand).
If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
}];
let hasFolder = 1;
}
@@ -1490,9 +1493,11 @@ def Arith_TruncFOp :
let description = [{
Truncate a floating-point value to a smaller floating-point-typed value.
The destination type must be strictly narrower than the source type.
- If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or the default one if no rounding mode is provided.
When operating on vectors, casts elementwise.
+
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
}];
let builders = [
OpBuilder<(ins "Type":$out, "Value":$in), [{
@@ -1531,9 +1536,11 @@ def Arith_ConvertFOp :
be represented by `arith.extf` or `arith.truncf`.
The source and destination element types must be different and must have
- the same bitwidth. If the value cannot be exactly represented, it is
- rounded using the provided rounding mode or the default one if no rounding
- mode is provided. When operating on vectors, casts elementwise.
+ the same bitwidth. When operating on vectors, casts elementwise.
+
+ If the value cannot be exactly represented, it is rounded using the
+ provided rounding mode or, if no rounding mode is provided, according to
+ the default LLVM floating-point environment.
}];
let hasFolder = 1;
>From 23a52279abed2a24c68c146654fdd3795491ec54 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 13:50:07 +0000
Subject: [PATCH 6/7] remove remainderf
---
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 6 ++--
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 11 ++----
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 ++++++++-----------
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 11 ------
mlir/test/Dialect/Arith/ops.mlir | 4 +--
5 files changed, 21 insertions(+), 47 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 9fe8cde876b0c..6b447960d74b7 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1250,15 +1250,13 @@ def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
// RemFOp
//===----------------------------------------------------------------------===//
-def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
let summary = "floating point division remainder operation";
let description = [{
Returns the floating point division remainder.
The remainder has the same sign as the dividend (lhs operand).
- If the value cannot be exactly represented, it is rounded using the
- provided rounding mode or, if no rounding mode is provided, according to
- the default LLVM floating-point environment.
+ TODO: Add support for rounding modes.
}];
let hasFolder = 1;
}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 30f32b8e248f4..2624420cf5318 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -170,13 +170,9 @@ using NegFOpLowering =
/*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
- ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- /*HasRoundingMode=*/false,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
-using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::RemFOp, LLVM::ConstrainedFRemIntr, /*HasRoundingMode=*/true,
- arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
+ VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -764,7 +760,6 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
NegFOpLowering,
OrIOpLowering,
RemFOpLowering,
- ConstrainedRemFOpLowering,
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8eb47fea92f8b..e11a38ffec50c 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -34,7 +34,7 @@
using namespace mlir;
using namespace mlir::arith;
-/// Default rounding mode according to IEEE-754.
+/// Default rounding mode according to default LLVM floating-point environment.
static constexpr llvm::RoundingMode kDefaultRoundingMode =
llvm::RoundingMode::NearestTiesToEven;
@@ -109,8 +109,10 @@ arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
/// circular dependency with MLIRArithAttrToLLVMConversion and make arith depend
/// on the LLVM dialect and on translation to LLVM.
static llvm::RoundingMode
-convertArithRoundingModeToLLVMIR(RoundingMode roundingMode) {
- switch (roundingMode) {
+convertArithRoundingModeToLLVMIR(std::optional<RoundingMode> roundingMode) {
+ if (!roundingMode)
+ return kDefaultRoundingMode;
+ switch (*roundingMode) {
case RoundingMode::downward:
return llvm::RoundingMode::TowardNegative;
case RoundingMode::to_nearest_away:
@@ -1111,12 +1113,11 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
return getLhs();
- auto rm = getRoundingmodeAttr();
+ auto rm = getRoundingmode();
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
- result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : kDefaultRoundingMode);
+ result.add(b, convertArithRoundingModeToLLVMIR(rm));
return result;
});
}
@@ -1130,12 +1131,11 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();
- auto rm = getRoundingmodeAttr();
+ auto rm = getRoundingmode();
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
- result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : kDefaultRoundingMode);
+ result.subtract(b, convertArithRoundingModeToLLVMIR(rm));
return result;
});
}
@@ -1326,12 +1326,11 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
return getRhs();
}
- auto rm = getRoundingmodeAttr();
+ auto rm = getRoundingmode();
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
- result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : kDefaultRoundingMode);
+ result.multiply(b, convertArithRoundingModeToLLVMIR(rm));
return result;
});
}
@@ -1350,12 +1349,11 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
- auto rm = getRoundingmodeAttr();
+ auto rm = getRoundingmode();
return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
- result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : kDefaultRoundingMode);
+ result.divide(b, convertArithRoundingModeToLLVMIR(rm));
return result;
});
}
@@ -1710,10 +1708,8 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[this, &targetSemantics](const APFloat &a, bool &castStatus) {
- RoundingMode roundingMode =
- getRoundingmode().value_or(RoundingMode::to_nearest_even);
llvm::RoundingMode llvmRoundingMode =
- convertArithRoundingModeToLLVMIR(roundingMode);
+ convertArithRoundingModeToLLVMIR(getRoundingmode());
FailureOr<APFloat> result =
convertFloatValue(a, targetSemantics, llvmRoundingMode);
if (failed(result)) {
@@ -1747,10 +1743,8 @@ OpFoldResult arith::ConvertFOp::fold(FoldAdaptor adaptor) {
return constFoldCastOp<FloatAttr, FloatAttr>(
adaptor.getOperands(), getType(),
[this, &targetSemantics](const APFloat &a, bool &castStatus) {
- RoundingMode roundingMode =
- getRoundingmode().value_or(RoundingMode::to_nearest_even);
llvm::RoundingMode llvmRoundingMode =
- convertArithRoundingModeToLLVMIR(roundingMode);
+ convertArithRoundingModeToLLVMIR(getRoundingmode());
FailureOr<APFloat> result =
convertFloatValue(a, targetSemantics, llvmRoundingMode);
if (failed(result)) {
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a13f850a7cc2a..df58d4ffcaf51 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -451,17 +451,6 @@ func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
// -----
-// CHECK-LABEL: experimental_constrained_remf
-func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
-// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
- %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
-// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
- %1 = arith.remf %arg0, %arg1 downward : f64
- return
-}
-
-// -----
-
// Verify that fastmath flags are stripped when lowering to constrained
// intrinsics (constrained FP and fastmath are contradictory).
// CHECK-LABEL: constrained_addf_with_fastmath
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f6e21c94969f2..3874c85818eb4 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1244,10 +1244,8 @@ func.func @roundingmode(%arg0: f32, %arg1: f32) {
%2 = arith.mulf %arg0, %arg1 upward : f32
// CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
%3 = arith.divf %arg0, %arg1 toward_zero : f32
-// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
- %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
- %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+ %4 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
return
}
>From 2b7ce18d850d82050d022f9610b0249bc1e515ca Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 7 Apr 2026 14:03:56 +0000
Subject: [PATCH 7/7] skip canonicalization if rounding mode is set
---
mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 03200c16ac378..b822f1eadf0eb 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -437,21 +437,25 @@ def UIToFPOfExtUI :
//===----------------------------------------------------------------------===//
// mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags and rounding mode of original mulf)
+// TODO: Verify if this canonicalization is safe when a rounding mode is
+// specified. For the moment, bail on custom rounding modes.
def MulFOfNegF :
Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
(Arith_MulFOp $x, $y, $fmf, $rm),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+ [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y),
+ (Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rm)]>;
//===----------------------------------------------------------------------===//
// DivFOp
//===----------------------------------------------------------------------===//
// divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags and rounding mode of original divf)
+// TODO: Verify if this canonicalization is safe when a rounding mode is
+// specified. For the moment, bail on custom rounding modes.
def DivFOfNegF :
Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
(Arith_DivFOp $x, $y, $fmf, $rm),
- [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
+ [(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y),
+ (Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rm)]>;
#endif // ARITH_PATTERNS
More information about the Mlir-commits
mailing list