[Mlir-commits] [mlir] [mlir][math] Add rounding modes to `math.fma` (PR #192839)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 19 02:49:01 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Rounding modes have recently been added for `arith` FP operations (#<!-- -->188458). This commit adds rounding modes to `math.fma`, following the same design as for `arith` FP operations.
If a rounding mode is present, the LLVM lowering produces `llvm.intr.experimental.constrained.fma`.
In the absence of a rounding mode, the rounding behavior is deferred to the target backend.
---
Full diff: https://github.com/llvm/llvm-project/pull/192839.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Math/IR/MathBase.td (+7)
- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+45-4)
- (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+37-2)
- (modified) mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir (+40)
- (modified) mlir/test/Dialect/Math/ops.mlir (+6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
index 0e606bb9b63dd..19fb39d9fd51d 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
@@ -28,6 +28,13 @@ def Math_Dialect : Dialect {
// Tensor elementwise absolute value.
%x = math.absf %y : tensor<4x?xf8>
```
+
+ Some floating-point operations may specify rounding modes and/or fast-math
+ flags. In the absence of an explicit rounding mode, the math dialect uses
+ this default rounding mode for internal purposes such as constant folding and
+ canonicalization: round-to-nearest, ties-to-even. The runtime behavior of
+ operations without an explicit rounding mode is deferred to the target
+ backend and may differ from this default rounding mode.
}];
let hasConstantMaterializer = 1;
let dependentDialects = [
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..21e076307eb69 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -110,6 +110,43 @@ class Math_FloatTernaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Base class for floating point ternary operations with an optional rounding
+// mode.
+class Math_FloatTernaryOpWithRoundingMode<string mnemonic,
+ list<Trait> traits = []> :
+ Math_FloatTernaryOp<mnemonic,
+ !listconcat([DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+ traits)> {
+ let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath,
+ OptionalAttr<Arith_RoundingModeAttr>:$roundingmode);
+ let builders = [
+ OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+ build($_builder, $_state, a, b, c, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c,
+ "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+ build($_builder, $_state, a, b, c, fastmath,
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ OpBuilder<(ins "Type":$type, "Value":$a, "Value":$b, "Value":$c,
+ CArg<"::mlir::arith::FastMathFlags",
+ "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+ build($_builder, $_state, type, a, b, c,
+ ::mlir::arith::FastMathFlagsAttr::get(
+ $_builder.getContext(), fastmath),
+ ::mlir::arith::RoundingModeAttr{});
+ }]>,
+ ];
+ let assemblyFormat = [{ $a `,` $b `,` $c ($roundingmode^)?
+ (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($result) }];
+}
+
//===----------------------------------------------------------------------===//
// AbsFOp
//===----------------------------------------------------------------------===//
@@ -747,7 +784,7 @@ def Math_FloorOp : Math_FloatUnaryOp<"floor"> {
// FmaOp
//===----------------------------------------------------------------------===//
-def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
+def Math_FmaOp : Math_FloatTernaryOpWithRoundingMode<"fma"> {
let summary = "floating point fused multipy-add operation";
let description = [{
The `fma` operation takes three operands and returns one result, each of
@@ -759,12 +796,16 @@ def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
```mlir
// Scalar fused multiply-add: d = a*b + c
%d = math.fma %a, %b, %c : f64
+
+ // With an explicit IEEE-754 rounding mode.
+ %e = math.fma %a, %b, %c to_nearest_even : f64
```
The semantics of the operation correspond to those of the `llvm.fma`
- [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
- particular case of lowering to LLVM, this is guaranteed to lower
- to the `llvm.fma.*` intrinsic.
+ [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). When
+ no rounding mode is set, lowering to LLVM is guaranteed to produce the
+ `llvm.fma.*` intrinsic. When a rounding mode is set, the LLVM lowering
+ instead produces the `llvm.experimental.constrained.fma.*` intrinsic.
}];
}
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 76ee9d00b53ad..fa2ea13f348f3 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,6 +37,34 @@ using ConvertFMFMathToLLVMPattern =
VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
FailOnUnsupportedFP>;
+/// Lowering pattern that matches only when the source op's rounding mode
+/// presence agrees with `HasRoundingMode`. Mirrors the helper of the same
+/// name in `mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp`. This lets us
+/// register two patterns for one math op: an unconstrained one that lowers
+/// to a regular LLVM op, and a constrained one (rounding mode present) that
+/// lowers to an `llvm.intr.experimental.constrained.*` intrinsic.
+template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
+ template <typename, typename> typename AttrConvert =
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = true>
+struct ConstrainedVectorConvertToLLVMPattern
+ : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP> {
+ using VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
+ return failure();
+ return VectorConvertToLLVMPattern<
+ SourceOp, TargetOp, AttrConvert,
+ FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
+ }
+};
+
using AbsFOpLowering =
ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
/*FailOnUnsupportedFP=*/true>;
@@ -54,8 +82,14 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
- /*FailOnUnsupportedFP=*/true>;
+using FmaOpLowering =
+ ConstrainedVectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp,
+ /*HasRoundingMode=*/false,
+ ConvertFastMath,
+ /*FailOnUnsupportedFP=*/true>;
+using ConstrainedFmaOpLowering = ConstrainedVectorConvertToLLVMPattern<
+ math::FmaOp, LLVM::ConstrainedFMAIntr, /*HasRoundingMode=*/true,
+ arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
using Log10OpLowering =
ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -444,6 +478,7 @@ void mlir::populateMathToLLVMConversionPatterns(
FPowIOpLowering,
FloorOpLowering,
FmaOpLowering,
+ ConstrainedFmaOpLowering,
Log10OpLowering,
Log2OpLowering,
LogOpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 504dc1afb0eef..04a13616feee8 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -641,3 +641,43 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN
%2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
return
}
+
+// -----
+
+// CHECK-LABEL: func @experimental_constrained_fma
+func.func @experimental_constrained_fma(%a : f64, %b : f64, %c : f64) {
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearest ignore : f64
+ %0 = math.fma %a, %b, %c to_nearest_even : f64
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} downward ignore : f64
+ %1 = math.fma %a, %b, %c downward : f64
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} upward ignore : f64
+ %2 = math.fma %a, %b, %c upward : f64
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} towardzero ignore : f64
+ %3 = math.fma %a, %b, %c toward_zero : f64
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearestaway ignore : f64
+ %4 = math.fma %a, %b, %c to_nearest_away : f64
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @experimental_constrained_fma_vector
+func.func @experimental_constrained_fma_vector(%a : vector<4xf32>,
+ %b : vector<4xf32>,
+ %c : vector<4xf32>) {
+ // CHECK: llvm.intr.experimental.constrained.fma {{.*}} tonearest ignore : vector<4xf32>
+ %0 = math.fma %a, %b, %c to_nearest_even : vector<4xf32>
+ return
+}
+
+// -----
+
+// The constrained lowering does not propagate fastmath flags: the fastmath
+// attribute is expected to be silently dropped.
+// CHECK-LABEL: func @constrained_fma_with_fastmath
+func.func @constrained_fma_with_fastmath(%a : f64, %b : f64, %c : f64) {
+ // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearest ignore : f64
+ // CHECK-NOT: fastmath
+ %0 = math.fma %a, %b, %c to_nearest_even fastmath<fast> : f64
+ return
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index f085d1c62ea86..a21eb06e696df 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -304,6 +304,12 @@ func.func @fastmath(%f: f32, %i: i32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>)
%1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : vector<4xf32>
// CHECK: math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32>
%2 = math.fma %t, %t, %t fastmath<none> : tensor<4x4x?xf32>
+ // CHECK: math.fma %[[F]], %[[F]], %[[F]] to_nearest_even : f32
+ %2a = math.fma %f, %f, %f to_nearest_even : f32
+ // CHECK: math.fma %[[F]], %[[F]], %[[F]] downward fastmath<contract> : f32
+ %2b = math.fma %f, %f, %f downward fastmath<contract> : f32
+ // CHECK: math.fma %[[V]], %[[V]], %[[V]] toward_zero : vector<4xf32>
+ %2c = math.fma %v, %v, %v toward_zero : vector<4xf32>
// CHECK: math.absf %[[F]] fastmath<ninf> : f32
%3 = math.absf %f fastmath<ninf> : f32
// CHECK: math.fpowi %[[F]], %[[I]] fastmath<fast> : f32, i32
``````````
</details>
https://github.com/llvm/llvm-project/pull/192839
More information about the Mlir-commits mailing list