[llvm] 2286118 - [SPIRV] Enable `bfloat16` arithmetic (#166031)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 4 08:10:31 PST 2025


Author: Alex Voicu
Date: 2025-11-04T18:10:26+02:00
New Revision: 2286118e6f2cda56b78d2e6b0193dd6f0ca7b7ea

URL: https://github.com/llvm/llvm-project/commit/2286118e6f2cda56b78d2e6b0193dd6f0ca7b7ea
DIFF: https://github.com/llvm/llvm-project/commit/2286118e6f2cda56b78d2e6b0193dd6f0ca7b7ea.diff

LOG: [SPIRV] Enable `bfloat16` arithmetic (#166031)

Enable the `SPV_INTEL_bfloat16_arithmetic` extension, which allows arithmetic, relational and `OpExtInst` instructions to take `bfloat16` arguments. This patch only adds support for the arithmetic and relational ops; handling of `OpExtInst` instructions is left for later. The extension itself is rather new, but `bfloat16` is ubiquitous at this point, and not supporting these ops is limiting.

Added: 
    llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
    llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll

Modified: 
    llvm/docs/SPIRVUsage.rst
    llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
    llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
    llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
    llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Removed: 
    


################################################################################
diff  --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 99f56a5cbc63a..749961356d23e 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -173,6 +173,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
      - Allows generating arbitrary width integer types.
    * - ``SPV_INTEL_bindless_images``
      - Adds instructions to convert convert unsigned integer handles to images, samplers and sampled images.
+   * - ``SPV_INTEL_bfloat16_arithmetic``
+     - Allows the use of 16-bit bfloat16 values in arithmetic and relational operators.
    * - ``SPV_INTEL_bfloat16_conversion``
      - Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
    * - ``SPV_INTEL_cache_controls``

diff  --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1fc90d0852aad..4fd220481cedb 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -294,6 +294,10 @@ void IRTranslator::addMachineCFGPred(CFGEdge Edge, MachineBasicBlock *NewPred) {
   MachinePreds[Edge].push_back(NewPred);
 }
 
+static bool targetSupportsBF16Type(const MachineFunction *MF) {
+  return MF->getTarget().getTargetTriple().isSPIRV();
+}
+
 static bool containsBF16Type(const User &U) {
   // BF16 cannot currently be represented by LLT, to avoid miscompiles we
   // prevent any instructions using them. FIXME: This can be removed once LLT
@@ -306,7 +310,7 @@ static bool containsBF16Type(const User &U) {
 
 bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
                                      MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
     return false;
 
   // Get or create a virtual register for each value.
@@ -328,7 +332,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
 
 bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
     return false;
 
   Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +352,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateCompare(const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
     return false;
 
   auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1573,7 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
     return false;
 
   uint32_t Flags = 0;
@@ -2688,7 +2692,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
 
 bool IRTranslator::translateInlineAsm(const CallBase &CB,
                                       MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(CB))
+  if (containsBF16Type(CB) && !targetSupportsBF16Type(MF))
     return false;
 
   const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
@@ -2779,7 +2783,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
+  if (containsBF16Type(U) && !targetSupportsBF16Type(MF))
     return false;
 
   const CallInst &CI = cast<CallInst>(U);

diff  --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index f0558ebcb6681..43b2869cecdf7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -107,6 +107,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
         {"SPV_INTEL_bindless_images",
          SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
+        {"SPV_INTEL_bfloat16_arithmetic",
+         SPIRV::Extension::Extension::SPV_INTEL_bfloat16_arithmetic},
         {"SPV_INTEL_bfloat16_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
         {"SPV_KHR_subgroup_rotate",

diff  --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index d154a06c6f313..e5ac76c405841 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1435,6 +1435,8 @@ void addInstrRequirements(const MachineInstr &MI,
       addPrintfRequirements(MI, Reqs, ST);
       break;
     }
+    // TODO: handle bfloat16 extended instructions when
+    // SPV_INTEL_bfloat16_arithmetic is enabled.
     break;
   }
   case SPIRV::OpAliasDomainDeclINTEL:
@@ -2060,7 +2062,64 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
     break;
   }
-
+  case SPIRV::OpFAddS:
+  case SPIRV::OpFSubS:
+  case SPIRV::OpFMulS:
+  case SPIRV::OpFDivS:
+  case SPIRV::OpFRemS:
+  case SPIRV::OpFMod:
+  case SPIRV::OpFNegate:
+  case SPIRV::OpFAddV:
+  case SPIRV::OpFSubV:
+  case SPIRV::OpFMulV:
+  case SPIRV::OpFDivV:
+  case SPIRV::OpFRemV:
+  case SPIRV::OpFNegateV: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Arithmetic instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
+  case SPIRV::OpOrdered:
+  case SPIRV::OpUnordered:
+  case SPIRV::OpFOrdEqual:
+  case SPIRV::OpFOrdNotEqual:
+  case SPIRV::OpFOrdLessThan:
+  case SPIRV::OpFOrdLessThanEqual:
+  case SPIRV::OpFOrdGreaterThan:
+  case SPIRV::OpFOrdGreaterThanEqual:
+  case SPIRV::OpFUnordEqual:
+  case SPIRV::OpFUnordNotEqual:
+  case SPIRV::OpFUnordLessThan:
+  case SPIRV::OpFUnordLessThanEqual:
+  case SPIRV::OpFUnordGreaterThan:
+  case SPIRV::OpFUnordGreaterThanEqual: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    MachineInstr *OperandDef = MRI.getVRegDef(MI.getOperand(2).getReg());
+    SPIRVType *TypeDef = MRI.getVRegDef(OperandDef->getOperand(1).getReg());
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Relational instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
   default:
     break;
   }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 267118364c371..1b4b29bbb160a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -387,6 +387,8 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
 defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
 defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
+defm SPV_INTEL_bfloat16_arithmetic
+    : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +572,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ArithmeticINTEL : CapabilityOperand<6226, 0, 0, [SPV_INTEL_bfloat16_arithmetic], []>;
 defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
 defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
new file mode 100644
index 0000000000000..4cabddb94df25
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
@@ -0,0 +1,142 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Arithmetic instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[NEG:%.*]] "neg"
+; CHECK-DAG: OpName [[NEGV:%.*]] "negv"
+; CHECK-DAG: OpName [[ADD:%.*]] "add"
+; CHECK-DAG: OpName [[ADDV:%.*]] "addv"
+; CHECK-DAG: OpName [[SUB:%.*]] "sub"
+; CHECK-DAG: OpName [[SUBV:%.*]] "subv"
+; CHECK-DAG: OpName [[MUL:%.*]] "mul"
+; CHECK-DAG: OpName [[MULV:%.*]] "mulv"
+; CHECK-DAG: OpName [[DIV:%.*]] "div"
+; CHECK-DAG: OpName [[DIVV:%.*]] "divv"
+; CHECK-DAG: OpName [[REM:%.*]] "rem"
+; CHECK-DAG: OpName [[REMV:%.*]] "remv"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 4
+
+; CHECK-DAG: [[NEG]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOAT]] [[X]]
+define spir_func bfloat @neg(bfloat %x) {
+entry:
+  %r = fneg bfloat %x
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[NEGV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOATV]] [[X]]
+define spir_func <4 x bfloat> @negv(<4 x bfloat> %x) {
+entry:
+  %r = fneg <4 x bfloat> %x
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[ADD]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @add(bfloat %x, bfloat %y) {
+entry:
+  %r = fadd bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[ADDV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @addv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fadd <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[SUB]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @sub(bfloat %x, bfloat %y) {
+entry:
+  %r = fsub bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[SUBV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @subv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fsub <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[MUL]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @mul(bfloat %x, bfloat %y) {
+entry:
+  %r = fmul bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[MULV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @mulv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fmul <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[DIV]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @div(bfloat %x, bfloat %y) {
+entry:
+  %r = fdiv bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[DIVV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @divv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fdiv <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[REM]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @rem(bfloat %x, bfloat %y) {
+entry:
+  %r = frem bfloat %x, %y
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[REMV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @remv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = frem <4 x bfloat> %x, %y
+  ret <4 x bfloat> %r
+}

diff  --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
new file mode 100644
index 0000000000000..3774791d58f87
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
@@ -0,0 +1,376 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Relational instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[UEQ:%.*]] "test_ueq"
+; CHECK-DAG: OpName [[OEQ:%.*]] "test_oeq"
+; CHECK-DAG: OpName [[UNE:%.*]] "test_une"
+; CHECK-DAG: OpName [[ONE:%.*]] "test_one"
+; CHECK-DAG: OpName [[ULT:%.*]] "test_ult"
+; CHECK-DAG: OpName [[OLT:%.*]] "test_olt"
+; CHECK-DAG: OpName [[ULE:%.*]] "test_ule"
+; CHECK-DAG: OpName [[OLE:%.*]] "test_ole"
+; CHECK-DAG: OpName [[UGT:%.*]] "test_ugt"
+; CHECK-DAG: OpName [[OGT:%.*]] "test_ogt"
+; CHECK-DAG: OpName [[UGE:%.*]] "test_uge"
+; CHECK-DAG: OpName [[OGE:%.*]] "test_oge"
+; CHECK-DAG: OpName [[UNO:%.*]] "test_uno"
+; CHECK-DAG: OpName [[ORD:%.*]] "test_ord"
+; CHECK-DAG: OpName [[v3UEQ:%.*]] "test_v3_ueq"
+; CHECK-DAG: OpName [[v3OEQ:%.*]] "test_v3_oeq"
+; CHECK-DAG: OpName [[v3UNE:%.*]] "test_v3_une"
+; CHECK-DAG: OpName [[v3ONE:%.*]] "test_v3_one"
+; CHECK-DAG: OpName [[v3ULT:%.*]] "test_v3_ult"
+; CHECK-DAG: OpName [[v3OLT:%.*]] "test_v3_olt"
+; CHECK-DAG: OpName [[v3ULE:%.*]] "test_v3_ule"
+; CHECK-DAG: OpName [[v3OLE:%.*]] "test_v3_ole"
+; CHECK-DAG: OpName [[v3UGT:%.*]] "test_v3_ugt"
+; CHECK-DAG: OpName [[v3OGT:%.*]] "test_v3_ogt"
+; CHECK-DAG: OpName [[v3UGE:%.*]] "test_v3_uge"
+; CHECK-DAG: OpName [[v3OGE:%.*]] "test_v3_oge"
+; CHECK-DAG: OpName [[v3UNO:%.*]] "test_v3_uno"
+; CHECK-DAG: OpName [[v3ORD:%.*]] "test_v3_ord"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 3
+
+; CHECK:      [[UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ueq(bfloat %a, bfloat %b) {
+  %r = fcmp ueq bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oeq(bfloat %a, bfloat %b) {
+  %r = fcmp oeq bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_une(bfloat %a, bfloat %b) {
+  %r = fcmp une bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_one(bfloat %a, bfloat %b) {
+  %r = fcmp one bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ult(bfloat %a, bfloat %b) {
+  %r = fcmp ult bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_olt(bfloat %a, bfloat %b) {
+  %r = fcmp olt bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ULE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ule(bfloat %a, bfloat %b) {
+  %r = fcmp ule bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OLE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ole(bfloat %a, bfloat %b) {
+  %r = fcmp ole bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[UGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ugt(bfloat %a, bfloat %b) {
+  %r = fcmp ugt bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ogt(bfloat %a, bfloat %b) {
+  %r = fcmp ogt bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[UGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_uge(bfloat %a, bfloat %b) {
+  %r = fcmp uge bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[OGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oge(bfloat %a, bfloat %b) {
+  %r = fcmp oge bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[ORD]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpOrdered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ord(bfloat %a, bfloat %b) {
+  %r = fcmp ord bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[UNO]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpUnordered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_uno(bfloat %a, bfloat %b) {
+  %r = fcmp uno bfloat %a, %b
+  ret i1 %r
+}
+
+; CHECK:      [[v3UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ueq(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ueq <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_oeq(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp oeq <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_une(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp une <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_one(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp one <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ult(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ult <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_olt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp olt <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ULE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ule(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ule <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OLE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ole(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ole <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ugt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ugt <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ogt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ogt <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_uge(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp uge <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_oge(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp oge <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ORD]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpOrdered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ord(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ord <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UNO]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpUnordered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_uno(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp uno <3 x bfloat> %a, %b
+  ret <3 x i1> %r
+}


        


More information about the llvm-commits mailing list