[Mlir-commits] [mlir] 03e6bf5 - [mlir][spirv] Define `spirv.*Dot` integer dot product ops
Jakub Kuderski
llvmlistbot at llvm.org
Tue Dec 6 17:21:47 PST 2022
Author: Jakub Kuderski
Date: 2022-12-06T20:17:41-05:00
New Revision: 03e6bf5f564c440ffbbac3a7a30015b6ca779afe
URL: https://github.com/llvm/llvm-project/commit/03e6bf5f564c440ffbbac3a7a30015b6ca779afe
DIFF: https://github.com/llvm/llvm-project/commit/03e6bf5f564c440ffbbac3a7a30015b6ca779afe.diff
LOG: [mlir][spirv] Define `spirv.*Dot` integer dot product ops
This covers `SDot`, `SUDot`, and `UDot`. The `*AccSat` version will be
added in a follow-up revision.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D139242
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/availability.mlir
mlir/test/Dialect/SPIRV/IR/target-env.mlir
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/utils/spirv/define_enum.sh
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 1d6c98d017350..f18796b7c96ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -76,7 +76,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
}]>
];
- // These op require a custom verifier.
+ // These ops require a custom verifier.
let hasVerifier = 1;
}
@@ -423,75 +423,6 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
// -----
-def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
- [Pure, Commutative]> {
- let summary = [{
- Result is the full value of the signed integer multiplication of Operand
- 1 and Operand 2.
- }];
-
- let description = [{
- Result Type must be from OpTypeStruct. The struct must have two
- members, and the two members must be the same type. The member type
- must be a scalar or vector of integer type.
-
- Operand 1 and Operand 2 must have the same type as the members of Result
- Type. These are consumed as signed integers.
-
- Results are computed per component.
-
- Member 0 of the result gets the low-order bits of the multiplication.
-
- Member 1 of the result gets the high-order bits of the multiplication.
-
- <!-- End of AutoGen section -->
-
- #### Example:
-
- ```mlir
- %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
- %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
- ```
- }];
-}
-
-// -----
-
-def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
- [Pure, Commutative]> {
- let summary = [{
- Result is the full value of the unsigned integer multiplication of
- Operand 1 and Operand 2.
- }];
-
- let description = [{
- Result Type must be from OpTypeStruct. The struct must have two
- members, and the two members must be the same type. The member type
- must be a scalar or vector of integer type, whose Signedness operand is
- 0.
-
- Operand 1 and Operand 2 must have the same type as the members of Result
- Type. These are consumed as unsigned integers.
-
- Results are computed per component.
-
- Member 0 of the result gets the low-order bits of the multiplication.
-
- Member 1 of the result gets the high-order bits of the multiplication.
-
- <!-- End of AutoGen section -->
-
- #### Example:
-
- ```mlir
- %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
- %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
- ```
- }];
-}
-
-// -----
-
def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
@@ -646,6 +577,40 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
// -----
+def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
+ [Pure, Commutative]> {
+ let summary = [{
+ Result is the full value of the signed integer multiplication of Operand
+ 1 and Operand 2.
+ }];
+
+ let description = [{
+ Result Type must be from OpTypeStruct. The struct must have two
+ members, and the two members must be the same type. The member type
+ must be a scalar or vector of integer type.
+
+ Operand 1 and Operand 2 must have the same type as the members of Result
+ Type. These are consumed as signed integers.
+
+ Results are computed per component.
+
+ Member 0 of the result gets the low-order bits of the multiplication.
+
+ Member 1 of the result gets the high-order bits of the multiplication.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+ %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
@@ -654,7 +619,7 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
let description = [{
Result Type must be a scalar or vector of integer type.
- Operand’s type must be a scalar or vector of integer type. It must
+ Operand's type must be a scalar or vector of integer type. It must
have the same number of components as Result Type. The component width
must equal the component width in Result Type.
@@ -746,6 +711,41 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
// -----
+def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
+ [Pure, Commutative]> {
+ let summary = [{
+ Result is the full value of the unsigned integer multiplication of
+ Operand 1 and Operand 2.
+ }];
+
+ let description = [{
+ Result Type must be from OpTypeStruct. The struct must have two
+ members, and the two members must be the same type. The member type
+ must be a scalar or vector of integer type, whose Signedness operand is
+ 0.
+
+ Operand 1 and Operand 2 must have the same type as the members of Result
+ Type. These are consumed as unsigned integers.
+
+ Results are computed per component.
+
+ Member 0 of the result gets the low-order bits of the multiplication.
+
+ Member 1 of the result gets the high-order bits of the multiplication.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+ %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+ ```
+ }];
+}
+
+// -----
+
def SPIRV_VectorTimesScalarOp : SPIRV_Op<"VectorTimesScalar", [Pure]> {
let summary = "Scale a floating-point vector.";
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3c993cbc77445..4be10e61ddee2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3969,6 +3969,18 @@ def SPIRV_StorageClassAttr :
SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
]>;
+def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
+ list<Availability> availability = [
+ MinVersion<SPIRV_V_1_6>,
+ Extension<[SPV_KHR_integer_dot_product]>
+ ];
+}
+
+def SPIRV_PackedVectorFormatAttr :
+ SPIRV_I32EnumAttr<"PackedVectorFormat", "valid SPIR-V PackedVectorFormat", "packed_vector_format", [
+ SPIRV_PVF_PackedVectorFormat4x8Bit
+ ]>;
+
// End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
// Enums added manually that are not part of SPIR-V spec
@@ -4365,6 +4377,12 @@ def SPIRV_OC_OpGroupNonUniformSMax : I32EnumAttrCase<"OpGroupNonUniformSM
def SPIRV_OC_OpGroupNonUniformUMax : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
def SPIRV_OC_OpGroupNonUniformFMax : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
+def SPIRV_OC_OpUDot : I32EnumAttrCase<"OpUDot", 4451>;
+def SPIRV_OC_OpSUDot : I32EnumAttrCase<"OpSUDot", 4452>;
+def SPIRV_OC_OpSDotAccSat : I32EnumAttrCase<"OpSDotAccSat", 4453>;
+def SPIRV_OC_OpUDotAccSat : I32EnumAttrCase<"OpUDotAccSat", 4454>;
+def SPIRV_OC_OpSUDotAccSat : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
@@ -4457,7 +4475,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformSMin, SPIRV_OC_OpGroupNonUniformUMin,
SPIRV_OC_OpGroupNonUniformFMin, SPIRV_OC_OpGroupNonUniformSMax,
SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
- SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpTypeCooperativeMatrixNV,
+ SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
+ SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
+ SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV,
SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
@@ -4494,6 +4514,19 @@ class SPIRV_Op<string mnemonic, list<Trait> traits = []> :
Capability<[]>
];
+ // Controls whether to auto-generate this op's availability specification.
+ // If set, generates the following methods:
+ //
+ // ```c++
+ // SmallVector<ArrayRef<Capability>, 1> OpTy::getCapabilities();
+ // SmallVector<ArrayRef<Extension>, 1> OpTy::getExtensions();
+ // Optional<Version> OpTy::getMinVersion();
+ // Optional<Version> OpTy::getMaxVersion();
+ // ```
+ //
+ // When not set, manual implementation of these methods is required.
+ bit autogenAvailability = 1;
+
// For each SPIR-V op, the following static functions need to be defined
// in SPIRVOps.cpp:
//
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
new file mode 100644
index 0000000000000..451aeb27207da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -0,0 +1,190 @@
+//===-- SPIRVIntegerDotProductOps.td - MLIR SPIR-V IDP Ops -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the integer dot product ops for the SPIR-V dialect. They
+// correspond to instructions defined by the "SPV_KHR_integer_dot_product"
+// SPIR-V extension.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
+#define MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+// Common base for the SPV_KHR_integer_dot_product ops. Each op is Pure,
+// produces a single integer result, and spells out its operand and result
+// types in the custom assembly format.
+class SPIRV_IntegerDotProductOp<string mnemonic, list<Trait> traits = []> :
+    SPIRV_Op<mnemonic, !listconcat(traits, [Pure])> {
+  // Availability (versions, extensions, capabilities) depends on the operand
+  // and result types, so it is implemented by hand instead of generated.
+  let autogenAvailability = 0;
+
+  // Verification is implemented in C++.
+  let hasVerifier = 1;
+
+  let results = (outs SPIRV_Integer:$result);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` `(` type(operands) `)` `->` type($result)
+  }];
+}
+
+// Integer dot product of two operands (Vector 1 and Vector 2). The optional
+// `format` attribute carries the Packed Vector Format used when the operands
+// are scalar integers.
+class SPIRV_IntegerDotProductBinaryOp<string mnemonic, list<Trait> traits = []> :
+    SPIRV_IntegerDotProductOp<mnemonic, traits> {
+  let arguments = (ins SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
+                       SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
+                       OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format);
+}
+
+// Integer dot product of two operands plus an integer accumulator. The
+// optional `format` attribute carries the Packed Vector Format used when the
+// vector operands are scalar integers. No op in this file derives from this
+// class yet; it is intended for the `*AccSat` ops.
+class SPIRV_IntegerDotProductTernaryOp<string mnemonic, list<Trait> traits = []> :
+    SPIRV_IntegerDotProductOp<mnemonic, traits> {
+  let arguments = (ins SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
+                       SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
+                       SPIRV_Integer:$accumulator,
+                       OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format);
+}
+
+// -----
+
+def SPIRV_SDotOp : SPIRV_IntegerDotProductBinaryOp<"SDot",
+ [SignedOp, Commutative]> {
+ let summary = "Signed integer dot product of Vector 1 and Vector 2.";
+
+ let description = [{
+ Result Type must be an integer type whose Width must be greater than or
+ equal to that of the components of Vector 1 and Vector 2.
+
+ Vector 1 and Vector 2 must have the same type.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type
+ (enabled by the DotProductInput4x8Bit or DotProductInputAll capability).
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of the input vectors are sign-extended to the bit width
+ of the result's type. The sign-extended input vectors are then
+ multiplied component-wise and all components of the vector resulting
+ from the component-wise multiplication are added together. The resulting
+ value will equal the low-order N bits of the correct result R, where N
+ is the result width and R is computed with enough precision to avoid
+ overflow and underflow.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_SUDotOp : SPIRV_IntegerDotProductBinaryOp<"SUDot",
+ [SignedOp, UnsignedOp]> {
+ let summary = [{
+ Mixed-signedness integer dot product of Vector 1 and Vector 2.
+ Components of Vector 1 are treated as signed, components of Vector 2 are
+ treated as unsigned.
+ }];
+
+ let description = [{
+ Result Type must be an integer type whose Width must be greater than or
+ equal to that of the components of Vector 1 and Vector 2.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type with
+ the same number of components and same component Width (enabled by the
+ DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1
+ and Vector 2 are vectors, the components of Vector 2 must have a
+ Signedness of 0.
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of Vector 1 are sign-extended to the bit width of the
+ result's type. All components of Vector 2 are zero-extended to the bit
+ width of the result's type. The sign- or zero-extended input vectors are
+ then multiplied component-wise and all components of the vector
+ resulting from the component-wise multiplication are added together. The
+ resulting value will equal the low-order N bits of the correct result R,
+ where N is the result width and R is computed with enough precision to
+ avoid overflow and underflow.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ ```
+ }];
+}
+
+// -----
+
+def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot",
+ [UnsignedOp, Commutative]> {
+ let summary = "Unsigned integer dot product of Vector 1 and Vector 2.";
+
+ let description = [{
+ Result Type must be an integer type with Signedness of 0 whose Width
+ must be greater than or equal to that of the components of Vector 1 and
+ Vector 2.
+
+ Vector 1 and Vector 2 must have the same type.
+
+ Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+ DotProductInput4x8BitPacked capability) or vectors of integer type with
+ Signedness of 0 (enabled by the DotProductInput4x8Bit or
+ DotProductInputAll capability).
+
+ When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+ Format must be specified to select how the integers are to be
+ interpreted as vectors.
+
+ All components of the input vectors are zero-extended to the bit width
+ of the result's type. The zero-extended input vectors are then
+ multiplied component-wise and all components of the vector resulting
+ from the component-wise multiplication are added together. The resulting
+ value will equal the low-order N bits of the correct result R, where N
+ is the result width and R is computed with enough precision to avoid
+ overflow and underflow.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ ```
+ }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 5e8e5e4c7ce92..767e939f04473 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -34,6 +34,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1a93882e8f5f4..888a756be5201 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,33 +29,35 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/bit.h"
+#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <numeric>
using namespace mlir;
// TODO: generate these strings using ODS.
-constexpr char kMemoryAccessAttrName[] = "memory_access";
-constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
constexpr char kAlignmentAttrName[] = "alignment";
-constexpr char kSourceAlignmentAttrName[] = "source_alignment";
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kClusterSize[] = "cluster_size";
constexpr char kControl[] = "control";
constexpr char kDefaultValueAttrName[] = "default_value";
-constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
+constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
+constexpr char kPackedVectorFormatAttrName[] = "format";
constexpr char kSemanticsAttrName[] = "semantics";
+constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
@@ -4791,6 +4793,125 @@ LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+/// Verifies the invariants shared by the integer dot product ops:
+///  - both vector operands have the same type;
+///  - scalar (packed) operands carry a Packed Vector Format attribute and are
+///    32 bits wide;
+///  - vector operands do not carry a Packed Vector Format attribute;
+///  - no attribute other than 'format' is attached;
+///  - the result is at least as wide as the vector operand type.
+static LogicalResult verifyIntegerDotProduct(Operation *op) {
+  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
+         "Not an integer dot product op?");
+  assert(op->getNumResults() == 1 && "Expected a single result");
+
+  Type factorTy = op->getOperand(0).getType();
+  if (op->getOperand(1).getType() != factorTy)
+    return op->emitOpError("requires the same type for both vector operands");
+
+  Attribute formatAttr = op->getAttr(kPackedVectorFormatAttrName);
+  if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+    auto packedVectorFormat =
+        formatAttr.dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
+    if (!packedVectorFormat)
+      return op->emitOpError("requires Packed Vector Format attribute for "
+                             "integer vector operands");
+
+    assert(packedVectorFormat.getValue() ==
+               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
+           "unknown Packed Vector format");
+    if (intTy.getWidth() != 32)
+      return op->emitOpError(
+          llvm::formatv("with specified Packed Vector Format ({0}) requires "
+                        "integer vector operands to be 32-bits wide",
+                        packedVectorFormat.getValue()));
+  } else if (formatAttr) {
+    // The Packed Vector Format only describes how a scalar is split into
+    // components; it has no meaning for operands that are already vectors.
+    return op->emitOpError(
+        llvm::formatv("with invalid format attribute for vector operands of "
+                      "type '{0}'",
+                      factorTy));
+  }
+
+  // Check attribute names rather than just counting them: a count check would
+  // accept a single unrelated attribute (e.g. `{volatile = ...}`) when the
+  // operands are vectors and no 'format' is present.
+  if (llvm::any_of(op->getAttrs(), [](NamedAttribute attr) {
+        return attr.getName().getValue() != kPackedVectorFormatAttrName;
+      }))
+    return op->emitError(
+        "op only supports the 'format' #spirv.packed_vector_format attribute");
+
+  Type resultTy = op->getResultTypes().front();
+  unsigned factorBitWidth = getBitWidth(factorTy);
+  unsigned resultBitWidth = getBitWidth(resultTy);
+  if (factorBitWidth > resultBitWidth)
+    return op->emitOpError(
+        llvm::formatv("result type has insufficient bit-width ({0} bits) "
+                      "for the specified vector operand type ({1} bits)",
+                      resultBitWidth, factorBitWidth));
+
+  return success();
+}
+
+// Lowest SPIR-V version in which the integer dot product ops are available.
+static Optional<spirv::Version> getIntegerDotProductMinVersion() {
+  constexpr spirv::Version kMinVersion = spirv::Version::V_1_0;
+  return kMinVersion;
+}
+
+// Highest SPIR-V version in which the integer dot product ops are available.
+static Optional<spirv::Version> getIntegerDotProductMaxVersion() {
+  constexpr spirv::Version kMaxVersion = spirv::Version::V_1_6;
+  return kMaxVersion;
+}
+
+/// Returns the extensions required by the integer dot product ops. The only
+/// one is SPV_KHR_integer_dot_product. It may be listed explicitly or implied
+/// by a target environment whose SPIR-V version is >= 1.6.
+static SmallVector<ArrayRef<spirv::Extension>, 1>
+getIntegerDotProductExtensions() {
+  // The returned ArrayRef points at this object, so it must have static
+  // storage duration.
+  static const spirv::Extension kRequiredExtension =
+      spirv::Extension::SPV_KHR_integer_dot_product;
+  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
+  extensions.emplace_back(kRequiredExtension);
+  return extensions;
+}
+
+/// Returns the capabilities required by an integer dot product op. DotProduct
+/// is always required. The input capability depends on the operand type:
+///  - scalar operands with the 4x8-bit packed format: DotProductInput4x8BitPacked
+///  - 4-component vectors of 8-bit integers: DotProductInput4x8Bit
+///  - any other integer vector: DotProductInputAll
+static SmallVector<ArrayRef<spirv::Capability>, 1>
+getIntegerDotProductCapabilities(Operation *op) {
+  // The returned ArrayRefs point at these objects, so they must have static
+  // storage duration.
+  static const auto dotProductCap = spirv::Capability::DotProduct;
+  static const auto dotProductInput4x8BitPackedCap =
+      spirv::Capability::DotProductInput4x8BitPacked;
+  static const auto dotProductInput4x8BitCap =
+      spirv::Capability::DotProductInput4x8Bit;
+  static const auto dotProductInputAllCap =
+      spirv::Capability::DotProductInputAll;
+
+  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
+
+  Type factorTy = op->getOperand(0).getType();
+  if (factorTy.isa<IntegerType>()) {
+    // Use dyn_cast_or_null so a missing attribute (e.g. when queried on an op
+    // that has not been verified) does not trip an assertion inside `cast`.
+    auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
+                          .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
+    if (formatAttr &&
+        formatAttr.getValue() ==
+            spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
+      capabilities.push_back(dotProductInput4x8BitPackedCap);
+    return capabilities;
+  }
+
+  // DotProductInput4x8Bit only covers vectors of exactly 4 x 8-bit integers.
+  // Other integer vectors (e.g. vector<2xi8>, vector<4xi16>) need
+  // DotProductInputAll.
+  auto vecTy = factorTy.cast<VectorType>();
+  if (vecTy.getNumElements() == 4 && vecTy.getElementTypeBitWidth() == 8)
+    capabilities.push_back(dotProductInput4x8BitCap);
+  else
+    capabilities.push_back(dotProductInputAllCap);
+  return capabilities;
+}
+
+// The integer dot product ops opt out of ODS-generated availability
+// (`autogenAvailability = 0`). This macro defines each op's verifier and
+// availability hooks by forwarding to the shared helpers.
+#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpTy)                                \
+  LogicalResult OpTy::verify() {                                               \
+    return verifyIntegerDotProduct(getOperation());                            \
+  }                                                                            \
+  Optional<spirv::Version> OpTy::getMinVersion() {                             \
+    return getIntegerDotProductMinVersion();                                   \
+  }                                                                            \
+  Optional<spirv::Version> OpTy::getMaxVersion() {                             \
+    return getIntegerDotProductMaxVersion();                                   \
+  }                                                                            \
+  SmallVector<ArrayRef<spirv::Extension>, 1> OpTy::getExtensions() {           \
+    return getIntegerDotProductExtensions();                                   \
+  }                                                                            \
+  SmallVector<ArrayRef<spirv::Capability>, 1> OpTy::getCapabilities() {       \
+    return getIntegerDotProductCapabilities(getOperation());                   \
+  }
+
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
+
+#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
+
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 290e07de41d53..5cd7253da620b 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -49,3 +49,97 @@ func.func @module_physical_storage_buffer64_vulkan() {
spirv.module PhysicalStorageBuffer64 Vulkan { }
return
}
+
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: sdot_scalar_i32_i32
+func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.SDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: sdot_vector_4xi8_i64
+func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.SDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sdot_vector_4xi16_i64
+func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.SDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sudot_scalar_i32_i32
+func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.SUDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: sudot_vector_4xi8_i64
+func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.SUDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: sudot_vector_4xi16_i64
+func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.SUDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: udot_scalar_i32_i32
+func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+ %r = spirv.UDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r: i32
+}
+
+// CHECK-LABEL: udot_vector_4xi8_i64
+func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+ %r = spirv.UDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+ return %r: i64
+}
+
+// CHECK-LABEL: udot_vector_4xi16_i64
+func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+ // CHECK: min version: v1.0
+ // CHECK: max version: v1.6
+ // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+ // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+ %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+ return %r: i64
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
new file mode 100644
index 0000000000000..c0c5cf39b03fd
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -0,0 +1,144 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// This test covers the Integer Dot Product ops defined in the
+// SPV_KHR_integer_dot_product extension.
+
+//===----------------------------------------------------------------------===//
+// spirv.SDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sdot_scalar_i32
+func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
+ // CHECK-NEXT: spirv.SDot
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sdot_scalar_i64
+func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
+ // CHECK-NEXT: spirv.SDot
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ return %r : i64
+}
+
+// CHECK: @sdot_vector_4xi8
+func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+ // CHECK-NEXT: spirv.SDot
+ %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ return %r : i32
+}
+
+// CHECK: @sdot_vector_4xi16
+func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
+ // CHECK-NEXT: spirv.SDot
+ %r = spirv.SDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+ return %r : i64
+}
+
+// CHECK: @sdot_vector_8xi8
+func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
+ // CHECK-NEXT: spirv.SDot
+ %r = spirv.SDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+ return %r : i64
+}
+
+// -----
+
+func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
+ // expected-error @+1 {{op requires the same type for both vector operands}}
+ %r = spirv.SDot %a, %b : (i32, i64) -> i32
+ return %r : i32
+}
+
+// -----
+
+func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 {
+ // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
+ %r = spirv.SDot %a, %b {volatile = #spirv.decoration<Volatile>,
+ format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r : i32
+}
+
+// -----
+
+func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
+ // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16
+ return %r : i16
+}
+
+// -----
+
+func.func @sdot_scalar_bad_types(%a: i64, %b: i64) -> i64 {
+ // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
+ %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64) -> i64
+ return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SUDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sudot_scalar_i32
+func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
+ // CHECK-NEXT: spirv.SUDot
+ %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @sudot_scalar_i64
+func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
+ // CHECK-NEXT: spirv.SUDot
+ %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ return %r : i64
+}
+
+// CHECK: @sudot_vector_4xi8
+func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+ // CHECK-NEXT: spirv.SUDot
+ %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ return %r : i32
+}
+
+// CHECK: @sudot_vector_4xi16
+func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
+ // CHECK-NEXT: spirv.SUDot
+ %r = spirv.SUDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+ return %r : i64
+}
+
+// CHECK: @sudot_vector_8xi8
+func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
+ // CHECK-NEXT: spirv.SUDot
+ %r = spirv.SUDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+ return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @udot_scalar_i32
+func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
+ // CHECK-NEXT: spirv.UDot
+ %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+ return %r : i32
+}
+
+// CHECK: @udot_scalar_i64
+func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
+ // CHECK-NEXT: spirv.UDot
+ %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+ return %r : i64
+}
+
+// CHECK: @udot_vector_4xi8
+func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+ // CHECK-NEXT: spirv.UDot
+ %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+ return %r : i32
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index ecf87678473b7..91ffdf26242fc 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -19,6 +19,9 @@
// spirv.KHR.SubgroupBallot is available under in all SPIR-V versions under
// SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.
+// Integer Dot Product ops (spirv.*Dot*) require the
+// SPV_KHR_integer_dot_product extension and a number of related capabilities.
+
// The GeometryPointSize capability implies the Geometry capability, which
// implies the Shader capability.
@@ -122,6 +125,96 @@ func.func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attr
return %0: i32
}
+// CHECK-LABEL: @sdot_scalar_i32_i32_capabilities
+func.func @sdot_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.SDot
+ %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability1
+func.func @sdot_scalar_i32_i32_missing_capability1(%operand: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_sdot_op
+ %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability2
+func.func @sdot_scalar_i32_i32_missing_capability2(%operand: i32) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_sdot_op
+ %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_capabilities
+func.func @sudot_vector_4xi8_i32_capabilities(%operand: vector<4xi8>) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.SUDot
+ %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability1
+func.func @sudot_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_sudot_op
+ %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability2
+func.func @sudot_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>) -> i32 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_sudot_op
+ %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+ return %0: i32
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_capabilities
+func.func @udot_vector_4xi16_i64_capabilities(%operand: vector<4xi16>) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.UDot
+ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability1
+func.func @udot_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProduct, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_op
+ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability2
+func.func @udot_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>) -> i64 attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+ [DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_op
+ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0: i64
+}
+
//===----------------------------------------------------------------------===//
// Extension
//===----------------------------------------------------------------------===//
@@ -189,3 +282,25 @@ func.func @module_implied_extension() attributes {
"test.convert_to_module_op"() : () -> ()
return
}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_implied_extension
+func.func @udot_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>) -> i64 attributes {
+ // Version 1.6 implies SPV_KHR_integer_dot_product, so no extension is listed.
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+ [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK: spirv.UDot
+ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_extension
+func.func @udot_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>) -> i64 attributes {
+ // Version 1.5 does not imply SPV_KHR_integer_dot_product; the op stays unconverted.
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+ [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+ // CHECK: test.convert_to_udot_op
+ %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+ return %0: i64
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index e29d167e1e5d8..13c35ca5b150d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -191,6 +191,19 @@ struct ConvertToSubgroupBallot : RewritePattern {
return success();
}
};
+
+template <const char *TestOpName, typename SPIRVOp>
+struct ConvertToIntegerDotProd : RewritePattern { // Rewrites TestOpName ops into SPIRVOp.
+ ConvertToIntegerDotProd(MLIRContext *context)
+ : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {} // Root op TestOpName, benefit 1.
+ // Replaces the matched op with SPIRVOp, forwarding result types, operands and all attributes.
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
+ op->getOperands(), op->getAttrs());
+ return success();
+ }
+};
} // namespace
void ConvertToTargetEnv::runOnOperation() {
@@ -207,10 +220,17 @@ void ConvertToTargetEnv::runOnOperation() {
auto target = SPIRVConversionTarget::get(targetEnv);
+ static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
+ static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
+ static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
+
RewritePatternSet patterns(context);
patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
ConvertToGroupNonUniformBallot, ConvertToModule,
- ConvertToSubgroupBallot>(context);
+ ConvertToSubgroupBallot,
+ ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
+ ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
+ ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>>(context);
if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 014c94783119b..dad90566ffff8 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -1395,7 +1395,8 @@ static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
for (const auto *def : defs) {
Operator op(def);
- emitAvailabilityImpl(op, os);
+ if (def->getValueAsBit("autogenAvailability"))
+ emitAvailabilityImpl(op, os);
}
return false;
}
diff --git a/mlir/utils/spirv/define_enum.sh b/mlir/utils/spirv/define_enum.sh
index 496f90c320789..ca9d8642cd902 100755
--- a/mlir/utils/spirv/define_enum.sh
+++ b/mlir/utils/spirv/define_enum.sh
@@ -12,7 +12,7 @@
# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.
#
-# If <enum-name> is missing, this script updates existing ones.
+# If <enum-class-name> is missing, this script updates existing ones.
set -e
More information about the Mlir-commits
mailing list