[Mlir-commits] [mlir] [mlir][spirv] Small fix around SPV_EXT_FP8 when converting types (PR #192466)

Davide Grohmann llvmlistbot at llvm.org
Thu Apr 16 07:23:41 PDT 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/192466

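Treat `f8E4M3FN` and `f8E5M2` as valid SPIR-V scalar float types in `ScalarType::isValid(FloatType)`, and expand the tests covering Float8EXT and bf16 types (including as `!spirv.arm.tensor` element types).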

From 857a6018ee2fff9d0c7a6bbeeb6d297918407914 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 16 Apr 2026 16:09:44 +0200
Subject: [PATCH] [mlir][spirv] Small fix around SPV_EXT_FP8 when converting
 types

Expand tests

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia80ace0a83e1d638ba00a43d04fa7212ed7e7092
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 11 ++-
 .../FuncToSPIRV/types-to-spirv.mlir           | 77 ++++++++++++-------
 mlir/test/Dialect/SPIRV/IR/types.mlir         | 35 +++++++--
 mlir/test/Target/SPIRV/tensorARM.mlir         | 22 +++++-
 4 files changed, 106 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 0853c5aa59f92..7ab7c46ce594b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -541,6 +541,9 @@ bool ScalarType::classof(Type type) {
 }
 
 bool ScalarType::isValid(FloatType type) {
+  if (type.isF8E4M3FN() || type.isF8E5M2()) {
+    return true;
+  }
   return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
 }
 
@@ -549,12 +552,12 @@ bool ScalarType::isValid(IntegerType type) {
 }
 
 void TypeExtensionVisitor::addConcrete(ScalarType type) {
-  if (isa<BFloat16Type>(type)) {
+  if (type.isBF16()) {
     static constexpr auto ext = Extension::SPV_KHR_bfloat16;
     extensions.push_back(ext);
   }
 
-  if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
+  if (type.isF8E4M3FN() || type.isF8E5M2()) {
     static constexpr auto ext = Extension::SPV_EXT_float8;
     extensions.push_back(ext);
   }
@@ -657,7 +660,7 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
     assert(isa<FloatType>(type));
     switch (bitwidth) {
     case 8: {
-      if (isa<Float8E4M3FNType, Float8E5M2Type>(type)) {
+      if (type.isF8E4M3FN() || type.isF8E5M2()) {
         static constexpr auto cap = Capability::Float8EXT;
         capabilities.push_back(cap);
       } else {
@@ -666,7 +669,7 @@ void TypeCapabilityVisitor::addConcrete(ScalarType type) {
       break;
     }
     case 16: {
-      if (isa<BFloat16Type>(type)) {
+      if (type.isBF16()) {
         static constexpr auto cap = Capability::BFloat16TypeKHR;
         capabilities.push_back(cap);
       } else {
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 0c77c88334572..829401c194f55 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -962,37 +962,60 @@ module attributes {
   // CHECK-SAME: %arg3: i8
   // CHECK-SAME: %arg4: i8
   // CHECK-SAME: %arg5: i8
-  // CHECK-SAME: %arg6: i8
-  // CHECK-SAME: %arg7: i8
-  // CHECK-SAME: %arg8: vector<4xi8>
-  // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
-  // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+  // CHECK-SAME: %arg6: vector<4xi8>
+  // CHECK-SAME: %arg7: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-SAME: %arg8: !spirv.array<4 x i8>
   // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
-  // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
-  // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
-  // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
-  // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
-  // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
-  // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
-  // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
-  // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
-  // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
-  // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
-  // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+  // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E4M3
+  // UNSUPPORTED_FLOAT-SAME: %arg1: f8E5M2FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg3: f8E4M3B11FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg4: f8E3M4
+  // UNSUPPORTED_FLOAT-SAME: %arg5: f8E8M0FNU
+  // UNSUPPORTED_FLOAT-SAME: %arg6: vector<4xf8E4M3B11FNUZ>
+  // UNSUPPORTED_FLOAT-SAME: %arg7: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+  // UNSUPPORTED_FLOAT-SAME: %arg8: tensor<4xf8E4M3>
   // UNSUPPORTED_FLOAT-SAME: ) {
 
   func.func @float8_to_integer8(
-    %arg0: f8E5M2,                   // CHECK-NOT: f8E5M2
-    %arg1: f8E4M3,                   // CHECK-NOT: f8E4M3
-    %arg2: f8E4M3FN,                // CHECK-NOT: f8E4M3FN
-    %arg3: f8E5M2FNUZ,              // CHECK-NOT: f8E5M2FNUZ
-    %arg4: f8E4M3FNUZ,              // CHECK-NOT: f8E4M3FNUZ
-    %arg5: f8E4M3B11FNUZ,           // CHECK-NOT: f8E4M3B11FNUZ
-    %arg6: f8E3M4,                  // CHECK-NOT: f8E3M4
-    %arg7: f8E8M0FNU,               // CHECK-NOT: f8E8M0FNU
-    %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
-    %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
-    %arg10: tensor<4xf8E5M2>        // CHECK-NOT: tensor
+    %arg0: f8E4M3,                  // CHECK-NOT: f8E4M3
+    %arg1: f8E5M2FNUZ,              // CHECK-NOT: f8E5M2FNUZ
+    %arg2: f8E4M3FNUZ,              // CHECK-NOT: f8E4M3FNUZ
+    %arg3: f8E4M3B11FNUZ,           // CHECK-NOT: f8E4M3B11FNUZ
+    %arg4: f8E3M4,                  // CHECK-NOT: f8E3M4
+    %arg5: f8E8M0FNU,               // CHECK-NOT: f8E8M0FNU
+    %arg6: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+    %arg7: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+    %arg8: tensor<4xf8E4M3>        // CHECK-NOT: tensor
+  ) {
+    // CHECK: spirv.Return
+    return
+  }
+}
+
+// -----
+
+// Check that supported Float8EXT types remain legal SPIR-V scalar types when
+// float emulation is disabled.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Float8EXT], [SPV_EXT_float8]>, #spirv.resource_limits<>>
+} {
+
+  // CHECK-LABEL: spirv.func @supported_float8_types
+  // CHECK-SAME: (%arg0: vector<4xi32>
+  // CHECK-SAME: %arg1: vector<4xi32>
+  // CHECK-SAME: %arg2: !spirv.array<4 x i32>
+  // CHECK-SAME: %arg3: !spirv.array<4 x i32>
+  // UNSUPPORTED_FLOAT-LABEL: spirv.func @supported_float8_types
+  // UNSUPPORTED_FLOAT-SAME: (%arg0: vector<4xf8E5M2>
+  // UNSUPPORTED_FLOAT-SAME: %arg1: vector<4xf8E4M3FN>
+  // UNSUPPORTED_FLOAT-SAME: %arg2: !spirv.array<4 x f8E5M2>
+  // UNSUPPORTED_FLOAT-SAME: %arg3: !spirv.array<4 x f8E4M3FN>
+  func.func @supported_float8_types(
+    %arg0: vector<4xf8E5M2>,
+    %arg1: vector<4xf8E4M3FN>,
+    %arg2: tensor<4xf8E5M2>,
+    %arg3: tensor<4xf8E4M3FN>
   ) {
     // CHECK: spirv.Return
     return
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 99443a13e0ec3..12a7e5df8b592 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -606,6 +606,24 @@ func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Float8_EXT
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @type_f8E4M3FN(f8E4M3FN)
+func.func private @type_f8E4M3FN(f8E4M3FN) -> ()
+
+// CHECK: func private @vector_type_f8E4M3FN(vector<4xf8E4M3FN>)
+func.func private @vector_type_f8E4M3FN(vector<4xf8E4M3FN>) -> ()
+
+// CHECK: func private @type_f8E5M2(f8E5M2)
+func.func private @type_f8E5M2(f8E5M2) -> ()
+
+// CHECK: func private @vector_type_f8E5M2(vector<4xf8E5M2>)
+func.func private @vector_type_f8E5M2(vector<4xf8E5M2>) -> ()
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // TensorArm
 //===----------------------------------------------------------------------===//
@@ -659,12 +677,15 @@ func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// Float8_EXT
-//===----------------------------------------------------------------------===//
+// CHECK: func private @arm_tensor_type_bf16(!spirv.arm.tensor<2x3xbf16>)
+func.func private @arm_tensor_type_bf16(!spirv.arm.tensor<2x3xbf16>) -> ()
 
-// CHECK: func private @type_f8E4M3FN(f8E4M3FN)
-func.func private @type_f8E4M3FN(f8E4M3FN) -> ()
+// -----
 
-// CHECK: func private @type_f8E5M2(f8E5M2)
-func.func private @type_f8E5M2(f8E5M2) -> ()
+// CHECK: func private @arm_tensor_type_fp8e4m3fn(!spirv.arm.tensor<2x3xf8E4M3FN>)
+func.func private @arm_tensor_type_fp8e4m3fn(!spirv.arm.tensor<2x3xf8E4M3FN>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_fp8e5m2(!spirv.arm.tensor<2x3xf8E5M2>)
+func.func private @arm_tensor_type_fp8e5m2(!spirv.arm.tensor<2x3xf8E5M2>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
index 53a41e19f930f..65a0fe62e8c86 100644
--- a/mlir/test/Target/SPIRV/tensorARM.mlir
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -5,7 +5,7 @@
 // RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --spirv-save-validation-files-with-prefix=%t/module %s %}
 // RUN: %if spirv-tools %{ spirv-val %t %}
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorsARM, Int64], [SPV_ARM_tensors]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, TensorsARM, Int64, BFloat16TypeKHR, Float8EXT], [SPV_ARM_tensors, SPV_KHR_bfloat16, SPV_EXT_float8]> {
   // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
   spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
     spirv.Return
@@ -68,4 +68,24 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Tensors
   spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
     spirv.Return
   }
+// -----
+
+  // CHECK: spirv.func @shaped_bf16_arm_tensor(%arg0: !spirv.arm.tensor<2xbf16>) "None" {
+  spirv.func @shaped_bf16_arm_tensor(%arg0 : !spirv.arm.tensor<2xbf16>) "None" {
+    spirv.Return
+  }
+// -----
+
+  // CHECK: spirv.func @shaped_fp8e4m3fn_arm_tensor(%arg0: !spirv.arm.tensor<2xf8E4M3FN>) "None" {
+  spirv.func @shaped_fp8e4m3fn_arm_tensor(%arg0 : !spirv.arm.tensor<2xf8E4M3FN>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @shaped_fp8e5m2_arm_tensor(%arg0: !spirv.arm.tensor<2xf8E5M2>) "None" {
+  spirv.func @shaped_fp8e5m2_arm_tensor(%arg0 : !spirv.arm.tensor<2xf8E5M2>) "None" {
+    spirv.Return
+  }
+
 }



More information about the Mlir-commits mailing list