[llvm] [SPIRV] Add FPEncoding operand support for OpTypeFloat (PR #156871)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 15 09:42:07 PDT 2025


https://github.com/YixingZhang007 updated https://github.com/llvm/llvm-project/pull/156871

>From 5b8b92db9e27191adfd3f4a78b35d632b26055c2 Mon Sep 17 00:00:00 2001
From: "Zhang, Yixing" <yixing.zhang at intel.com>
Date: Thu, 4 Sep 2025 05:14:39 -0700
Subject: [PATCH] add the support for bfloat in SPIRV

---
 .../Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h |  5 ++++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 22 ++++++++++++++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  3 +++
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  2 +-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 26 +++++++++++++++++++
 llvm/test/CodeGen/SPIRV/basic_float_types.ll  | 25 ++++++++++++++++--
 6 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
index c2c08f8831307..d76180ce97e9e 100644
--- a/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
+++ b/llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h
@@ -232,6 +232,11 @@ namespace SpecConstantOpOperands {
 #include "SPIRVGenTables.inc"
 } // namespace SpecConstantOpOperands
 
+namespace FPEncoding {
+#define GET_FPEncoding_DECL
+#include "SPIRVGenTables.inc"
+} // namespace FPEncoding
+
 struct ExtendedBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index cfe24c84941a9..115766ce886c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -203,6 +203,18 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
   });
 }
 
+SPIRVType *
+SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
+                                    MachineIRBuilder &MIRBuilder,
+                                    SPIRV::FPEncoding::FPEncoding FPEncode) {
+  return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
+    return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
+        .addDef(createTypeVReg(MIRBuilder))
+        .addImm(Width)
+        .addImm(FPEncode);
+  });
+}
+
 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
@@ -1041,8 +1053,14 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     return Width == 1 ? getOpTypeBool(MIRBuilder)
                       : getOpTypeInt(Width, MIRBuilder, false);
   }
-  if (Ty->isFloatingPointTy())
-    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+  if (Ty->isFloatingPointTy()) {
+    if (Ty->isBFloatTy()) {
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder,
+                            SPIRV::FPEncoding::BFloat16KHR);
+    } else {
+      return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
+    }
+  }
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 7ef812828b7cc..a648defa0a888 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -438,6 +438,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
 
   SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
 
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+                            SPIRV::FPEncoding::FPEncoding FPEncode);
+
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 8d10cd0ffb3dd..496dcba17c10d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -167,7 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
 def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
 def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
                   "$type = OpTypeInt $width $signedness">;
-def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
+def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, variable_ops),
                   "$type = OpTypeFloat $width">;
 def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
                   "$type = OpTypeVector $compType $compCount">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..ed933f872d136 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -210,6 +210,7 @@ def CooperativeMatrixLayoutOperand : OperandCategory;
 def CooperativeMatrixOperandsOperand : OperandCategory;
 def SpecConstantOpOperandsOperand : OperandCategory;
 def MatrixMultiplyAccumulateOperandsOperand : OperandCategory;
+def FPEncodingOperand : OperandCategory;
 
 //===----------------------------------------------------------------------===//
 // Definition of the Environments
@@ -1996,3 +1997,28 @@ defm MatrixAPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x400,
 defm MatrixBPackedFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x800, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixAPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x1000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
 defm MatrixBPackedBFloat16INTEL :  MatrixMultiplyAccumulateOperandsOperand<0x2000, [SPV_INTEL_subgroup_matrix_multiply_accumulate]>;
+
+//===----------------------------------------------------------------------===//
+// Multiclass used to define FPEncoding enum values and at the
+// same time SymbolicOperand entries with extensions.
+//===----------------------------------------------------------------------===//
+def FPEncoding : GenericEnum, Operand<i32> {
+  let FilterClass = "FPEncoding";
+  let NameField = "Name";
+  let ValueField = "Value";
+  let PrintMethod = !strconcat("printSymbolicOperand<OperandCategory::", FilterClass, "Operand>");
+}
+
+class FPEncoding<string name, bits<32> value> {
+  string Name = name;
+  bits<32> Value = value;
+}
+
+multiclass FPEncodingOperand<bits<32> value, list<Extension> reqExtensions>{
+  def NAME : FPEncoding<NAME, value>;
+  defm : SymbolicOperandWithRequirements<
+             FPEncodingOperand, value, NAME, 0, 0,
+             reqExtensions, [], []>;
+}
+
+defm BFloat16KHR : FPEncodingOperand<0, []>;
diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index dfee1ace2205d..486f6358ce5de 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -1,6 +1,6 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUNx: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 define void @main() {
 entry:
@@ -8,7 +8,8 @@ entry:
 ; CHECK-DAG: OpCapability Float16
 ; CHECK-DAG: OpCapability Float64
 
-; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16
+; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16{{$}}
+; CHECK-DAG:   %[[#bfloat:]] = OpTypeFloat 16 0{{$}}
 ; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:   %[[#double:]] = OpTypeFloat 64
 
@@ -16,6 +17,10 @@ entry:
 ; CHECK-DAG:   %[[#v3half:]] = OpTypeVector %[[#half]] 3
 ; CHECK-DAG:   %[[#v4half:]] = OpTypeVector %[[#half]] 4
 
+; CHECK-DAG:  %[[#v2bfloat:]] = OpTypeVector %[[#bfloat]] 2
+; CHECK-DAG:  %[[#v3bfloat:]] = OpTypeVector %[[#bfloat]] 3
+; CHECK-DAG:  %[[#v4bfloat:]] = OpTypeVector %[[#bfloat]] 4
+
 ; CHECK-DAG:  %[[#v2float:]] = OpTypeVector %[[#float]] 2
 ; CHECK-DAG:  %[[#v3float:]] = OpTypeVector %[[#float]] 3
 ; CHECK-DAG:  %[[#v4float:]] = OpTypeVector %[[#float]] 4
@@ -25,11 +30,15 @@ entry:
 ; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
 
 ; CHECK-DAG:     %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
+; CHECK-DAG:    %[[#ptr_Function_bfloat:]] = OpTypePointer Function %[[#bfloat]]
 ; CHECK-DAG:    %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
 ; CHECK-DAG:   %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
 ; CHECK-DAG:   %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
 ; CHECK-DAG:   %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
 ; CHECK-DAG:   %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
+; CHECK-DAG:  %[[#ptr_Function_v2bfloat:]] = OpTypePointer Function %[[#v2bfloat]]
+; CHECK-DAG:  %[[#ptr_Function_v3bfloat:]] = OpTypePointer Function %[[#v3bfloat]]
+; CHECK-DAG:  %[[#ptr_Function_v4bfloat:]] = OpTypePointer Function %[[#v4bfloat]]
 ; CHECK-DAG:  %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
 ; CHECK-DAG:  %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
 ; CHECK-DAG:  %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -40,6 +49,9 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
   %half_Val = alloca half, align 2
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_bfloat]] Function
+  %bfloat_Val = alloca bfloat, align 2
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
   %float_Val = alloca float, align 4
 
@@ -55,6 +67,15 @@ entry:
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
   %half4_Val = alloca <4 x half>, align 8
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2bfloat]] Function
+  %bfloat2_Val = alloca <2 x bfloat>, align 4
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3bfloat]] Function
+  %bfloat3_Val = alloca <3 x bfloat>, align 8
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4bfloat]] Function
+  %bfloat4_Val = alloca <4 x bfloat>, align 8
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
   %float2_Val = alloca <2 x float>, align 8
 



More information about the llvm-commits mailing list