[Mlir-commits] [mlir] 03e6bf5 - [mlir][spirv] Define `spirv.*Dot` integer dot product ops

Jakub Kuderski llvmlistbot at llvm.org
Tue Dec 6 17:21:47 PST 2022


Author: Jakub Kuderski
Date: 2022-12-06T20:17:41-05:00
New Revision: 03e6bf5f564c440ffbbac3a7a30015b6ca779afe

URL: https://github.com/llvm/llvm-project/commit/03e6bf5f564c440ffbbac3a7a30015b6ca779afe
DIFF: https://github.com/llvm/llvm-project/commit/03e6bf5f564c440ffbbac3a7a30015b6ca779afe.diff

LOG: [mlir][spirv] Define `spirv.*Dot` integer dot product ops

This covers `SDot`, `SUDot`, and `UDot`. The `*AccSat` versions will be
added in a follow-up revision.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D139242

Added: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
    mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/Dialect/SPIRV/IR/availability.mlir
    mlir/test/Dialect/SPIRV/IR/target-env.mlir
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
    mlir/utils/spirv/define_enum.sh

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 1d6c98d017350..f18796b7c96ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -76,7 +76,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
     }]>
   ];
 
-  // These op require a custom verifier.
+  // These ops require a custom verifier.
   let hasVerifier = 1;
 }
 
@@ -423,75 +423,6 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
 
 // -----
 
-def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
-                                                            [Pure, Commutative]> {
-  let summary = [{
-    Result is the full value of the signed integer multiplication of Operand
-    1 and Operand 2.
-  }];
-
-  let description = [{
-    Result Type must be from OpTypeStruct.  The struct must have two
-    members, and the two members must be the same type.  The member type
-    must be a scalar or vector of integer type.
-
-    Operand 1 and Operand 2 must have the same type as the members of Result
-    Type. These are consumed as signed integers.
-
-    Results are computed per component.
-
-    Member 0 of the result gets the low-order bits of the multiplication.
-
-    Member 1 of the result gets the high-order bits of the multiplication.
-
-    <!-- End of AutoGen section -->
-
-    #### Example:
-
-    ```mlir
-    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
-    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
-    ```
-  }];
-}
-
-// -----
-
-def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
-                                                            [Pure, Commutative]> {
-  let summary = [{
-    Result is the full value of the unsigned integer multiplication of
-    Operand 1 and Operand 2.
-  }];
-
-  let description = [{
-    Result Type must be from OpTypeStruct.  The struct must have two
-    members, and the two members must be the same type.  The member type
-    must be a scalar or vector of integer type, whose Signedness operand is
-    0.
-
-    Operand 1 and Operand 2 must have the same type as the members of Result
-    Type. These are consumed as unsigned integers.
-
-    Results are computed per component.
-
-    Member 0 of the result gets the low-order bits of the multiplication.
-
-    Member 1 of the result gets the high-order bits of the multiplication.
-
-    <!-- End of AutoGen section -->
-
-    #### Example:
-
-    ```mlir
-    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
-    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
-    ```
-  }];
-}
-
-// -----
-
 def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
                                         SPIRV_Integer,
                                         [UsableInSpecConstantOp]> {
@@ -646,6 +577,40 @@ def SPIRV_SModOp : SPIRV_ArithmeticBinaryOp<"SMod",
 
 // -----
 
+// spirv.SMulExtended: full-width signed integer multiply. Member 0 of the
+// two-member result struct holds the low-order bits, member 1 the high-order
+// bits. Verification is custom (inherited hasVerifier from the base class).
+def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
+                                                            [Pure, Commutative]> {
+  let summary = [{
+    Result is the full value of the signed integer multiplication of Operand
+    1 and Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct.  The struct must have two
+    members, and the two members must be the same type.  The member type
+    must be a scalar or vector of integer type.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as signed integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits of the multiplication.
+
+    Member 1 of the result gets the high-order bits of the multiplication.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+    %2 = spirv.SMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
                                           SPIRV_Integer,
                                           [UsableInSpecConstantOp]> {
@@ -654,7 +619,7 @@ def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
   let description = [{
     Result Type must be a scalar or vector of integer type.
 
-    Operand’s type  must be a scalar or vector of integer type.  It must
+    Operand's type  must be a scalar or vector of integer type.  It must
     have the same number of components as Result Type.  The component width
     must equal the component width in Result Type.
 
@@ -746,6 +711,41 @@ def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
 
 // -----
 
+// spirv.UMulExtended: full-width unsigned integer multiply. Member 0 of the
+// two-member result struct holds the low-order bits, member 1 the high-order
+// bits. Verification is custom (inherited hasVerifier from the base class).
+def SPIRV_UMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"UMulExtended",
+                                                            [Pure, Commutative]> {
+  let summary = [{
+    Result is the full value of the unsigned integer multiplication of
+    Operand 1 and Operand 2.
+  }];
+
+  let description = [{
+    Result Type must be from OpTypeStruct.  The struct must have two
+    members, and the two members must be the same type.  The member type
+    must be a scalar or vector of integer type, whose Signedness operand is
+    0.
+
+    Operand 1 and Operand 2 must have the same type as the members of Result
+    Type. These are consumed as unsigned integers.
+
+    Results are computed per component.
+
+    Member 0 of the result gets the low-order bits of the multiplication.
+
+    Member 1 of the result gets the high-order bits of the multiplication.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(i32, i32)>
+    %2 = spirv.UMulExtended %0, %1 : !spirv.struct<(vector<2xi32>, vector<2xi32>)>
+    ```
+  }];
+}
+
+// -----
+
 def SPIRV_VectorTimesScalarOp : SPIRV_Op<"VectorTimesScalar", [Pure]> {
   let summary = "Scale a floating-point vector.";
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3c993cbc77445..4be10e61ddee2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3969,6 +3969,18 @@ def SPIRV_StorageClassAttr :
       SPIRV_SC_CodeSectionINTEL, SPIRV_SC_DeviceOnlyINTEL, SPIRV_SC_HostOnlyINTEL
     ]>;
 
+def SPIRV_PVF_PackedVectorFormat4x8Bit : I32EnumAttrCase<"PackedVectorFormat4x8Bit", 0> {
+  list<Availability> availability = [
+    MinVersion<SPIRV_V_1_6>,
+    Extension<[SPV_KHR_integer_dot_product]>
+  ];
+}
+
+def SPIRV_PackedVectorFormatAttr :
+    SPIRV_I32EnumAttr<"PackedVectorFormat", "valid SPIR-V PackedVectorFormat", "packed_vector_format", [
+      SPIRV_PVF_PackedVectorFormat4x8Bit
+    ]>;
+
 // End enum section. Generated from SPIR-V spec; DO NOT MODIFY!
 
 // Enums added manually that are not part of SPIR-V spec
@@ -4365,6 +4377,12 @@ def SPIRV_OC_OpGroupNonUniformSMax        : I32EnumAttrCase<"OpGroupNonUniformSM
 def SPIRV_OC_OpGroupNonUniformUMax        : I32EnumAttrCase<"OpGroupNonUniformUMax", 357>;
 def SPIRV_OC_OpGroupNonUniformFMax        : I32EnumAttrCase<"OpGroupNonUniformFMax", 358>;
 def SPIRV_OC_OpSubgroupBallotKHR          : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
+def SPIRV_OC_OpSDot                       : I32EnumAttrCase<"OpSDot", 4450>;
+def SPIRV_OC_OpUDot                       : I32EnumAttrCase<"OpUDot", 4451>;
+def SPIRV_OC_OpSUDot                      : I32EnumAttrCase<"OpSUDot", 4452>;
+def SPIRV_OC_OpSDotAccSat                 : I32EnumAttrCase<"OpSDotAccSat", 4453>;
+def SPIRV_OC_OpUDotAccSat                 : I32EnumAttrCase<"OpUDotAccSat", 4454>;
+def SPIRV_OC_OpSUDotAccSat                : I32EnumAttrCase<"OpSUDotAccSat", 4455>;
 def SPIRV_OC_OpTypeCooperativeMatrixNV    : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
 def SPIRV_OC_OpCooperativeMatrixLoadNV    : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
 def SPIRV_OC_OpCooperativeMatrixStoreNV   : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
@@ -4457,7 +4475,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformSMin, SPIRV_OC_OpGroupNonUniformUMin,
       SPIRV_OC_OpGroupNonUniformFMin, SPIRV_OC_OpGroupNonUniformSMax,
       SPIRV_OC_OpGroupNonUniformUMax, SPIRV_OC_OpGroupNonUniformFMax,
-      SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpTypeCooperativeMatrixNV,
+      SPIRV_OC_OpSubgroupBallotKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
+      SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
+      SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixNV,
       SPIRV_OC_OpCooperativeMatrixLoadNV, SPIRV_OC_OpCooperativeMatrixStoreNV,
       SPIRV_OC_OpCooperativeMatrixMulAddNV, SPIRV_OC_OpCooperativeMatrixLengthNV,
       SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
@@ -4494,6 +4514,19 @@ class SPIRV_Op<string mnemonic, list<Trait> traits = []> :
     Capability<[]>
   ];
 
+  // Controls whether to auto-generate this op's availability specification.
+  // If set, generates the following methods:
+  //
+  // ```c++
+  // SmallVector<ArrayRef<Capability>, 1> OpTy::getCapabilities();
+  // SmallVector<ArrayRef<Extension>, 1>  OpTy::getExtensions();
+  // Optional<Version>                    OpTy::getMinVersion();
+  // Optional<Version>                    OpTy::getMaxVersion();
+  // ```
+  //
+  // When not set, manual implementation of these methods is required.
+  bit autogenAvailability = 1;
+
   // For each SPIR-V op, the following static functions need to be defined
   // in SPIRVOps.cpp:
   //

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
new file mode 100644
index 0000000000000..451aeb27207da
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td
@@ -0,0 +1,190 @@
+//===-- SPIRVIntegerDotProductOps.td - MLIR SPIR-V IDP Ops -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains integer dot product ops for the SPIR-V dialect. They
+// correspond to instructions defined by the "SPV_KHR_integer_dot_product"
+// SPIR-V extension.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
+#define MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+// Common base for all SPV_KHR_integer_dot_product ops: a single integer
+// result, a generic `operands attr-dict : (types) -> type` assembly format, a
+// custom verifier, and hand-written (non-autogenerated) availability methods.
+class SPIRV_IntegerDotProductOp<string mnemonic,
+                                list<Trait> traits = []> :
+      SPIRV_Op<mnemonic, !listconcat(traits, [Pure])> {
+  let results = (outs
+    SPIRV_Integer:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` `(` type(operands) `)` `->` type($result)
+  }];
+
+  // These ops require dynamic availability specification based on operand and
+  // result types, so the availability methods are implemented by hand.
+  // Override the field inherited from SPIRV_Op with `let` instead of
+  // redeclaring it: TableGen then reports an error if the field name does not
+  // exist in the parent class.
+  let autogenAvailability = 0;
+
+  // These ops require a custom verifier.
+  let hasVerifier = 1;
+}
+
+// Base for the two-operand dot product ops: two integer scalar-or-vector
+// factors plus an optional Packed Vector Format attribute. The C++ verifier
+// requires the attribute when the factors are scalars.
+class SPIRV_IntegerDotProductBinaryOp<string mnemonic,
+                                      list<Trait> traits = []> :
+      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
+    OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
+  );
+}
+
+// Base for the accumulating dot product ops: the two factors, an integer
+// accumulator, and the optional Packed Vector Format attribute.
+// NOTE(review): nothing in this file derives from it yet; presumably it is
+// intended for the *AccSat variants announced as a follow-up.
+class SPIRV_IntegerDotProductTernaryOp<string mnemonic,
+                                       list<Trait> traits = []> :
+      SPIRV_IntegerDotProductOp<mnemonic, traits> {
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector1,
+    SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$vector2,
+    SPIRV_Integer:$accumulator,
+    OptionalAttr<SPIRV_PackedVectorFormatAttr>:$format
+  );
+}
+
+// -----
+
+// spirv.SDot: signed dot product. Marked Commutative because both factors are
+// sign-extended identically, so swapping them does not change the result.
+def SPIRV_SDotOp : SPIRV_IntegerDotProductBinaryOp<"SDot",
+                                                   [SignedOp, Commutative]> {
+  let summary = "Signed integer dot product of Vector 1 and Vector 2.";
+
+  let description = [{
+    Result Type must be an integer type whose Width must be greater than or
+    equal to that of the components of Vector 1 and Vector 2.
+
+    Vector 1 and Vector 2 must have the same type.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type
+    (enabled by the DotProductInput4x8Bit or DotProductInputAll capability).
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of the input vectors are sign-extended to the bit width
+    of the result's type. The sign-extended input vectors are then
+    multiplied component-wise and all components of the vector resulting
+    from the component-wise multiplication are added together. The resulting
+    value will equal the low-order N bits of the correct result R, where N
+    is the result width and R is computed with enough precision to avoid
+    overflow and underflow.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+    %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+    %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+    ```
+  }];
+}
+
+// -----
+
+// spirv.SUDot: mixed-signedness dot product. Intentionally not Commutative:
+// Vector 1 is sign-extended while Vector 2 is zero-extended, so swapping the
+// operands can change the result.
+def SPIRV_SUDotOp : SPIRV_IntegerDotProductBinaryOp<"SUDot",
+                                                    [SignedOp, UnsignedOp]> {
+  let summary = [{
+    Mixed-signedness integer dot product of Vector 1 and Vector 2.
+    Components of Vector 1 are treated as signed, components of Vector 2 are
+    treated as unsigned.
+  }];
+
+  let description = [{
+    Result Type must be an integer type whose Width must be greater than or
+    equal to that of the components of Vector 1 and Vector 2.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type with
+    the same number of components and same component Width (enabled by the
+    DotProductInput4x8Bit or DotProductInputAll capability). When Vector 1
+    and Vector 2 are vectors, the components of Vector 2 must have a
+    Signedness of 0.
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of Vector 1 are sign-extended to the bit width of the
+    result's type. All components of Vector 2 are zero-extended to the bit
+    width of the result's type. The sign- or zero-extended input vectors are
+    then multiplied component-wise and all components of the vector
+    resulting from the component-wise multiplication are added together. The
+    resulting value will equal the low-order N bits of the correct result R,
+    where N is the result width and R is computed with enough precision to
+    avoid overflow and underflow.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+    %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+    %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+    ```
+  }];
+}
+
+// -----
+
+// spirv.UDot: unsigned dot product. Marked Commutative because both factors
+// are zero-extended identically, so swapping them does not change the result.
+def SPIRV_UDotOp : SPIRV_IntegerDotProductBinaryOp<"UDot",
+                                                   [UnsignedOp, Commutative]> {
+  let summary = "Unsigned integer dot product of Vector 1 and Vector 2.";
+
+  let description = [{
+    Result Type must be an integer type with Signedness of 0 whose Width
+    must be greater than or equal to that of the components of Vector 1 and
+    Vector 2.
+
+    Vector 1 and Vector 2 must have the same type.
+
+    Vector 1 and Vector 2 must be either 32-bit integers (enabled by the
+    DotProductInput4x8BitPacked capability) or vectors of integer type with
+    Signedness of 0 (enabled by the DotProductInput4x8Bit or
+    DotProductInputAll capability).
+
+    When Vector 1 and Vector 2 are scalar integer types, Packed Vector
+    Format must be specified to select how the integers are to be
+    interpreted as vectors.
+
+    All components of the input vectors are zero-extended to the bit width
+    of the result's type. The zero-extended input vectors are then
+    multiplied component-wise and all components of the vector resulting
+    from the component-wise multiplication are added together. The resulting
+    value will equal the low-order N bits of the correct result R, where N
+    is the result width and R is computed with enough precision to avoid
+    overflow and underflow.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+    %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+    %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+    ```
+  }];
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEGER_DOT_PRODUCT_OPS

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 5e8e5e4c7ce92..767e939f04473 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -34,6 +34,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td"

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1a93882e8f5f4..888a756be5201 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,33 +29,35 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/bit.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <numeric>
 
 using namespace mlir;
 
 // TODO: generate these strings using ODS.
-constexpr char kMemoryAccessAttrName[] = "memory_access";
-constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
 constexpr char kAlignmentAttrName[] = "alignment";
-constexpr char kSourceAlignmentAttrName[] = "source_alignment";
 constexpr char kBranchWeightAttrName[] = "branch_weights";
 constexpr char kCallee[] = "callee";
 constexpr char kClusterSize[] = "cluster_size";
 constexpr char kControl[] = "control";
 constexpr char kDefaultValueAttrName[] = "default_value";
-constexpr char kExecutionScopeAttrName[] = "execution_scope";
 constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+constexpr char kExecutionScopeAttrName[] = "execution_scope";
 constexpr char kFnNameAttrName[] = "fn";
 constexpr char kGroupOperationAttrName[] = "group_operation";
 constexpr char kIndicesAttrName[] = "indices";
 constexpr char kInitializerAttrName[] = "initializer";
 constexpr char kInterfaceAttrName[] = "interface";
+constexpr char kMemoryAccessAttrName[] = "memory_access";
 constexpr char kMemoryScopeAttrName[] = "memory_scope";
+constexpr char kPackedVectorFormatAttrName[] = "format";
 constexpr char kSemanticsAttrName[] = "semantics";
+constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
 constexpr char kSpecIdAttrName[] = "spec_id";
 constexpr char kTypeAttrName[] = "type";
 constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
@@ -4791,6 +4793,125 @@ LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
 
 LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
 
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+/// Shared verifier for the integer dot product ops (spirv.SDot, spirv.SUDot,
+/// spirv.UDot).
+///
+/// Checks that:
+///   * both factor operands have the same type;
+///   * scalar factors carry a 'format' #spirv.packed_vector_format attribute
+///     and are 32 bits wide;
+///   * 'format' is the only attribute present;
+///   * the result is at least as wide as the factor type, as measured by
+///     getBitWidth.
+static LogicalResult verifyIntegerDotProduct(Operation *op) {
+  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
+         "Not an integer dot product op?");
+  assert(op->getNumResults() == 1 && "Expected a single result");
+
+  Type factorTy = op->getOperand(0).getType();
+  if (op->getOperand(1).getType() != factorTy)
+    return op->emitOpError("requires the same type for both vector operands");
+
+  if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
+    // Scalar factors are packed vectors; 'format' says how to unpack them.
+    auto packedVectorFormat =
+        op->getAttr(kPackedVectorFormatAttrName)
+            .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
+    if (!packedVectorFormat)
+      return op->emitOpError("requires Packed Vector Format attribute for "
+                             "integer vector operands");
+
+    assert(packedVectorFormat.getValue() ==
+               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
+           "unknown Packed Vector format");
+    if (intTy.getWidth() != 32)
+      return op->emitOpError(
+          llvm::formatv("with specified Packed Vector Format ({0}) requires "
+                        "integer vector operands to be 32-bits wide",
+                        packedVectorFormat.getValue()));
+  }
+
+  // Reject every attribute other than 'format'. Merely counting attributes
+  // (`size() > 1`) lets a single unrelated attribute through whenever the
+  // factors are vectors and no 'format' attribute is present.
+  if (llvm::any_of(op->getAttrs(), [](NamedAttribute attr) {
+        return attr.getName().getValue() != kPackedVectorFormatAttrName;
+      }))
+    return op->emitError(
+        "op only supports the 'format' #spirv.packed_vector_format attribute");
+
+  // NOTE(review): for vector factors this presumably compares against the
+  // element width (per the op descriptions); confirm getBitWidth semantics.
+  Type resultTy = op->getResultTypes().front();
+  unsigned factorBitWidth = getBitWidth(factorTy);
+  unsigned resultBitWidth = getBitWidth(resultTy);
+  if (factorBitWidth > resultBitWidth)
+    return op->emitOpError(
+        llvm::formatv("result type has insufficient bit-width ({0} bits) "
+                      "for the specified vector operand type ({1} bits)",
+                      resultBitWidth, factorBitWidth));
+
+  return success();
+}
+
+/// Integer dot product ops report the version range [1.0, 1.6]; their actual
+/// availability is constrained by the SPV_KHR_integer_dot_product extension
+/// and the DotProduct* capabilities rather than by the version.
+static Optional<spirv::Version> getIntegerDotProductMinVersion() {
+  // Available in SPIR-V >= 1.0.
+  const Optional<spirv::Version> minVersion = spirv::Version::V_1_0;
+  return minVersion;
+}
+
+static Optional<spirv::Version> getIntegerDotProductMaxVersion() {
+  // Available in SPIR-V <= 1.6.
+  const Optional<spirv::Version> maxVersion = spirv::Version::V_1_6;
+  return maxVersion;
+}
+
+/// Returns the extension requirements of integer dot product ops: a single
+/// requirement group containing SPV_KHR_integer_dot_product, which may be
+/// enabled explicitly or implied by a target env with SPIR-V >= 1.6.
+static SmallVector<ArrayRef<spirv::Extension>, 1>
+getIntegerDotProductExtensions() {
+  // Function-local static: the returned ArrayRef must outlive this call.
+  static const spirv::Extension kIntegerDotProductExt =
+      spirv::Extension::SPV_KHR_integer_dot_product;
+  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
+  extensions.push_back(ArrayRef<spirv::Extension>(kIntegerDotProductExt));
+  return extensions;
+}
+
+/// Returns the capabilities required by an integer dot product op.
+///
+/// Every op needs DotProduct; the second requirement depends on the factor
+/// type:
+///   * scalar (packed) factors with the 4x8-bit format need
+///     DotProductInput4x8BitPacked;
+///   * vectors of exactly four 8-bit integers need DotProductInput4x8Bit;
+///   * any other integer vector needs DotProductInputAll.
+///
+/// `op` must already be verified: for scalar factors the 'format' attribute
+/// is assumed to be present.
+static SmallVector<ArrayRef<spirv::Capability>, 1>
+getIntegerDotProductCapabilities(Operation *op) {
+  // Requires the DotProduct capability and capabilities that depend on
+  // exact op types. The values are function-local statics because the
+  // returned ArrayRefs refer to them after this function returns.
+  static const auto dotProductCap = spirv::Capability::DotProduct;
+  static const auto dotProductInput4x8BitPackedCap =
+      spirv::Capability::DotProductInput4x8BitPacked;
+  static const auto dotProductInput4x8BitCap =
+      spirv::Capability::DotProductInput4x8Bit;
+  static const auto dotProductInputAllCap =
+      spirv::Capability::DotProductInputAll;
+
+  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
+
+  Type factorTy = op->getOperand(0).getType();
+  if (factorTy.isa<IntegerType>()) {
+    auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
+                          .cast<spirv::PackedVectorFormatAttr>();
+    if (formatAttr.getValue() ==
+        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
+      capabilities.push_back(dotProductInput4x8BitPackedCap);
+
+    return capabilities;
+  }
+
+  // DotProductInput4x8Bit only covers 4-component vectors of 8-bit integers;
+  // any other shape (e.g. vector<8xi8>, which the verifier accepts) needs
+  // DotProductInputAll.
+  auto vecTy = factorTy.cast<VectorType>();
+  if (vecTy.getNumElements() == 4 && vecTy.getElementTypeBitWidth() == 8) {
+    capabilities.push_back(dotProductInput4x8BitCap);
+    return capabilities;
+  }
+
+  capabilities.push_back(dotProductInputAllCap);
+  return capabilities;
+}
+
+// Defines the custom verifier and the hand-written availability methods of an
+// integer dot product op. ODS disables availability autogeneration for these
+// ops, so each method here forwards to the shared integer dot product helpers.
+#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
+  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
+  Optional<spirv::Version> OpName::getMinVersion() {                           \
+    return getIntegerDotProductMinVersion();                                   \
+  }                                                                            \
+  Optional<spirv::Version> OpName::getMaxVersion() {                           \
+    return getIntegerDotProductMaxVersion();                                   \
+  }                                                                            \
+  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
+    return getIntegerDotProductExtensions();                                   \
+  }                                                                            \
+  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
+    return getIntegerDotProductCapabilities(*this);                            \
+  }
+
+// Instantiate for each two-operand integer dot product op.
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
+
+#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
+
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"

diff  --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 290e07de41d53..5cd7253da620b 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -49,3 +49,97 @@ func.func @module_physical_storage_buffer64_vulkan() {
   spirv.module PhysicalStorageBuffer64 Vulkan { }
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+// Packed i32 factors require DotProductInput4x8BitPacked on top of DotProduct.
+// CHECK-LABEL: sdot_scalar_i32_i32
+func.func @sdot_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.SDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r: i32
+}
+
+// vector<4xi8> factors require DotProductInput4x8Bit on top of DotProduct.
+// CHECK-LABEL: sdot_vector_4xi8_i64
+func.func @sdot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.SDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  return %r: i64
+}
+
+// Vectors with non-8-bit components require DotProductInputAll.
+// CHECK-LABEL: sdot_vector_4xi16_i64
+func.func @sdot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.SDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  return %r: i64
+}
+
+// Packed i32 factors require DotProductInput4x8BitPacked on top of DotProduct.
+// CHECK-LABEL: sudot_scalar_i32_i32
+func.func @sudot_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.SUDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r: i32
+}
+
+// vector<4xi8> factors require DotProductInput4x8Bit on top of DotProduct.
+// CHECK-LABEL: sudot_vector_4xi8_i64
+func.func @sudot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.SUDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  return %r: i64
+}
+
+// Vectors with non-8-bit components require DotProductInputAll.
+// CHECK-LABEL: sudot_vector_4xi16_i64
+func.func @sudot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.SUDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  return %r: i64
+}
+
+// Packed i32 factors require DotProductInput4x8BitPacked on top of DotProduct.
+// CHECK-LABEL: udot_scalar_i32_i32
+func.func @udot_scalar_i32_i32(%a: i32) -> i32 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8BitPacked] ]
+  %r = spirv.UDot %a, %a {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r: i32
+}
+
+// vector<4xi8> factors require DotProductInput4x8Bit on top of DotProduct.
+// CHECK-LABEL: udot_vector_4xi8_i64
+func.func @udot_vector_4xi8_i64(%a: vector<4xi8>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInput4x8Bit] ]
+  %r = spirv.UDot %a, %a: (vector<4xi8>, vector<4xi8>) -> i64
+  return %r: i64
+}
+
+// Vectors with non-8-bit components require DotProductInputAll.
+// CHECK-LABEL: udot_vector_4xi16_i64
+func.func @udot_vector_4xi16_i64(%a: vector<4xi16>) -> i64 {
+  // CHECK: min version: v1.0
+  // CHECK: max version: v1.6
+  // CHECK: extensions: [ [SPV_KHR_integer_dot_product] ]
+  // CHECK: capabilities: [ [DotProduct] [DotProductInputAll] ]
+  %r = spirv.UDot %a, %a: (vector<4xi16>, vector<4xi16>) -> i64
+  return %r: i64
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
new file mode 100644
index 0000000000000..c0c5cf39b03fd
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/integer-dot-product-ops.mlir
@@ -0,0 +1,158 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// This test covers the Integer Dot Product ops defined in the
+// SPV_KHR_integer_dot_product extension.
+
+//===----------------------------------------------------------------------===//
+// spirv.SDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sdot_scalar_i32
+func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
+  // CHECK-NEXT: spirv.SDot
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sdot_scalar_i64
+func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
+  // CHECK-NEXT: spirv.SDot
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  return %r : i64
+}
+
+// CHECK: @sdot_vector_4xi8
+func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // CHECK-NEXT: spirv.SDot
+  %r = spirv.SDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
+
+// CHECK: @sdot_vector_4xi16
+func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
+  // CHECK-NEXT: spirv.SDot
+  %r = spirv.SDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  return %r : i64
+}
+
+// CHECK: @sdot_vector_8xi8
+func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
+  // CHECK-NEXT: spirv.SDot
+  %r = spirv.SDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  return %r : i64
+}
+
+// -----
+
+func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
+  // expected-error @+1 {{op requires the same type for both vector operands}}
+  %r = spirv.SDot %a, %b : (i32, i64) -> i32
+  return %r : i32
+}
+
+// -----
+
+func.func @sdot_scalar_i32_bad_attr(%a: i32, %b: i32) -> i32 {
+  // expected-error @+1 {{op only supports the 'format' #spirv.packed_vector_format attribute}}
+  %r = spirv.SDot %a, %b {volatile = #spirv.decoration<Volatile>,
+                          format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r : i32
+}
+
+// -----
+
+func.func @sdot_scalar_bad_result_width(%a: i32, %b: i32) -> i16 {
+  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i16
+  return %r : i16
+}
+
+// -----
+
+func.func @sdot_scalar_bad_packed_operand_width(%a: i64, %b: i64) -> i64 {
+  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
+  %r = spirv.SDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i64, i64) -> i64
+  return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.SUDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @sudot_scalar_i32
+func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
+  // CHECK-NEXT: spirv.SUDot
+  %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @sudot_scalar_i64
+func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
+  // CHECK-NEXT: spirv.SUDot
+  %r = spirv.SUDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  return %r : i64
+}
+
+// CHECK: @sudot_vector_4xi8
+func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // CHECK-NEXT: spirv.SUDot
+  %r = spirv.SUDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
+
+// CHECK: @sudot_vector_4xi16
+func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
+  // CHECK-NEXT: spirv.SUDot
+  %r = spirv.SUDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  return %r : i64
+}
+
+// CHECK: @sudot_vector_8xi8
+func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
+  // CHECK-NEXT: spirv.SUDot
+  %r = spirv.SUDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  return %r : i64
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.UDot
+//===----------------------------------------------------------------------===//
+
+// CHECK: @udot_scalar_i32
+func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
+  // CHECK-NEXT: spirv.UDot
+  %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i32
+  return %r : i32
+}
+
+// CHECK: @udot_scalar_i64
+func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
+  // CHECK-NEXT: spirv.UDot
+  %r = spirv.UDot %a, %b {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> i64
+  return %r : i64
+}
+
+// CHECK: @udot_vector_4xi8
+func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
+  // CHECK-NEXT: spirv.UDot
+  %r = spirv.UDot %a, %b : (vector<4xi8>, vector<4xi8>) -> i32
+  return %r : i32
+}
+
+// CHECK: @udot_vector_4xi16
+func.func @udot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
+  // CHECK-NEXT: spirv.UDot
+  %r = spirv.UDot %a, %b : (vector<4xi16>, vector<4xi16>) -> i64
+  return %r : i64
+}
+
+// CHECK: @udot_vector_8xi8
+func.func @udot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
+  // CHECK-NEXT: spirv.UDot
+  %r = spirv.UDot %a, %b : (vector<8xi8>, vector<8xi8>) -> i64
+  return %r : i64
+}

diff  --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index ecf87678473b7..91ffdf26242fc 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -19,6 +19,9 @@
 // spirv.KHR.SubgroupBallot is available under in all SPIR-V versions under
 // SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension.
 
+// Integer Dot Product ops (spirv.*Dot*) require the
+// SPV_KHR_integer_dot_product extension and a number of related capabilities.
+
 // The GeometryPointSize capability implies the Geometry capability, which
 // implies the Shader capability.
 
@@ -122,6 +125,96 @@ func.func @bit_reverse_recursively_implied_capability(%operand: i32) -> i32 attr
   return %0: i32
 }
 
+// CHECK-LABEL: @sdot_scalar_i32_i32_capabilities
+func.func @sdot_scalar_i32_i32_capabilities(%operand: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.SDot
+  %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability1
+func.func @sdot_scalar_i32_i32_missing_capability1(%operand: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_sdot_op
+  %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sdot_scalar_i32_i32_missing_capability2
+func.func @sdot_scalar_i32_i32_missing_capability2(%operand: i32) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProductInput4x8BitPacked], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_sdot_op
+  %0 = "test.convert_to_sdot_op"(%operand, %operand) {format = #spirv.packed_vector_format<PackedVectorFormat4x8Bit>}: (i32, i32) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_capabilities
+func.func @sudot_vector_4xi8_i32_capabilities(%operand: vector<4xi8>) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.SUDot
+  %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability1
+func.func @sudot_vector_4xi8_i32_missing_capability1(%operand: vector<4xi8>) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_sudot_op
+  %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @sudot_vector_4xi8_i32_missing_capability2
+func.func @sudot_vector_4xi8_i32_missing_capability2(%operand: vector<4xi8>) -> i32 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProductInput4x8Bit], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_sudot_op
+  %0 = "test.convert_to_sudot_op"(%operand, %operand): (vector<4xi8>, vector<4xi8>) -> (i32)
+  return %0: i32
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_capabilities
+func.func @udot_vector_4xi16_i64_capabilities(%operand: vector<4xi16>) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.UDot
+  %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability1
+func.func @udot_vector_4xi16_i64_missing_capability1(%operand: vector<4xi16>) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProduct, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_op
+  %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_capability2
+func.func @udot_vector_4xi16_i64_missing_capability2(%operand: vector<4xi16>) -> i64 attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0,
+    [DotProductInputAll, Int16, Int64], [SPV_KHR_integer_dot_product]>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_op
+  %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0: i64
+}
+
 //===----------------------------------------------------------------------===//
 // Extension
 //===----------------------------------------------------------------------===//
@@ -189,3 +282,25 @@ func.func @module_implied_extension() attributes {
   "test.convert_to_module_op"() : () -> ()
   return
 }
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_implied_extension
+func.func @udot_vector_4xi16_i64_implied_extension(%operand: vector<4xi16>) -> i64 attributes {
+  // Version 1.6 implies SPV_KHR_integer_dot_product, so no extension is listed.
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+    [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK: spirv.UDot
+  %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0: i64
+}
+
+// CHECK-LABEL: @udot_vector_4xi16_i64_missing_extension
+func.func @udot_vector_4xi16_i64_missing_extension(%operand: vector<4xi16>) -> i64 attributes {
+  // Version 1.5 does not imply SPV_KHR_integer_dot_product; op stays unconverted.
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.5,
+    [DotProduct, DotProductInputAll, Int16, Int64], []>, #spirv.resource_limits<>>
+} {
+  // CHECK: test.convert_to_udot_op
+  %0 = "test.convert_to_udot_op"(%operand, %operand): (vector<4xi16>, vector<4xi16>) -> (i64)
+  return %0: i64
+}

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index e29d167e1e5d8..13c35ca5b150d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -191,6 +191,22 @@ struct ConvertToSubgroupBallot : RewritePattern {
     return success();
   }
 };
+
+/// Rewrites the test op named `TestOpName` into the SPIR-V integer dot product
+/// op `SPIRVOp`, forwarding result types, operands, and attributes (e.g. the
+/// packed vector `format`) unchanged. Used by ConvertToTargetEnv.
+template <const char *TestOpName, typename SPIRVOp>
+struct ConvertToIntegerDotProd : RewritePattern {
+  ConvertToIntegerDotProd(MLIRContext *context)
+      : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
+                                         op->getOperands(), op->getAttrs());
+    return success();
+  }
+};
 } // namespace
 
 void ConvertToTargetEnv::runOnOperation() {
@@ -207,10 +223,19 @@ void ConvertToTargetEnv::runOnOperation() {
 
   auto target = SPIRVConversionTarget::get(targetEnv);
 
+  // String literals cannot be pointer template arguments, so the test op
+  // names passed to ConvertToIntegerDotProd live in static arrays.
+  static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
+  static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
+  static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
+
   RewritePatternSet patterns(context);
   patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
                ConvertToGroupNonUniformBallot, ConvertToModule,
-               ConvertToSubgroupBallot>(context);
+               ConvertToSubgroupBallot,
+               ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
+               ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
+               ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>>(context);
 
   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
     return signalPassFailure();

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 014c94783119b..dad90566ffff8 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -1395,7 +1395,8 @@ static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
   auto defs = recordKeeper.getAllDerivedDefinitions("SPIRV_Op");
   for (const auto *def : defs) {
     Operator op(def);
-    emitAvailabilityImpl(op, os);
+    if (def->getValueAsBit("autogenAvailability"))
+      emitAvailabilityImpl(op, os);
   }
   return false;
 }

diff  --git a/mlir/utils/spirv/define_enum.sh b/mlir/utils/spirv/define_enum.sh
index 496f90c320789..ca9d8642cd902 100755
--- a/mlir/utils/spirv/define_enum.sh
+++ b/mlir/utils/spirv/define_enum.sh
@@ -12,7 +12,7 @@
 # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
 # SPIR-V enum classes.
 #
-# If <enum-name> is missing, this script updates existing ones.
+# If <enum-class-name> is missing, this script updates existing ones.
 
 set -e
 


        


More information about the Mlir-commits mailing list