[Mlir-commits] [mlir] [mlir][arith] Add rounding mode flags to binary arithmetic operations (PR #188458)
Matthias Springer
llvmlistbot at llvm.org
Wed Mar 25 06:06:18 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/188458
>From 306b777f87647a8a60cdd1729f049f929510f148 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 10:33:39 +0000
Subject: [PATCH 1/3] [mlir][arith] Add rounding mode flags to binary
arithmetic operations
---
.../ArithCommon/AttrToLLVMConverter.h | 5 ++
.../include/mlir/Dialect/Arith/IR/ArithOps.td | 90 ++++++++++++++++---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 55 ++++++++----
.../ComplexToStandard/ComplexToStandard.cpp | 8 +-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 9 +-
.../Dialect/Arith/IR/ArithCanonicalization.td | 12 +--
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 ++++++--
.../lib/Dialect/Math/Transforms/ExpandOps.cpp | 18 ++--
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 61 +++++++++++++
mlir/test/Dialect/Arith/canonicalize.mlir | 75 ++++++++++++++++
mlir/test/Dialect/Arith/ops.mlir | 17 ++++
11 files changed, 325 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index fccfe4897114e..c0773c9b69b6b 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -157,6 +157,11 @@ class AttrConverterConstrainedFPToLLVM {
convertedAttr.set(TargetOp::getRoundingModeAttrName(),
convertArithRoundingModeAttrToLLVM(arithAttr));
}
+ // Constrained intrinsics do not support fastmath flags. Remove the
+ // arith fastmath attribute if present.
+ if constexpr (SourceOp::template hasTrait<
+ arith::ArithFastMathInterface::Trait>())
+ convertedAttr.erase(srcOp.getFastMathAttrName());
convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
}
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 45cb3cecef3d8..864c947c005fa 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -90,6 +90,37 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Base class for floating point binary ops that take fastmath flags and an
+// optional rounding mode; without one, the default rounding mode applies.
+class Arith_FloatBinaryOpWithRoundingMode<string mnemonic,
+ list<Trait> traits = []> :
+ Arith_BinaryOp<mnemonic,
+ !listconcat([Pure, DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+ traits)>,
+ Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+ DefaultValuedAttr<
+ Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
+ Results<(outs FloatLike:$result)> {
+ let builders = [
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+ build($_builder, $_state, lhs, rhs, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ OpBuilder<(ins "Value":$lhs, "Value":$rhs,
+ "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+ build($_builder, $_state, lhs, rhs, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ ];
+ let assemblyFormat = [{ $lhs `,` $rhs ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
+}
+
// Checks that tensor input and outputs have identical shapes. This is stricker
// than the verification done in `SameOperandsAndResultShape` that allows for
// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
@@ -957,7 +988,7 @@ def Arith_NegFOp : Arith_FloatUnaryOp<"negf"> {
// AddFOp
//===----------------------------------------------------------------------===//
-def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
+def Arith_AddFOp : Arith_FloatBinaryOpWithRoundingMode<"addf", [Commutative]> {
let summary = "floating point addition operation";
let description = [{
The `addf` operation takes two operands and returns one result, each of
@@ -965,6 +996,9 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the exact result is not representable in the result type, it is rounded
+ using the given rounding mode, or the default rounding mode if none is given.
+
Example:
```mlir
@@ -976,10 +1010,10 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
// Tensor addition.
%x = arith.addf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar addition with rounding mode.
+ %a = arith.addf %b, %c to_nearest_even : f64
+ ```
}];
let hasFolder = 1;
}
@@ -988,7 +1022,7 @@ def Arith_AddFOp : Arith_FloatBinaryOp<"addf", [Commutative]> {
// SubFOp
//===----------------------------------------------------------------------===//
-def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
+def Arith_SubFOp : Arith_FloatBinaryOpWithRoundingMode<"subf"> {
let summary = "floating point subtraction operation";
let description = [{
The `subf` operation takes two operands and returns one result, each of
@@ -996,6 +1030,9 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the exact result is not representable in the result type, it is rounded
+ using the given rounding mode, or the default rounding mode if none is given.
+
Example:
```mlir
@@ -1007,10 +1044,10 @@ def Arith_SubFOp : Arith_FloatBinaryOp<"subf"> {
// Tensor subtraction.
%x = arith.subf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar subtraction with rounding mode.
+ %a = arith.subf %b, %c downward : f64
+ ```
}];
let hasFolder = 1;
}
@@ -1139,7 +1176,7 @@ def Arith_MinUIOp : Arith_TotalIntBinaryOp<"minui", [Commutative]> {
// MulFOp
//===----------------------------------------------------------------------===//
-def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
+def Arith_MulFOp : Arith_FloatBinaryOpWithRoundingMode<"mulf", [Commutative]> {
let summary = "floating point multiplication operation";
let description = [{
The `mulf` operation takes two operands and returns one result, each of
@@ -1147,6 +1184,9 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
scalar type, a vector whose element type is a floating point type, or a
floating point tensor.
+ If the exact result is not representable in the result type, it is rounded
+ using the given rounding mode, or the default rounding mode if none is given.
+
Example:
```mlir
@@ -1158,10 +1198,10 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
// Tensor pointwise multiplication.
%x = arith.mulf %y, %z : tensor<4x?xbf16>
- ```
- TODO: In the distant future, this will accept optional attributes for fast
- math, contraction, rounding mode, and other controls.
+ // Scalar multiplication with rounding mode.
+ %a = arith.mulf %b, %c upward : f64
+ ```
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
@@ -1171,8 +1211,27 @@ def Arith_MulFOp : Arith_FloatBinaryOp<"mulf", [Commutative]> {
// DivFOp
//===----------------------------------------------------------------------===//
-def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
+def Arith_DivFOp : Arith_FloatBinaryOpWithRoundingMode<"divf"> {
let summary = "floating point division operation";
+ let description = [{
+ The `divf` operation takes two operands and returns one result, each of
+ these is required to be the same type. This type may be a floating point
+ scalar type, a vector whose element type is a floating point type, or a
+ floating point tensor.
+
+ If the exact result is not representable in the result type, it is rounded
+ using the given rounding mode, or the default rounding mode if none is given.
+
+ Example:
+
+ ```mlir
+ // Scalar division.
+ %a = arith.divf %b, %c : f64
+
+ // Scalar division with rounding mode.
+ %d = arith.divf %e, %f toward_zero : f64
+ ```
+ }];
let hasFolder = 1;
let hasCanonicalizer = 1;
}
@@ -1181,11 +1240,14 @@ def Arith_DivFOp : Arith_FloatBinaryOp<"divf"> {
// RemFOp
//===----------------------------------------------------------------------===//
-def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
+def Arith_RemFOp : Arith_FloatBinaryOpWithRoundingMode<"remf"> {
let summary = "floating point division remainder operation";
let description = [{
Returns the floating point division remainder.
The remainder has the same sign as the dividend (lhs operand).
+
+ The remainder is always exactly representable, so the rounding mode has no
+ effect on the result; it is accepted for consistency with the other ops.
}];
let hasFolder = 1;
}
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index a0346ec6f4fb6..9aba2b42926ce 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -81,9 +81,13 @@ struct IdentityBitcastLowering final
//===----------------------------------------------------------------------===//
using AddFOpLowering =
- VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
arith::AttrConvertOverflowToLLVM>;
@@ -91,9 +95,13 @@ using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
- VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
using DivUIOpLowering =
@@ -139,9 +147,13 @@ using MinSIOpLowering =
using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
- VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
arith::AttrConvertOverflowToLLVM>;
@@ -151,9 +163,13 @@ using NegFOpLowering =
/*FailOnUnsupportedFP=*/true>;
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
- VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
using RemUIOpLowering =
@@ -170,9 +186,13 @@ using ShRUIOpLowering =
using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
- VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- arith::AttrConvertFastMathToLLVM,
- /*FailOnUnsupportedFP=*/true>;
+ ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
+ /*Constrained=*/false,
+ arith::AttrConvertFastMathToLLVM,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
@@ -690,6 +710,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
// clang-format off
patterns.add<
AddFOpLowering,
+ ConstrainedAddFOpLowering,
AddIOpLowering,
AndIOpLowering,
AddUIExtendedOpLowering,
@@ -698,6 +719,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
CmpFOpLowering,
CmpIOpLowering,
DivFOpLowering,
+ ConstrainedDivFOpLowering,
DivSIOpLowering,
DivUIOpLowering,
ExtFOpLowering,
@@ -717,12 +739,14 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
MinSIOpLowering,
MinUIOpLowering,
MulFOpLowering,
+ ConstrainedMulFOpLowering,
MulIOpLowering,
MulSIExtendedOpLowering,
MulUIExtendedOpLowering,
NegFOpLowering,
OrIOpLowering,
RemFOpLowering,
+ ConstrainedRemFOpLowering,
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
@@ -732,6 +756,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
ShRUIOpLowering,
SIToFPOpLowering,
SubFOpLowering,
+ ConstrainedSubFOpLowering,
SubIOpLowering,
TruncFOpLowering,
ConstrainedTruncFOpLowering,
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 9e46b7d78baca..b899220f2e9af 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -182,12 +182,12 @@ struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
- Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
- realRhs, fmf.getValue());
+ Value resultReal =
+ BinaryStandardOp::create(b, realLhs, realRhs, fmf.getValue());
Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
- Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
- imagRhs, fmf.getValue());
+ Value resultImag =
+ BinaryStandardOp::create(b, imagLhs, imagRhs, fmf.getValue());
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..11b3aabcbfeb4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -120,7 +120,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
auto one =
arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
- return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
+ return arith::DivFOp::create(rewriter, loc, one, args[0]);
}
// tosa::MulOp
@@ -140,8 +140,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
"Cannot have shift value for float");
return nullptr;
}
- return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
- args[1]);
+ return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
}
if (isa<IntegerType>(elementTy)) {
@@ -538,8 +537,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
- auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
- return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
+ auto added = arith::AddFOp::create(rewriter, loc, exp, one);
+ return arith::DivFOp::create(rewriter, loc, one, added);
}
// tosa::CastOp
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index e22fc1d478e4f..488dac54569d5 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -449,10 +449,10 @@ def UIToFPOfExtUI :
//===----------------------------------------------------------------------===//
// mulf(negf(x), negf(y)) -> mulf(x,y)
-// (retain fastmath flags of original mulf)
+// (retain fastmath flags and rounding mode of original mulf)
def MulFOfNegF :
- Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_MulFOp $x, $y, $fmf),
+ Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+ (Arith_MulFOp $x, $y, $fmf, $rm),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
//===----------------------------------------------------------------------===//
@@ -460,10 +460,10 @@ def MulFOfNegF :
//===----------------------------------------------------------------------===//
// divf(negf(x), negf(y)) -> divf(x,y)
-// (retain fastmath flags of original divf)
+// (retain fastmath flags and rounding mode of original divf)
def DivFOfNegF :
- Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf),
- (Arith_DivFOp $x, $y, $fmf),
+ Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf, $rm),
+ (Arith_DivFOp $x, $y, $fmf, $rm),
[(Constraint<CPred<"$0.getType() == $1.getType()">> $x, $y)]>;
#endif // ARITH_PATTERNS
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 5f10a94522350..b00ae7bfc4724 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1107,9 +1107,14 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_NegZeroFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a + b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1121,9 +1126,14 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a - b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
//===----------------------------------------------------------------------===//
@@ -1312,9 +1322,14 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
return getRhs();
}
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a * b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
@@ -1331,9 +1346,14 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();
+ auto rm = getRoundingmodeAttr();
return constFoldBinaryOp<FloatAttr>(
- adaptor.getOperands(),
- [](const APFloat &a, const APFloat &b) { return a / b; });
+ adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
+ APFloat result(a);
+ result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
+ : llvm::RoundingMode::NearestTiesToEven);
+ return result;
+ });
}
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index f76ddfae2a67a..ec56ab3ab42e5 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -150,7 +150,7 @@ static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
Type type = operand.getType();
Value sin = math::SinOp::create(b, type, operand);
Value cos = math::CosOp::create(b, type, operand);
- Value div = arith::DivFOp::create(b, type, sin, cos);
+ Value div = arith::DivFOp::create(b, sin, cos);
rewriter.replaceOp(op, div);
return success();
}
@@ -212,8 +212,8 @@ static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
Value operandB = op.getOperand(1);
Value operandC = op.getOperand(2);
Type type = op.getType();
- Value mult = arith::MulFOp::create(b, type, operandA, operandB);
- Value add = arith::AddFOp::create(b, type, mult, operandC);
+ Value mult = arith::MulFOp::create(b, operandA, operandB);
+ Value add = arith::AddFOp::create(b, mult, operandC);
rewriter.replaceOp(op, add);
return success();
}
@@ -289,7 +289,7 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value add = arith::AddFOp::create(b, fpFixedConvert, incrValue);
Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
@@ -331,9 +331,9 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
while (absPower > 0) {
if (absPower & 1)
- res = arith::MulFOp::create(b, baseType, base, res);
+ res = arith::MulFOp::create(b, base, res);
absPower >>= 1;
- base = arith::MulFOp::create(b, baseType, base, base);
+ base = arith::MulFOp::create(b, base, base);
}
// Make sure not to introduce UB in case of negative power.
@@ -356,7 +356,7 @@ static LogicalResult convertFPowIOp(math::FPowIOp op,
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
Value negZeroEqCheck =
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
- res = arith::DivFOp::create(b, baseType, one, res);
+ res = arith::DivFOp::create(b, one, res);
res =
arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
@@ -450,7 +450,7 @@ static LogicalResult convertExp2fOp(math::Exp2Op op,
Value operand = op.getOperand();
Type opType = operand.getType();
Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
- Value mult = arith::MulFOp::create(b, opType, operand, ln2);
+ Value mult = arith::MulFOp::create(b, operand, ln2);
Value exp = math::ExpOp::create(b, op->getLoc(), mult);
rewriter.replaceOp(op, exp);
return success();
@@ -478,7 +478,7 @@ static LogicalResult convertRoundOp(math::RoundOp op,
Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
Value incrValue = math::CopySignOp::create(b, half, operand);
- Value add = arith::AddFOp::create(b, opType, operand, incrValue);
+ Value add = arith::AddFOp::create(b, operand, incrValue);
Value fpFixedConvert = createTruncatedFPValue(add, b);
// There are three cases where adding 0.5 to the value and truncating by
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 6a6016c4f5b16..7b952cf3aadfc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -377,6 +377,67 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// -----
+// CHECK-LABEL: experimental_constrained_addf
+func.func @experimental_constrained_addf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore
+ %0 = arith.addf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 downward ignore
+ %1 = arith.addf %arg0, %arg1 downward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 upward ignore
+ %2 = arith.addf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 towardzero ignore
+ %3 = arith.addf %arg0, %arg1 toward_zero : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearestaway ignore
+ %4 = arith.addf %arg0, %arg1 to_nearest_away : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_subf
+func.func @experimental_constrained_subf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 tonearest ignore
+ %0 = arith.subf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fsub %arg0, %arg1 downward ignore
+ %1 = arith.subf %arg0, %arg1 downward : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_mulf
+func.func @experimental_constrained_mulf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 upward ignore
+ %0 = arith.mulf %arg0, %arg1 upward : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fmul %arg0, %arg1 towardzero ignore
+ %1 = arith.mulf %arg0, %arg1 toward_zero : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_divf
+func.func @experimental_constrained_divf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearest ignore
+ %0 = arith.divf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fdiv %arg0, %arg1 tonearestaway ignore
+ %1 = arith.divf %arg0, %arg1 to_nearest_away : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: experimental_constrained_remf
+func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 tonearest ignore
+ %0 = arith.remf %arg0, %arg1 to_nearest_even : f64
+// CHECK-NEXT: = llvm.intr.experimental.constrained.frem %arg0, %arg1 downward ignore
+ %1 = arith.remf %arg0, %arg1 downward : f64
+ return
+}
+
+// -----
+
// CHECK-LABEL: @convertf_f16_to_bf16
func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 18665e2eb6f4a..1e10be80a13cc 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2447,6 +2447,81 @@ func.func @test_divf1(%arg0 : f32, %arg1 : f32) -> (f32) {
// -----
+// Verify that constant folding respects rounding modes. The f32 constant
+// 1.0000001 is 1 + 2^-23, so its exact sum with 1.0 (2 + 2^-23) is not
+// representable in f32: upward rounding gives 2.00000024, downward gives 2.0.
+// CHECK-LABEL: @test_addf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_addf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 2.00000024 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant 1.0 : f32
+ // addf(x, -0) folds even with a rounding mode.
+ %c_neg0 = arith.constant -0.0 : f32
+ %0 = arith.addf %arg0, %c_neg0 to_nearest_even : f32
+ %1 = arith.addf %a, %b upward : f32
+ %2 = arith.addf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_subf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_subf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 2.00000024 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 2.000000e+00 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant -1.0 : f32
+ // subf(x, +0) folds even with a rounding mode.
+ %c0 = arith.constant 0.0 : f32
+ %0 = arith.subf %arg0, %c0 downward : f32
+ %1 = arith.subf %a, %b upward : f32
+ %2 = arith.subf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_mulf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_mulf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 3.00000048 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 3.00000024 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0000001 : f32
+ %b = arith.constant 3.0 : f32
+ // mulf(x, 1) folds even with a rounding mode.
+ %c1 = arith.constant 1.0 : f32
+ %0 = arith.mulf %arg0, %c1 upward : f32
+ %1 = arith.mulf %a, %b upward : f32
+ %2 = arith.mulf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @test_divf_rounding_mode(
+// CHECK-SAME: %[[ARG0:.+]]: f32
+func.func @test_divf_rounding_mode(%arg0 : f32) -> (f32, f32, f32) {
+ // CHECK-DAG: %[[UP:.+]] = arith.constant 0.333333343 : f32
+ // CHECK-DAG: %[[DOWN:.+]] = arith.constant 0.333333313 : f32
+ // CHECK-NEXT: return %[[ARG0]], %[[UP]], %[[DOWN]]
+ %a = arith.constant 1.0 : f32
+ %b = arith.constant 3.0 : f32
+ // divf(x, 1) folds even with a rounding mode.
+ %c1 = arith.constant 1.0 : f32
+ %0 = arith.divf %arg0, %c1 toward_zero : f32
+ %1 = arith.divf %a, %b upward : f32
+ %2 = arith.divf %a, %b downward : f32
+ return %0, %1, %2 : f32, f32, f32
+}
+
+// -----
+
func.func @fold_divui_of_muli_0(%arg0 : index, %arg1 : index) -> index {
%0 = arith.muli %arg0, %arg1 overflow<nuw> : index
%1 = arith.divui %0, %arg0 : index
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 2c5371de9ff24..f6e21c94969f2 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1234,6 +1234,23 @@ func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
return
}
+// CHECK-LABEL: @roundingmode
+func.func @roundingmode(%arg0: f32, %arg1: f32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even : f32
+ %0 = arith.addf %arg0, %arg1 to_nearest_even : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 downward : f32
+ %1 = arith.subf %arg0, %arg1 downward : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 upward : f32
+ %2 = arith.mulf %arg0, %arg1 upward : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 toward_zero : f32
+ %3 = arith.divf %arg0, %arg1 toward_zero : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 to_nearest_away : f32
+ %4 = arith.remf %arg0, %arg1 to_nearest_away : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+ %5 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f32
+ return
+}
+
// CHECK-LABEL: @select_tensor
func.func @select_tensor(%arg0 : tensor<8xi1>, %arg1 : tensor<8xi32>, %arg2 : tensor<8xi32>) -> tensor<8xi32> {
// CHECK: = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xi1>, tensor<8xi32>
>From 728fd3c28ca7da846ff5eb082831b4e1de641c12 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 12:56:53 +0000
Subject: [PATCH 2/3] address comments
---
.../Conversion/ArithToLLVM/ArithToLLVM.cpp | 46 +++++++++++--------
.../Conversion/ArithToLLVM/arith-to-llvm.mlir | 12 +++++
2 files changed, 39 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 9aba2b42926ce..0c09da76c2694 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -29,13 +29,20 @@ using namespace mlir;
namespace {
-/// Operations whose conversion will depend on whether they are passed a
-/// rounding mode attribute or not.
+/// Lowering pattern that matches only if the presence of a rounding mode
+/// attribute on the source op equals `HasRoundingMode`. This allows
+/// registering two instances of the pattern for one source op: one for the
+/// unconstrained case (no rounding mode, lowering to a regular LLVM op) and
+/// one for the constrained case (rounding mode present, lowering to a
+/// constrained LLVM intrinsic).
///
-/// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
-/// to; `AttrConvert` is the attribute conversion to convert the rounding mode
-/// attribute.
-template <typename SourceOp, typename TargetOp, bool Constrained,
+/// * `SourceOp` / `TargetOp`: the arith op to match and the LLVM op to emit.
+/// * `HasRoundingMode`: the pattern matches if and only if the source op has
+/// a rounding mode attribute.
+/// * `AttrConvert`: attribute converter from source to target attributes.
+/// * `FailOnUnsupportedFP`: whether to fail if the source op has unsupported
+/// floating point types.
+template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough,
bool FailOnUnsupportedFP = false>
@@ -49,7 +56,7 @@ struct ConstrainedVectorConvertToLLVMPattern
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
+ if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
return failure();
return VectorConvertToLLVMPattern<
SourceOp, TargetOp, AttrConvert,
@@ -82,11 +89,11 @@ struct IdentityBitcastLowering final
using AddFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedAddFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::AddFOp, LLVM::ConstrainedFAddIntr, /*Constrained=*/true,
+ arith::AddFOp, LLVM::ConstrainedFAddIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using AddIOpLowering =
VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
@@ -96,11 +103,11 @@ using BitcastOpLowering =
VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
using DivFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedDivFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::DivFOp, LLVM::ConstrainedFDivIntr, /*Constrained=*/true,
+ arith::DivFOp, LLVM::ConstrainedFDivIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using DivSIOpLowering =
VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
@@ -148,11 +155,11 @@ using MinUIOpLowering =
VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
using MulFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedMulFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::MulFOp, LLVM::ConstrainedFMulIntr, /*Constrained=*/true,
+ arith::MulFOp, LLVM::ConstrainedFMulIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using MulIOpLowering =
VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
@@ -164,11 +171,11 @@ using NegFOpLowering =
using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
using RemFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedRemFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::RemFOp, LLVM::ConstrainedFRemIntr, /*Constrained=*/true,
+ arith::RemFOp, LLVM::ConstrainedFRemIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using RemSIOpLowering =
VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
@@ -187,21 +194,22 @@ using SIToFPOpLowering =
VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
using SubFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
- /*Constrained=*/false,
+ /*HasRoundingMode=*/false,
arith::AttrConvertFastMathToLLVM,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedSubFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::SubFOp, LLVM::ConstrainedFSubIntr, /*Constrained=*/true,
+ arith::SubFOp, LLVM::ConstrainedFSubIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using SubIOpLowering =
VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
arith::AttrConvertOverflowToLLVM>;
using TruncFOpLowering =
ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
- false, AttrConvertPassThrough,
+ /*HasRoundingMode=*/false,
+ AttrConvertPassThrough,
/*FailOnUnsupportedFP=*/true>;
using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
- arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
+ arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, /*HasRoundingMode=*/true,
arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7b952cf3aadfc..ba6c4e41387e3 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -438,6 +438,18 @@ func.func @experimental_constrained_remf(%arg0 : f64, %arg1 : f64) {
// -----
+// Verify that fastmath flags are stripped when lowering to constrained
+// intrinsics (constrained FP and fastmath are contradictory).
+// CHECK-LABEL: constrained_addf_with_fastmath
+func.func @constrained_addf_with_fastmath(%arg0 : f64, %arg1 : f64) {
+// CHECK-NEXT: = llvm.intr.experimental.constrained.fadd %arg0, %arg1 tonearest ignore : f64
+// CHECK-NOT: fastmath
+ %0 = arith.addf %arg0, %arg1 to_nearest_even fastmath<fast> : f64
+ return
+}
+
+// -----
+
// CHECK-LABEL: @convertf_f16_to_bf16
func.func @convertf_f16_to_bf16(%arg0 : f16) -> bf16 {
// CHECK-NEXT: %[[EXT:.*]] = llvm.fpext %arg0 : f16 to f32
>From c5d367ec6c542a4f22e2168fabaf5ff4793dc417 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 25 Mar 2026 13:05:58 +0000
Subject: [PATCH 3/3] update
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 19 ++++++++++++-------
1 file changed, 12 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index b00ae7bfc4724..3f65d3f8ea3ec 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -34,6 +34,10 @@
using namespace mlir;
using namespace mlir::arith;
+/// Default rounding mode according to IEEE-754.
+static constexpr llvm::RoundingMode kDefaultRoundingMode =
+ llvm::RoundingMode::NearestTiesToEven;
+
//===----------------------------------------------------------------------===//
// Pattern helpers
//===----------------------------------------------------------------------===//
@@ -1112,7 +1116,7 @@ OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.add(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1131,7 +1135,7 @@ OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.subtract(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1327,7 +1331,7 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.multiply(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1351,7 +1355,7 @@ OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
adaptor.getOperands(), [rm](const APFloat &a, const APFloat &b) {
APFloat result(a);
result.divide(b, rm ? convertArithRoundingModeToLLVMIR(rm.getValue())
- : llvm::RoundingMode::NearestTiesToEven);
+ : kDefaultRoundingMode);
return result;
});
}
@@ -1481,9 +1485,10 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
/// Attempts to convert `sourceValue` to an APFloat value with
/// `targetSemantics` and `roundingMode`, without any information loss.
-static FailureOr<APFloat> convertFloatValue(
- APFloat sourceValue, const llvm::fltSemantics &targetSemantics,
- llvm::RoundingMode roundingMode = llvm::RoundingMode::NearestTiesToEven) {
+static FailureOr<APFloat>
+convertFloatValue(APFloat sourceValue,
+ const llvm::fltSemantics &targetSemantics,
+ llvm::RoundingMode roundingMode = kDefaultRoundingMode) {
// Reject special values that are not representable in the target type before
// calling APFloat::convert, which would llvm_unreachable on them.
using fltNonfiniteBehavior = llvm::fltNonfiniteBehavior;
More information about the Mlir-commits
mailing list