[Mlir-commits] [mlir] [mlir][math] Add rounding modes to `math.fma` (PR #192839)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 19 02:49:01 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-math

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Rounding modes were recently added to `arith` floating-point operations (#188458). This commit adds rounding modes to `math.fma`, following the same design as the `arith` FP operations.

If a rounding mode is present, the LLVM lowering produces `llvm.intr.experimental.constrained.fma`.

In the absence of a rounding mode, the rounding behavior is deferred to the target backend.


---
Full diff: https://github.com/llvm/llvm-project/pull/192839.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Math/IR/MathBase.td (+7) 
- (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+45-4) 
- (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+37-2) 
- (modified) mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir (+40) 
- (modified) mlir/test/Dialect/Math/ops.mlir (+6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathBase.td b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
index 0e606bb9b63dd..19fb39d9fd51d 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathBase.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathBase.td
@@ -28,6 +28,13 @@ def Math_Dialect : Dialect {
     // Tensor elementwise absolute value.
     %x = math.absf %y : tensor<4x?xf8>
     ```
+
+    Some floating-point operations may specify rounding modes and/or fast-math
+    flags. In the absence of an explicit rounding mode, the math dialect uses
+    this default rounding mode for internal purposes such as constant folding and
+    canonicalization: round-to-nearest, ties-to-even. The runtime behavior of
+    operations without an explicit rounding mode is deferred to the target
+    backend and may differ from the default math rounding mode.
   }];
   let hasConstantMaterializer = 1;
   let dependentDialects = [
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 1265bfb18aaa2..21e076307eb69 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -110,6 +110,43 @@ class Math_FloatTernaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Base class for floating point ternary operations with an optional rounding
+// mode.
+class Math_FloatTernaryOpWithRoundingMode<string mnemonic,
+                                          list<Trait> traits = []> :
+    Math_FloatTernaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithRoundingModeInterface>],
+                  traits)> {
+  let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath,
+      OptionalAttr<Arith_RoundingModeAttr>:$roundingmode);
+  let builders = [
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, a, b, c, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c,
+      "::mlir::arith::FastMathFlagsAttr":$fastmath), [{
+      build($_builder, $_state, a, b, c, fastmath,
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+    OpBuilder<(ins "Type":$type, "Value":$a, "Value":$b, "Value":$c,
+      CArg<"::mlir::arith::FastMathFlags",
+           "::mlir::arith::FastMathFlags::none">:$fastmath), [{
+      build($_builder, $_state, type, a, b, c,
+            ::mlir::arith::FastMathFlagsAttr::get(
+                $_builder.getContext(), fastmath),
+            ::mlir::arith::RoundingModeAttr{});
+    }]>,
+  ];
+  let assemblyFormat = [{ $a `,` $b `,` $c ($roundingmode^)?
+                          (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($result) }];
+}
+
 //===----------------------------------------------------------------------===//
 // AbsFOp
 //===----------------------------------------------------------------------===//
@@ -747,7 +784,7 @@ def Math_FloorOp : Math_FloatUnaryOp<"floor"> {
 // FmaOp
 //===----------------------------------------------------------------------===//
 
-def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
+def Math_FmaOp : Math_FloatTernaryOpWithRoundingMode<"fma"> {
   let summary = "floating point fused multipy-add operation";
   let description = [{
     The `fma` operation takes three operands and returns one result, each of
@@ -759,12 +796,16 @@ def Math_FmaOp : Math_FloatTernaryOp<"fma"> {
     ```mlir
     // Scalar fused multiply-add: d = a*b + c
     %d = math.fma %a, %b, %c : f64
+
+    // With an explicit IEEE-754 rounding mode.
+    %e = math.fma %a, %b, %c to_nearest_even : f64
     ```
 
     The semantics of the operation correspond to those of the `llvm.fma`
-    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
-    particular case of lowering to LLVM, this is guaranteed to lower
-    to the `llvm.fma.*` intrinsic.
+    [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). When
+    no rounding mode is set, lowering to LLVM is guaranteed to produce the
+    `llvm.fma.*` intrinsic. When a rounding mode is set, the LLVM lowering
+    instead produces `llvm.experimental.constrained.fma`.
   }];
 }
 
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 76ee9d00b53ad..fa2ea13f348f3 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -37,6 +37,34 @@ using ConvertFMFMathToLLVMPattern =
     VectorConvertToLLVMPattern<SourceOp, TargetOp, ConvertFastMath,
                                FailOnUnsupportedFP>;
 
+/// Lowering pattern that matches only when the source op's rounding mode
+/// presence agrees with `HasRoundingMode`. Mirrors the helper of the same
+/// name in `mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp`. This lets us
+/// register two patterns for one math op: an unconstrained one that lowers
+/// to a regular LLVM op, and a constrained one (rounding mode present) that
+/// lowers to an `llvm.intr.experimental.constrained.*` intrinsic.
+template <typename SourceOp, typename TargetOp, bool HasRoundingMode,
+          template <typename, typename> typename AttrConvert =
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = true>
+struct ConstrainedVectorConvertToLLVMPattern
+    : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert,
+                                        FailOnUnsupportedFP> {
+  using VectorConvertToLLVMPattern<
+      SourceOp, TargetOp, AttrConvert,
+      FailOnUnsupportedFP>::VectorConvertToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (HasRoundingMode != static_cast<bool>(op.getRoundingModeAttr()))
+      return failure();
+    return VectorConvertToLLVMPattern<
+        SourceOp, TargetOp, AttrConvert,
+        FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter);
+  }
+};
+
 using AbsFOpLowering =
     ConvertFMFMathToLLVMPattern<math::AbsFOp, LLVM::FAbsOp,
                                 /*FailOnUnsupportedFP=*/true>;
@@ -54,8 +82,14 @@ using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
 using ExpOpLowering = ConvertFMFMathToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
 using FloorOpLowering =
     ConvertFMFMathToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
-using FmaOpLowering = ConvertFMFMathToLLVMPattern<math::FmaOp, LLVM::FMAOp,
-                                                  /*FailOnUnsupportedFP=*/true>;
+using FmaOpLowering =
+    ConstrainedVectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp,
+                                          /*HasRoundingMode=*/false,
+                                          ConvertFastMath,
+                                          /*FailOnUnsupportedFP=*/true>;
+using ConstrainedFmaOpLowering = ConstrainedVectorConvertToLLVMPattern<
+    math::FmaOp, LLVM::ConstrainedFMAIntr, /*HasRoundingMode=*/true,
+    arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>;
 using Log10OpLowering =
     ConvertFMFMathToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
 using Log2OpLowering = ConvertFMFMathToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -444,6 +478,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     FPowIOpLowering,
     FloorOpLowering,
     FmaOpLowering,
+    ConstrainedFmaOpLowering,
     Log10OpLowering,
     Log2OpLowering,
     LogOpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 504dc1afb0eef..04a13616feee8 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -641,3 +641,43 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN
   %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @experimental_constrained_fma
+func.func @experimental_constrained_fma(%a : f64, %b : f64, %c : f64) {
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearest ignore : f64
+  %0 = math.fma %a, %b, %c to_nearest_even : f64
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} downward ignore : f64
+  %1 = math.fma %a, %b, %c downward : f64
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} upward ignore : f64
+  %2 = math.fma %a, %b, %c upward : f64
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} towardzero ignore : f64
+  %3 = math.fma %a, %b, %c toward_zero : f64
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearestaway ignore : f64
+  %4 = math.fma %a, %b, %c to_nearest_away : f64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @experimental_constrained_fma_vector
+func.func @experimental_constrained_fma_vector(%a : vector<4xf32>,
+                                                %b : vector<4xf32>,
+                                                %c : vector<4xf32>) {
+  // CHECK: llvm.intr.experimental.constrained.fma {{.*}} tonearest ignore : vector<4xf32>
+  %0 = math.fma %a, %b, %c to_nearest_even : vector<4xf32>
+  return
+}
+
+// -----
+
+// Constrained intrinsics never carry fastmath flags; the fastmath attribute
+// must be silently dropped during the lowering.
+// CHECK-LABEL: func @constrained_fma_with_fastmath
+func.func @constrained_fma_with_fastmath(%a : f64, %b : f64, %c : f64) {
+  // CHECK-NEXT: llvm.intr.experimental.constrained.fma %{{.*}}, %{{.*}}, %{{.*}} tonearest ignore : f64
+  // CHECK-NOT: fastmath
+  %0 = math.fma %a, %b, %c to_nearest_even fastmath<fast> : f64
+  return
+}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index f085d1c62ea86..a21eb06e696df 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -304,6 +304,12 @@ func.func @fastmath(%f: f32, %i: i32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>)
   %1 = math.powf %v, %v fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : vector<4xf32>
   // CHECK: math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32>
   %2 = math.fma %t, %t, %t fastmath<none> : tensor<4x4x?xf32>
+  // CHECK: math.fma %[[F]], %[[F]], %[[F]] to_nearest_even : f32
+  %2a = math.fma %f, %f, %f to_nearest_even : f32
+  // CHECK: math.fma %[[F]], %[[F]], %[[F]] downward fastmath<contract> : f32
+  %2b = math.fma %f, %f, %f downward fastmath<contract> : f32
+  // CHECK: math.fma %[[V]], %[[V]], %[[V]] toward_zero : vector<4xf32>
+  %2c = math.fma %v, %v, %v toward_zero : vector<4xf32>
   // CHECK: math.absf %[[F]] fastmath<ninf> : f32
   %3 = math.absf %f fastmath<ninf> : f32
   // CHECK: math.fpowi %[[F]], %[[I]] fastmath<fast> : f32, i32

``````````

</details>


https://github.com/llvm/llvm-project/pull/192839


More information about the Mlir-commits mailing list