[llvm] [SPIRV] Enable `bfloat16` arithmetic (PR #166031)

Alex Voicu via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 1 17:51:26 PDT 2025


https://github.com/AlexVlx created https://github.com/llvm/llvm-project/pull/166031

Enable the `SPV_INTEL_bfloat16_arithmetic` extension, which allows arithmetic, relational and `OpExtInst` instructions to take `bfloat16` arguments. This patch only adds support for arithmetic and relational ops. The extension itself is rather fresh, but `bfloat16` is ubiquitous at this point and not supporting these ops is limiting.

>From 3e0abe10c53e08cc1ebf16662d8d58cd7245f70f Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Thu, 15 May 2025 23:27:42 +0100
Subject: [PATCH 1/2] Add pass which forwards unimplemented math builtins /
 libcalls to the HIPSTDPAR runtime component.

---
 .../llvm/Transforms/HipStdPar/HipStdPar.h     |   7 +
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../lib/Target/AMDGPU/AMDGPUTargetMachine.cpp |   8 +-
 llvm/lib/Transforms/HipStdPar/HipStdPar.cpp   | 117 +++++++++
 llvm/test/Transforms/HipStdPar/math-fixup.ll  | 240 ++++++++++++++++++
 5 files changed, 371 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Transforms/HipStdPar/math-fixup.ll

diff --git a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
index 5ff38bdf04812..27195051ed7eb 100644
--- a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
+++ b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
   static bool isRequired() { return true; }
 };
 
+class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+
+  static bool isRequired() { return true; }
+};
+
 } // namespace llvm
 
 #endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 94dabe290213d..3acdbf4d49fde 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
 MODULE_PASS("globalopt", GlobalOptPass())
 MODULE_PASS("globalsplit", GlobalSplitPass())
 MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
+MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
 MODULE_PASS("hipstdpar-select-accelerator-code",
             HipStdParAcceleratorCodeSelectionPass())
 MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index ccb251b730f16..c3f8cee1e1783 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
           // When we are not using -fgpu-rdc, we can run accelerator code
           // selection relatively early, but still after linking to prevent
           // eager removal of potentially reachable symbols.
-          if (EnableHipStdPar)
+          if (EnableHipStdPar) {
+            PM.addPass(HipStdParMathFixupPass());
             PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+          }
           PM.addPass(AMDGPUPrintfRuntimeBindingPass());
         }
 
@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // selection after linking to prevent, otherwise we end up removing
         // potentially reachable symbols that were exported as external in other
         // modules.
-        if (EnableHipStdPar)
+        if (EnableHipStdPar) {
+          PM.addPass(HipStdParMathFixupPass());
           PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+        }
         // We want to support the -lto-partitions=N option as "best effort".
         // For that, we need to lower LDS earlier in the pipeline before the
         // module is partitioned for codegen.
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
index 5a87cf8c83d79..815878089c69e 100644
--- a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -37,6 +37,16 @@
 //    memory that ends up in one of the runtime equivalents, since this can
 //    happen if e.g. a library that was compiled without interposition returns
 //    an allocation that can be validly passed to `free`.
+//
+// 3. MathFixup (required): Some accelerators might have an incomplete
+//    implementation for the intrinsics used to implement some of the math
+//    functions in <cmath> / their corresponding libcall lowerings. Since this
+//    can vary quite significantly between accelerators, we replace calls to a
+//    set of intrinsics / lib functions known to be problematic with calls to a
+//    HIPSTDPAR specific forwarding layer, which gives a uniform interface for
+//    accelerators to implement in their own runtime components. This pass
+//    should run before AcceleratorCodeSelection so as to prevent the spurious
+//    removal of the HIPSTDPAR specific forwarding functions.
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
 
   return PreservedAnalyses::none();
 }
+
+static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
+    {"acosh", "__hipstdpar_acosh_f64"},
+    {"acoshf", "__hipstdpar_acosh_f32"},
+    {"asinh", "__hipstdpar_asinh_f64"},
+    {"asinhf", "__hipstdpar_asinh_f32"},
+    {"atanh", "__hipstdpar_atanh_f64"},
+    {"atanhf", "__hipstdpar_atanh_f32"},
+    {"cbrt", "__hipstdpar_cbrt_f64"},
+    {"cbrtf", "__hipstdpar_cbrt_f32"},
+    {"erf", "__hipstdpar_erf_f64"},
+    {"erff", "__hipstdpar_erf_f32"},
+    {"erfc", "__hipstdpar_erfc_f64"},
+    {"erfcf", "__hipstdpar_erfc_f32"},
+    {"fdim", "__hipstdpar_fdim_f64"},
+    {"fdimf", "__hipstdpar_fdim_f32"},
+    {"expm1", "__hipstdpar_expm1_f64"},
+    {"expm1f", "__hipstdpar_expm1_f32"},
+    {"hypot", "__hipstdpar_hypot_f64"},
+    {"hypotf", "__hipstdpar_hypot_f32"},
+    {"ilogb", "__hipstdpar_ilogb_f64"},
+    {"ilogbf", "__hipstdpar_ilogb_f32"},
+    {"lgamma", "__hipstdpar_lgamma_f64"},
+    {"lgammaf", "__hipstdpar_lgamma_f32"},
+    {"log1p", "__hipstdpar_log1p_f64"},
+    {"log1pf", "__hipstdpar_log1p_f32"},
+    {"logb", "__hipstdpar_logb_f64"},
+    {"logbf", "__hipstdpar_logb_f32"},
+    {"nextafter", "__hipstdpar_nextafter_f64"},
+    {"nextafterf", "__hipstdpar_nextafter_f32"},
+    {"nexttoward", "__hipstdpar_nexttoward_f64"},
+    {"nexttowardf", "__hipstdpar_nexttoward_f32"},
+    {"remainder", "__hipstdpar_remainder_f64"},
+    {"remainderf", "__hipstdpar_remainder_f32"},
+    {"remquo", "__hipstdpar_remquo_f64"},
+    {"remquof", "__hipstdpar_remquo_f32"},
+    {"scalbln", "__hipstdpar_scalbln_f64"},
+    {"scalblnf", "__hipstdpar_scalbln_f32"},
+    {"scalbn", "__hipstdpar_scalbn_f64"},
+    {"scalbnf", "__hipstdpar_scalbn_f32"},
+    {"tgamma", "__hipstdpar_tgamma_f64"},
+    {"tgammaf", "__hipstdpar_tgamma_f32"}};
+
+/// Redirects every use of the math libcalls / intrinsics that accelerators may
+/// not implement to the matching `__hipstdpar_` forwarding function. Returns
+/// PreservedAnalyses::all() if nothing was replaced, none() otherwise.
+PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (M.empty())
+    return PreservedAnalyses::all();
+
+  SmallVector<std::pair<Function *, std::string>> ToReplace;
+  for (auto &&F : M) {
+    if (!F.hasName())
+      continue;
+
+    auto N = F.getName().str();
+    switch (F.getIntrinsicID()) {
+    case Intrinsic::not_intrinsic: {
+      auto It = find_if(MathLibToHipStdPar,
+                        [&](auto &&Entry) { return Entry.first == N; });
+      // Libcalls are renamed via the table, not via the mangling done below.
+      if (It != std::cend(MathLibToHipStdPar))
+        ToReplace.emplace_back(&F, It->second);
+      continue;
+    }
+    case Intrinsic::acos:
+    case Intrinsic::asin:
+    case Intrinsic::atan:
+    case Intrinsic::atan2:
+    case Intrinsic::cosh:
+    case Intrinsic::modf:
+    case Intrinsic::sinh:
+    case Intrinsic::tan:
+    case Intrinsic::tanh:
+      break; // Forwarded irrespective of the overloaded type.
+    case Intrinsic::cos:
+    case Intrinsic::exp:
+    case Intrinsic::exp2:
+    case Intrinsic::log:
+    case Intrinsic::log10:
+    case Intrinsic::log2:
+    case Intrinsic::pow:
+    case Intrinsic::sin:
+      // Only the overloads returning double are forwarded.
+      if (!F.getReturnType()->isDoubleTy())
+        continue;
+      break;
+    default:
+      continue;
+    }
+
+    // llvm.<name>.<type> -> __hipstdpar_<name>_<type>; sizeof("llvm") counts
+    // the NUL terminator, hence it also covers the separator following "llvm".
+    llvm::replace(N, '.', '_');
+    N.replace(0, sizeof("llvm"), "__hipstdpar_");
+    ToReplace.emplace_back(&F, std::move(N));
+  }
+  if (ToReplace.empty())
+    return PreservedAnalyses::all();
+
+  for (auto &&F : ToReplace)
+    F.first->replaceAllUsesWith(M.getOrInsertFunction(
+        F.second, F.first->getFunctionType()).getCallee());
+  return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/HipStdPar/math-fixup.ll b/llvm/test/Transforms/HipStdPar/math-fixup.ll
new file mode 100644
index 0000000000000..e0e2ca79c0843
--- /dev/null
+++ b/llvm/test/Transforms/HipStdPar/math-fixup.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -passes=hipstdpar-math-fixup %s | FileCheck %s
+
+define void @foo(double noundef %dbl, float noundef %flt, i32 noundef %quo) #0 {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: double noundef [[DBL:%.*]], float noundef [[FLT:%.*]], i32 noundef [[QUO:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[QUO_ADDR:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 [[QUO]], ptr [[QUO_ADDR]], align 4
+entry:
+  %quo.addr = alloca i32, align 4
+  store i32 %quo, ptr %quo.addr, align 4
+  ; CHECK-NEXT:    [[TMP0:%.*]] = tail call contract double @llvm.fabs.f64(double [[DBL]])
+  %0 = tail call contract double @llvm.fabs.f64(double %dbl)
+  ; CHECK-NEXT:    [[TMP1:%.*]] = tail call contract float @llvm.fabs.f32(float [[FLT]])
+  %1 = tail call contract float @llvm.fabs.f32(float %flt)
+  ; CHECK-NEXT:    [[CALL:%.*]] = tail call contract double @__hipstdpar_remainder_f64(double noundef [[TMP0]], double noundef [[TMP0]]) #[[ATTR4:[0-9]+]]
+  %call = tail call contract double @remainder(double noundef %0, double noundef %0) #4
+  ; CHECK-NEXT:    [[CALL1:%.*]] = tail call contract float @__hipstdpar_remainder_f32(float noundef [[TMP1]], float noundef [[TMP1]]) #[[ATTR4]]
+  %call1 = tail call contract float @remainderf(float noundef %1, float noundef %1) #4
+  ; CHECK-NEXT:    [[CALL2:%.*]] = call contract double @__hipstdpar_remquo_f64(double noundef [[CALL]], double noundef [[CALL]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3:[0-9]+]]
+  %call2 = call contract double @remquo(double noundef %call, double noundef %call, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[CALL3:%.*]] = call contract float @__hipstdpar_remquo_f32(float noundef [[CALL1]], float noundef [[CALL1]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3]]
+  %call3 = call contract float @remquof(float noundef %call1, float noundef %call1, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[CALL2]], double [[CALL2]], double [[CALL2]])
+  %2 = call contract double @llvm.fma.f64(double %call2, double %call2, double %call2)
+  ; CHECK-NEXT:    [[TMP3:%.*]] = call contract float @llvm.fma.f32(float [[CALL3]], float [[CALL3]], float [[CALL3]])
+  %3 = call contract float @llvm.fma.f32(float %call3, float %call3, float %call3)
+  ; CHECK-NEXT:    [[CALL4:%.*]] = call contract double @__hipstdpar_fdim_f64(double noundef [[TMP2]], double noundef [[TMP2]]) #[[ATTR4]]
+  %call4 = call contract double @fdim(double noundef %2, double noundef %2) #4
+  ; CHECK-NEXT:    [[CALL5:%.*]] = call contract float @__hipstdpar_fdim_f32(float noundef [[TMP3]], float noundef [[TMP3]]) #[[ATTR4]]
+  %call5 = call contract float @fdimf(float noundef %3, float noundef %3) #4
+  ; CHECK-NEXT:    [[TMP4:%.*]] = call contract double @__hipstdpar_exp_f64(double [[CALL4]])
+  %4 = call contract double @llvm.exp.f64(double %call4)
+  ; CHECK-NEXT:    [[TMP5:%.*]] = call contract float @llvm.exp.f32(float [[CALL5]])
+  %5 = call contract float @llvm.exp.f32(float %call5)
+  ; CHECK-NEXT:    [[TMP6:%.*]] = call contract double @__hipstdpar_exp2_f64(double [[TMP4]])
+  %6 = call contract double @llvm.exp2.f64(double %4)
+  ; CHECK-NEXT:    [[TMP7:%.*]] = call contract float @llvm.exp2.f32(float [[TMP5]])
+  %7 = call contract float @llvm.exp2.f32(float %5)
+  ; CHECK-NEXT:    [[CALL6:%.*]] = call contract double @__hipstdpar_expm1_f64(double noundef [[TMP6]]) #[[ATTR4]]
+  %call6 = call contract double @expm1(double noundef %6) #4
+  ; CHECK-NEXT:    [[CALL7:%.*]] = call contract float @__hipstdpar_expm1_f32(float noundef [[TMP7]]) #[[ATTR4]]
+  %call7 = call contract float @expm1f(float noundef %7) #4
+  ; CHECK-NEXT:    [[TMP8:%.*]] = call contract double @__hipstdpar_log_f64(double [[CALL6]])
+  %8 = call contract double @llvm.log.f64(double %call6)
+  ; CHECK-NEXT:    [[TMP9:%.*]] = call contract float @llvm.log.f32(float [[CALL7]])
+  %9 = call contract float @llvm.log.f32(float %call7)
+  ; CHECK-NEXT:    [[TMP10:%.*]] = call contract double @__hipstdpar_log10_f64(double [[TMP8]])
+  %10 = call contract double @llvm.log10.f64(double %8)
+  ; CHECK-NEXT:    [[TMP11:%.*]] = call contract float @llvm.log10.f32(float [[TMP9]])
+  %11 = call contract float @llvm.log10.f32(float %9)
+  ; CHECK-NEXT:    [[TMP12:%.*]] = call contract double @__hipstdpar_log2_f64(double [[TMP10]])
+  %12 = call contract double @llvm.log2.f64(double %10)
+  ; CHECK-NEXT:    [[TMP13:%.*]] = call contract float @llvm.log2.f32(float [[TMP11]])
+  %13 = call contract float @llvm.log2.f32(float %11)
+  ; CHECK-NEXT:    [[CALL8:%.*]] = call contract double @__hipstdpar_log1p_f64(double noundef [[TMP12]]) #[[ATTR4]]
+  %call8 = call contract double @log1p(double noundef %12) #4
+  ; CHECK-NEXT:    [[CALL9:%.*]] = call contract float @__hipstdpar_log1p_f32(float noundef [[TMP13]]) #[[ATTR4]]
+  %call9 = call contract float @log1pf(float noundef %13) #4
+  ; CHECK-NEXT:    [[TMP14:%.*]] = call contract float @llvm.pow.f32(float [[CALL9]], float [[CALL9]])
+  %14 = call contract float @llvm.pow.f32(float %call9, float %call9)
+  ; CHECK-NEXT:    [[TMP15:%.*]] = call contract double @llvm.sqrt.f64(double [[CALL8]])
+  %15 = call contract double @llvm.sqrt.f64(double %call8)
+  ; CHECK-NEXT:    [[TMP16:%.*]] = call contract float @llvm.sqrt.f32(float [[TMP14]])
+  %16 = call contract float @llvm.sqrt.f32(float %14)
+  ; CHECK-NEXT:    [[CALL10:%.*]] = call contract double @__hipstdpar_cbrt_f64(double noundef [[TMP15]]) #[[ATTR4]]
+  %call10 = call contract double @cbrt(double noundef %15) #4
+  ; CHECK-NEXT:    [[CALL11:%.*]] = call contract float @__hipstdpar_cbrt_f32(float noundef [[TMP16]]) #[[ATTR4]]
+  %call11 = call contract float @cbrtf(float noundef %16) #4
+  ; CHECK-NEXT:    [[CALL12:%.*]] = call contract double @__hipstdpar_hypot_f64(double noundef [[CALL10]], double noundef [[CALL10]]) #[[ATTR4]]
+  %call12 = call contract double @hypot(double noundef %call10, double noundef %call10) #4
+  ; CHECK-NEXT:    [[CALL13:%.*]] = call contract float @__hipstdpar_hypot_f32(float noundef [[CALL11]], float noundef [[CALL11]]) #[[ATTR4]]
+  %call13 = call contract float @hypotf(float noundef %call11, float noundef %call11) #4
+  ; CHECK-NEXT:    [[TMP17:%.*]] = call contract float @llvm.sin.f32(float [[CALL13]])
+  %17 = call contract float @llvm.sin.f32(float %call13)
+  ; CHECK-NEXT:    [[TMP18:%.*]] = call contract float @llvm.cos.f32(float [[TMP17]])
+  %18 = call contract float @llvm.cos.f32(float %17)
+  ; CHECK-NEXT:    [[TMP19:%.*]] = call contract double @__hipstdpar_tan_f64(double [[CALL12]])
+  %19 = call contract double @llvm.tan.f64(double %call12)
+  ; CHECK-NEXT:    [[TMP20:%.*]] = call contract double @__hipstdpar_asin_f64(double [[TMP19]])
+  %20 = call contract double @llvm.asin.f64(double %19)
+  ; CHECK-NEXT:    [[TMP21:%.*]] = call contract double @__hipstdpar_acos_f64(double [[TMP20]])
+  %21 = call contract double @llvm.acos.f64(double %20)
+  ; CHECK-NEXT:    [[TMP22:%.*]] = call contract double @__hipstdpar_atan_f64(double [[TMP21]])
+  %22 = call contract double @llvm.atan.f64(double %21)
+  ; CHECK-NEXT:    [[TMP23:%.*]] = call contract double @__hipstdpar_atan2_f64(double [[TMP22]], double [[TMP22]])
+  %23 = call contract double @llvm.atan2.f64(double %22, double %22)
+  ; CHECK-NEXT:    [[TMP24:%.*]] = call contract double @__hipstdpar_sinh_f64(double [[TMP23]])
+  %24 = call contract double @llvm.sinh.f64(double %23)
+  ; CHECK-NEXT:    [[TMP25:%.*]] = call contract double @__hipstdpar_cosh_f64(double [[TMP24]])
+  %25 = call contract double @llvm.cosh.f64(double %24)
+  ; CHECK-NEXT:    [[TMP26:%.*]] = call contract double @__hipstdpar_tanh_f64(double [[TMP25]])
+  %26 = call contract double @llvm.tanh.f64(double %25)
+  ; CHECK-NEXT:    [[CALL14:%.*]] = call contract double @__hipstdpar_asinh_f64(double noundef [[TMP26]]) #[[ATTR4]]
+  %call14 = call contract double @asinh(double noundef %26) #4
+  ; CHECK-NEXT:    [[CALL15:%.*]] = call contract float @__hipstdpar_asinh_f32(float noundef [[TMP18]]) #[[ATTR4]]
+  %call15 = call contract float @asinhf(float noundef %18) #4
+  ; CHECK-NEXT:    [[CALL16:%.*]] = call contract double @__hipstdpar_acosh_f64(double noundef [[CALL14]]) #[[ATTR4]]
+  %call16 = call contract double @acosh(double noundef %call14) #4
+  ; CHECK-NEXT:    [[CALL17:%.*]] = call contract float @__hipstdpar_acosh_f32(float noundef [[CALL15]]) #[[ATTR4]]
+  %call17 = call contract float @acoshf(float noundef %call15) #4
+  ; CHECK-NEXT:    [[CALL18:%.*]] = call contract double @__hipstdpar_atanh_f64(double noundef [[CALL16]]) #[[ATTR4]]
+  %call18 = call contract double @atanh(double noundef %call16) #4
+  ; CHECK-NEXT:    [[CALL19:%.*]] = call contract float @__hipstdpar_atanh_f32(float noundef [[CALL17]]) #[[ATTR4]]
+  %call19 = call contract float @atanhf(float noundef %call17) #4
+  ; CHECK-NEXT:    [[CALL20:%.*]] = call contract double @__hipstdpar_erf_f64(double noundef [[CALL18]]) #[[ATTR4]]
+  %call20 = call contract double @erf(double noundef %call18) #4
+  ; CHECK-NEXT:    [[CALL21:%.*]] = call contract float @__hipstdpar_erf_f32(float noundef [[CALL19]]) #[[ATTR4]]
+  %call21 = call contract float @erff(float noundef %call19) #4
+  ; CHECK-NEXT:    [[CALL22:%.*]] = call contract double @__hipstdpar_erfc_f64(double noundef [[CALL20]]) #[[ATTR4]]
+  %call22 = call contract double @erfc(double noundef %call20) #4
+  ; CHECK-NEXT:    [[CALL23:%.*]] = call contract float @__hipstdpar_erfc_f32(float noundef [[CALL21]]) #[[ATTR4]]
+  %call23 = call contract float @erfcf(float noundef %call21) #4
+  ; CHECK-NEXT:    [[CALL24:%.*]] = call contract double @__hipstdpar_tgamma_f64(double noundef [[CALL22]]) #[[ATTR4]]
+  %call24 = call contract double @tgamma(double noundef %call22) #4
+  ; CHECK-NEXT:    [[CALL25:%.*]] = call contract float @__hipstdpar_tgamma_f32(float noundef [[CALL23]]) #[[ATTR4]]
+  %call25 = call contract float @tgammaf(float noundef %call23) #4
+  ; CHECK-NEXT:    [[CALL26:%.*]] = call contract double @__hipstdpar_lgamma_f64(double noundef [[CALL24]]) #[[ATTR3]]
+  %call26 = call contract double @lgamma(double noundef %call24) #5
+  ; CHECK-NEXT:    [[CALL27:%.*]] = call contract float @__hipstdpar_lgamma_f32(float noundef [[CALL25]]) #[[ATTR3]]
+  %call27 = call contract float @lgammaf(float noundef %call25) #5
+  ret void
+}
+
+declare double @llvm.fabs.f64(double) #1
+
+declare float @llvm.fabs.f32(float) #1
+
+declare hidden double @remainder(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @remainderf(float noundef, float noundef) local_unnamed_addr #2
+
+declare hidden double @remquo(double noundef, double noundef, ptr noundef) local_unnamed_addr #3
+
+declare hidden float @remquof(float noundef, float noundef, ptr noundef) local_unnamed_addr #3
+
+declare double @llvm.fma.f64(double, double, double) #1
+
+declare float @llvm.fma.f32(float, float, float) #1
+
+declare hidden double @fdim(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @fdimf(float noundef, float noundef) local_unnamed_addr #2
+
+declare double @llvm.exp.f64(double) #1
+
+declare float @llvm.exp.f32(float) #1
+
+declare double @llvm.exp2.f64(double) #1
+
+declare float @llvm.exp2.f32(float) #1
+
+declare hidden double @expm1(double noundef) local_unnamed_addr #2
+
+declare hidden float @expm1f(float noundef) local_unnamed_addr #2
+
+declare double @llvm.log.f64(double) #1
+
+declare float @llvm.log.f32(float) #1
+
+declare double @llvm.log10.f64(double) #1
+
+declare float @llvm.log10.f32(float) #1
+
+declare double @llvm.log2.f64(double) #1
+
+declare float @llvm.log2.f32(float) #1
+
+declare hidden double @log1p(double noundef) local_unnamed_addr #2
+
+declare hidden float @log1pf(float noundef) local_unnamed_addr #2
+
+declare float @llvm.pow.f32(float, float) #1
+
+declare double @llvm.sqrt.f64(double) #1
+
+declare float @llvm.sqrt.f32(float) #1
+
+declare hidden double @cbrt(double noundef) local_unnamed_addr #2
+
+declare hidden float @cbrtf(float noundef) local_unnamed_addr #2
+
+declare hidden double @hypot(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @hypotf(float noundef, float noundef) local_unnamed_addr #2
+
+declare float @llvm.sin.f32(float) #1
+
+declare float @llvm.cos.f32(float) #1
+
+declare double @llvm.tan.f64(double) #1
+
+declare double @llvm.asin.f64(double) #1
+
+declare double @llvm.acos.f64(double) #1
+
+declare double @llvm.atan.f64(double) #1
+
+declare double @llvm.atan2.f64(double, double) #1
+
+declare double @llvm.sinh.f64(double) #1
+
+declare double @llvm.cosh.f64(double) #1
+
+declare double @llvm.tanh.f64(double) #1
+
+declare hidden double @asinh(double noundef) local_unnamed_addr #2
+
+declare hidden float @asinhf(float noundef) local_unnamed_addr #2
+
+declare hidden double @acosh(double noundef) local_unnamed_addr #2
+
+declare hidden float @acoshf(float noundef) local_unnamed_addr #2
+
+declare hidden double @atanh(double noundef) local_unnamed_addr #2
+
+declare hidden float @atanhf(float noundef) local_unnamed_addr #2
+
+declare hidden double @erf(double noundef) local_unnamed_addr #2
+
+declare hidden float @erff(float noundef) local_unnamed_addr #2
+
+declare hidden double @erfc(double noundef) local_unnamed_addr #2
+
+declare hidden float @erfcf(float noundef) local_unnamed_addr #2
+
+declare hidden double @tgamma(double noundef) local_unnamed_addr #2
+
+declare hidden float @tgammaf(float noundef) local_unnamed_addr #2
+
+declare hidden double @lgamma(double noundef) local_unnamed_addr #3
+
+declare hidden float @lgammaf(float noundef) local_unnamed_addr #3
+
+attributes #0 = { convergent mustprogress norecurse nounwind }
+attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+attributes #2 = { convergent mustprogress nofree nounwind willreturn memory(none) }
+attributes #3 = { convergent nounwind }
+attributes #4 = { convergent nounwind willreturn memory(none) }
+attributes #5 = { convergent nounwind }

>From 965e1fa59bfdccda2e2b81fe6b35cb857fade7db Mon Sep 17 00:00:00 2001
From: Alex Voicu <alexandru.voicu at amd.com>
Date: Sun, 2 Nov 2025 02:44:50 +0200
Subject: [PATCH 2/2] Add support for `SPV_INTEL_bfloat16_arithmetic`.

---
 llvm/docs/SPIRVUsage.rst                      |   6 +-
 llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp  |  10 +-
 llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp    |   2 +
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  62 ++-
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |   5 +-
 .../bfloat16-arithmetic.ll                    | 142 +++++++
 .../bfloat16-relational.ll                    | 376 ++++++++++++++++++
 7 files changed, 594 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll

diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 85eeabf10244a..922c9e1a3cfe3 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -173,6 +173,8 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
      - Allows generating arbitrary width integer types.
    * - ``SPV_INTEL_bindless_images``
      - Adds instructions to convert convert unsigned integer handles to images, samplers and sampled images.
+   * - ``SPV_INTEL_bfloat16_arithmetic``
+     - Allows the use of 16-bit bfloat16 values in arithmetic and relational operators.
    * - ``SPV_INTEL_bfloat16_conversion``
      - Adds instructions to convert between single-precision 32-bit floating-point values and 16-bit bfloat16 values.
    * - ``SPV_INTEL_cache_controls``
@@ -226,9 +228,9 @@ Below is a list of supported SPIR-V extensions, sorted alphabetically by their e
    * - ``SPV_INTEL_fp_max_error``
      - Adds the ability to specify the maximum error for floating-point operations.
    * - ``SPV_INTEL_ternary_bitwise_function``
-     - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform. 
+     - Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
    * - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
-     - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix. 
+     - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
    * - ``SPV_INTEL_int4``
      - Adds support for 4-bit integer type, and allow this type to be used in cooperative matrices.
    * - ``SPV_KHR_float_controls2``
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 1fc90d0852aad..847163edcbc4b 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -306,7 +306,7 @@ static bool containsBF16Type(const User &U) {
 
 bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
                                      MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   // Get or create a virtual register for each value.
@@ -328,7 +328,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U,
 
 bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   Register Op0 = getOrCreateVReg(*U.getOperand(0));
@@ -348,7 +348,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) {
 
 bool IRTranslator::translateCompare(const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   auto *CI = cast<CmpInst>(&U);
@@ -1569,7 +1569,7 @@ bool IRTranslator::translateBitCast(const User &U,
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(U))
     return false;
 
   uint32_t Flags = 0;
@@ -2688,7 +2688,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
 
 bool IRTranslator::translateInlineAsm(const CallBase &CB,
                                       MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(CB))
+  if (!MF->getTarget().getTargetTriple().isSPIRV() && containsBF16Type(CB))
     return false;
 
   const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering();
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 96f5dee21bc2a..9643db1f1bf53 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -107,6 +107,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_INTEL_inline_assembly},
         {"SPV_INTEL_bindless_images",
          SPIRV::Extension::Extension::SPV_INTEL_bindless_images},
+        {"SPV_INTEL_bfloat16_arithmetic",
+         SPIRV::Extension::Extension::SPV_INTEL_bfloat16_arithmetic},
         {"SPV_INTEL_bfloat16_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_bfloat16_conversion},
         {"SPV_KHR_subgroup_rotate",
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index db036a55ee6c6..009d2dc1a567a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1435,6 +1435,8 @@ void addInstrRequirements(const MachineInstr &MI,
       addPrintfRequirements(MI, Reqs, ST);
       break;
     }
+    // TODO: handle bfloat16 extended instructions when
+    // SPV_INTEL_bfloat16_arithmetic is enabled.
     break;
   }
   case SPIRV::OpAliasDomainDeclINTEL:
@@ -2060,7 +2062,65 @@ void addInstrRequirements(const MachineInstr &MI,
     Reqs.addCapability(SPIRV::Capability::PredicatedIOINTEL);
     break;
   }
-
+  case SPIRV::OpFAddS:
+  case SPIRV::OpFSubS:
+  case SPIRV::OpFMulS:
+  case SPIRV::OpFDivS:
+  case SPIRV::OpFRemS:
+  case SPIRV::OpFMod:
+  case SPIRV::OpFNegate:
+  case SPIRV::OpFAddV:
+  case SPIRV::OpFSubV:
+  case SPIRV::OpFMulV:
+  case SPIRV::OpFDivV:
+  case SPIRV::OpFRemV:
+  case SPIRV::OpFNegateV: {
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Arithmetic instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
+  case SPIRV::OpOrdered:
+  case SPIRV::OpUnordered:
+  case SPIRV::OpFOrdEqual:
+  case SPIRV::OpFOrdNotEqual:
+  case SPIRV::OpFOrdLessThan:
+  case SPIRV::OpFOrdLessThanEqual:
+  case SPIRV::OpFOrdGreaterThan:
+  case SPIRV::OpFOrdGreaterThanEqual:
+  case SPIRV::OpFUnordEqual:
+  case SPIRV::OpFUnordNotEqual:
+  case SPIRV::OpFUnordLessThan:
+  case SPIRV::OpFUnordLessThanEqual:
+  case SPIRV::OpFUnordGreaterThan:
+  case SPIRV::OpFUnordGreaterThanEqual: {
+    const SPIRVSubtarget &ST = MI.getMF()->getSubtarget<SPIRVSubtarget>();
+    SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
+    SPIRVType *TypeDef = GR->getSPIRVTypeForVReg(MI.getOperand(2).getReg());
+    const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+    if (TypeDef->getOpcode() == SPIRV::OpTypeVector)
+      TypeDef = MRI.getVRegDef(TypeDef->getOperand(1).getReg());
+    if (isBFloat16Type(TypeDef)) {
+      if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic))
+        report_fatal_error(
+            "Relational instructions with bfloat16 arguments require the "
+            "following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic",
+            false);
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_arithmetic);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ArithmeticINTEL);
+    }
+    break;
+  }
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 7d08b29a51a6e..263b59fbe6959 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -387,6 +387,8 @@ defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
 defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
 defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
+defm SPV_INTEL_bfloat16_arithmetic
+    : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>; // next value after SPV_KHR_maximal_reconvergence (128)
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -570,6 +572,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ArithmeticINTEL : CapabilityOperand<6226, 0, 0, [SPV_INTEL_bfloat16_arithmetic], []>; // presumably the list names the enabling extension - confirm against CapabilityOperand
 defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
 defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
@@ -1919,7 +1922,7 @@ defm GenericCastToPtr :  SpecConstantOpOperandsOperand<122, [], [Kernel]>;
 defm PtrCastToGeneric :  SpecConstantOpOperandsOperand<121, [], [Kernel]>;
 defm Bitcast :  SpecConstantOpOperandsOperand<124, [], []>;
 defm QuantizeToF16 :  SpecConstantOpOperandsOperand<116, [], [Shader]>;
-// Arithmetic 
+// Arithmetic
 defm SNegate :  SpecConstantOpOperandsOperand<126, [], []>;
 defm Not :  SpecConstantOpOperandsOperand<200, [], []>;
 defm IAdd :  SpecConstantOpOperandsOperand<128, [], []>;
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
new file mode 100644
index 0000000000000..4cabddb94df25
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-arithmetic.ll
@@ -0,0 +1,142 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Arithmetic instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[NEG:%.*]] "neg"
+; CHECK-DAG: OpName [[NEGV:%.*]] "negv"
+; CHECK-DAG: OpName [[ADD:%.*]] "add"
+; CHECK-DAG: OpName [[ADDV:%.*]] "addv"
+; CHECK-DAG: OpName [[SUB:%.*]] "sub"
+; CHECK-DAG: OpName [[SUBV:%.*]] "subv"
+; CHECK-DAG: OpName [[MUL:%.*]] "mul"
+; CHECK-DAG: OpName [[MULV:%.*]] "mulv"
+; CHECK-DAG: OpName [[DIV:%.*]] "div"
+; CHECK-DAG: OpName [[DIVV:%.*]] "divv"
+; CHECK-DAG: OpName [[REM:%.*]] "rem"
+; CHECK-DAG: OpName [[REMV:%.*]] "remv"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 4
+
+; CHECK-DAG: [[NEG]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOAT]] [[X]]
+define spir_func bfloat @neg(bfloat %x) {
+entry:
+  %r = fneg bfloat %x ; scalar bfloat negation; the test expects OpFNegate
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[NEGV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFNegate [[BFLOATV]] [[X]]
+define spir_func <4 x bfloat> @negv(<4 x bfloat> %x) {
+entry:
+  %r = fneg <4 x bfloat> %x ; <4 x bfloat> negation; the test expects OpFNegate on the vector type
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[ADD]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @add(bfloat %x, bfloat %y) {
+entry:
+  %r = fadd bfloat %x, %y ; scalar bfloat addition; the test expects OpFAdd
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[ADDV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFAdd [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @addv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fadd <4 x bfloat> %x, %y ; <4 x bfloat> addition; the test expects OpFAdd on the vector type
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[SUB]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @sub(bfloat %x, bfloat %y) {
+entry:
+  %r = fsub bfloat %x, %y ; scalar bfloat subtraction; the test expects OpFSub
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[SUBV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFSub [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @subv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fsub <4 x bfloat> %x, %y ; <4 x bfloat> subtraction; the test expects OpFSub on the vector type
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[MUL]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @mul(bfloat %x, bfloat %y) {
+entry:
+  %r = fmul bfloat %x, %y ; scalar bfloat multiplication; the test expects OpFMul
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[MULV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFMul [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @mulv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fmul <4 x bfloat> %x, %y ; <4 x bfloat> multiplication; the test expects OpFMul on the vector type
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[DIV]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @div(bfloat %x, bfloat %y) {
+entry:
+  %r = fdiv bfloat %x, %y ; scalar bfloat division; the test expects OpFDiv
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[DIVV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFDiv [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @divv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = fdiv <4 x bfloat> %x, %y ; <4 x bfloat> division; the test expects OpFDiv on the vector type
+  ret <4 x bfloat> %r
+}
+
+; CHECK-DAG: [[REM]] = OpFunction [[BFLOAT]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOAT]] [[X]] [[Y]]
+define spir_func bfloat @rem(bfloat %x, bfloat %y) {
+entry:
+  %r = frem bfloat %x, %y ; scalar bfloat remainder; the test expects OpFRem
+  ret bfloat %r
+}
+
+; CHECK-DAG: [[REMV]] = OpFunction [[BFLOATV]]
+; CHECK: [[X:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK: [[Y:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-DAG: [[R:%.*]] = OpFRem [[BFLOATV]] [[X]] [[Y]]
+define spir_func <4 x bfloat> @remv(<4 x bfloat> %x, <4 x bfloat> %y) {
+entry:
+  %r = frem <4 x bfloat> %x, %y ; <4 x bfloat> remainder; the test expects OpFRem on the vector type
+  ret <4 x bfloat> %r
+}
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
new file mode 100644
index 0000000000000..3774791d58f87
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_arithmetic/bfloat16-relational.ll
@@ -0,0 +1,376 @@
+; RUN: not llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - | FileCheck %s
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_bfloat16_arithmetic,+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-ERROR: LLVM ERROR: Relational instructions with bfloat16 arguments require the following SPIR-V extension: SPV_INTEL_bfloat16_arithmetic
+
+; CHECK-DAG: OpCapability BFloat16TypeKHR
+; CHECK-DAG: OpCapability BFloat16ArithmeticINTEL
+; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
+; CHECK-DAG: OpExtension "SPV_INTEL_bfloat16_arithmetic"
+; CHECK-DAG: OpName [[UEQ:%.*]] "test_ueq"
+; CHECK-DAG: OpName [[OEQ:%.*]] "test_oeq"
+; CHECK-DAG: OpName [[UNE:%.*]] "test_une"
+; CHECK-DAG: OpName [[ONE:%.*]] "test_one"
+; CHECK-DAG: OpName [[ULT:%.*]] "test_ult"
+; CHECK-DAG: OpName [[OLT:%.*]] "test_olt"
+; CHECK-DAG: OpName [[ULE:%.*]] "test_ule"
+; CHECK-DAG: OpName [[OLE:%.*]] "test_ole"
+; CHECK-DAG: OpName [[UGT:%.*]] "test_ugt"
+; CHECK-DAG: OpName [[OGT:%.*]] "test_ogt"
+; CHECK-DAG: OpName [[UGE:%.*]] "test_uge"
+; CHECK-DAG: OpName [[OGE:%.*]] "test_oge"
+; CHECK-DAG: OpName [[UNO:%.*]] "test_uno"
+; CHECK-DAG: OpName [[ORD:%.*]] "test_ord"
+; CHECK-DAG: OpName [[v3UEQ:%.*]] "test_v3_ueq"
+; CHECK-DAG: OpName [[v3OEQ:%.*]] "test_v3_oeq"
+; CHECK-DAG: OpName [[v3UNE:%.*]] "test_v3_une"
+; CHECK-DAG: OpName [[v3ONE:%.*]] "test_v3_one"
+; CHECK-DAG: OpName [[v3ULT:%.*]] "test_v3_ult"
+; CHECK-DAG: OpName [[v3OLT:%.*]] "test_v3_olt"
+; CHECK-DAG: OpName [[v3ULE:%.*]] "test_v3_ule"
+; CHECK-DAG: OpName [[v3OLE:%.*]] "test_v3_ole"
+; CHECK-DAG: OpName [[v3UGT:%.*]] "test_v3_ugt"
+; CHECK-DAG: OpName [[v3OGT:%.*]] "test_v3_ogt"
+; CHECK-DAG: OpName [[v3UGE:%.*]] "test_v3_uge"
+; CHECK-DAG: OpName [[v3OGE:%.*]] "test_v3_oge"
+; CHECK-DAG: OpName [[v3UNO:%.*]] "test_v3_uno"
+; CHECK-DAG: OpName [[v3ORD:%.*]] "test_v3_ord"
+; CHECK: [[BFLOAT:%.*]] = OpTypeFloat 16 0
+; CHECK: [[BFLOATV:%.*]] = OpTypeVector [[BFLOAT]] 3
+
+; CHECK:      [[UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ueq(bfloat %a, bfloat %b) {
+  %r = fcmp ueq bfloat %a, %b ; scalar ueq compare; the test expects OpFUnordEqual
+  ret i1 %r
+}
+
+; CHECK:      [[OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oeq(bfloat %a, bfloat %b) {
+  %r = fcmp oeq bfloat %a, %b ; scalar oeq compare; the test expects OpFOrdEqual
+  ret i1 %r
+}
+
+; CHECK:      [[UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_une(bfloat %a, bfloat %b) {
+  %r = fcmp une bfloat %a, %b ; scalar une compare; the test expects OpFUnordNotEqual
+  ret i1 %r
+}
+
+; CHECK:      [[ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_one(bfloat %a, bfloat %b) {
+  %r = fcmp one bfloat %a, %b ; scalar one compare; the test expects OpFOrdNotEqual
+  ret i1 %r
+}
+
+; CHECK:      [[ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ult(bfloat %a, bfloat %b) {
+  %r = fcmp ult bfloat %a, %b ; scalar ult compare; the test expects OpFUnordLessThan
+  ret i1 %r
+}
+
+; CHECK:      [[OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_olt(bfloat %a, bfloat %b) {
+  %r = fcmp olt bfloat %a, %b ; scalar olt compare; the test expects OpFOrdLessThan
+  ret i1 %r
+}
+
+; CHECK:      [[ULE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ule(bfloat %a, bfloat %b) {
+  %r = fcmp ule bfloat %a, %b ; scalar ule compare; the test expects OpFUnordLessThanEqual
+  ret i1 %r
+}
+
+; CHECK:      [[OLE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ole(bfloat %a, bfloat %b) {
+  %r = fcmp ole bfloat %a, %b ; scalar ole compare; the test expects OpFOrdLessThanEqual
+  ret i1 %r
+}
+
+; CHECK:      [[UGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ugt(bfloat %a, bfloat %b) {
+  %r = fcmp ugt bfloat %a, %b ; scalar ugt compare; the test expects OpFUnordGreaterThan
+  ret i1 %r
+}
+
+; CHECK:      [[OGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ogt(bfloat %a, bfloat %b) {
+  %r = fcmp ogt bfloat %a, %b ; scalar ogt compare; the test expects OpFOrdGreaterThan
+  ret i1 %r
+}
+
+; CHECK:      [[UGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_uge(bfloat %a, bfloat %b) {
+  %r = fcmp uge bfloat %a, %b ; scalar uge compare; the test expects OpFUnordGreaterThanEqual
+  ret i1 %r
+}
+
+; CHECK:      [[OGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_oge(bfloat %a, bfloat %b) {
+  %r = fcmp oge bfloat %a, %b ; scalar oge compare; the test expects OpFOrdGreaterThanEqual
+  ret i1 %r
+}
+
+; CHECK:      [[ORD]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpOrdered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_ord(bfloat %a, bfloat %b) {
+  %r = fcmp ord bfloat %a, %b ; scalar ord compare; the test expects OpOrdered
+  ret i1 %r
+}
+
+; CHECK:      [[UNO]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOAT]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpUnordered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define i1 @test_uno(bfloat %a, bfloat %b) {
+  %r = fcmp uno bfloat %a, %b ; scalar uno compare; the test expects OpUnordered
+  ret i1 %r
+}
+
+; CHECK:      [[v3UEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ueq(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ueq <3 x bfloat> %a, %b ; <3 x bfloat> ueq compare; the test expects OpFUnordEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OEQ]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_oeq(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp oeq <3 x bfloat> %a, %b ; <3 x bfloat> oeq compare; the test expects OpFOrdEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UNE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_une(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp une <3 x bfloat> %a, %b ; <3 x bfloat> une compare; the test expects OpFUnordNotEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ONE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdNotEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_one(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp one <3 x bfloat> %a, %b ; <3 x bfloat> one compare; the test expects OpFOrdNotEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ULT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ult(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ult <3 x bfloat> %a, %b ; <3 x bfloat> ult compare; the test expects OpFUnordLessThan
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OLT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_olt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp olt <3 x bfloat> %a, %b ; <3 x bfloat> olt compare; the test expects OpFOrdLessThan
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ULE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ule(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ule <3 x bfloat> %a, %b ; <3 x bfloat> ule compare; the test expects OpFUnordLessThanEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OLE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdLessThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ole(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ole <3 x bfloat> %a, %b ; <3 x bfloat> ole compare; the test expects OpFOrdLessThanEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ugt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ugt <3 x bfloat> %a, %b ; <3 x bfloat> ugt compare; the test expects OpFUnordGreaterThan
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OGT]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThan {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ogt(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ogt <3 x bfloat> %a, %b ; <3 x bfloat> ogt compare; the test expects OpFOrdGreaterThan
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFUnordGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_uge(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp uge <3 x bfloat> %a, %b ; <3 x bfloat> uge compare; the test expects OpFUnordGreaterThanEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3OGE]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpFOrdGreaterThanEqual {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_oge(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp oge <3 x bfloat> %a, %b ; <3 x bfloat> oge compare; the test expects OpFOrdGreaterThanEqual
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3ORD]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpOrdered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_ord(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp ord <3 x bfloat> %a, %b ; <3 x bfloat> ord compare; the test expects OpOrdered
+  ret <3 x i1> %r
+}
+
+; CHECK:      [[v3UNO]] = OpFunction
+; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter [[BFLOATV]]
+; CHECK-NEXT: OpLabel
+; CHECK-NEXT: [[R:%.*]] = OpUnordered {{%.+}} [[A]] [[B]]
+; CHECK-NEXT: OpReturnValue [[R]]
+; CHECK-NEXT: OpFunctionEnd
+define <3 x i1> @test_v3_uno(<3 x bfloat> %a, <3 x bfloat> %b) {
+  %r = fcmp uno <3 x bfloat> %a, %b ; <3 x bfloat> uno compare; the test expects OpUnordered
+  ret <3 x i1> %r
+}



More information about the llvm-commits mailing list