[Mlir-commits] [mlir] 81e0de2 - [MLIR][Arith] FastMath extf conversion without NaN checks (#180926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 11 04:46:12 PST 2026


Author: Renato Golin
Date: 2026-02-11T12:46:07Z
New Revision: 81e0de2ea538991a759c921a5d257b271d7d0f7c

URL: https://github.com/llvm/llvm-project/commit/81e0de2ea538991a759c921a5d257b271d7d0f7c
DIFF: https://github.com/llvm/llvm-project/commit/81e0de2ea538991a759c921a5d257b271d7d0f7c.diff

LOG: [MLIR][Arith] FastMath extf conversion without NaN checks (#180926)

This PR allows the expand-ops converter to honor the `nnan` fastmath flag
on `arith.extf` and skip the runtime NaN checks for E8M0 types. Without
the flag, the default behaviour is unchanged.

The OCP document defines the all-ones bit pattern as NaN for E8M0, but
for pre-MX I8 quantization the NaN checks are prohibitively expensive,
especially if the hardware doesn't have native support for that type.

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Arith/expand-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index c4e81e5dbed21..46f8c1037d47b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -452,18 +452,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
 
     Value bitcast = arith::BitcastOp::create(b, i8Ty, operand);
-    // create constants for NaNs
-    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
-    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
-
     Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast);
     Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth);
 
-    Value isNan =
-        arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
-    // select for NaNs
-    f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    // If FastMathFlag allows no NaN checks, skip it
+    auto fastMath = op.getFastmathAttr();
+    bool NoNaN = fastMath
+                     ? (fastMath.getValue() & arith::FastMathFlags::nnan) ==
+                           arith::FastMathFlags::nnan
+                     : false;
+    if (!NoNaN) {
+      Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+      Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+      Value isNan =
+          arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN);
+      // select for NaNs
+      f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits);
+    }
+
     Value result = arith::BitcastOp::create(b, f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
       result = arith::TruncFOp::create(b, resultTy, result, nullptr,

diff  --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 61e22af31f030..75c4de2168761 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -383,11 +383,11 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f32
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
-// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
@@ -395,6 +395,21 @@ func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
 
 // -----
 
+func.func @extf_f8E8M0FNU_to_f32_no_nan(%arg0 : f8E8M0FNU) -> f32 {
+    %0 = arith.extf %arg0 fastmath<nnan> : f8E8M0FNU to f32
+    return %0 : f32
+}
+
+// CHECK-LABEL: @extf_f8E8M0FNU_to_f32_no_nan
+// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
+// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
+// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SHLI]] : i32 to f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
     %0 = arith.extf %arg0 : f8E8M0FNU to f16
     return %0 : f16
@@ -402,11 +417,11 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
 
 // CHECK-LABEL: @extf_f8E8M0FNU_to_f16
 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
-// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
-// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
 // CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
 // CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
+// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
 // CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
 // CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
 // CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32


        


More information about the Mlir-commits mailing list