[Mlir-commits] [mlir] [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (PR #151337)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 30 08:37:06 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/151337

>From 02e73b2412b774cd4d7eae420801e21de24e7a7c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 06:46:00 -0700
Subject: [PATCH 1/3] add the mlir support for
 SPV_INTEL_tensor_float32_conversion extension

---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 17 ++++--
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 +++++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         | 21 ++++++++
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 +++++++++++++
 mlir/test/Target/SPIRV/intel-ext-ops.mlir     | 22 ++++++++
 5 files changed, 147 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c9eefd054fa6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_me
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
       SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
       SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
       SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
-      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
+      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
+      SPV_INTEL_tensor_float32_conversion
     ]>;
 
 //===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"B
   ];
 }
 
+def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_tensor_float32_conversion]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
-      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+      SPIRV_C_TensorFloat32RoundingINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL      : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
-      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+      SPIRV_OC_OpRoundFToTF32INTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..b692c07122683 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
 // at (https://github.com/intel/llvm)
 // Supported extensions
 // * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
 //===----------------------------------------------------------------------===//
 
 
@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
   let hasVerifier = 1;
 }
 
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+  let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+  let description = [{
+    Convert value numerically from a 32-bit floating point type to tensor float32,
+    with rounding to the nearest even.
+
+    Result Type must be a scalar or vector of 32-bit floating-point type.
+    The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+    Float Value must be a scalar or vector of floating-point type.
+    It must have the same number of components as Result Type. The component width must be 32 bits.
+
+    Results are computed per component.
+  
+
+    ```
+    convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
+                          `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+    %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+    ```
+
+  }];
+
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_tensor_float32_conversion]>,
+    Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
 
 // -----
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fc3e7308356bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTELRoundFToTF32Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELRoundFToTF32Op::verify() {
+  auto operandType = getOperand().getType();
+  auto resultType = getResult().getType();
+  // ODS checks that vector result type and vector operand type have the same
+  // shape.
+  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
+    unsigned operandNumElements = vectorType.getNumElements();
+    unsigned resultNumElements =
+        llvm::cast<VectorType>(resultType).getNumElements();
+    if (operandNumElements != resultNumElements) {
+      return emitOpError(
+          "operand and result must have same number of elements");
+    }
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d018a6c44..aa5bee5796cfa 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+  // CHECK-LABEL: @f32_to_tf32
+  spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_tf32_vec
+  spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//

>From 4f4bfd035051ae70373dec6e90d8f19fa728ef8d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 07:18:11 -0700
Subject: [PATCH 2/3] remove the grammar definition

---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index b692c07122683..215d57532ca84 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -127,12 +127,6 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
     It must have the same number of components as Result Type. The component width must be 32 bits.
 
     Results are computed per component.
-  
-
-    ```
-    convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
-                          `:` operand-type `to` result-type
-    ```
 
     #### Example:
 

>From 0bbd2687ecdd2ca51ec53792fea5dadcc1383a68 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 08:36:55 -0700
Subject: [PATCH 3/3] modify CastOps.cpp to add vector non-scalable check

---
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 51 ++++++++++++++++-----------
 1 file changed, 30 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index fc3e7308356bf..d3672220d7c03 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -276,13 +276,16 @@ LogicalResult ConvertUToFOp::verify() {
 LogicalResult INTELConvertBF16ToFOp::verify() {
   auto operandType = getOperand().getType();
   auto resultType = getResult().getType();
-  // ODS checks that vector result type and vector operand type have the same
-  // shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    unsigned operandNumElements = vectorType.getNumElements();
-    unsigned resultNumElements =
-        llvm::cast<VectorType>(resultType).getNumElements();
-    if (operandNumElements != resultNumElements) {
+  // Verify that vector operand and result types are non-scalable and have
+  // the same number of elements.
+  auto operandVectorType = dyn_cast<VectorType>(operandType);
+  auto resultVectorType = dyn_cast<VectorType>(resultType);
+  if (operandVectorType && resultVectorType) {
+    if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+      return emitOpError("scalable vectors are not supported");
+    }
+    if (operandVectorType.getNumElements() !=
+        resultVectorType.getNumElements()) {
       return emitOpError(
           "operand and result must have same number of elements");
     }
@@ -297,13 +300,16 @@ LogicalResult INTELConvertBF16ToFOp::verify() {
 LogicalResult INTELConvertFToBF16Op::verify() {
   auto operandType = getOperand().getType();
   auto resultType = getResult().getType();
-  // ODS checks that vector result type and vector operand type have the same
-  // shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    unsigned operandNumElements = vectorType.getNumElements();
-    unsigned resultNumElements =
-        llvm::cast<VectorType>(resultType).getNumElements();
-    if (operandNumElements != resultNumElements) {
+  // Verify that vector operand and result types are non-scalable and have
+  // the same number of elements.
+  auto operandVectorType = dyn_cast<VectorType>(operandType);
+  auto resultVectorType = dyn_cast<VectorType>(resultType);
+  if (operandVectorType && resultVectorType) {
+    if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+      return emitOpError("scalable vectors are not supported");
+    }
+    if (operandVectorType.getNumElements() !=
+        resultVectorType.getNumElements()) {
       return emitOpError(
           "operand and result must have same number of elements");
     }
@@ -318,13 +324,16 @@ LogicalResult INTELConvertFToBF16Op::verify() {
 LogicalResult INTELRoundFToTF32Op::verify() {
   auto operandType = getOperand().getType();
   auto resultType = getResult().getType();
-  // ODS checks that vector result type and vector operand type have the same
-  // shape.
-  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
-    unsigned operandNumElements = vectorType.getNumElements();
-    unsigned resultNumElements =
-        llvm::cast<VectorType>(resultType).getNumElements();
-    if (operandNumElements != resultNumElements) {
+  // Verify that vector operand and result types are non-scalable and have
+  // the same number of elements.
+  auto operandVectorType = dyn_cast<VectorType>(operandType);
+  auto resultVectorType = dyn_cast<VectorType>(resultType);
+  if (operandVectorType && resultVectorType) {
+    if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+      return emitOpError("scalable vectors are not supported");
+    }
+    if (operandVectorType.getNumElements() !=
+        resultVectorType.getNumElements()) {
       return emitOpError(
           "operand and result must have same number of elements");
     }



More information about the Mlir-commits mailing list