[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Propagate fast-math flags to SPIR-V FPFastMathMode decorations (PR #193414)

Arseniy Obolenskiy llvmlistbot at llvm.org
Tue Apr 21 23:03:53 PDT 2026


https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/193414

Add ElementwiseFPOpPattern to convert arith fast-math flags (nnan, ninf, nsz, arcp) into SPIR-V FPFastMathMode decorations on the converted floating-point ops (negf, addf, subf, mulf, divf, remf).

Flags that have no SPIR-V FPFastMathMode counterpart (reassoc, contract, afn) are dropped.

>From 60a199c1510b7926007b429f288f29f38487350d Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 22 Apr 2026 07:55:52 +0200
Subject: [PATCH] [mlir][ArithToSPIRV] Propagate fast-math flags to SPIR-V
 FPFastMathMode decorations

Add ElementwiseFPOpPattern to convert arith fast-math flags (nnan, ninf, nsz, arcp) into SPIR-V FPFastMathMode decorations on the converted floating-point ops (negf, addf, subf, mulf, divf, remf).

Flags that have no SPIR-V FPFastMathMode counterpart (reassoc, contract, afn) are dropped.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 58 ++++++++++++++--
 .../Conversion/ArithToSPIRV/fast-math.mlir    | 69 +++++++++++++++++++
 2 files changed, 121 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1b5a8728dd3f8..c2fe10fd7a8e0 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,26 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Translates arith fast-math flags into the equivalent SPIR-V
+/// FPFastMathMode bits. Only nnan, ninf, nsz and arcp have a SPIR-V
+/// counterpart; reassoc, contract and afn are not represented and are
+/// dropped. Returns FPFastMathMode::None when no flag maps.
+static spirv::FPFastMathMode
+convertArithFastMathFlagsToSPIRV(arith::FastMathFlags arithFMF) {
+  spirv::FPFastMathMode result = spirv::FPFastMathMode::None;
+  // ORs `spirvBit` into the result iff `arithBit` is set in the input.
+  auto mapFlag = [&](arith::FastMathFlags arithBit,
+                     spirv::FPFastMathMode spirvBit) {
+    if (bitEnumContainsAll(arithFMF, arithBit))
+      result = result | spirvBit;
+  };
+  mapFlag(arith::FastMathFlags::nnan, spirv::FPFastMathMode::NotNaN);
+  mapFlag(arith::FastMathFlags::ninf, spirv::FPFastMathMode::NotInf);
+  mapFlag(arith::FastMathFlags::nsz, spirv::FPFastMathMode::NSZ);
+  mapFlag(arith::FastMathFlags::arcp, spirv::FPFastMathMode::AllowRecip);
+  return result;
+}
+
 /// Creates a scalar/vector integer constant.
 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                        OpBuilder &builder, Location loc) {
@@ -225,6 +240,52 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
   }
 };
 
+/// Converts elementwise unary, binary and ternary floating-point arith
+/// operations to SPIR-V operations, propagating fast-math flags as
+/// FPFastMathMode decorations.
+///
+/// The FPFastMathMode decoration requires the Kernel capability, so the
+/// flags are attached only when the target environment allows Kernel; for
+/// other targets (e.g. Shader-only/Vulkan) they are dropped so the resulting
+/// module stays valid. NOTE(review): the SPV_KHR_float_controls2
+/// FloatControls2 capability also enables the decoration but is not
+/// considered here.
+template <typename Op, typename SPIRVOp>
+struct ElementwiseFPOpPattern final : OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() <= 3);
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
+    auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, dstType, adaptor.getOperands());
+
+    // Emitting FPFastMathMode without Kernel would yield a module that fails
+    // SPIR-V validation, so drop the flags instead.
+    const auto *spirvConverter =
+        this->template getTypeConverter<SPIRVTypeConverter>();
+    if (!spirvConverter->allows(spirv::Capability::Kernel))
+      return success();
+
+    spirv::FPFastMathMode spirvFMF =
+        convertArithFastMathFlagsToSPIRV(op.getFastmath());
+    // Skip the attribute entirely when no flag has a SPIR-V counterpart.
+    if (spirvFMF != spirv::FPFastMathMode::None) {
+      newOp->setAttr("fp_fast_math_mode",
+                     spirv::FPFastMathModeAttr::get(op.getContext(), spirvFMF));
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//
@@ -1530,12 +1576,12 @@ void mlir::arith::populateArithToSPIRVPatterns(
     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
     ShRSIBoolPattern,                      // shrsi(a,b) = a (identity; see pattern comment)
     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
-    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
-    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
-    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
-    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
-    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
-    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
+    ElementwiseFPOpPattern<arith::NegFOp, spirv::FNegateOp>,
+    ElementwiseFPOpPattern<arith::AddFOp, spirv::FAddOp>,
+    ElementwiseFPOpPattern<arith::SubFOp, spirv::FSubOp>,
+    ElementwiseFPOpPattern<arith::MulFOp, spirv::FMulOp>,
+    ElementwiseFPOpPattern<arith::DivFOp, spirv::FDivOp>,
+    ElementwiseFPOpPattern<arith::RemFOp, spirv::FRemOp>,
     ExtUIPattern, ExtUII1Pattern,
     ExtSIPattern, ExtSII1Pattern,
     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9bbe28fb127a7..bd3788e8f7615 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -67,3 +67,79 @@ func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32
 }
 
 } // end module
+
+// -----
+
+// FPFastMathMode decoration tests. The decoration is only valid with the
+// Kernel capability, so this module uses a Kernel target environment.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @addf_fast_math
+func.func @addf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf>} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  return %0: f32
+}
+
+// The negative tests below match operands with %{{[^ ]+}} instead of
+// %{{.*}}: a greedy .* would also swallow an attribute dictionary printed
+// after the operands, hiding it from the fp_fast_math_mode negative check.
+// CHECK-LABEL: @mulf_no_fast_math
+func.func @mulf_no_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FMul %{{[^ ]+}}, %{{[^ ]+}} : f32
+  // CHECK-NOT: fp_fast_math_mode
+  %0 = arith.mulf %arg0, %arg1 : f32
+  return %0: f32
+}
+
+// arith `fast` implies every flag; only the four with a SPIR-V counterpart
+// survive.
+// CHECK-LABEL: @subf_all_flags
+func.func @subf_all_flags(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FSub %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ|AllowRecip>} : f32
+  %0 = arith.subf %arg0, %arg1 fastmath<fast> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @negf_fast_math
+func.func @negf_fast_math(%arg0 : f32) -> f32 {
+  // CHECK: spirv.FNegate %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NSZ>} : f32
+  %0 = arith.negf %arg0 fastmath<nsz> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @divf_fast_math
+func.func @divf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FDiv %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<AllowRecip>} : f32
+  %0 = arith.divf %arg0, %arg1 fastmath<arcp> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @remf_fast_math
+func.func @remf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FRem %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN>} : f32
+  %0 = arith.remf %arg0, %arg1 fastmath<nnan> : f32
+  return %0: f32
+}
+
+// Flags without a SPIR-V counterpart (reassoc, contract, afn) are dropped,
+// so no decoration is attached.
+// CHECK-LABEL: @addf_unsupported_flags_only
+func.func @addf_unsupported_flags_only(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FAdd %{{[^ ]+}}, %{{[^ ]+}} : f32
+  // CHECK-NOT: fp_fast_math_mode
+  %0 = arith.addf %arg0, %arg1 fastmath<reassoc,contract,afn> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @addf_vector_fast_math
+func.func @addf_vector_fast_math(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN>} : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 fastmath<nnan> : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+} // end module



More information about the Mlir-commits mailing list