[Mlir-commits] [mlir] d2a559f - [mlir][spirv] Add OpExtension "SPV_INTEL_bfloat16_conversion"

Lei Zhang llvmlistbot at llvm.org
Fri Mar 31 14:12:28 PDT 2023


Author: Md Abdullah Shahneous Bari
Date: 2023-03-31T14:02:59-07:00
New Revision: d2a559ffc0dc61b9d7426064bd5076b66d2f96d6

URL: https://github.com/llvm/llvm-project/commit/d2a559ffc0dc61b9d7426064bd5076b66d2f96d6
DIFF: https://github.com/llvm/llvm-project/commit/d2a559ffc0dc61b9d7426064bd5076b66d2f96d6.diff

LOG: [mlir][spirv] Add OpExtension "SPV_INTEL_bfloat16_conversion"

Add the Intel-specific "SPV_INTEL_bfloat16_conversion" extension, its
capability (Bfloat16ConversionINTEL), and the two ops it introduces
(OpConvertFToBF16INTEL, OpConvertBF16ToFINTEL). These ops convert
32-bit floats to bfloat16 values (stored as 16-bit integers) and back.

Reference Specification:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D147087

Added: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
    mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
    mlir/test/Target/SPIRV/intel-ext-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 7ca32d92c583a..43c6c3edecf44 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -399,6 +399,7 @@ def SPV_INTEL_fp_fast_math_mode                  : I32EnumAttrCase<"SPV_INTEL_fp
 def SPV_INTEL_memory_access_aliasing             : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
 def SPV_INTEL_split_barrier                      : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
 def SPV_INTEL_joint_matrix                       : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
+def SPV_INTEL_bfloat16_conversion                : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
 
 def SPV_NV_compute_shader_derivatives    : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
 def SPV_NV_cooperative_matrix            : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -457,7 +458,7 @@ def SPIRV_ExtensionAttr :
       SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
       SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
       SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
-      SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
+      SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
       SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
       SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
       SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1413,6 +1414,12 @@ def SPIRV_C_JointMatrixINTEL                         : I32EnumAttrCase<"JointMat
   ];
 }
 
+def SPIRV_C_Bfloat16ConversionINTEL                  : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
+  list<Availability> availability = [
+    Extension<[SPV_INTEL_bfloat16_conversion]>
+  ];
+}
+
 def SPIRV_CapabilityAttr :
     SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
       SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1504,7 +1511,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
       SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
       SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
-      SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL
+      SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
     ]>;
 
 def SPIRV_AM_Logical                 : I32EnumAttrCase<"Logical", 0>;
@@ -4079,6 +4086,7 @@ def SPIRV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
 def SPIRV_Void : TypeAlias<NoneType, "void">;
 def SPIRV_Bool : TypeAlias<I1, "bool">;
 def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
@@ -4407,6 +4415,9 @@ def SPIRV_OC_OpJointMatrixStoreINTEL      : I32EnumAttrCase<"OpJointMatrixStoreI
 def SPIRV_OC_OpJointMatrixMadINTEL        : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
 def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
 
+def SPIRV_OC_OpConvertFToBF16INTEL        : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
+def SPIRV_OC_OpConvertBF16ToFINTEL        : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
+
 def SPIRV_OpcodeAttr :
     SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
       SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued,
@@ -4492,7 +4503,9 @@ def SPIRV_OpcodeAttr :
 
       SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
       SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
-      SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL
+      SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
+
+      SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
new file mode 100644
index 0000000000000..55753b316fe66
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
@@ -0,0 +1,125 @@
+//===- SPIRVIntelExtOps.td - Intel SPIR-V extensions -------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Intel-specific SPIR-V extensions.
+// These are vendor extensions, not part of the core SPIR-V specification.
+// Their specifications are published in the Khronos SPIR-V registry.
+// Supported extensions:
+// * SPV_INTEL_bfloat16_conversion
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+#define MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
+
+// -----
+
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
+  let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+  let description = [{
+    Convert value numerically from 32-bit floating point to bfloat16,
+    which is represented as a 16-bit unsigned integer.
+
+    Result Type must be a scalar or vector of integer type.
+    The component width must be 16 bits. Bit pattern in the Result represents a bfloat16 value.
+
+    Float Value must be a scalar or vector of floating-point type.
+    It must have the same number of components as Result Type. The component width must be 32 bits.
+
+    Results are computed per component.
+
+    ```
+    convert-f-to-bf16-op ::= ssa-id `=` `spirv.INTEL.ConvertFToBF16` ssa-use
+                          `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.ConvertFToBF16 %0 : f32 to i16
+    %3 = spirv.INTEL.ConvertFToBF16 %2 : vector<3xf32> to vector<3xi16>
+    ```
+
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_bfloat16_conversion]>,
+    Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
+  );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+// -----
+
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
+  let summary = "See extension SPV_INTEL_bfloat16_conversion";
+
+  let description = [{
+    Interpret a 16-bit integer as bfloat16 and convert the value numerically to 32-bit floating point type.
+
+    Result Type must be a scalar or vector of floating-point. The component width must be 32 bits.
+
+    Bfloat16 Value must be a scalar or vector of integer type, which is interpreted as a bfloat16 type.
+    The type must have the same number of components as the Result Type. The component width must be 16 bits.
+
+    Results are computed per component.
+
+    ```
+    convert-bf16-to-f-op ::= ssa-id `=` `spirv.INTEL.ConvertBF16ToF` ssa-use
+                          `:` operand-type `to` result-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %1 = spirv.INTEL.ConvertBF16ToF %0 : i16 to f32
+    %3 = spirv.INTEL.ConvertBF16ToF %2 : vector<3xi16> to vector<3xf32>
+    ```
+
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_INTEL_bfloat16_conversion]>,
+    Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
+  ];
+
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
+  );
+
+  let assemblyFormat = [{
+    $operand attr-dict `:` type($operand) `to` type($result)
+  }];
+
+  let hasVerifier = 1;
+}
+
+// -----
+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 767e939f04473..13533d1d65b8f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -31,6 +31,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index bb3ad91ce620a..181c9e0a23bb7 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2201,6 +2201,45 @@ LogicalResult spirv::ConvertUToFOp::verify() {
                       /*skipBitWidthCheck=*/true);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertBF16ToFOp and spirv.INTELConvertFToBF16Op helpers
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operand and result of a bfloat16 conversion op have the
+/// same number of components. The ODS type constraints only require each side
+/// to be a scalar or vector of the expected element type; they do not relate
+/// the two, so a scalar/vector mix (e.g. vector<2xf32> to i16) has to be
+/// rejected here rather than reaching an unchecked cast<VectorType>.
+static LogicalResult verifyBF16ConversionNumComponents(Operation *op) {
+  // A scalar counts as a single component.
+  auto getNumComponents = [](Type type) -> int64_t {
+    if (auto vectorType = type.dyn_cast<VectorType>())
+      return vectorType.getNumElements();
+    return 1;
+  };
+  if (getNumComponents(op->getOperand(0).getType()) !=
+      getNumComponents(op->getResult(0).getType()))
+    return op->emitOpError(
+        "operand and result must have same number of elements");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertBF16ToFOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
+  return verifyBF16ConversionNumComponents(getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertFToBF16Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELConvertFToBF16Op::verify() {
+  return verifyBF16ConversionNumComponents(getOperation());
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
new file mode 100644
index 0000000000000..53a1015de75bc
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertFToBF16
+//===----------------------------------------------------------------------===//
+
+spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
+  // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16>
+  spirv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.ConvertBF16ToF
+//===----------------------------------------------------------------------===//
+
+spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+  // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_unsupported(%arg0 : i16) "None" {
+  // expected-error @+1 {{result #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}}
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f16
+  spirv.Return
+}
+
+// -----
+
+spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
+  // expected-error @+1 {{operand and result must have same number of elements}}
+  %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
+  spirv.Return
+}

diff  --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
new file mode 100644
index 0000000000000..fe86fd2b7be25
--- /dev/null
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+// Serialization round-trip tests for the SPV_INTEL_bfloat16_conversion ops.
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL], [SPV_INTEL_bfloat16_conversion]> {
+  // CHECK-LABEL: @f32_to_bf16
+  spirv.func @f32_to_bf16(%arg0 : f32) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
+    %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @f32_to_bf16_vec
+  spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
+    %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @bf16_to_f32
+  spirv.func @bf16_to_f32(%arg0 : i16) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
+    %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
+    spirv.Return
+  }
+
+  // CHECK-LABEL: @bf16_to_f32_vec
+  spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
+    // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
+    %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
+    spirv.Return
+  }
+}


        


More information about the Mlir-commits mailing list