[llvm] [SPIR-V] Add support for the SPIR-V extension SPV_INTEL_bfloat16_conversion (PR #83443)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 29 08:54:58 PST 2024
https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/83443
>From f92a8b4b52ac5f59d96db74e25333f216981cb8d Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 29 Feb 2024 08:11:56 -0800
Subject: [PATCH 1/2] add support for SPV_INTEL_bfloat16_conversion
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 44 +++++++--
llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 24 ++++-
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 9 ++
llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 4 +
llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 4 +
llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 7 ++
llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp | 4 +
.../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 +
.../bfloat16-conv.ll | 96 +++++++++++++++++++
9 files changed, 185 insertions(+), 9 deletions(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c1bb27322443ff..296782bb0d2689 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -134,6 +134,7 @@ struct ConvertBuiltin {
bool IsDestinationSigned;
bool IsSaturated;
bool IsRounded;
+ bool IsBfloat16;
FPRoundingMode::FPRoundingMode RoundingMode;
};
@@ -1986,6 +1987,7 @@ static bool generateConvertInst(const StringRef DemangledCall,
SPIRV::Decoration::FPRoundingMode,
{(unsigned)Builtin->RoundingMode});
+ std::string NeedExtMsg; // no errors if empty
unsigned Opcode = SPIRV::OpNop;
if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
// Int -> ...
@@ -2000,23 +2002,49 @@ static bool generateConvertInst(const StringRef DemangledCall,
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
SPIRV::OpTypeFloat)) {
// Int -> Float
- bool IsSourceSigned =
- DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
- Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+ if (Builtin->IsBfloat16) {
+ const auto *ST = static_cast<const SPIRVSubtarget *>(
+ &MIRBuilder.getMF().getSubtarget());
+ if (!ST->canUseExtension(
+ SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+ NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+ Opcode = SPIRV::OpConvertBF16ToFINTEL;
+ } else {
+ bool IsSourceSigned =
+ DemangledCall[DemangledCall.find_first_of('(') + 1] != 'u';
+ Opcode = IsSourceSigned ? SPIRV::OpConvertSToF : SPIRV::OpConvertUToF;
+ }
}
} else if (GR->isScalarOrVectorOfType(Call->Arguments[0],
SPIRV::OpTypeFloat)) {
// Float -> ...
- if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt))
+ if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeInt)) {
// Float -> Int
- Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
- : SPIRV::OpConvertFToU;
- else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
- SPIRV::OpTypeFloat))
+ if (Builtin->IsBfloat16) {
+ const auto *ST = static_cast<const SPIRVSubtarget *>(
+ &MIRBuilder.getMF().getSubtarget());
+ if (!ST->canUseExtension(
+ SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
+ NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+ Opcode = SPIRV::OpConvertFToBF16INTEL;
+ } else {
+ Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
+ : SPIRV::OpConvertFToU;
+ }
+ } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
+ SPIRV::OpTypeFloat)) {
// Float -> Float
Opcode = SPIRV::OpFConvert;
+ }
}
+ if (!NeedExtMsg.empty()) {
+ std::string DiagMsg = std::string(Builtin->Name) +
+ ": the builtin requires the following SPIR-V "
+ "extension: " +
+ NeedExtMsg;
+ report_fatal_error(DiagMsg.c_str(), false);
+ }
assert(Opcode != SPIRV::OpNop &&
"Conversion between the types not implemented!");
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 28a63b93b43b6e..eb26f70b1861f2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1177,6 +1177,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
bit IsDestinationSigned = !eq(!find(name, "convert_u"), -1);
bit IsSaturated = !not(!eq(!find(name, "_sat"), -1));
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
+ bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
+ !not(!eq(!find(name, "bfloat16"), -1)));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1187,7 +1189,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
// Table gathering all the convert builtins.
def ConvertBuiltins : GenericTable {
let FilterClass = "ConvertBuiltin";
- let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", "IsRounded", "RoundingMode"];
+ let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
+ "IsRounded", "IsBfloat16", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
@@ -1229,6 +1232,25 @@ defm : DemangledConvertBuiltin<"convert_long", OpenCL_std>;
defm : DemangledConvertBuiltin<"convert_ulong", OpenCL_std>;
defm : DemangledConvertBuiltin<"convert_float", OpenCL_std>;
+// cl_intel_bfloat16_conversions / SPV_INTEL_bfloat16_conversion
+// Multiclass used to define both a demangled builtin record and the
+// corresponding convert builtin record at the same time.
+multiclass DemangledBF16ConvertBuiltin<string name1, string name2> {
+ // Create records for scalar and vector conversions.
+ foreach i = ["", "2", "3", "4", "8", "16"] in {
+ def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
+ def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
+ }
+}
+
+defm : DemangledBF16ConvertBuiltin<"bfloat16", "_as_ushort">;
+defm : DemangledBF16ConvertBuiltin<"as_bfloat16", "_float">;
+
+foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
+ def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
+ def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
+}
+
//===----------------------------------------------------------------------===//
// Class defining a vector data load/store builtin record used for lowering
// into OpExtInst instruction.
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index a1cb630f1aa477..21cba98ca8b6b7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -819,6 +819,15 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
return false;
}
+unsigned
+SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
+ if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
+ return Type->getOpcode() == SPIRV::OpTypeVector
+ ? static_cast<unsigned>(Type->getOperand(2).getImm())
+ : 1; // any non-vector type counts as one component
+ return 0; // no SPIR-V type is associated with VReg
+}
+
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
assert(Type && "Invalid Type pointer");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 792a00786f0aaf..965d5b848fcb87 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -197,6 +197,10 @@ class SPIRVGlobalRegistry {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
+ // Return number of elements in a vector if the given VReg is associated with
+ // a vector type. Return 1 for a scalar type, and 0 for a missing type.
+ unsigned getScalarOrVectorComponentCount(Register VReg) const;
+
// For vectors or scalars of ints/floats, return the scalar type's bitwidth.
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index fe8c909236cde3..99c57dac4141d8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -443,6 +443,10 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+// SPV_INTEL_bfloat16_conversion
+def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
+def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
+
// 3.42.12 Composite Instructions
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index ac3d6b362d350b..b7be7ffd3f0c61 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1110,6 +1110,13 @@ void addInstrRequirements(const MachineInstr &MI,
case SPIRV::OpAtomicFMaxEXT:
AddAtomicFloatRequirements(MI, Reqs, ST);
break;
+ case SPIRV::OpConvertBF16ToFINTEL:
+ case SPIRV::OpConvertFToBF16INTEL:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_bfloat16_conversion);
+ Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
+ }
+ break;
case SPIRV::OpVariableLengthArrayINTEL:
case SPIRV::OpSaveMemoryINTEL:
case SPIRV::OpRestoreMemoryINTEL:
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index 0e8952dc6a9c9f..b866def589853f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -81,6 +81,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
"Allows to use the LinkOnceODR linkage type that is to let "
"a function or global variable to be merged with other functions "
"or global variables of the same name when linkage occurs."),
+ clEnumValN(SPIRV::Extension::SPV_INTEL_bfloat16_conversion,
+ "SPV_INTEL_bfloat16_conversion",
+ "Adds instructions to convert between single-precision "
+ "32-bit floating-point values and 16-bit bfloat16 values."),
clEnumValN(SPIRV::Extension::SPV_KHR_subgroup_rotate,
"SPV_KHR_subgroup_rotate",
"Adds a new instruction that enables rotating values across "
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 211c22340eb82c..8dbbd9049844c8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -297,6 +297,7 @@ defm SPV_INTEL_fpga_argument_interfaces : ExtensionOperand<102>;
defm SPV_INTEL_optnone : ExtensionOperand<103>;
defm SPV_INTEL_function_pointers : ExtensionOperand<104>;
defm SPV_INTEL_variable_length_array : ExtensionOperand<105>;
+defm SPV_INTEL_bfloat16_conversion : ExtensionOperand<106>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -466,6 +467,7 @@ defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atom
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
+defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
new file mode 100644
index 00000000000000..2bd59b22322ffd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv.ll
@@ -0,0 +1,96 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o - -filetype=obj | spirv-val %}
+
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_bfloat16_conversion
+
+; CHECK: OpCapability BFloat16ConversionINTEL
+; CHECK: OpExtension "SPV_INTEL_bfloat16_conversion"
+
+; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
+; CHECK-DAG: %[[Int16Ty:.*]] = OpTypeInt 16 0
+; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
+; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
+; CHECK-DAG: %[[VecInt162:.*]] = OpTypeVector %[[Int16Ty]] 2
+; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
+; CHECK-DAG: %[[VecInt163:.*]] = OpTypeVector %[[Int16Ty]] 3
+; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
+; CHECK-DAG: %[[VecInt164:.*]] = OpTypeVector %[[Int16Ty]] 4
+; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
+; CHECK-DAG: %[[VecInt168:.*]] = OpTypeVector %[[Int16Ty]] 8
+; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
+; CHECK-DAG: %[[VecInt1616:.*]] = OpTypeVector %[[Int16Ty]] 16
+; CHECK-DAG: %[[IntConstId:.*]] = OpConstant %[[Int16Ty]] 67
+; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
+
+; CHECK: OpFunction %[[VoidTy]]
+; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
+; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
+
+; CHECK: %[[Int16ValId:.*]] = OpConvertFToBF16INTEL %[[Int16Ty]] %[[FP32ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[Int16ValId]]
+; CHECK: %[[Int16v8ValId:.*]] = OpConvertFToBF16INTEL %[[VecInt168]] %[[FP32v8ValId]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]] %[[Int16v8ValId]]
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]] %[[FloatConstId]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]] %[[IntConstId]]
+
+; CHECK: OpConvertFToBF16INTEL %[[Int16Ty]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt162]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt163]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt164]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt168]]
+; CHECK: OpConvertFToBF16INTEL %[[VecInt1616]]
+; CHECK: OpConvertBF16ToFINTEL %[[FP32Ty]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat2]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat3]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat4]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat8]]
+; CHECK: OpConvertBF16ToFINTEL %[[VecFloat16]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(float %a, <8 x float> %in) {
+ %res1 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float %a)
+ %res2 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 zeroext %res1)
+ %res3 = tail call spir_func <8 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+ %res4 = tail call spir_func <8 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %res3)
+ %res5 = tail call spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float 1.500000e+00)
+ %res6 = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 67)
+ ret void
+}
+
+declare spir_func zeroext i16 @_Z27__spirv_ConvertFToBF16INTELf(float)
+declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELs(i16 zeroext)
+declare spir_func <8 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
+declare spir_func <8 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
+
+define dso_local spir_kernel void @test_ocl() {
+entry:
+ %res = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(float 0.000000e+00)
+ %res1 = call spir_func <2 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x float> zeroinitializer)
+ %res2 = call spir_func <3 x i16> @_Z34intel_convert_bfloat163_as_ushort3Dv3_f(<3 x float> zeroinitializer)
+ %res3 = call spir_func <4 x i16> @_Z34intel_convert_bfloat164_as_ushort4Dv4_f(<4 x float> zeroinitializer)
+ %res4 = call spir_func <8 x i16> @_Z34intel_convert_bfloat168_as_ushort8Dv8_f(<8 x float> zeroinitializer)
+ %res5 = call spir_func <16 x i16> @_Z36intel_convert_bfloat1616_as_ushort16Dv16_f(<16 x float> zeroinitializer)
+ %res6 = call spir_func float @_Z31intel_convert_as_bfloat16_floatt(i16 zeroext 0)
+ %res7 = call spir_func <2 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i16> zeroinitializer)
+ %res8 = call spir_func <3 x float> @_Z33intel_convert_as_bfloat163_float3Dv3_t(<3 x i16> zeroinitializer)
+ %res9 = call spir_func <4 x float> @_Z33intel_convert_as_bfloat164_float4Dv4_t(<4 x i16> zeroinitializer)
+ %res10 = call spir_func <8 x float> @_Z33intel_convert_as_bfloat168_float8Dv8_t(<8 x i16> zeroinitializer)
+ %res11 = call spir_func <16 x float> @_Z35intel_convert_as_bfloat1616_float16Dv16_t(<16 x i16> zeroinitializer)
+ ret void
+}
+
+declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(float)
+declare spir_func <2 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x float>)
+declare spir_func <3 x i16> @_Z34intel_convert_bfloat163_as_ushort3Dv3_f(<3 x float>)
+declare spir_func <4 x i16> @_Z34intel_convert_bfloat164_as_ushort4Dv4_f(<4 x float>)
+declare spir_func <8 x i16> @_Z34intel_convert_bfloat168_as_ushort8Dv8_f(<8 x float>)
+declare spir_func <16 x i16> @_Z36intel_convert_bfloat1616_as_ushort16Dv16_f(<16 x float>)
+declare spir_func float @_Z31intel_convert_as_bfloat16_floatt(i16 zeroext)
+declare spir_func <2 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i16>)
+declare spir_func <3 x float> @_Z33intel_convert_as_bfloat163_float3Dv3_t(<3 x i16>)
+declare spir_func <4 x float> @_Z33intel_convert_as_bfloat164_float4Dv4_t(<4 x i16>)
+declare spir_func <8 x float> @_Z33intel_convert_as_bfloat168_float8Dv8_t(<8 x i16>)
+declare spir_func <16 x float> @_Z35intel_convert_as_bfloat1616_float16Dv16_t(<16 x i16>)
>From cdbb39968fc1beacb0003dacf25fd15f0bb6dd31 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 29 Feb 2024 08:54:47 -0800
Subject: [PATCH 2/2] add validation and negative test cases
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 15 ++++++++++++++-
.../bfloat16-conv-negative1.ll | 12 ++++++++++++
.../bfloat16-conv-negative2.ll | 12 ++++++++++++
.../bfloat16-conv-negative3.ll | 12 ++++++++++++
.../bfloat16-conv-negative4.ll | 13 +++++++++++++
5 files changed, 63 insertions(+), 1 deletion(-)
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 296782bb0d2689..5652ab5bcd9462 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1987,7 +1987,8 @@ static bool generateConvertInst(const StringRef DemangledCall,
SPIRV::Decoration::FPRoundingMode,
{(unsigned)Builtin->RoundingMode});
- std::string NeedExtMsg; // no errors if empty
+ std::string NeedExtMsg; // no errors if empty
+ bool IsRightComponentsNumber = true; // check if input/output accepts vectors
unsigned Opcode = SPIRV::OpNop;
if (GR->isScalarOrVectorOfType(Call->Arguments[0], SPIRV::OpTypeInt)) {
// Int -> ...
@@ -2008,6 +2009,9 @@ static bool generateConvertInst(const StringRef DemangledCall,
if (!ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+ IsRightComponentsNumber =
+ GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+ GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
Opcode = SPIRV::OpConvertBF16ToFINTEL;
} else {
bool IsSourceSigned =
@@ -2026,6 +2030,9 @@ static bool generateConvertInst(const StringRef DemangledCall,
if (!ST->canUseExtension(
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
+ IsRightComponentsNumber =
+ GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
+ GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
Opcode = SPIRV::OpConvertFToBF16INTEL;
} else {
Opcode = Builtin->IsDestinationSigned ? SPIRV::OpConvertFToS
@@ -2045,6 +2052,12 @@ static bool generateConvertInst(const StringRef DemangledCall,
NeedExtMsg;
report_fatal_error(DiagMsg.c_str(), false);
}
+ if (!IsRightComponentsNumber) {
+ std::string DiagMsg =
+ std::string(Builtin->Name) +
+ ": result and argument must have the same number of components";
+ report_fatal_error(DiagMsg.c_str(), false);
+ }
assert(Opcode != SPIRV::OpNop &&
"Conversion between the types not implemented!");
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
new file mode 100644
index 00000000000000..2f3c859db346df
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative1.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+ %res = tail call spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+ ret void
+}
+
+declare spir_func i16 @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
new file mode 100644
index 00000000000000..c02d50cfab21d8
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative2.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x float> %in) {
+ %res = tail call spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float> %in)
+ ret void
+}
+
+declare spir_func <4 x i16> @_Z27__spirv_ConvertFToBF16INTELDv8_f(<8 x float>)
\ No newline at end of file
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
new file mode 100644
index 00000000000000..20a8042ad9c297
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative3.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+ %res = tail call spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+ ret void
+}
+
+declare spir_func <4 x float> @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
new file mode 100644
index 00000000000000..87d26472a4eeb6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_bfloat16_conversion/bfloat16-conv-negative4.ll
@@ -0,0 +1,13 @@
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-extensions=SPV_INTEL_bfloat16_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: result and argument must have the same number of components
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+define spir_func void @test(<8 x i16> %in) {
+ %res = tail call spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16> %in)
+ ret void
+}
+
+declare spir_func float @_Z27__spirv_ConvertBF16ToFINTELDv8_s(<8 x i16>)
+
More information about the llvm-commits
mailing list