[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Propagate fast-math flags to SPIR-V FPFastMathMode decorations (PR #193414)
Arseniy Obolenskiy
llvmlistbot at llvm.org
Tue Apr 21 23:03:53 PDT 2026
https://github.com/aobolensk created https://github.com/llvm/llvm-project/pull/193414
Add ElementwiseFPOpPattern to convert arith fast-math flags (nnan, ninf, nsz, arcp) to SPIR-V FPFastMathMode decorations on floating-point ops.
The remaining arith flags (reassoc, contract, afn) are not mapped to FPFastMathMode and are dropped.
>From 60a199c1510b7926007b429f288f29f38487350d Mon Sep 17 00:00:00 2001
From: Arseniy Obolenskiy <arseniy.obolenskiy at amd.com>
Date: Wed, 22 Apr 2026 07:55:52 +0200
Subject: [PATCH] [mlir][ArithToSPIRV] Propagate fast-math flags to SPIR-V
FPFastMathMode decorations
Add ElementwiseFPOpPattern to convert arith fast-math flags (nnan, ninf, nsz, arcp) to SPIR-V FPFastMathMode decorations on floating-point ops.
The remaining arith flags (reassoc, contract, afn) are not mapped to FPFastMathMode and are dropped.
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 58 ++++++++++++++--
.../Conversion/ArithToSPIRV/fast-math.mlir | 69 +++++++++++++++++++
2 files changed, 121 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1b5a8728dd3f8..c2fe10fd7a8e0 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,21 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Translates arith fast-math flags into the equivalent SPIR-V FPFastMathMode
+/// bit set. Only nnan, ninf, nsz and arcp have a counterpart here; any other
+/// arith flag (reassoc, contract, afn) is ignored, so the result may be None.
+static spirv::FPFastMathMode
+convertArithFastMathFlagsToSPIRV(arith::FastMathFlags arithFMF) {
+  // One row per arith flag that has a direct FPFastMathMode counterpart.
+  struct FlagMapping {
+    arith::FastMathFlags arithFlag;
+    spirv::FPFastMathMode spirvFlag;
+  };
+  static constexpr FlagMapping kMappings[] = {
+      {arith::FastMathFlags::nnan, spirv::FPFastMathMode::NotNaN},
+      {arith::FastMathFlags::ninf, spirv::FPFastMathMode::NotInf},
+      {arith::FastMathFlags::nsz, spirv::FPFastMathMode::NSZ},
+      {arith::FastMathFlags::arcp, spirv::FPFastMathMode::AllowRecip},
+  };
+
+  spirv::FPFastMathMode result = spirv::FPFastMathMode::None;
+  for (const FlagMapping &mapping : kMappings) {
+    if (bitEnumContainsAll(arithFMF, mapping.arithFlag))
+      result = result | mapping.spirvFlag;
+  }
+  return result;
+}
+
/// Creates a scalar/vector integer constant.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
OpBuilder &builder, Location loc) {
@@ -225,6 +240,37 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
}
};
+/// Converts elementwise unary, binary and ternary floating-point arith
+/// operations to SPIR-V operations, propagating the op's fast-math flags as an
+/// FPFastMathMode decoration (the `fp_fast_math_mode` attribute).
+///
+/// FPFastMathMode is a Kernel-capability decoration, so the attribute is only
+/// attached when the target environment allows Kernel. For other targets
+/// (e.g. Shader-only environments) the flags are dropped instead of producing
+/// a decoration the target cannot accept.
+/// NOTE(review): SPV_KHR_float_controls2 also permits this decoration outside
+/// Kernel; that case is not handled here.
+template <typename Op, typename SPIRVOp>
+struct ElementwiseFPOpPattern final : OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() <= 3);
+    const auto *converter =
+        this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert type {0} for SPIR-V", op.getType()));
+    }
+
+    // Compute the decoration before replacing `op`, so nothing is read from
+    // the source op once the rewriter has replaced it.
+    spirv::FPFastMathMode spirvFMF = spirv::FPFastMathMode::None;
+    if (converter->allows(spirv::Capability::Kernel))
+      spirvFMF = convertArithFastMathFlagsToSPIRV(op.getFastmath());
+
+    auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
+        op, dstType, adaptor.getOperands());
+    if (spirvFMF != spirv::FPFastMathMode::None) {
+      newOp->setAttr(
+          "fp_fast_math_mode",
+          spirv::FPFastMathModeAttr::get(rewriter.getContext(), spirvFMF));
+    }
+    return success();
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
@@ -1530,12 +1576,12 @@ void mlir::arith::populateArithToSPIRVPatterns(
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
ShRSIBoolPattern, // shrsi(a,b) = a (identity; see pattern comment)
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
- spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
- spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
- spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
- spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
- spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
- spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
+ ElementwiseFPOpPattern<arith::NegFOp, spirv::FNegateOp>,
+ ElementwiseFPOpPattern<arith::AddFOp, spirv::FAddOp>,
+ ElementwiseFPOpPattern<arith::SubFOp, spirv::FSubOp>,
+ ElementwiseFPOpPattern<arith::MulFOp, spirv::FMulOp>,
+ ElementwiseFPOpPattern<arith::DivFOp, spirv::FDivOp>,
+ ElementwiseFPOpPattern<arith::RemFOp, spirv::FRemOp>,
ExtUIPattern, ExtUII1Pattern,
ExtSIPattern, ExtSII1Pattern,
TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9bbe28fb127a7..bd3788e8f7615 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -67,3 +67,72 @@ func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32
}
} // end module
+
+// -----
+
+// FPFastMathMode decoration tests (requires Kernel capability)
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @addf_fast_math
+func.func @addf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf>} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  return %0: f32
+}
+
+// In the "no decoration" tests below, operands are matched with %{{[^ ]+}}
+// rather than a greedy .* pattern: .* would also swallow an attribute
+// dictionary printed before the type, so the line would still match if the
+// decoration were wrongly attached.
+// CHECK-LABEL: @mulf_no_fast_math
+func.func @mulf_no_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FMul %{{[^ ]+}}, %{{[^ ]+}} : f32
+  %0 = arith.mulf %arg0, %arg1 : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @subf_all_flags
+func.func @subf_all_flags(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FSub %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ|AllowRecip>} : f32
+  %0 = arith.subf %arg0, %arg1 fastmath<fast> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @negf_fast_math
+func.func @negf_fast_math(%arg0 : f32) -> f32 {
+  // CHECK: spirv.FNegate %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NSZ>} : f32
+  %0 = arith.negf %arg0 fastmath<nsz> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @divf_fast_math
+func.func @divf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FDiv %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<AllowRecip>} : f32
+  %0 = arith.divf %arg0, %arg1 fastmath<arcp> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @remf_fast_math
+func.func @remf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FRem %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN>} : f32
+  %0 = arith.remf %arg0, %arg1 fastmath<nnan> : f32
+  return %0: f32
+}
+
+// Test that unsupported flags (reassoc, contract, afn) are silently dropped
+// CHECK-LABEL: @addf_unsupported_flags_only
+func.func @addf_unsupported_flags_only(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FAdd %{{[^ ]+}}, %{{[^ ]+}} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<reassoc,contract,afn> : f32
+  return %0: f32
+}
+
+// Test that supported flags survive when mixed with unsupported ones
+// CHECK-LABEL: @mulf_mixed_flags
+func.func @mulf_mixed_flags(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: spirv.FMul %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NSZ>} : f32
+  %0 = arith.mulf %arg0, %arg1 fastmath<nsz,contract> : f32
+  return %0: f32
+}
+
+// CHECK-LABEL: @addf_vector_fast_math
+func.func @addf_vector_fast_math(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN>} : vector<4xf32>
+  %0 = arith.addf %arg0, %arg1 fastmath<nnan> : vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+} // end module
More information about the Mlir-commits
mailing list