[Mlir-commits] [mlir] [mlir][spirv] Add OpExtension "SPV_INTEL_tensor_float32_conversion" (PR #151337)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 31 12:57:18 PDT 2025
https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/151337
>From 02e73b2412b774cd4d7eae420801e21de24e7a7c Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 06:46:00 -0700
Subject: [PATCH 1/8] add the mlir support for
SPV_INTEL_tensor_float32_conversion extension
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 17 ++++--
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 +++++++++++++++++++
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 21 ++++++++
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 +++++++++++++
mlir/test/Target/SPIRV/intel-ext-ops.mlir | 22 ++++++++
5 files changed, 147 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c9eefd054fa6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
- SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
+ SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
+ SPV_INTEL_tensor_float32_conversion
]>;
//===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}
+def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_tensor_float32_conversion]>
+ ];
+}
+
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
list<Availability> availability = [
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
- SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+ SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+ SPIRV_C_TensorFloat32RoundingINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
- SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+ SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+ SPIRV_OC_OpRoundFToTF32INTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..b692c07122683 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
// at (https://github.com/intel/llvm)
// Supported extensions
// * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
//===----------------------------------------------------------------------===//
@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
let hasVerifier = 1;
}
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+ let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+ let description = [{
+ Convert value numerically from a 32-bit floating point type to tensor float32,
+ with rounding to the nearest even.
+
+ Result Type must be a scalar or vector of 32-bit floating-point type.
+ The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+ Float Value must be a scalar or vector of floating-point type.
+ It must have the same number of components as Result Type. The component width must be 32 bits.
+
+ Results are computed per component.
+
+
+ ```
+ convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
+ `:` operand-type `to` result-type
+ ```
+
+ #### Example:
+
+ ```mlir
+ %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+ %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+ ```
+
+ }];
+
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_tensor_float32_conversion]>,
+ Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ );
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fc3e7308356bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.INTELRoundFToTF32Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELRoundFToTF32Op::verify() {
+ auto operandType = getOperand().getType();
+ auto resultType = getResult().getType();
+ // ODS checks that vector result type and vector operand type have the same
+ // shape.
+ if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
+ unsigned operandNumElements = vectorType.getNumElements();
+ unsigned resultNumElements =
+ llvm::cast<VectorType>(resultType).getNumElements();
+ if (operandNumElements != resultNumElements) {
+ return emitOpError(
+ "operand and result must have same number of elements");
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d018a6c44..aa5bee5796cfa 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+ // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+ // expected-error @+1 {{operand and result must have same number of elements}}
+ %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
// -----
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+ // CHECK-LABEL: @f32_to_tf32
+ spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f32_to_tf32_vec
+ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+ %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+ spirv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.INTEL.SplitBarrier
//===----------------------------------------------------------------------===//
>From 4f4bfd035051ae70373dec6e90d8f19fa728ef8d Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 07:18:11 -0700
Subject: [PATCH 2/8] remove the grammar definition
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 6 ------
1 file changed, 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index b692c07122683..215d57532ca84 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -127,12 +127,6 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
It must have the same number of components as Result Type. The component width must be 32 bits.
Results are computed per component.
-
-
- ```
- convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
- `:` operand-type `to` result-type
- ```
#### Example:
>From 0bbd2687ecdd2ca51ec53792fea5dadcc1383a68 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 08:36:55 -0700
Subject: [PATCH 3/8] modify CastOps.cpp to add vector non-scalable check
---
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 51 ++++++++++++++++-----------
1 file changed, 30 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index fc3e7308356bf..d3672220d7c03 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -276,13 +276,16 @@ LogicalResult ConvertUToFOp::verify() {
LogicalResult INTELConvertBF16ToFOp::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
+ // ODS checks that vector result type and vector operand type are
+ // non-scalable and have the same shape.
+ auto operandVectorType = dyn_cast<VectorType>(operandType);
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
+ if (operandVectorType && resultVectorType) {
+ if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+ return emitOpError("scalable vectors are not supported");
+ }
+ if (operandVectorType.getNumElements() !=
+ resultVectorType.getNumElements()) {
return emitOpError(
"operand and result must have same number of elements");
}
@@ -297,13 +300,16 @@ LogicalResult INTELConvertBF16ToFOp::verify() {
LogicalResult INTELConvertFToBF16Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
+ // ODS checks that vector result type and vector operand type are
+ // non-scalable and have the same shape.
+ auto operandVectorType = dyn_cast<VectorType>(operandType);
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
+ if (operandVectorType && resultVectorType) {
+ if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+ return emitOpError("scalable vectors are not supported");
+ }
+ if (operandVectorType.getNumElements() !=
+ resultVectorType.getNumElements()) {
return emitOpError(
"operand and result must have same number of elements");
}
@@ -318,13 +324,16 @@ LogicalResult INTELConvertFToBF16Op::verify() {
LogicalResult INTELRoundFToTF32Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
+ // ODS checks that vector result type and vector operand type are
+ // non-scalable and have the same shape.
+ auto operandVectorType = dyn_cast<VectorType>(operandType);
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
+ if (operandVectorType && resultVectorType) {
+ if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
+ return emitOpError("scalable vectors are not supported");
+ }
+ if (operandVectorType.getNumElements() !=
+ resultVectorType.getNumElements()) {
return emitOpError(
"operand and result must have same number of elements");
}
>From 446149f885ef2efadfcdac583c8bbe5d7f1df88a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 13:38:55 -0700
Subject: [PATCH 4/8] use SameOperandsAndResultShape vector shape check
---
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 6 ++--
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 36 +++++--------------
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 6 ++--
3 files changed, 15 insertions(+), 33 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 215d57532ca84..62d5826e008b3 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -20,7 +20,7 @@
// -----
-def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";
let description = [{
@@ -68,7 +68,7 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
// -----
-def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_bfloat16_conversion";
let description = [{
@@ -113,7 +113,7 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
// -----
-def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperandsAndResultShape]> {
let summary = "See extension SPV_INTEL_tensor_float32_conversion";
let description = [{
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index d3672220d7c03..d5f19ab710daa 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -277,18 +277,12 @@ LogicalResult INTELConvertBF16ToFOp::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type are
- // non-scalable and have the same shape.
- auto operandVectorType = dyn_cast<VectorType>(operandType);
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType && resultVectorType) {
+ // non-scalable.
+ if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
return emitOpError("scalable vectors are not supported");
}
- if (operandVectorType.getNumElements() !=
- resultVectorType.getNumElements()) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
}
return success();
}
@@ -301,18 +295,12 @@ LogicalResult INTELConvertFToBF16Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type are
- // non-scalable and have the same shape.
- auto operandVectorType = dyn_cast<VectorType>(operandType);
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType && resultVectorType) {
+ // non-scalable.
+ if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
return emitOpError("scalable vectors are not supported");
}
- if (operandVectorType.getNumElements() !=
- resultVectorType.getNumElements()) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
}
return success();
}
@@ -325,18 +313,12 @@ LogicalResult INTELRoundFToTF32Op::verify() {
auto operandType = getOperand().getType();
auto resultType = getResult().getType();
// ODS checks that vector result type and vector operand type are
- // non-scalable and have the same shape.
- auto operandVectorType = dyn_cast<VectorType>(operandType);
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType && resultVectorType) {
+ // non-scalable.
+ if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
return emitOpError("scalable vectors are not supported");
}
- if (operandVectorType.getNumElements() !=
- resultVectorType.getNumElements()) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
}
return success();
}
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index aa5bee5796cfa..e3cce924802a3 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -29,7 +29,7 @@ spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
// -----
spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
spirv.Return
}
@@ -65,7 +65,7 @@ spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
// -----
spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}
@@ -101,7 +101,7 @@ spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
// -----
spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
- // expected-error @+1 {{operand and result must have same number of elements}}
+ // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
spirv.Return
}
>From d12696c0fcc1e48df9e1043ecd881c4e4331e458 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 14:05:37 -0700
Subject: [PATCH 5/8] formatting
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 62d5826e008b3..805719bda770d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -152,6 +152,7 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperand
let results = (outs
SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
);
+
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
>From fd900223a6b2c7513ff2cc73b1606fdd93dcdcfe Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Wed, 30 Jul 2025 14:32:00 -0700
Subject: [PATCH 6/8] formatting
---
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 9c9eefd054fa6..89ae6bba13149 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -469,14 +469,14 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
+ SPV_INTEL_tensor_float32_conversion,
SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
- SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
- SPV_INTEL_tensor_float32_conversion
+ SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
]>;
//===----------------------------------------------------------------------===//
>From 9c078c1d6ad4e0fa564a88842b9e2407796a489f Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 31 Jul 2025 11:26:11 -0700
Subject: [PATCH 7/8] replace the verify function with
FixedVectorOfLengthAndType
---
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 38 +++++++++----
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 54 -------------------
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 6 +--
3 files changed, 32 insertions(+), 66 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 805719bda770d..7729703a1f4e7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -53,17 +53,24 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOpe
];
let arguments = (ins
- SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ AnyTypeOf<[
+ SPIRV_Float32,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
+ ]>:$operand
);
let results = (outs
- SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
+ AnyTypeOf<[
+ SPIRV_Int16,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Int16]>
+ ]>:$result
);
+
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasVerifier = 1;
+ let hasVerifier = 0;
}
// -----
@@ -98,17 +105,24 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOpe
];
let arguments = (ins
- SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
+ AnyTypeOf<[
+ SPIRV_Int16,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Int16]>
+ ]>:$operand
);
let results = (outs
- SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ AnyTypeOf<[
+ SPIRV_Float32,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
+ ]>:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasVerifier = 1;
+
+ let hasVerifier = 0;
}
// -----
@@ -146,18 +160,24 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperand
];
let arguments = (ins
- SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ AnyTypeOf<[
+ SPIRV_Float32,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
+ ]>:$operand
);
let results = (outs
- SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ AnyTypeOf<[
+ SPIRV_Float32,
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
+ ]>:$result
);
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
- let hasVerifier = 1;
+ let hasVerifier = 0;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index d5f19ab710daa..fcf4eb6fbcf60 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -269,60 +269,6 @@ LogicalResult ConvertUToFOp::verify() {
/*skipBitWidthCheck=*/true);
}
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertBF16ToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertBF16ToFOp::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type are
- // non-scalable.
- if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
- return emitOpError("scalable vectors are not supported");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertFToBF16Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertFToBF16Op::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type are
- // non-scalable.
- if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
- return emitOpError("scalable vectors are not supported");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELRoundFToTF32Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELRoundFToTF32Op::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type are
- // non-scalable.
- if (auto operandVectorType = dyn_cast<VectorType>(operandType)) {
- auto resultVectorType = dyn_cast<VectorType>(resultType);
- if (operandVectorType.isScalable() || resultVectorType.isScalable()) {
- return emitOpError("scalable vectors are not supported");
- }
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index e3cce924802a3..55153a78fba5b 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -21,7 +21,7 @@ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
%0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
spirv.Return
}
@@ -57,7 +57,7 @@ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
// -----
spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
- // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{op result #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f16'}}
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
spirv.Return
}
@@ -93,7 +93,7 @@ spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
// -----
spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
- // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ // expected-error @+1 {{op operand #0 must be Float32 or fixed-length vector of Float32 values of length 2/3/4/8/16, but got 'f64'}}
%0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
spirv.Return
}
>From cc8641e5519f58da1995e336b3a7a448a2e0f071 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 31 Jul 2025 12:57:03 -0700
Subject: [PATCH 8/8] update SPIRV_VectorOf to use FixedVectorOfLengthAndType
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 2 +-
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 30 ++++---------------
2 files changed, 7 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 89ae6bba13149..305a51aae050f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4286,7 +4286,7 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
"Matrix">;
class SPIRV_VectorOf<Type type> :
- VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
+ FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 7729703a1f4e7..abf373cf3c511 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -53,17 +53,11 @@ def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", [SameOpe
];
let arguments = (ins
- AnyTypeOf<[
- SPIRV_Float32,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
- ]>:$operand
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
);
let results = (outs
- AnyTypeOf<[
- SPIRV_Int16,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Int16]>
- ]>:$result
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
);
let assemblyFormat = [{
@@ -105,17 +99,11 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", [SameOpe
];
let arguments = (ins
- AnyTypeOf<[
- SPIRV_Int16,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Int16]>
- ]>:$operand
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
);
let results = (outs
- AnyTypeOf<[
- SPIRV_Float32,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
- ]>:$result
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
);
let assemblyFormat = [{
@@ -160,17 +148,11 @@ def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", [SameOperand
];
let arguments = (ins
- AnyTypeOf<[
- SPIRV_Float32,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
- ]>:$operand
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
);
let results = (outs
- AnyTypeOf<[
- SPIRV_Float32,
- FixedVectorOfLengthAndType<[2, 3, 4, 8, 16], [SPIRV_Float32]>
- ]>:$result
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
);
let assemblyFormat = [{
More information about the Mlir-commits
mailing list