[Mlir-commits] [llvm] [mlir] [SPIR-V] Add support for the SPIR-V extension SPV_INTEL_tensor_float32_conversion (PR #150090)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 29 20:07:46 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/150090

>From 3749aa623d613e32f0cc6691faac1cdde6a8ea85 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 22 Jul 2025 12:14:29 -0700
Subject: [PATCH 1/4] Draft implementation of support for
 SPV_INTEL_tensor_float32_conversion

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 17 +++++++++++++--
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 21 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  5 ++++-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  6 ++++++
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  2 ++
 5 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 6ec7544767c52..1c7c1750af1c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
   bool IsSaturated;
   bool IsRounded;
   bool IsBfloat16;
+  bool IsTF32;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
 
@@ -2677,8 +2678,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
       }
     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
                                           SPIRV::OpTypeFloat)) {
-      // Float -> Float
-      Opcode = SPIRV::OpFConvert;
+      if(Builtin->IsTF32){
+        const auto *ST = static_cast<const SPIRVSubtarget *>(
+          &MIRBuilder.getMF().getSubtarget());
+        if (!ST->canUseExtension(
+                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+          IsRightComponentsNumber =
+            GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+            GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
+        Opcode = SPIRV::OpRoundFToTF32INTEL;
+      } else {
+        Float -> Float
+        Opcode = SPIRV::OpFConvert;
+      }
     }
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index ea78dcd135267..326109c9fdff4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1461,6 +1461,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
                        !not(!eq(!find(name, "bfloat16"), -1)));
+  bit IsTF32 = !not(!eq(!find(name, "TF32"), -1));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
                                   !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
                                   !not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1473,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
 def ConvertBuiltins : GenericTable {
   let FilterClass = "ConvertBuiltin";
   let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
-                "IsRounded", "IsBfloat16", "RoundingMode"];
+                "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
@@ -1556,6 +1557,24 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
   def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
 }
 
+// SPV_INTEL_tensor_float32_conversion
+// Multiclass used to define at the same time both a demangled builtin records
+// and a corresponding convert builtin records.
+multiclass DemangledTF32ConvertBuiltin<string name1, string name2> {
+  // Create records for scalar and vector conversions.
+  foreach i = ["", "2", "3", "4", "8", "16"] in {
+    def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+    def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+  }
+}
+
+defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">;
+
+foreach conv = ["FToTF32INTEL"] in {
+  def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
+  def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+}
+
 //===----------------------------------------------------------------------===//
 // Class defining a vector data load/store builtin record used for lowering
 // into OpExtInst instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 049ba0275f223..a04ed6a42c868 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -441,10 +441,13 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
 def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
 def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
 
-// SPV_INTEL_bfloat16_conversion
+// SPV_INTEL_tensor_float32_conversion
 def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
 def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
 
+// SPV_INTEL_bfloat16_conversion
+def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ad976e5288927..c252fc5897518 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1564,6 +1564,12 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
     }
     break;
+  case SPIRV::OpRoundFToTF32INTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
+      Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
+    }
+    break;
   case SPIRV::OpVariableLengthArrayINTEL:
   case SPIRV::OpSaveMemoryINTEL:
   case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 548e9b717c161..7b2139a1c84a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
 defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
 defm SPV_INTEL_int4 : ExtensionOperand<123>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
+defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -502,6 +503,7 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
+defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
 defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;

>From b90fd0c590b67d8a9eb217b82a82ecbf50dacbc3 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 22 Jul 2025 13:34:02 -0700
Subject: [PATCH 2/4] Add tests, finalize the implementation, and clean up code

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  9 +--
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 18 +++---
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |  4 +-
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  4 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  2 +-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  2 +-
 .../tf32-conv-negative1.ll                    | 12 ++++
 .../tf32-conv-negative2.ll                    | 12 ++++
 .../tf32-conv.ll                              | 62 +++++++++++++++++++
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        | 17 ++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 ++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         | 21 +++++++
 mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 +++++++++++
 mlir/test/Target/SPIRV/intel-ext-ops.mlir     | 22 +++++++
 14 files changed, 255 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 1c7c1750af1c9..03ca2ad1d8fa5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -231,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
   // - "__spirv_SubgroupImageMediaBlockReadINTEL"
   // - "__spirv_SubgroupImageMediaBlockWriteINTEL"
   // - "__spirv_Convert"
+  // - "__spirv_Round"
   // - "__spirv_UConvert"
   // - "__spirv_SConvert"
   // - "__spirv_FConvert"
@@ -243,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
       "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
-      "Convert|"
+      "Convert|Round|"
       "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
   std::smatch Match;
   if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
@@ -2682,14 +2683,14 @@ static bool generateConvertInst(const StringRef DemangledCall,
         const auto *ST = static_cast<const SPIRVSubtarget *>(
           &MIRBuilder.getMF().getSubtarget());
         if (!ST->canUseExtension(
-                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
-          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+                SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
+          NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
           IsRightComponentsNumber =
             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
         Opcode = SPIRV::OpRoundFToTF32INTEL;
       } else {
-        Float -> Float
+        // Float -> Float
         Opcode = SPIRV::OpFConvert;
       }
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 326109c9fdff4..49d11bf7c8dca 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1461,7 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
                        !not(!eq(!find(name, "bfloat16"), -1)));
-  bit IsTF32 = !not(!eq(!find(name, "TF32"), -1));
+  bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
+                       !not(!eq(!find(name, "tensor_float32"), -1)));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
                                   !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
                                   !not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1557,22 +1558,23 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
   def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
 }
 
-// SPV_INTEL_tensor_float32_conversion
+// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
 // Multiclass used to define at the same time both a demangled builtin records
 // and a corresponding convert builtin records.
-multiclass DemangledTF32ConvertBuiltin<string name1, string name2> {
+multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
   // Create records for scalar and vector conversions.
   foreach i = ["", "2", "3", "4", "8", "16"] in {
-    def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
-    def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+    def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+    def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
   }
 }
 
-defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">;
+defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
+defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;
 
 foreach conv = ["FToTF32INTEL"] in {
-  def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
-  def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+  def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
+  def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 2726203d253ad..945d3febe0bcf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
         {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
         {"SPV_KHR_float_controls2",
-         SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
+         SPIRV::Extension::Extension::SPV_KHR_float_controls2},
+         {"SPV_INTEL_tensor_float32_conversion",
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a04ed6a42c868..f0b938d681dba 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -441,11 +441,11 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
 def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
 def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
 
-// SPV_INTEL_tensor_float32_conversion
+// SPV_INTEL_bfloat16_conversion
 def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
 def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
 
-// SPV_INTEL_bfloat16_conversion
+// SPV_INTEL_tensor_float32_conversion
 def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
 
 // 3.42.12 Composite Instructions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index c252fc5897518..eac337c3c4246 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1567,7 +1567,7 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpRoundFToTF32INTEL:
     if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
-      Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
+      Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
     }
     break;
   case SPIRV::OpVariableLengthArrayINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7b2139a1c84a8..614e83ae9b286 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -503,7 +503,6 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
 defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
-defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
 defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
@@ -531,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
 defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
+defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll
new file mode 100644
index 0000000000000..fa708ab022a85
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll
new file mode 100644
index 0000000000000..630b2fdd7696c
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll
new file mode 100644
index 0000000000000..dcad78d17bff7
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll
@@ -0,0 +1,62 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %}
+
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion 
+
+; CHECK: OpCapability TensorFloat32RoundingINTEL
+; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion"
+
+; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
+; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
+; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
+; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
+; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
+; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
+; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
+
+; CHECK: OpFunction %[[VoidTy]]
+; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
+; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
+; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]]
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]]
+; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]]
+
+; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] 
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]]
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]]
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]]
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]]
+; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(float %a, <8 x float> %in) {
+  %res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
+  %res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
+  %res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00)
+  ret void
+}
+
+declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)
+declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
+
+define dso_local spir_kernel void @test_ocl(float %a) {
+entry:
+  %res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00)
+  %res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer)
+  %res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer)
+  %res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer)
+  %res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer)
+  %res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer)
+  ret void
+}
+
+declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float)
+declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>)
+declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>)
+declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>)
+declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>)
+declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 90383265002a3..9c9eefd054fa6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_me
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 def SPV_INTEL_cache_controls                     : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
+def SPV_INTEL_tensor_float32_conversion          : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
       SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
       SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
       SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
-      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
+      SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
+      SPV_INTEL_tensor_float32_conversion
     ]>;
 
 //===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL                         : I32EnumAttrCase<"B
   ];
 }
 
+def SPIRV_C_TensorFloat32RoundingINTEL                       : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_tensor_float32_conversion]>
+  ];
+}
+
 def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
       SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
       SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
-      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
+      SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
+      SPIRV_C_TensorFloat32RoundingINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL      : I32EnumAttrCase<"OpControlBarrie
 def SPIRV_OC_OpControlBarrierWaitINTEL        : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
 def SPIRV_OC_OpGroupIMulKHR                   : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
 def SPIRV_OC_OpGroupFMulKHR                   : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
+def SPIRV_OC_OpRoundFToTF32INTEL              : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
 
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
       SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
       SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
-      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
+      SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
+      SPIRV_OC_OpRoundFToTF32INTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
index 82d26e365fb24..b692c07122683 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -11,6 +11,7 @@
 // at (https://github.com/intel/llvm)
 // Supported extensions
 // * SPV_INTEL_bfloat16_conversion
+// * SPV_INTEL_tensor_float32_conversion
 //===----------------------------------------------------------------------===//
 
 
@@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
   let hasVerifier = 1;
 }
 
+// -----
+
+def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> {
+  let summary = "See extension SPV_INTEL_tensor_float32_conversion";
+
+  let description = [{
+    Convert value numerically from a 32-bit floating point type to tensor float32,
+    with rounding to the nearest even.
+
+    Result Type must be a scalar or vector of 32-bit floating-point type.
+    The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value.
+
+    Float Value must be a scalar or vector of floating-point type.
+    It must have the same number of components as Result Type. The component width must be 32 bits.
+
+    Results are computed per component.
+
+
+    ```
+    round-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use
+                        `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.RoundFToTF32 %0 : f32 to f32
+    %3 = spirv.INTEL.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32>
+    ```
+
+  }];
+
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_tensor_float32_conversion]>,
+    Capability<[SPIRV_C_TensorFloat32RoundingINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
 
 // -----
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc274673be..fc3e7308356bf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTELRoundFToTF32Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELRoundFToTF32Op::verify() {
+  // ODS only constrains each side to be f32 or a vector of f32; it does not
+  // tie the operand and result shapes together. Treat a scalar as a single
+  // component so scalar/vector mismatches (e.g. vector<2xf32> to f32) are
+  // diagnosed here as well, not just vector/vector length mismatches.
+  auto getNumComponents = [](Type type) -> int64_t {
+    if (auto vectorType = llvm::dyn_cast<VectorType>(type))
+      return vectorType.getNumElements();
+    return 1;
+  };
+  if (getNumComponents(getOperand().getType()) !=
+      getNumComponents(getResult().getType()))
+    return emitOpError(
+        "operand and result must have same number of elements");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index bb15d018a6c44..aa5bee5796cfa 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32>
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
index 6d2fd324363c6..53cf8bf8fbd62 100644
--- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL]
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.RoundFToTF32
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [TensorFloat32RoundingINTEL], [SPV_INTEL_tensor_float32_conversion]> {
+  // CHECK-LABEL: @f32_to_tf32
+  spirv.func @f32_to_tf32(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_tf32_vec
+  spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32>
+    %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32>
+    spirv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.INTEL.SplitBarrier
 //===----------------------------------------------------------------------===//

>From 6135f1414ed426c510c9c626ccc6bc876e2d0e22 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 29 Jul 2025 19:18:46 -0700
Subject: [PATCH 3/4] fix clang format

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 9 +++++----
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    | 2 +-
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 ++-
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 03ca2ad1d8fa5..a3c9cb96d013b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -699,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call,
+                              Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -2679,13 +2680,13 @@ static bool generateConvertInst(const StringRef DemangledCall,
       }
     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
                                           SPIRV::OpTypeFloat)) {
-      if(Builtin->IsTF32){
+      if (Builtin->IsTF32) {
         const auto *ST = static_cast<const SPIRVSubtarget *>(
-          &MIRBuilder.getMF().getSubtarget());
+            &MIRBuilder.getMF().getSubtarget());
         if (!ST->canUseExtension(
                 SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
           NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
-          IsRightComponentsNumber =
+        IsRightComponentsNumber =
             GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
             GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
         Opcode = SPIRV::OpRoundFToTF32INTEL;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 945d3febe0bcf..d9265f498973e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -103,7 +103,7 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
-         {"SPV_INTEL_tensor_float32_conversion",
+        {"SPV_INTEL_tensor_float32_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index eac337c3c4246..0cd9d7882a52a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1565,7 +1565,8 @@ void addInstrRequirements(const MachineInstr &MI,
     }
     break;
   case SPIRV::OpRoundFToTF32INTEL:
-    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
+    if (ST.canUseExtension(
+            SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
       Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
       Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
     }

>From 8528c2f9f755b14af347a1f46274014bf7749388 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 29 Jul 2025 20:07:24 -0700
Subject: [PATCH 4/4] fix the CI test failure

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index a3c9cb96d013b..25cdf72a658a8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -244,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
       "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
-      "Convert|Round"
+      "Convert|Round|"
       "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
   std::smatch Match;
   if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {



More information about the Mlir-commits mailing list