[Mlir-commits] [mlir] d2a559f - [mlir][spirv] Add OpExtension "SPV_INTEL_bfloat16_conversion"
Lei Zhang
llvmlistbot at llvm.org
Fri Mar 31 14:12:28 PDT 2023
Author: Md Abdullah Shahneous Bari
Date: 2023-03-31T14:02:59-07:00
New Revision: d2a559ffc0dc61b9d7426064bd5076b66d2f96d6
URL: https://github.com/llvm/llvm-project/commit/d2a559ffc0dc61b9d7426064bd5076b66d2f96d6
DIFF: https://github.com/llvm/llvm-project/commit/d2a559ffc0dc61b9d7426064bd5076b66d2f96d6.diff
LOG: [mlir][spirv] Add OpExtension "SPV_INTEL_bfloat16_conversion"
Add the Intel-specific "SPV_INTEL_bfloat16_conversion" extension and its
capability (Bfloat16ConversionINTEL), together with the
two ops (OpConvertFToBF16INTEL, OpConvertBF16ToFINTEL)
introduced by this extension.
These ops convert 32-bit float values to BF16 and vice versa.
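For example, a round trip through the two ops looks like the following
(a sketch in the same style as the tests added below; names are illustrative):
  spirv.func @bf16_roundtrip(%arg0 : f32) "None" {
    %bf16 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
    %f32 = spirv.INTEL.ConvertBF16ToF %bf16 : i16 to f32
    spirv.Return
  }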
Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D147087
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
mlir/test/Target/SPIRV/intel-ext-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7ca32d92c583a..43c6c3edecf44 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -399,6 +399,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
+def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -457,7 +458,7 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
- SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
+ SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1413,6 +1414,12 @@ def SPIRV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMat
];
}
+def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_bfloat16_conversion]>
+ ];
+}
+
def SPIRV_CapabilityAttr :
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1504,7 +1511,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
- SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4079,6 +4086,7 @@ def SPIRV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
def SPIRV_Void : TypeAlias<NoneType, "void">;
def SPIRV_Bool : TypeAlias<I1, "bool">;
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int16 : TypeAlias<I16, "Int16">;
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4407,6 +4415,9 @@ def SPIRV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreI
def SPIRV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
+def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
+def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
+
def SPIRV_OpcodeAttr :
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued,
@@ -4492,7 +4503,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
- SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL
+ SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
+
+ SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
]>;
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
new file mode 100644
index 0000000000000..55753b316fe66
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -0,0 +1,126 @@
+//===- SPIRVIntelExtOps.td - Intel SPIR-V extension ops ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the op definition spec of Intel-specific SPIR-V
+// extensions. These extensions are not part of the Khronos specification but
+// are publicly available at https://github.com/intel/llvm.
+// Supported extensions:
+// * SPV_INTEL_bfloat16_conversion
+//===----------------------------------------------------------------------===//
+
+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+#define MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+
+// -----
+
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+ let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+ let description = [{
+    Convert a value numerically from 32-bit floating point to bfloat16,
+    which is represented as a 16-bit unsigned integer.
+
+    Result Type must be a scalar or vector of integer type. The component width
+    must be 16 bits. The bit pattern in the Result represents a bfloat16 value.
+
+    Float Value must be a scalar or vector of floating-point type. It must have
+    the same number of components as Result Type. The component width must be
+    32 bits.
+
+ Results are computed per component.
+
+ ```
+ convert-f-to-bf16-op ::= ssa-id `=` `spirv.INTEL.ConvertFToBF16` ssa-use
+ `:` operand-type `to` result-type
+ ```
+
+ #### Example:
+
+ ```mlir
+    %1 = spirv.INTEL.ConvertFToBF16 %0 : f32 to i16
+    %3 = spirv.INTEL.ConvertFToBF16 %2 : vector<3xf32> to vector<3xi16>
+ ```
+
+ }];
+
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_bfloat16_conversion]>,
+ Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
+ );
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
+
+ let hasVerifier = 1;
+}
+
+// -----
+
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+ let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+ let description = [{
+    Interpret a 16-bit integer as bfloat16 and convert the value numerically
+    to a 32-bit floating-point type.
+
+    Result Type must be a scalar or vector of floating-point type. The
+    component width must be 32 bits.
+
+    Bfloat16 Value must be a scalar or vector of integer type, which is
+    interpreted as a bfloat16 type. It must have the same number of components
+    as Result Type. The component width must be 16 bits.
+
+ Results are computed per component.
+
+ ```
+ convert-bf16-to-f-op ::= ssa-id `=` `spirv.INTEL.ConvertBF16ToF` ssa-use
+ `:` operand-type `to` result-type
+ ```
+
+ #### Example:
+
+ ```mlir
+    %1 = spirv.INTEL.ConvertBF16ToF %0 : i16 to f32
+    %3 = spirv.INTEL.ConvertBF16ToF %2 : vector<3xi16> to vector<3xf32>
+ ```
+
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_INTEL_bfloat16_conversion]>,
+ Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+ ];
+
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+ );
+
+ let assemblyFormat = [{
+ $operand attr-dict `:` type($operand) `to` type($result)
+ }];
+ let hasVerifier = 1;
+}
+
+
+// -----
+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 767e939f04473..13533d1d65b8f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -31,6 +31,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bb3ad91ce620a..181c9e0a23bb7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2201,6 +2201,46 @@ LogicalResult spirv::ConvertUToFOp::verify() {
/*skipBitWidthCheck=*/true);
}
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertBF16ToFOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
+ auto operandType = getOperand().getType();
+ auto resultType = getResult().getType();
+  // ODS enforces the element types; additionally check that a vector operand
+  // and the vector result have the same number of elements.
+ if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+ unsigned operandNumElements = vectorType.getNumElements();
+ unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+ if (operandNumElements != resultNumElements) {
+ return emitOpError(
+ "operand and result must have same number of elements");
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertFToBF16Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertFToBF16Op::verify() {
+ auto operandType = getOperand().getType();
+ auto resultType = getResult().getType();
+  // ODS enforces the element types; additionally check that a vector operand
+  // and the vector result have the same number of elements.
+ if (auto vectorType = operandType.dyn_cast<VectorType>()) {
+ unsigned operandNumElements = vectorType.getNumElements();
+ unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
+ if (operandNumElements != resultNumElements) {
+ return emitOpError(
+ "operand and result must have same number of elements");
+ }
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
new file mode 100644
index 0000000000000..53a1015de75bc
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertFToBF16
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
+ // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
+ spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+ // expected-error @+1 {{operand and result must have same number of elements}}
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
+ spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertBF16ToF
+//===----------------------------------------------------------------------===//
+
+spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+ spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+ spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
+ // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
+ spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
+ // expected-error @+1 {{operand and result must have same number of elements}}
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
+ spirv.Return
+}
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
new file mode 100644
index 0000000000000..fe86fd2b7be25
--- /dev/null
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL], [SPV_INTEL_bfloat16_conversion]> {
+ // CHECK-LABEL: @f32_to_bf16
+ spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @f32_to_bf16_vec
+ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @bf16_to_f32
+ spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+ spirv.Return
+ }
+
+ // CHECK-LABEL: @bf16_to_f32_vec
+ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+ spirv.Return
+ }
+}