[llvm] [SPIR-V] Add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (PR #83443)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 08:54:58 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/83443

>From f92a8b4b52ac5f59d96db74e25333f216981cb8d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 29 Feb 2024 08:11:56 -0800
Subject: [PATCH 1/2] add support for SPV_INTEL_bfloat16_conversion

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 44 +++++++--
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 24 ++++-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  9 ++
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  4 +
 llvm/lib/Target/SPIRV/SPIRVInstrInfo.td       |  4 +
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp |  7 ++
 llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp      |  4 +
 .../lib/Target/SPIRV/SPIRVSymbolicOperands.td |  2 +
 .../bfloat16-conv.ll                          | 96 +++++++++++++++++++
 9 files changed, 185 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..296782bb0d2689 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -134,6 +134,7 @@ struct ConvertBuiltin {
   bool IsDestinationSigned;
   bool IsSaturated;
   bool IsRounded;
+  bool IsBfloat16;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
 
@@ -1986,6 +1987,7 @@ static bool generateConvertInst(const StringRef DemangledCall,
                     SPIRV::Decoration::FPRoundingMode,
                     {(unsigned)Builtin->RoundingMode});
 
+  std::string NeedExtMsg; // no errors if empty
   unsigned Opcode = SPIRV::OpNop;
   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
     // Int -> ...
@@ -2000,23 +2002,49 @@ static bool generateConvertInst(const StringRef DemangledCall,
     } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
                                           SPIRV::OpTypeFloat)) {
       // Int -> Float
-      bool IsSourceSigned =
-          DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
-      Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+      if (Builtin->IsBfloat16) {
+        const auto *ST = static_cast<const SPIRVSubtarget *>(
+            &MIRBuilder.getMF().getSubtarget());
+        if (!ST->canUseExtension(
+                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        Opcode = SPIRV::OpConvertBF16ToFINTEL;
+      } else {
+        bool IsSourceSigned =
+            DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
+        Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+      }
     }
   } else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
                                         SPIRV::OpTypeFloat)) {
     // Float -> ...
-    if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
+    if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
       // Float -> Int
-      Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
-                                            : SPIRV::OpConvertFToU;
-    else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
-                                        SPIRV::OpTypeFloat))
+      if (Builtin->IsBfloat16) {
+        const auto *ST = static_cast<const SPIRVSubtarget *>(
+            &MIRBuilder.getMF().getSubtarget());
+        if (!ST->canUseExtension(
+                SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+          NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        Opcode = SPIRV::OpConvertFToBF16INTEL;
+      } else {
+        Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
+                                              : SPIRV::OpConvertFToU;
+      }
+    } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
+                                          SPIRV::OpTypeFloat)) {
       // Float -> Float
       Opcode = SPIRV::OpFConvert;
+    }
   }
 
+  if (!NeedExtMsg.empty()) {
+    std::string DiagMsg = std::string(Builtin->Name) +
+                          ": the builtin requires the following SPIR-V "
+                          "extension: " +
+                          NeedExtMsg;
+    report_fatal_error(DiagMsg.c_str(), false);
+  }
   assert(Opcode != SPIRV::OpNop &&
          "Conversion between the types not implemented!");
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 28a63b93b43b6e..eb26f70b1861f2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1177,6 +1177,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
   bit IsDestinationSigned = !eq(!find(name, "convert_u"), -1);
   bit IsSaturated = !not(!eq(!find(name, "_sat"), -1));
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
+  bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
+                       !not(!eq(!find(name, "bfloat16"), -1)));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
                                   !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
                                   !not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1187,7 +1189,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
 // Table gathering all the convert builtins.
 def ConvertBuiltins : GenericTable {
   let FilterClass = "ConvertBuiltin";
-  let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", "IsRounded", "RoundingMode"];
+  let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
+                "IsRounded", "IsBfloat16", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
@@ -1229,6 +1232,25 @@ defm : DemangledConvertBuiltin<"convert_long", OpenCL_std>;
 defm : DemangledConvertBuiltin<"convert_ulong", OpenCL_std>;
 defm : DemangledConvertBuiltin<"convert_float", OpenCL_std>;
 
+// cl_intel_bfloat16_conversions / SPV_INTEL_bfloat16_conversion
+// Multiclass defining, for each scalar/vector width, both a demangled builtin
+// record and the corresponding convert builtin record (intel_convert_*).
+multiclass DemangledBF16ConvertBuiltin<string name1, string name2> {
+  // "" yields the scalar name; the other suffixes are vector widths.
+  foreach i = ["", "2", "3", "4", "8", "16"] in {
+    def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+    def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+  }
+}
+
+defm : DemangledBF16ConvertBuiltin<"bfloat16", "_as_ushort">;
+defm : DemangledBF16ConvertBuiltin<"as_bfloat16", "_float">;
+
+foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
+  def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
+  def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+}
+
 //===----------------------------------------------------------------------===//
 // Class defining a vector data load/store builtin record used for lowering
 // into OpExtInst instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..21cba98ca8b6b7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -819,6 +819,15 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
   return false;
 }
 
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
+  const SPIRVType *Ty = getSPIRVTypeForVReg(VReg);
+  if (!Ty)
+    return 0; // VReg has no associated SPIR-V type.
+  const bool IsVector = Ty->getOpcode() == SPIRV::OpTypeVector;
+  return IsVector ? static_cast<unsigned>(Ty->getOperand(2).getImm()) : 1;
+}
+
 unsigned
 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
   assert(Type && "Invalid Type pointer");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..965d5b848fcb87 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,6 +197,10 @@ class SPIRVGlobalRegistry {
   // opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
   bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
 
+  // Return the element count if VReg's type is a vector, 1 if it is any
+  // non-vector type, and 0 if VReg has no associated SPIR-V type.
+  unsigned getScalarOrVectorComponentCount(Register VReg) const;
+
   // For vectors or scalars of ints/floats, return the scalar type's bitwidth.
   unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index fe8c909236cde3..99c57dac4141d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -443,6 +443,10 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
 def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
 def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
 
+// SPV_INTEL_bfloat16_conversion: float32 <-> bfloat16 (held in 16-bit ints).
+def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
+def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
+
 // 3.42.12 Composite Instructions
 
 def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ac3d6b362d350b..b7be7ffd3f0c61 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1110,6 +1110,13 @@ void addInstrRequirements(const MachineInstr &MI,
   case SPIRV::OpAtomicFMaxEXT:
     AddAtomicFloatRequirements(MI, Reqs, ST);
     break;
+  case SPIRV::OpConvertBF16ToFINTEL:
+  case SPIRV::OpConvertFToBF16INTEL:
+    if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
+      Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
+    }
+    break;
   case SPIRV::OpVariableLengthArrayINTEL:
   case SPIRV::OpSaveMemoryINTEL:
   case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 0e8952dc6a9c9f..b866def589853f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -81,6 +81,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
             "Allows to use the LinkOnceODR linkage type that is to let "
             "a function or global variable to be merged with other functions "
             "or global variables of the same name when linkage occurs."),
+        clEnumValN(SPIRV::Extension::SPV_INTEL_bfloat16_conversion,
+                   "SPV_INTEL_bfloat16_conversion",
+                   "Adds instructions to convert between single-precision "
+                   "32-bit floating-point values and 16-bit bfloat16 values."),
         clEnumValN(SPIRV::Extension::SPV_KHR_subgroup_rotate,
                    "SPV_KHR_subgroup_rotate",
                    "Adds a new instruction that enables rotating values across "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 211c22340eb82c..8dbbd9049844c8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -297,6 +297,7 @@ defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
 defm SPV_INTEL_optnone : ExtensionOperand<103>;
 defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
 defm SPV_INTEL_variable_length_array : ExtensionOperand<105>;
+defm SPV_INTEL_bfloat16_conversion : ExtensionOperand<106>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -466,6 +467,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
 defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
 defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
 defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
new file mode 100644
index 00000000000000..2bd59b22322ffd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
@@ -0,0 +1,96 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - -filetype=obj | spirv-val %}
+
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_bfloat16_conversion
+
+; CHECK: OpCapability BFloat16ConversionINTEL
+; CHECK: OpExtension "SPV_INTEL_bfloat16_conversion"
+
+; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
+; CHECK-DAG: %[[Int16Ty:.*]] = OpTypeInt 16 0
+; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
+; CHECK-DAG: %[[VecInt162:.*]] = OpTypeVector %[[Int16Ty]] 2
+; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
+; CHECK-DAG: %[[VecInt163:.*]] = OpTypeVector %[[Int16Ty]] 3
+; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
+; CHECK-DAG: %[[VecInt164:.*]] = OpTypeVector %[[Int16Ty]] 4
+; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
+; CHECK-DAG: %[[VecInt168:.*]] = OpTypeVector %[[Int16Ty]] 8
+; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
+; CHECK-DAG: %[[VecInt1616:.*]] = OpTypeVector %[[Int16Ty]] 16
+; CHECK-DAG: %[[IntConstId:.*]] = OpConstant %[[Int16Ty]] 67
+; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
+
+; CHECK: OpFunction %[[VoidTy]]
+; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
+; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
+
+; CHECK: %[[Int16ValId:.*]] = OpConvertFToBF16INTEL %[[Int16Ty]] %[[FP32ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[Int16ValId]]
+; CHECK: %[[Int16v8ValId:.*]] = OpConvertFToBF16INTEL %[[VecInt168]] %[[FP32v8ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]] %[[Int16v8ValId]]
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]] %[[FloatConstId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[IntConstId]]
+
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt162]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt163]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt164]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt168]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt1616]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat2]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat3]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat4]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat16]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(float %a, <8 x float> %in) {
+  %res1 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
+  %res2 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 zeroext %res1)
+  %res3 = tail call spir_func <8 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  %res4 = tail call spir_func <8 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %res3)
+  %res5 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 1.500000e+00)
+  %res6 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 67)
+  ret void
+}
+
+declare spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float)
+declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 zeroext)
+declare spir_func <8 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
+declare spir_func <8 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
+
+define dso_local spir_kernel void @test_ocl() {
+entry:
+  %res = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(float 0.000000e+00)
+  %res1 = call spir_func <2 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x float> zeroinitializer)
+  %res2 = call spir_func <3 x i16> @_Z34intel_convert_bfloat163_as_ushort3Dv3_f(<3 x float> zeroinitializer)
+  %res3 = call spir_func <4 x i16> @_Z34intel_convert_bfloat164_as_ushort4Dv4_f(<4 x float> zeroinitializer)
+  %res4 = call spir_func <8 x i16> @_Z34intel_convert_bfloat168_as_ushort8Dv8_f(<8 x float> zeroinitializer)
+  %res5 = call spir_func <16 x i16> @_Z36intel_convert_bfloat1616_as_ushort16Dv16_f(<16 x float> zeroinitializer)
+  %res6 = call spir_func float @_Z31intel_convert_as_bfloat16_floatt(i16 zeroext 0)
+  %res7 = call spir_func <2 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i16> zeroinitializer)
+  %res8 = call spir_func <3 x float> @_Z33intel_convert_as_bfloat163_float3Dv3_t(<3 x i16> zeroinitializer)
+  %res9 = call spir_func <4 x float> @_Z33intel_convert_as_bfloat164_float4Dv4_t(<4 x i16> zeroinitializer)
+  %res10 = call spir_func <8 x float> @_Z33intel_convert_as_bfloat168_float8Dv8_t(<8 x i16> zeroinitializer)
+  %res11 = call spir_func <16 x float> @_Z35intel_convert_as_bfloat1616_float16Dv16_t(<16 x i16> zeroinitializer)
+  ret void
+}
+
+declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(float)
+declare spir_func <2 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x float>)
+declare spir_func <3 x i16> @_Z34intel_convert_bfloat163_as_ushort3Dv3_f(<3 x float>)
+declare spir_func <4 x i16> @_Z34intel_convert_bfloat164_as_ushort4Dv4_f(<4 x float>)
+declare spir_func <8 x i16> @_Z34intel_convert_bfloat168_as_ushort8Dv8_f(<8 x float>)
+declare spir_func <16 x i16> @_Z36intel_convert_bfloat1616_as_ushort16Dv16_f(<16 x float>)
+declare spir_func float @_Z31intel_convert_as_bfloat16_floatt(i16 zeroext)
+declare spir_func <2 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i16>)
+declare spir_func <3 x float> @_Z33intel_convert_as_bfloat163_float3Dv3_t(<3 x i16>)
+declare spir_func <4 x float> @_Z33intel_convert_as_bfloat164_float4Dv4_t(<4 x i16>)
+declare spir_func <8 x float> @_Z33intel_convert_as_bfloat168_float8Dv8_t(<8 x i16>)
+declare spir_func <16 x float> @_Z35intel_convert_as_bfloat1616_float16Dv16_t(<16 x i16>)

>From cdbb39968fc1beacb0003dacf25fd15f0bb6dd31 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 29 Feb 2024 08:54:47 -0800
Subject: [PATCH 2/2] add validation and negative test cases

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp           | 15 ++++++++++++++-
 .../bfloat16-conv-negative1.ll                    | 12 ++++++++++++
 .../bfloat16-conv-negative2.ll                    | 12 ++++++++++++
 .../bfloat16-conv-negative3.ll                    | 12 ++++++++++++
 .../bfloat16-conv-negative4.ll                    | 13 +++++++++++++
 5 files changed, 63 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 296782bb0d2689..5652ab5bcd9462 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1987,7 +1987,8 @@ static bool generateConvertInst(const StringRef DemangledCall,
                     SPIRV::Decoration::FPRoundingMode,
                     {(unsigned)Builtin->RoundingMode});
 
-  std::string NeedExtMsg; // no errors if empty
+  std::string NeedExtMsg;              // no errors if empty
+  bool IsRightComponentsNumber = true; // check if input/output accepts vectors
   unsigned Opcode = SPIRV::OpNop;
   if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
     // Int -> ...
@@ -2008,6 +2009,9 @@ static bool generateConvertInst(const StringRef DemangledCall,
         if (!ST->canUseExtension(
                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        IsRightComponentsNumber =
+            GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+            GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
         Opcode = SPIRV::OpConvertBF16ToFINTEL;
       } else {
         bool IsSourceSigned =
@@ -2026,6 +2030,9 @@ static bool generateConvertInst(const StringRef DemangledCall,
         if (!ST->canUseExtension(
                 SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
           NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+        IsRightComponentsNumber =
+            GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+            GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
         Opcode = SPIRV::OpConvertFToBF16INTEL;
       } else {
         Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
@@ -2045,6 +2052,12 @@ static bool generateConvertInst(const StringRef DemangledCall,
                           NeedExtMsg;
     report_fatal_error(DiagMsg.c_str(), false);
   }
+  if (!IsRightComponentsNumber) {
+    std::string DiagMsg =
+        std::string(Builtin->Name) +
+        ": result and argument must have the same number of components";
+    report_fatal_error(DiagMsg.c_str(), false);
+  }
   assert(Opcode != SPIRV::OpNop &&
          "Conversion between the types not implemented!");
 
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
new file mode 100644
index 00000000000000..2f3c859db346df
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
new file mode 100644
index 00000000000000..c02d50cfab21d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+  %res = tail call spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+  ret void
+}
+
+declare spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
new file mode 100644
index 00000000000000..20a8042ad9c297
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+  %res = tail call spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+  ret void
+}
+
+declare spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
new file mode 100644
index 00000000000000..87d26472a4eeb6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
@@ -0,0 +1,13 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+  %res = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+  ret void
+}
+
+declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
+



More information about the llvm-commits mailing list