[llvm] [SPIRV] Add FPEncoding operand support for OpTypeFloat (PR #156871)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 12 13:25:16 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/156871

>From f44dccefc0dc98d43b583218def7e1880212e85a Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 4 Sep 2025 05:14:39 -0700
Subject: [PATCH 1/7] add support for bfloat in SPIR-V

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 26 ++++++++++++++++++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 13 ++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..0f258e03b23c8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1122,7 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
-  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
+
+  // Record the FPVariant of the floating-point registers in the
+  // VRegFPVariantMap.
+  MachineFunction *MF = &MIRBuilder.getMF();
+  Register TypeReg = getSPIRVTypeID(SpirvType);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy()) {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+    } else {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+    }
+  }
+  VRegToTypeMap[MF][TypeReg] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -2088,3 +2100,15 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   }
   return false;
 }
+
+SPIRVGlobalRegistry::FPVariant
+SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+                                         const MachineFunction *MF) {
+  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
+  if (t != VRegFPVariantMap.end()) {
+    auto tt = t->second.find(VReg);
+    if (tt != t->second.end())
+      return tt->second;
+  }
+  return FPVariant::NONE;
+}
\ No newline at end of file
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..1f8c30dc01f7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,6 +29,10 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
+public:
+  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+
+private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -88,6 +92,11 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
+  // Maps floating point Registers to their FPVariant (float type kind), given
+  // the MachineFunction.
+  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+      VRegFPVariantMap;
+
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -422,6 +431,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
 
+  // Return the FPVariant of to the given floating-point regiester.
+  FPVariant getFPVariantForVReg(Register VReg,
+                                const MachineFunction *MF = nullptr);
+
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 

>From a67050e7d47c853c5721e08a7990eef34c4e25f2 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 4 Sep 2025 05:56:56 -0700
Subject: [PATCH 2/7] revert all the changes

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 49 +++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 14 +++---
 2 files changed, 35 insertions(+), 28 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 0f258e03b23c8..adf6c64357a7e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1122,19 +1122,20 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
+  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
 
   // Record the FPVariant of the floating-point registers in the
   // VRegFPVariantMap.
-  MachineFunction *MF = &MIRBuilder.getMF();
-  Register TypeReg = getSPIRVTypeID(SpirvType);
-  if (Ty->isFloatingPointTy()) {
-    if (Ty->isBFloatTy()) {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
-    } else {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
-    }
-  }
-  VRegToTypeMap[MF][TypeReg] = SpirvType;
+  // MachineFunction *MF = &MIRBuilder.getMF();
+  // Register TypeReg = getSPIRVTypeID(SpirvType);
+  // if (Ty->isFloatingPointTy()) {
+  //   if (Ty->isBFloatTy()) {
+  //     VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+  //   } else {
+  //     VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+  //   }
+  // }
+  // VRegToTypeMap[MF][TypeReg] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -2101,14 +2102,20 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   return false;
 }
 
-SPIRVGlobalRegistry::FPVariant
-SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
-                                         const MachineFunction *MF) {
-  auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
-  if (t != VRegFPVariantMap.end()) {
-    auto tt = t->second.find(VReg);
-    if (tt != t->second.end())
-      return tt->second;
-  }
-  return FPVariant::NONE;
-}
\ No newline at end of file
+// SPIRVGlobalRegistry::FPVariant
+// SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+//                                          const MachineFunction *MF) {
+//   const MachineFunction *Func = MF ? MF : CurMF;
+//   DenseMap<const MachineFunction *,
+//            DenseMap<Register, FPVariant>>::const_iterator FuncIt =
+//       VRegFPVariantMap.find(Func);
+
+//   if (FuncIt != VRegFPVariantMap.end()) {
+//     const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+//     DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
+
+//     if (VRegIt != VRegMap.end())
+//       return VRegIt->second;
+//   }
+//   return FPVariant::NONE;
+// }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 1f8c30dc01f7f..fa397c5410dd8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,10 +29,10 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
-public:
-  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+// public:
+//   enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
 
-private:
+// private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -94,8 +94,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   // Maps floating point Registers to their FPVariant (float type kind), given
   // the MachineFunction.
-  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
-      VRegFPVariantMap;
+  // DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+  //     VRegFPVariantMap;
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
@@ -432,8 +432,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   void invalidateMachineInstr(MachineInstr *MI);
 
   // Return the FPVariant of to the given floating-point regiester.
-  FPVariant getFPVariantForVReg(Register VReg,
-                                const MachineFunction *MF = nullptr);
+  // FPVariant getFPVariantForVReg(Register VReg,
+  //                               const MachineFunction *MF = nullptr);
 
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);

>From 4e405b569661f34f25f34dd7d78c4737423e05fc Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 5 Sep 2025 13:11:58 -0700
Subject: [PATCH 3/7] re-apply all the changes

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 51 +++++++++----------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 14 ++---
 2 files changed, 32 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index adf6c64357a7e..9ff5408f65d1a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1122,20 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
-  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
 
   // Record the FPVariant of the floating-point registers in the
   // VRegFPVariantMap.
-  // MachineFunction *MF = &MIRBuilder.getMF();
-  // Register TypeReg = getSPIRVTypeID(SpirvType);
-  // if (Ty->isFloatingPointTy()) {
-  //   if (Ty->isBFloatTy()) {
-  //     VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
-  //   } else {
-  //     VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
-  //   }
-  // }
-  // VRegToTypeMap[MF][TypeReg] = SpirvType;
+  MachineFunction *MF = &MIRBuilder.getMF();
+  Register TypeReg = getSPIRVTypeID(SpirvType);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy()) {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
+    } else {
+      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
+    }
+  }
+  VRegToTypeMap[MF][TypeReg] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -2102,20 +2101,20 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   return false;
 }
 
-// SPIRVGlobalRegistry::FPVariant
-// SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
-//                                          const MachineFunction *MF) {
-//   const MachineFunction *Func = MF ? MF : CurMF;
-//   DenseMap<const MachineFunction *,
-//            DenseMap<Register, FPVariant>>::const_iterator FuncIt =
-//       VRegFPVariantMap.find(Func);
+SPIRVGlobalRegistry::FPVariant
+SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
+                                         const MachineFunction *MF) {
+  const MachineFunction *Func = MF ? MF : CurMF;
+  DenseMap<const MachineFunction *,
+           DenseMap<Register, FPVariant>>::const_iterator FuncIt =
+      VRegFPVariantMap.find(Func);
 
-//   if (FuncIt != VRegFPVariantMap.end()) {
-//     const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
-//     DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
+  if (FuncIt != VRegFPVariantMap.end()) {
+    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
+    DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
 
-//     if (VRegIt != VRegMap.end())
-//       return VRegIt->second;
-//   }
-//   return FPVariant::NONE;
-// }
+    if (VRegIt != VRegMap.end())
+      return VRegIt->second;
+  }
+  return FPVariant::NONE;
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index fa397c5410dd8..1f8c30dc01f7f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,10 +29,10 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
-// public:
-//   enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
+public:
+  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
 
-// private:
+private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -94,8 +94,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   // Maps floating point Registers to their FPVariant (float type kind), given
   // the MachineFunction.
-  // DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
-  //     VRegFPVariantMap;
+  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
+      VRegFPVariantMap;
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
@@ -432,8 +432,8 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   void invalidateMachineInstr(MachineInstr *MI);
 
   // Return the FPVariant of to the given floating-point regiester.
-  // FPVariant getFPVariantForVReg(Register VReg,
-  //                               const MachineFunction *MF = nullptr);
+  FPVariant getFPVariantForVReg(Register VReg,
+                                const MachineFunction *MF = nullptr);
 
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);

>From 01448f339341e9844d84cfb2359cd406b13db6b0 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Mon, 8 Sep 2025 05:38:28 -0700
Subject: [PATCH 4/7] address review comment

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 9ff5408f65d1a..349c826846fd1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -2105,14 +2105,10 @@ SPIRVGlobalRegistry::FPVariant
 SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
                                          const MachineFunction *MF) {
   const MachineFunction *Func = MF ? MF : CurMF;
-  DenseMap<const MachineFunction *,
-           DenseMap<Register, FPVariant>>::const_iterator FuncIt =
-      VRegFPVariantMap.find(Func);
-
+  auto FuncIt = VRegFPVariantMap.find(Func);
   if (FuncIt != VRegFPVariantMap.end()) {
     const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
-    DenseMap<Register, FPVariant>::const_iterator VRegIt = VRegMap.find(VReg);
-
+    auto VRegIt = VRegMap.find(VReg);
     if (VRegIt != VRegMap.end())
       return VRegIt->second;
   }

>From 3b26b71e3b8e2bc9f161330b6772c41bc568d519 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 9 Sep 2025 06:53:23 -0700
Subject: [PATCH 5/7] change the approach for supporting bfloat, use FPEncoding
 enum class instead

---
 .../Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h |  5 ++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 50 ++++++++-----------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 16 ++----
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  2 +-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 26 ++++++++++
 5 files changed, 56 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index c2c08f8831307..d76180ce97e9e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
 #include "SPIRVGenTables.inc"
 } // namespace SpecConstantOpOperands
 
+namespace FPEncoding {
+#define GET_FPEncoding_DECL
+#include "SPIRVGenTables.inc"
+} // namespace FPEncoding
+
 struct ExtendedBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 349c826846fd1..115766ce886c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
   });
 }
 
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRV::FPEncoding::FPEncoding FPEncode) {
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addImm(Width)
+        .addImm(FPEncode);
+  });
+}
+
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
@@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  if (Ty->isFloatingPointTy()) {
+    // bfloat has the same 16-bit width as half, so it is distinguished by the
+    // BFloat16KHR FPEncoding operand on OpTypeFloat.
+    if (Ty->isBFloatTy())
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
+                            SPIRV::FPEncoding::BFloat16KHR);
+    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  }
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
@@ -1122,19 +1140,7 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
                                          ExplicitLayoutRequired, EmitIR);
   TypesInProcessing.erase(Ty);
-
-  // Record the FPVariant of the floating-point registers in the
-  // VRegFPVariantMap.
-  MachineFunction *MF = &MIRBuilder.getMF();
-  Register TypeReg = getSPIRVTypeID(SpirvType);
-  if (Ty->isFloatingPointTy()) {
-    if (Ty->isBFloatTy()) {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
-    } else {
-      VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
-    }
-  }
-  VRegToTypeMap[MF][TypeReg] = SpirvType;
+  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
 
   // TODO: We could end up with two SPIR-V types pointing to the same llvm type.
   // Is that a problem?
@@ -2100,17 +2106,3 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
   }
   return false;
 }
-
-SPIRVGlobalRegistry::FPVariant
-SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
-                                         const MachineFunction *MF) {
-  const MachineFunction *Func = MF ? MF : CurMF;
-  auto FuncIt = VRegFPVariantMap.find(Func);
-  if (FuncIt != VRegFPVariantMap.end()) {
-    const DenseMap<Register, FPVariant> &VRegMap = FuncIt->second;
-    auto VRegIt = VRegMap.find(VReg);
-    if (VRegIt != VRegMap.end())
-      return VRegIt->second;
-  }
-  return FPVariant::NONE;
-}
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 1f8c30dc01f7f..a648defa0a888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -29,10 +29,6 @@ using SPIRVType = const MachineInstr;
 using StructOffsetDecorator = std::function<void(Register)>;
 
 class SPIRVGlobalRegistry : public SPIRVIRMapping {
-public:
-  enum class FPVariant { NONE, IEEE_FLOAT, BRAIN_FLOAT };
-
-private:
   // Registers holding values which have types associated with them.
   // Initialized upon VReg definition in IRTranslator.
   // Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
@@ -92,11 +88,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // map of aliasing decorations to aliasing metadata
   std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;
 
-  // Maps floating point Registers to their FPVariant (float type kind), given
-  // the MachineFunction.
-  DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
-      VRegFPVariantMap;
-
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                              SPIRV::AccessQualifier::AccessQualifier AQ,
@@ -431,10 +422,6 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
   // structures referring this instruction.
   void invalidateMachineInstr(MachineInstr *MI);
 
-  // Return the FPVariant of to the given floating-point regiester.
-  FPVariant getFPVariantForVReg(Register VReg,
-                                const MachineFunction *MF = nullptr);
-
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
@@ -451,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+                            SPIRV::FPEncoding::FPEncoding FPEncode);
+
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 8d10cd0ffb3dd..496dcba17c10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
                   "$type = OpTypeFloat $width">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..be03f33104872 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
 def CooperativeMatrixOperandsOperand : OperandCategory;
 def SpecConstantOpOperandsOperand : OperandCategory;
 def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
+def FPEncodingOperand : OperandCategory;
 
 //===----------------------------------------------------------------------===//
 // Definition of the Environments
@@ -1996,3 +1997,28 @@ defm MatrixAPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x400,
 defm MatrixBPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixAPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixBPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define FPEncoding enum values and at the
+// same time SymbolicOperand entries with extensions.
+//===----------------------------------------------------------------------===//
+def FPEncoding : GenericEnum, Operand<i32> {
+  let FilterClass = "FPEncoding";
+  let NameField = "Name";
+  let ValueField = "Value";
+  let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class FPEncoding<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
+  def NAME : FPEncoding<NAME, value>;
+  defm : SymbolicOperandWithRequirements<
+             FPEncodingOperand, value, NAME, 0, 0,
+             reqExtensions, [], []>;
+}
+
+defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;

>From 6c8352cb96cb6ac046304caeb16ddc879d5d9251 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Tue, 9 Sep 2025 14:20:05 -0700
Subject: [PATCH 6/7] solve the CI failure

---
 llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index be03f33104872..ed933f872d136 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -2021,4 +2021,4 @@ multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
              reqExtensions, [], []>;
 }
 
-defm BFloat16KHR : FPEncodingOperand<0, [SPV_KHR_bfloat16]>;
+defm BFloat16KHR : FPEncodingOperand<0, []>; // TODO: BFloat16KHR comes from SPV_KHR_bfloat16, which is not yet listed here as a required extension.

>From ac5fc5d9b9052781ad67e5da07420314e497e390 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Fri, 12 Sep 2025 13:24:15 -0700
Subject: [PATCH 7/7] add the test case for bfloat

---
 llvm/test/CodeGen/SPIRV/basic_float_types.ll | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index dfee1ace2205d..27bf0851f65ae 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -9,6 +9,7 @@ entry:
 ; CHECK-DAG: OpCapability Float64

 ; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16
+; CHECK-DAG:   %[[#bfloat:]] = OpTypeFloat 16 0
 ; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:   %[[#double:]] = OpTypeFloat 64

@@ -25,11 +26,14 @@ entry:
 ; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
+; CHECK-DAG: %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2

 ; CHECK-DAG:     %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
+; CHECK-DAG:    %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
 ; CHECK-DAG:    %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
 ; CHECK-DAG:   %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
 ; CHECK-DAG:   %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
 ; CHECK-DAG:   %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
 ; CHECK-DAG:   %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
+; CHECK-DAG:  %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
 ; CHECK-DAG:  %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
 ; CHECK-DAG:  %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
 ; CHECK-DAG:  %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +44,9 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
   %half_Val = alloca half, align 2

+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
+  %bfloat_Val = alloca bfloat, align 2
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
   %float_Val = alloca float, align 4

@@ -55,6 +62,9 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
   %half4_Val = alloca <4 x half>, align 8

+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
+  %bfloat2_Val = alloca <2 x bfloat>, align 4
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
   %float2_Val = alloca <2 x float>, align 8
 



More information about the llvm-commits mailing list