[Mlir-commits] [mlir] [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (PR #151337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 30 07:18:22 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/151337
>From 02e73b2412b774cd4d7eae420801e21de24e7a7c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 06:46:00 -0700
Subject: [PATCH 1/2] add the mlir support for
SPV_INTEL_tensor_float32_conversion extension
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 17 ++++--
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 +++++++++++++++++++
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 21 ++++++++
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 +++++++++++++
mlir/test/Target/SPIRV/intel-ext-ops.mlir | 22 ++++++++
5 files changed, 147 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c9eefd054fa6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
- SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
+ SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
+ SPV_INTEL_tensor_float32_conversion
]>;
//===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}
+def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_tensor_float32_conversion]>
+ ];
+}
+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
- SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+ SPIRV_C_TensorFloat32RoundingINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
- SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+ SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+ SPIRV_OC_OpRoundFToTF32INTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..b692c07122683 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
// at (https://github.com/intel/llvm)
// Supported extensions
// * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
//===----------------------------------------------------------------------===//
@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
let hasVerifier = 1;
}
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+ let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+ let description = [{
+ Convert value numerically from a 32-bit floating point type to tensor float32,
+ with rounding to the nearest even.
+
+ Result Type must be a scalar or vector of 32-bit floating-point type.
+ The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+ Float Value must be a scalar or vector of floating-point type.
+ It must have the same number of components as Result Type. The component width must be 32 bits.
+
+ Results are computed per component.
+
+
+ ```
+ convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
+ `:` operand-type `to` result-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+ %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+ ```
+
+ }];
+
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_tensor_float32_conversion]>,
+ Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ );
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fc3e7308356bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.INTELRoundFToTF32Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELRoundFToTF32Op::verify() {
+  // Operand and result must agree in shape: ODS only constrains each side to
+  // f32 or vector-of-f32 independently, and the op declares no shape trait.
+  auto operandVecType = llvm::dyn_cast<VectorType>(getOperand().getType());
+  auto resultVecType = llvm::dyn_cast<VectorType>(getResult().getType());
+  if (static_cast<bool>(operandVecType) != static_cast<bool>(resultVecType)) {
+    return emitOpError(
+        "operand and result must both be scalars or both be vectors");
+  }
+  if (operandVecType &&
+      operandVecType.getNumElements() != resultVecType.getNumElements()) {
+    return emitOpError(
+        "operand and result must have same number of elements");
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d018a6c44..aa5bee5796cfa 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // Scalar form: an f32 operand rounded to an f32 result is accepted.
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // Vector form: operand and result are both vector<2xf32>.
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+ // An f64 operand violates the Float32 operand type constraint declared in ODS.
+ // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+ // Both types satisfy the ODS constraint, but the vector lengths differ
+ // (2 vs 4); the op's C++ verifier reports this mismatch.
+ // expected-error @+1 {{operand and result must have same number of elements}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+ // The VCE triple lists the TensorFloat32RoundingINTEL capability and the
+ // SPV_INTEL_tensor_float32_conversion extension named in the op's availability.
+ // CHECK-LABEL: @f32_to_tf32
+ spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f32_to_tf32_vec
+ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
>From 4f4bfd035051ae70373dec6e90d8f19fa728ef8d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 07:18:11 -0700
Subject: [PATCH 2/2] remove the grammar definition
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index b692c07122683..215d57532ca84 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -127,12 +127,6 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
It must have the same number of components as Result Type. The component width must be 32 bits.
Results are computed per component.
-
-
- ```
- convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
- `:` operand-type `to` result-type
- ```
#### Example:
More information about the Mlir-commits
mailing list