[Mlir-commits] [mlir] b63a9b7 - [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (#151337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 1 12:20:35 PDT 2025
Author: YixingZhang007
Date: 2025-08-01T15:20:32-04:00
New Revision: b63a9b7a3cdc1e41545df595215077e50bfd04af
URL: https://github.com/llvm/llvm-project/commit/b63a9b7a3cdc1e41545df595215077e50bfd04af
DIFF: https://github.com/llvm/llvm-project/commit/b63a9b7a3cdc1e41545df595215077e50bfd04af.diff
LOG: [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (#151337)
This PR adds support for the capability
`TensorFloat32RoundingINTEL` and the instruction `OpRoundFToTF32INTEL`,
as specified by the `SPV_INTEL_tensor_float32_conversion` extension.
This extension introduces a rounding instruction that converts standard
32-bit floating-point values to the TensorFloat32 (TF32) format.
Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_tensor_float32_conversion.asciidoc
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
mlir/test/Target/SPIRV/intel-ext-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 37ee85b04f1eb..bdfd728d1d0b3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -468,6 +469,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
+ SPV_INTEL_tensor_float32_conversion,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}
+def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_tensor_float32_conversion]>
+ ];
+}
+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
- SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+ SPIRV_C_TensorFloat32RoundingINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4587,6 +4596,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4692,7 +4702,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
- SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+ SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+ SPIRV_OC_OpRoundFToTF32INTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..2a7fa534cc3dc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
// at (https://github.com/intel/llvm)
// Supported extensions
// * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
//===----------------------------------------------------------------------===//
@@ -19,7 +20,7 @@
// -----
-def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";
let description = [{
@@ -58,16 +59,17 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
let results = (outs
SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
);
+
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasVerifier = 1;
+ let hasVerifier = 0;
}
// -----
-def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";
let description = [{
@@ -107,9 +109,57 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasVerifier = 1;
+
+ let hasVerifier = 0;
}
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> {
+ let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+ let description = [{
+ Convert value numerically from a 32-bit floating point type to tensor float32,
+ with rounding to the nearest even.
+
+ Result Type must be a scalar or vector of 32-bit floating-point type.
+ The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+ Float Value must be a scalar or vector of floating-point type.
+ It must have the same number of components as Result Type. The component width must be 32 bits.
+
+ Results are computed per component.
+
+ #### Example:
+
+ ```mlir
+ %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+ %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+ ```
+
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_tensor_float32_conversion]>,
+ Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
+
+ let hasVerifier = 0;
+}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fcf4eb6fbcf60 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -269,48 +269,6 @@ LogicalResult ConvertUToFOp::verify() {
/*skipBitWidthCheck=*/true);
}
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertBF16ToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertBF16ToFOp::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertFToBF16Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertFToBF16Op::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 22352da07cf13..2e2fb1a9df328 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
// -----
spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
spirv.Return
}
@@ -65,13 +65,49 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
// -----
spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+ // CHECK-LABEL: @f32_to_tf32
+ spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f32_to_tf32_vec
+ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list