[Mlir-commits] [mlir] 1bc5885 - [MLIR][SPIRV] Add spirv.IsFinite and lower math.{isfinite, isinf, isnan} to spirv. (#151552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 31 10:54:17 PDT 2025


Author: Xiaolei Feng
Date: 2025-07-31T13:54:14-04:00
New Revision: 1bc58851868ad1f8ac6313d9f2337ec827b85019

URL: https://github.com/llvm/llvm-project/commit/1bc58851868ad1f8ac6313d9f2337ec827b85019
DIFF: https://github.com/llvm/llvm-project/commit/1bc58851868ad1f8ac6313d9f2337ec827b85019.diff

LOG: [MLIR][SPIRV] Add spirv.IsFinite and lower math.{isfinite,isinf,isnan} to spirv. (#151552)

This patch adds support for lowering several float classification ops
from the Math dialect to the SPIR-V dialect.

### Highlights:
- Introduced a new `spirv.IsFinite` operation corresponding to the
SPIR-V `OpIsFinite` instruction.
- Lowered `math.isfinite`, `math.isinf`, and `math.isnan` to SPIR-V
using `CheckedElementwiseOpPattern`.
- Added corresponding tests for op definition and conversion lowering.

This addresses the discussion in:
https://github.com/llvm/llvm-project/issues/150778

---

Let me know if any additional adjustments are needed!

---------

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
    mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
    mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
    mlir/test/Target/SPIRV/logical-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c74cff0d14f1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4448,6 +4448,7 @@ def SPIRV_OC_OpUMulExtended                   : I32EnumAttrCase<"OpUMulExtended"
 def SPIRV_OC_OpSMulExtended                   : I32EnumAttrCase<"OpSMulExtended", 152>;
 def SPIRV_OC_OpIsNan                          : I32EnumAttrCase<"OpIsNan", 156>;
 def SPIRV_OC_OpIsInf                          : I32EnumAttrCase<"OpIsInf", 157>;
+def SPIRV_OC_OpIsFinite                       : I32EnumAttrCase<"OpIsFinite", 158>;
 def SPIRV_OC_OpOrdered                        : I32EnumAttrCase<"OpOrdered", 162>;
 def SPIRV_OC_OpUnordered                      : I32EnumAttrCase<"OpUnordered", 163>;
 def SPIRV_OC_OpLogicalEqual                   : I32EnumAttrCase<"OpLogicalEqual", 164>;
@@ -4630,7 +4631,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
-      SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
+      SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpIsFinite,
+      SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
       SPIRV_OC_OpLogicalEqual, SPIRV_OC_OpLogicalNotEqual, SPIRV_OC_OpLogicalOr,
       SPIRV_OC_OpLogicalAnd, SPIRV_OC_OpLogicalNot, SPIRV_OC_OpSelect,
       SPIRV_OC_OpIEqual, SPIRV_OC_OpINotEqual, SPIRV_OC_OpUGreaterThan,

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index ab535d7b2a304..9331fc576c7bd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -403,6 +403,28 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
 
 // -----
 
+def SPIRV_IsFiniteOp : SPIRV_LogicalUnaryOp<"IsFinite", SPIRV_Float, []> {
+  let summary = "Result is true if x is an IEEE Finite, otherwise result is false";
+
+  let description = [{
+    Result Type must be a scalar or vector of Boolean type.
+
+    x must be a scalar or vector of floating-point type.  It must have the
+    same number of components as Result Type.
+
+    Results are computed per component.
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.IsFinite %0: f32
+    %3 = spirv.IsFinite %1: vector<4xf32>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> {
   let summary = "Result is true if x is an IEEE Inf, otherwise result is false";
 
@@ -418,7 +440,7 @@ def SPIRV_IsInfOp : SPIRV_LogicalUnaryOp<"IsInf", SPIRV_Float, []> {
 
     ```mlir
     %2 = spirv.IsInf %0: f32
-    %3 = spirv.IsInf %1: vector<4xi32>
+    %3 = spirv.IsInf %1: vector<4xf32>
     ```
   }];
 }
@@ -442,7 +464,7 @@ def SPIRV_IsNanOp : SPIRV_LogicalUnaryOp<"IsNan", SPIRV_Float, []> {
 
     ```mlir
     %2 = spirv.IsNan %0: f32
-    %3 = spirv.IsNan %1: vector<4xi32>
+    %3 = spirv.IsNan %1: vector<4xf32>
     ```
   }];
 }

diff  --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index a877ad21734a2..1787e0a44f8fd 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -488,7 +488,12 @@ namespace mlir {
 void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                  RewritePatternSet &patterns) {
   // Core patterns
-  patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
+  patterns
+      .add<CopySignPattern,
+           CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
+           CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
+           CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
+          typeConverter, patterns.getContext());
 
   // GLSL patterns
   patterns

diff  --git a/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
new file mode 100644
index 0000000000000..3e5f592049e7f
--- /dev/null
+++ b/mlir/test/Conversion/MathToSPIRV/math-to-fpclassify-spirv.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --convert-math-to-spirv %s | FileCheck %s
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
+} {
+
+  // CHECK-LABEL: @fpclassify
+  func.func @fpclassify(%x: f32, %v: vector<4xf32>) {
+    // CHECK: spirv.IsFinite %{{.*}} : f32
+    %0 = math.isfinite %x : f32
+    // CHECK: spirv.IsFinite %{{.*}} : vector<4xf32>
+    %1 = math.isfinite %v : vector<4xf32>
+
+    // CHECK: spirv.IsNan %{{.*}} : f32
+    %2 = math.isnan %x : f32
+    // CHECK: spirv.IsNan %{{.*}} : vector<4xf32>
+    %3 = math.isnan %v : vector<4xf32>
+
+    // CHECK: spirv.IsInf %{{.*}} : f32
+    %4 = math.isinf %x : f32
+    // CHECK: spirv.IsInf %{{.*}} : vector<4xf32>
+    %5 = math.isinf %v : vector<4xf32>
+
+    return
+  }
+
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index d6c34645f5746..58b828877e71d 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -32,6 +32,24 @@ func.func @inotequal_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> vecto
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.IsFinite
+//===----------------------------------------------------------------------===//
+
+func.func @isfinite_scalar(%arg0: f32) -> i1 {
+  // CHECK: spirv.IsFinite {{.*}} : f32
+  %0 = spirv.IsFinite %arg0 : f32
+  return %0 : i1
+}
+
+func.func @isfinite_vector(%arg0: vector<2xf32>) -> vector<2xi1> {
+  // CHECK: spirv.IsFinite {{.*}} : vector<2xf32>
+  %0 = spirv.IsFinite %arg0 : vector<2xf32>
+  return %0 : vector<2xi1>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IsInf
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index b2008719b021c..05cbddc048151 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -84,6 +84,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %15 = spirv.IsNan %arg0 : f32
     // CHECK: spirv.IsInf
     %16 = spirv.IsInf %arg1 : f32
+    // CHECK: spirv.IsFinite
+    %17 = spirv.IsFinite %arg0 : f32
     spirv.Return
   }
 }


        


More information about the Mlir-commits mailing list