[Mlir-commits] [mlir] [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (PR #151337)

llvmlistbot at llvm.org
Fri Aug 1 11:55:13 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/151337

>From 01d179fcab3d07b3a37eed94e94f2aaa32c4da6a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 06:46:00 -0700
Subject: [PATCH 1/2] Add MLIR support for the
 SPV_INTEL_tensor_float32_conversion extension

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 15 ++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 58 +++++++++++++++++--
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         | 42 --------------
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 40 ++++++++++++-
 mlir/test/Target/SPIRV/intel-ext-ops.mlir     | 22 +++++++
 5 files changed, 127 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 37ee85b04f1eb..bdfd728d1d0b3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_me
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
       SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
       SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
+      SPV_INTEL_tensor_float32_conversion,
       SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"B
   ];
 }
 
+def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_tensor_float32_conversion]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
-      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+      SPIRV_C_TensorFloat32RoundingINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4587,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL      : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4692,7 +4702,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
-      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+      SPIRV_OC_OpRoundFToTF32INTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..2a7fa534cc3dc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
 // at (https://github.com/intel/llvm)
 // Supported extensions
 // * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
 //===----------------------------------------------------------------------===//
 
 
@@ -19,7 +20,7 @@
 
 // -----
 
-def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> {
   let summary = "See extension SPV_INTEL_bfloat16_conversion";
 
   let description = [{
@@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
   let results = (outs
     SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
   );
+
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
 
-  let hasVerifier = 1;
+  let hasVerifier = 0;
 }
 
 // -----
 
-def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> {
   let summary = "See extension SPV_INTEL_bfloat16_conversion";
 
   let description = [{
@@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
   }];
-  let hasVerifier = 1;
+
+  let hasVerifier = 0;
 }
 
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> {
+  let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+  let description = [{
+    Convert value numerically from a 32-bit floating-point type to tensor float32,
+    with rounding to the nearest even.
+
+    Result Type must be a scalar or vector of 32-bit floating-point type.
+    The component width must be 32 bits. The bit pattern in the Result represents a tensor float32 value.
+
+    Float Value must be a scalar or vector of floating-point type.
+    It must have the same number of components as Result Type. The component width must be 32 bits.
+
+    Results are computed per component.
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.RoundFToTF32 %0 : f32 to f32
+    %3 = spirv.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+    ```
+
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_tensor_float32_conversion]>,
+    Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 0;
+}
 
 // -----
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fcf4eb6fbcf60 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -269,48 +269,6 @@ LogicalResult ConvertUToFOp::verify() {
                       /*skipBitWidthCheck=*/true);
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertBF16ToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertBF16ToFOp::verify() {
-  auto operandType = getOperand().getType();
-  auto resultType = getResult().getType();
-  // ODS checks that vector result type and vector operand type have the same
-  // shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    unsigned operandNumElements = vectorType.getNumElements();
-    unsigned resultNumElements =
-        llvm::cast<VectorType>(resultType).getNumElements();
-    if (operandNumElements != resultNumElements) {
-      return emitOpError(
-          "operand and result must have same number of elements");
-    }
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertFToBF16Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertFToBF16Op::verify() {
-  auto operandType = getOperand().getType();
-  auto resultType = getResult().getType();
-  // ODS checks that vector result type and vector operand type have the same
-  // shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    unsigned operandNumElements = vectorType.getNumElements();
-    unsigned resultNumElements =
-        llvm::cast<VectorType>(resultType).getNumElements();
-    if (operandNumElements != resultNumElements) {
-      return emitOpError(
-          "operand and result must have same number of elements");
-    }
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.FConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 22352da07cf13..0fdbdb66857f6 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
 // -----
 
 spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
-  // expected-error @+1 {{operand and result must have same number of elements}}
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
   %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
   spirv.Return
 }
@@ -65,13 +65,49 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
 // -----
 
 spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
-  // expected-error @+1 {{operand and result must have same number of elements}}
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
   %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
   spirv.Return
 }
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{op requires the same shape for all operands and results}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+  // CHECK-LABEL: @f32_to_tf32
+  spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_tf32_vec
+  spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//
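
The new `OpRoundFToTF32INTEL` instruction rounds an f32 value to the nearest tensor float32 (TF32) value. As a point of reference only (this is not part of the patch), the C++ sketch below shows what that rounding amounts to at the bit level, assuming the usual TF32 layout of 1 sign bit, 8 exponent bits, and 10 mantissa bits, so round-to-nearest-even discards the low 13 mantissa bits of the f32 input; NaN/Inf handling is omitted and the function name is illustrative.

```cpp
#include <cstdint>
#include <cstring>

// Round an f32 value to the nearest TF32 value (round to nearest even).
// The result is returned as an f32 whose low 13 mantissa bits are zero,
// matching the op's f32 result type. NaN/Inf are not handled specially.
float roundFToTF32(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  constexpr uint32_t kDroppedBits = 13;          // 23 - 10 kept mantissa bits
  uint32_t lsb = (bits >> kDroppedBits) & 1u;    // lowest kept mantissa bit
  bits += 0xFFFu + lsb;                          // ties-to-even rounding bias
  bits &= ~((1u << kDroppedBits) - 1u);          // clear the dropped bits
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}
```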

>From cd465dc4ebbd3682a22df411a274462ea871b403 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 1 Aug 2025 11:55:01 -0700
Subject: [PATCH 2/2] update intel-ext-ops.mlir

---
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 0fdbdb66857f6..2e2fb1a9df328 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
 // -----
 
 spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
-  // expected-error @+1 {{op operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+  // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
   %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
   spirv.Return
 }


