[Mlir-commits] [mlir] 9415241 - [mlir][spirv] Split op implementation file into subfiles. NFC.

Jakub Kuderski llvmlistbot at llvm.org
Wed Jul 19 13:50:59 PDT 2023


Author: Jakub Kuderski
Date: 2023-07-19T16:48:47-04:00
New Revision: 9415241c5ba700379d67006391e50204df4f32f4

URL: https://github.com/llvm/llvm-project/commit/9415241c5ba700379d67006391e50204df4f32f4
DIFF: https://github.com/llvm/llvm-project/commit/9415241c5ba700379d67006391e50204df4f32f4.diff

LOG: [mlir][spirv] Split op implementation file into subfiles. NFC.

The main op implementation file for SPIR-V grew past 5k LOC. This makes it
take a long time to compile and index with LSPs like clangd.

Pull out the first few SPIR-V extension ops into their own `.cpp` files,
just like we do with `.td` op definitions. This includes the
KHR/NV/Intel coop matrix and the integer dot prod extensions.

I plan to further split this in future revisions.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D155747

Added: 
    mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
    mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
    mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
    mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h

Modified: 
    mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
    mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Removed: 
    mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h
deleted file mode 100644
index 073bdcef2293fc..00000000000000
--- a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h
+++ /dev/null
@@ -1,45 +0,0 @@
-//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines utilities used for parsing types and ops for SPIR-V
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_
-#define MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_
-
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-
-namespace mlir {
-
-/// Parses the next keyword in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass, typename ParserType>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
-                     StringRef attrName = spirv::attributeName<EnumClass>()) {
-  StringRef keyword;
-  SmallVector<NamedAttribute, 1> attr;
-  auto loc = parser.getCurrentLocation();
-  if (parser.parseKeyword(&keyword))
-    return failure();
-  if (std::optional<EnumClass> attr =
-          spirv::symbolizeEnum<EnumClass>(keyword)) {
-    value = *attr;
-    return success();
-  }
-  return parser.emitError(loc, "invalid ")
-         << attrName << " attribute specification: " << keyword;
-}
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_

diff  --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b59cd07f87f92f..70e2eb786e397c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,12 +3,16 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
 add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
 
 add_mlir_dialect_library(MLIRSPIRVDialect
+  CooperativeMatrixOps.cpp
+  IntegerDotProductOps.cpp
+  JointMatrixOps.cpp
   SPIRVAttributes.cpp
   SPIRVCanonicalization.cpp
   SPIRVGLCanonicalization.cpp
   SPIRVDialect.cpp
   SPIRVEnums.cpp
   SPIRVOps.cpp
+  SPIRVParsingUtils.cpp
   SPIRVTypes.cpp
   TargetAndABI.cpp
 

diff  --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
new file mode 100644
index 00000000000000..bdd87677866501
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -0,0 +1,306 @@
+//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops  -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Cooperative Matrix operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult KHRCooperativeMatrixLengthOp::verify() {
+  if (!isa<CooperativeMatrixType>(getCooperativeMatrixType())) {
+    return emitOpError(
+               "type attribute must be a '!spirv.coopmatrix' type, found ")
+           << getCooperativeMatrixType() << " instead";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
+  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
+    return failure();
+  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
+    return failure();
+
+  CooperativeMatrixLayoutKHR layout;
+  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+    return failure();
+  }
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  Type elementType;
+  if (parser.parseColon() || parser.parseType(ptrType) ||
+      parser.parseKeywordType("as", elementType)) {
+    return failure();
+  }
+  result.addTypes(elementType);
+
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (parser.resolveOperands(operandInfo, {ptrType, strideType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getMatrixLayout();
+  // Print optional memory operand attribute.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << memOperand << "\"]";
+  printer << " : " << getPointer().getType() << " as " << getType();
+}
+
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  auto pointerType = cast<PointerType>(pointer);
+  Type pointeeType = pointerType.getPointeeType();
+  if (!isa<ScalarType, VectorType>(pointeeType)) {
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  }
+
+  // TODO: Verify the memory object behind the pointer:
+  // > If the Shader capability was declared, Pointer must point into an array
+  // > and any ArrayStride decoration on Pointer is ignored.
+
+  return success();
+}
+
+LogicalResult KHRCooperativeMatrixLoadOp::verify() {
+  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+                                        getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
+  for (auto &op : operandInfo) {
+    if (parser.parseOperand(op) || parser.parseComma())
+      return failure();
+  }
+
+  CooperativeMatrixLayoutKHR layout;
+  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
+    return failure();
+  }
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  Type ptrType;
+  Type objectType;
+  if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
+      parser.parseType(objectType)) {
+    return failure();
+  }
+
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+          << ", " << getMatrixLayout();
+
+  // Print optional memory operand attribute.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << *memOperand << "\"]";
+  printer << " : " << getPointer().getType() << ", " << getObject().getType();
+}
+
+LogicalResult KHRCooperativeMatrixStoreOp::verify() {
+  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
+                                        getObject().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult NVCooperativeMatrixLengthOp::verify() {
+  if (!isa<CooperativeMatrixNVType>(getCooperativeMatrixType())) {
+    return emitOpError(
+               "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
+           << getCooperativeMatrixType() << " instead";
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                             OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
+  Type elementType;
+  if (parser.parseOperandList(operandInfo, 3) ||
+      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
+      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
+    return failure();
+  }
+  if (parser.resolveOperands(operandInfo,
+                             {ptrType, strideType, columnMajorType},
+                             parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  result.addTypes(elementType);
+  return success();
+}
+
+void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getColumnmajor();
+  // Print optional memory access attribute.
+  if (auto memAccess = getMemoryAccess())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  printer << " : " << getPointer().getType() << " as " << getType();
+}
+
+static LogicalResult
+verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
+  Type pointeeType = llvm::cast<PointerType>(pointer).getPointeeType();
+  if (!llvm::isa<ScalarType>(pointeeType) &&
+      !llvm::isa<VectorType>(pointeeType))
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  StorageClass storage = llvm::cast<PointerType>(pointer).getStorageClass();
+  if (storage != StorageClass::Workgroup &&
+      storage != StorageClass::StorageBuffer &&
+      storage != StorageClass::PhysicalStorageBuffer)
+    return op->emitError(
+               "Pointer storage class must be Workgroup, StorageBuffer or "
+               "PhysicalStorageBufferEXT but provided ")
+           << stringifyStorageClass(storage);
+  return success();
+}
+
+LogicalResult NVCooperativeMatrixLoadOp::verify() {
+  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+                                          getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
+  Type strideType = parser.getBuilder().getIntegerType(32);
+  Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
+  Type elementType;
+  if (parser.parseOperandList(operandInfo, 4) ||
+      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
+      parser.parseType(ptrType) || parser.parseComma() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+  if (parser.resolveOperands(
+          operandInfo, {ptrType, elementType, strideType, columnMajorType},
+          parser.getNameLoc(), result.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+          << ", " << getColumnmajor();
+  // Print optional memory access attribute.
+  if (auto memAccess = getMemoryAccess())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
+}
+
+LogicalResult NVCooperativeMatrixStoreOp::verify() {
+  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+                                          getObject().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixMulAdd
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) {
+  if (op.getC().getType() != op.getResult().getType())
+    return op.emitOpError("result and third operand must have the same type");
+  auto typeA = llvm::cast<CooperativeMatrixNVType>(op.getA().getType());
+  auto typeB = llvm::cast<CooperativeMatrixNVType>(op.getB().getType());
+  auto typeC = llvm::cast<CooperativeMatrixNVType>(op.getC().getType());
+  auto typeR = llvm::cast<CooperativeMatrixNVType>(op.getResult().getType());
+  if (typeA.getRows() != typeR.getRows() ||
+      typeA.getColumns() != typeB.getRows() ||
+      typeB.getColumns() != typeR.getColumns())
+    return op.emitOpError("matrix size must match");
+  if (typeR.getScope() != typeA.getScope() ||
+      typeR.getScope() != typeB.getScope() ||
+      typeR.getScope() != typeC.getScope())
+    return op.emitOpError("matrix scope must match");
+  auto elementTypeA = typeA.getElementType();
+  auto elementTypeB = typeB.getElementType();
+  if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
+    if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
+        llvm::cast<IntegerType>(elementTypeB).getWidth())
+      return op.emitOpError(
+          "matrix A and B integer element types must be the same bit width");
+  } else if (elementTypeA != elementTypeB) {
+    return op.emitOpError(
+        "matrix A and B non-integer element types must match");
+  }
+  if (typeR.getElementType() != typeC.getElementType())
+    return op.emitOpError("matrix accumulator element type must match");
+  return success();
+}
+
+LogicalResult NVCooperativeMatrixMulAddOp::verify() {
+  return verifyCoopMatrixMulAddNV(*this);
+}
+
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
new file mode 100644
index 00000000000000..28efe4f046fcde
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -0,0 +1,158 @@
+//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops  ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Integer Dot Product operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyIntegerDotProduct(Operation *op) {
+  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
+         "Not an integer dot product op?");
+  assert(op->getNumResults() == 1 && "Expected a single result");
+
+  Type factorTy = op->getOperand(0).getType();
+  if (op->getOperand(1).getType() != factorTy)
+    return op->emitOpError("requires the same type for both vector operands");
+
+  unsigned expectedNumAttrs = 0;
+  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
+    ++expectedNumAttrs;
+    auto packedVectorFormat =
+        llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
+            op->getAttr(kPackedVectorFormatAttrName));
+    if (!packedVectorFormat)
+      return op->emitOpError("requires Packed Vector Format attribute for "
+                             "integer vector operands");
+
+    assert(packedVectorFormat.getValue() ==
+               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
+           "Unknown Packed Vector Format");
+    if (intTy.getWidth() != 32)
+      return op->emitOpError(
+          llvm::formatv("with specified Packed Vector Format ({0}) requires "
+                        "integer vector operands to be 32-bits wide",
+                        packedVectorFormat.getValue()));
+  } else {
+    if (op->hasAttr(kPackedVectorFormatAttrName))
+      return op->emitOpError(llvm::formatv(
+          "with invalid format attribute for vector operands of type '{0}'",
+          factorTy));
+  }
+
+  if (op->getAttrs().size() > expectedNumAttrs)
+    return op->emitError(
+        "op only supports the 'format' #spirv.packed_vector_format attribute");
+
+  Type resultTy = op->getResultTypes().front();
+  bool hasAccumulator = op->getNumOperands() == 3;
+  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
+    return op->emitOpError(
+        "requires the same accumulator operand and result types");
+
+  unsigned factorBitWidth = getBitWidth(factorTy);
+  unsigned resultBitWidth = getBitWidth(resultTy);
+  if (factorBitWidth > resultBitWidth)
+    return op->emitOpError(
+        llvm::formatv("result type has insufficient bit-width ({0} bits) "
+                      "for the specified vector operand type ({1} bits)",
+                      resultBitWidth, factorBitWidth));
+
+  return success();
+}
+
+static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
+  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
+}
+
+static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
+  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
+}
+
+static SmallVector<ArrayRef<spirv::Extension>, 1>
+getIntegerDotProductExtensions() {
+  // Requires the SPV_KHR_integer_dot_product extension, specified either
+  // explicitly or implied by target env's SPIR-V version >= 1.6.
+  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
+  return {extension};
+}
+
+static SmallVector<ArrayRef<spirv::Capability>, 1>
+getIntegerDotProductCapabilities(Operation *op) {
+  // Requires the DotProduct capability and capabilities that depend on
+  // exact op types.
+  static const auto dotProductCap = spirv::Capability::DotProduct;
+  static const auto dotProductInput4x8BitPackedCap =
+      spirv::Capability::DotProductInput4x8BitPacked;
+  static const auto dotProductInput4x8BitCap =
+      spirv::Capability::DotProductInput4x8Bit;
+  static const auto dotProductInputAllCap =
+      spirv::Capability::DotProductInputAll;
+
+  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
+
+  Type factorTy = op->getOperand(0).getType();
+  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
+    auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
+        op->getAttr(kPackedVectorFormatAttrName));
+    if (formatAttr.getValue() ==
+        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
+      capabilities.push_back(dotProductInput4x8BitPackedCap);
+
+    return capabilities;
+  }
+
+  auto vecTy = llvm::cast<VectorType>(factorTy);
+  if (vecTy.getElementTypeBitWidth() == 8) {
+    capabilities.push_back(dotProductInput4x8BitCap);
+    return capabilities;
+  }
+
+  capabilities.push_back(dotProductInputAllCap);
+  return capabilities;
+}
+
+#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
+  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
+  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
+    return getIntegerDotProductExtensions();                                   \
+  }                                                                            \
+  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
+    return getIntegerDotProductCapabilities(*this);                            \
+  }                                                                            \
+  std::optional<spirv::Version> OpName::getMinVersion() {                      \
+    return getIntegerDotProductMinVersion();                                   \
+  }                                                                            \
+  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
+    return getIntegerDotProductMaxVersion();                                   \
+  }
+
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)
+
+#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
+
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
new file mode 100644
index 00000000000000..63305ecdd0c4e9
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
@@ -0,0 +1,84 @@
+//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops  -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Intel Joint Matrix operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixLoad
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
+  Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
+  if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
+      !llvm::isa<VectorType>(pointeeType))
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  spirv::StorageClass storage =
+      llvm::cast<spirv::PointerType>(pointer).getStorageClass();
+  if (storage != spirv::StorageClass::Workgroup &&
+      storage != spirv::StorageClass::CrossWorkgroup &&
+      storage != spirv::StorageClass::UniformConstant &&
+      storage != spirv::StorageClass::Generic)
+    return op->emitError("Pointer storage class must be Workgroup or "
+                         "CrossWorkgroup but provided ")
+           << stringifyStorageClass(storage);
+  return success();
+}
+
+LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
+  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
+                                         getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixStore
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
+  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
+                                         getObject().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixMad
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
+  if (op.getC().getType() != op.getResult().getType())
+    return op.emitOpError("result and third operand must have the same type");
+  auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
+  auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
+  auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
+  auto typeR =
+      llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
+  if (typeA.getRows() != typeR.getRows() ||
+      typeA.getColumns() != typeB.getRows() ||
+      typeB.getColumns() != typeR.getColumns())
+    return op.emitOpError("matrix size must match");
+  if (typeR.getScope() != typeA.getScope() ||
+      typeR.getScope() != typeB.getScope() ||
+      typeR.getScope() != typeC.getScope())
+    return op.emitOpError("matrix scope must match");
+  if (typeA.getElementType() != typeB.getElementType() ||
+      typeR.getElementType() != typeC.getElementType())
+    return op.emitOpError("matrix element type must match");
+  return success();
+}
+
+LogicalResult spirv::INTELJointMatrixMadOp::verify() {
+  return verifyJointMatrixMad(*this);
+}
+
+} // namespace mlir

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 2e1c7923e24126..124d4ed6e8e6ed 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -11,7 +11,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
+
+#include "SPIRVParsingUtils.h"
+
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -341,11 +343,13 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
     return {};
 
   Scope scope;
-  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+  if (parser.parseComma() ||
+      spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
     return {};
 
   CooperativeMatrixUseKHR use;
-  if (parser.parseComma() || parseEnumKeywordAttr(use, parser, "use <id>"))
+  if (parser.parseComma() ||
+      spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
     return {};
 
   if (parser.parseGreater())
@@ -376,7 +380,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
     return Type();
 
   Scope scope;
-  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+  if (parser.parseComma() ||
+      spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
     return Type();
 
   if (parser.parseGreater())
@@ -407,10 +412,11 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
     return Type();
   MatrixLayout matrixLayout;
   if (parser.parseComma() ||
-      parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
+      spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
     return Type();
   Scope scope;
-  if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+  if (parser.parseComma() ||
+      spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
     return Type();
   if (parser.parseGreater())
     return Type();

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
new file mode 100644
index 00000000000000..efe596cd725c5e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
@@ -0,0 +1,32 @@
+//===- SPIRVOpUtils.h - MLIR SPIR-V Dialect Op Definition Utilities -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+namespace mlir::spirv {
+
+/// Returns the bit width of the `type`.
+inline unsigned getBitWidth(Type type) {
+  if (isa<spirv::PointerType>(type)) {
+    // Just return 64 bits for pointer types for now.
+    // TODO: Make sure no caller relies on the actual pointer width value.
+    return 64;
+  }
+
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    assert(vectorType.getElementType().isIntOrFloat());
+    return vectorType.getNumElements() *
+           vectorType.getElementType().getIntOrFloatBitWidth();
+  }
+  llvm_unreachable("unhandled bit width computation for type");
+}
+
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2516c29fbc58a8..2184cec953fb05 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -12,7 +12,9 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 
-#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
@@ -33,41 +35,12 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/FormatVariadic.h"
 #include <cassert>
 #include <numeric>
 #include <type_traits>
 
 using namespace mlir;
-
-// TODO: generate these strings using ODS.
-constexpr char kAlignmentAttrName[] = "alignment";
-constexpr char kBranchWeightAttrName[] = "branch_weights";
-constexpr char kCallee[] = "callee";
-constexpr char kClusterSize[] = "cluster_size";
-constexpr char kControl[] = "control";
-constexpr char kDefaultValueAttrName[] = "default_value";
-constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-constexpr char kExecutionScopeAttrName[] = "execution_scope";
-constexpr char kFnNameAttrName[] = "fn";
-constexpr char kGroupOperationAttrName[] = "group_operation";
-constexpr char kIndicesAttrName[] = "indices";
-constexpr char kInitializerAttrName[] = "initializer";
-constexpr char kInterfaceAttrName[] = "interface";
-constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-constexpr char kMemoryAccessAttrName[] = "memory_access";
-constexpr char kMemoryOperandAttrName[] = "memory_operand";
-constexpr char kMemoryScopeAttrName[] = "memory_scope";
-constexpr char kPackedVectorFormatAttrName[] = "format";
-constexpr char kSemanticsAttrName[] = "semantics";
-constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-constexpr char kSpecIdAttrName[] = "spec_id";
-constexpr char kTypeAttrName[] = "type";
-constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-constexpr char kValueAttrName[] = "value";
-constexpr char kValuesAttrName[] = "values";
-constexpr char kCompositeSpecConstituentsName[] = "constituents";
+using namespace mlir::spirv::AttrNames;
 
 //===----------------------------------------------------------------------===//
 // Common utility functions
@@ -158,79 +131,6 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
   return success();
 }
 
-template <typename Ty>
-static ArrayAttr
-getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
-                           function_ref<StringRef(Ty)> stringifyFn) {
-  if (enumValues.empty()) {
-    return nullptr;
-  }
-  SmallVector<StringRef, 1> enumValStrs;
-  enumValStrs.reserve(enumValues.size());
-  for (auto val : enumValues) {
-    enumValStrs.emplace_back(stringifyFn(val));
-  }
-  return builder.getStrArrayAttr(enumValStrs);
-}
-
-/// Parses the next string attribute in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass>
-static ParseResult
-parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
-                 StringRef attrName = spirv::attributeName<EnumClass>()) {
-  static_assert(std::is_enum_v<EnumClass>);
-  Attribute attrVal;
-  NamedAttrList attr;
-  auto loc = parser.getCurrentLocation();
-  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
-                            attrName, attr))
-    return failure();
-  if (!llvm::isa<StringAttr>(attrVal))
-    return parser.emitError(loc, "expected ")
-           << attrName << " attribute specified as string";
-  auto attrOptional = spirv::symbolizeEnum<EnumClass>(
-      llvm::cast<StringAttr>(attrVal).getValue());
-  if (!attrOptional)
-    return parser.emitError(loc, "invalid ")
-           << attrName << " attribute specification: " << attrVal;
-  value = *attrOptional;
-  return success();
-}
-
-/// Parses the next string attribute in `parser` as an enumerant of the given
-/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
-/// attribute with the enum class's name as attribute name.
-template <typename EnumAttrClass,
-          typename EnumClass = typename EnumAttrClass::ValueType>
-static ParseResult
-parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
-                 StringRef attrName = spirv::attributeName<EnumClass>()) {
-  static_assert(std::is_enum_v<EnumClass>);
-  if (parseEnumStrAttr(value, parser))
-    return failure();
-  state.addAttribute(attrName,
-                     parser.getBuilder().getAttr<EnumAttrClass>(value));
-  return success();
-}
-
-/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
-/// and inserts the enumerant into `state` as an 32-bit integer attribute with
-/// the enum class's name as attribute name.
-template <typename EnumAttrClass,
-          typename EnumClass = typename EnumAttrClass::ValueType>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
-                     OperationState &state,
-                     StringRef attrName = spirv::attributeName<EnumClass>()) {
-  static_assert(std::is_enum_v<EnumClass>);
-  if (parseEnumKeywordAttr(value, parser))
-    return failure();
-  state.addAttribute(attrName,
-                     parser.getBuilder().getAttr<EnumAttrClass>(value));
-  return success();
-}
-
 /// Parses Function, Selection and Loop control attributes. If no control is
 /// specified, "None" is used as a default.
 template <typename EnumAttrClass, typename EnumClass>
@@ -240,7 +140,7 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
   if (succeeded(parser.parseOptionalKeyword(kControl))) {
     EnumClass control;
     if (parser.parseLParen() ||
-        parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
+        spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
         parser.parseRParen())
       return failure();
     return success();
@@ -252,40 +152,6 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
   return success();
 }
 
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-///     (`[` memory-access `]`)?
-/// where:
-///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-///         integer-literal | `"NonTemporal"`
-static ParseResult
-parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state,
-                            StringRef attrName = kMemoryAccessAttrName) {
-  // Parse an optional list of attributes staring with '['
-  if (parser.parseOptionalLSquare()) {
-    // Nothing to do
-    return success();
-  }
-
-  spirv::MemoryAccess memoryAccessAttr;
-  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
-                                                attrName))
-    return failure();
-
-  if (spirv::bitEnumContainsAll(memoryAccessAttr,
-                                spirv::MemoryAccess::Aligned)) {
-    // Parse integer attribute for alignment.
-    Attribute alignmentAttr;
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
-                              state.attributes)) {
-      return failure();
-    }
-  }
-  return parser.parseRSquare();
-}
-
 // TODO Make sure to merge this and the previous function into one template
 // parameterized by memory access attribute name and alignment. Doing so now
 // results in VS2017 in producing an internal error (at the call site) that's
@@ -299,8 +165,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
-                                                kSourceMemoryAccessAttrName))
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -683,25 +549,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
 }
 
-// Get bit width of types.
-static unsigned getBitWidth(Type type) {
-  if (llvm::isa<spirv::PointerType>(type)) {
-    // Just return 64 bits for pointer types for now.
-    // TODO: Make sure not caller relies on the actual pointer width value.
-    return 64;
-  }
-
-  if (type.isIntOrFloat())
-    return type.getIntOrFloatBitWidth();
-
-  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
-    assert(vectorType.getElementType().isIntOrFloat());
-    return vectorType.getNumElements() *
-           vectorType.getElementType().getIntOrFloatBitWidth();
-  }
-  llvm_unreachable("unhandled bit width computation for type");
-}
-
 /// Walks the given type hierarchy with the given indices, potentially down
 /// to component granularity, to select an element type. Returns null type and
 /// emits errors with the given loc on failure.
@@ -839,10 +686,10 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
   OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
   Type type;
   SMLoc loc;
-  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
-                                         kMemoryScopeAttrName) ||
-      parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
-                                                   kSemanticsAttrName) ||
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
+                                                kMemoryScopeAttrName) ||
+      spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+          memoryScope, parser, state, kSemanticsAttrName) ||
       parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
       parser.getCurrentLocation(&loc) || parser.parseColonType(type))
     return failure();
@@ -916,10 +763,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   spirv::Scope executionScope;
   spirv::GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
-  if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                         kExecutionScopeAttrName) ||
-      parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
-                                                  kGroupOperationAttrName) ||
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
+                                                kExecutionScopeAttrName) ||
+      spirv::parseEnumStrAttr<spirv::GroupOperationAttr>(
+          groupOperation, parser, state, kGroupOperationAttrName) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -1199,11 +1046,11 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
   spirv::MemorySemantics equalSemantics, unequalSemantics;
   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
   Type type;
-  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
-                                         kMemoryScopeAttrName) ||
-      parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
+                                                kMemoryScopeAttrName) ||
+      spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
           equalSemantics, parser, state, kEqualSemanticsAttrName) ||
-      parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+      spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
           unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
       parser.parseOperandList(operandInfo, 3))
     return failure();
@@ -3478,10 +3325,10 @@ ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
   // Parse attributes
   spirv::AddressingModel addrModel;
   spirv::MemoryModel memoryModel;
-  if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
-                                                         result) ||
-      ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
-                                                     result))
+  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
+                                                              result) ||
+      spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
+                                                          result))
     return failure();
 
   if (succeeded(parser.parseOptionalKeyword("requires"))) {
@@ -4028,364 +3875,6 @@ LogicalResult spirv::VectorShuffleOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
-  if (!isa<spirv::CooperativeMatrixType>(getCooperativeMatrixType())) {
-    return emitOpError(
-               "type attribute must be a '!spirv.coopmatrix' type, found ")
-           << getCooperativeMatrixType() << " instead";
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
-                                                     OperationState &result) {
-  std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
-  if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
-    return failure();
-  if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
-    return failure();
-
-  spirv::CooperativeMatrixLayoutKHR layout;
-  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
-          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
-    return failure();
-  }
-
-  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
-    return failure();
-
-  Type ptrType;
-  Type elementType;
-  if (parser.parseColon() || parser.parseType(ptrType) ||
-      parser.parseKeywordType("as", elementType)) {
-    return failure();
-  }
-  result.addTypes(elementType);
-
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  if (parser.resolveOperands(operandInfo, {ptrType, strideType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getStride() << ", "
-          << getMatrixLayout();
-  // Print optional memory operand attribute.
-  if (auto memOperand = getMemoryOperand())
-    printer << " [\"" << memOperand << "\"]";
-  printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
-  auto pointerType = cast<spirv::PointerType>(pointer);
-  Type pointeeType = pointerType.getPointeeType();
-  if (!isa<spirv::ScalarType, VectorType>(pointeeType)) {
-    return op->emitError(
-               "Pointer must point to a scalar or vector type but provided ")
-           << pointeeType;
-  }
-
-  // TODO: Verify the memory object behind the pointer:
-  // > If the Shader capability was declared, Pointer must point into an array
-  // > and any ArrayStride decoration on Pointer is ignored.
-
-  return success();
-}
-
-LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
-                                                      OperationState &result) {
-  std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
-  for (auto &op : operandInfo) {
-    if (parser.parseOperand(op) || parser.parseComma())
-      return failure();
-  }
-
-  spirv::CooperativeMatrixLayoutKHR layout;
-  if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
-          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
-    return failure();
-  }
-
-  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
-    return failure();
-
-  Type ptrType;
-  Type objectType;
-  if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
-      parser.parseType(objectType)) {
-    return failure();
-  }
-
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
-          << ", " << getMatrixLayout();
-
-  // Print optional memory operand attribute.
-  if (auto memOperand = getMemoryOperand())
-    printer << " [\"" << *memOperand << "\"]";
-  printer << " : " << getPointer().getType() << ", " << getObject().getType();
-}
-
-LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() {
-  if (!isa<spirv::CooperativeMatrixNVType>(getCooperativeMatrixType())) {
-    return emitOpError(
-               "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
-           << getCooperativeMatrixType() << " instead";
-  }
-
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
-                                                    OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  Type columnMajorType = parser.getBuilder().getIntegerType(1);
-  Type ptrType;
-  Type elementType;
-  if (parser.parseOperandList(operandInfo, 3) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
-    return failure();
-  }
-  if (parser.resolveOperands(operandInfo,
-                             {ptrType, strideType, columnMajorType},
-                             parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  result.addTypes(elementType);
-  return success();
-}
-
-void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getStride() << ", "
-          << getColumnmajor();
-  // Print optional memory access attribute.
-  if (auto memAccess = getMemoryAccess())
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult
-verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
-  Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
-  if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
-      !llvm::isa<VectorType>(pointeeType))
-    return op->emitError(
-               "Pointer must point to a scalar or vector type but provided ")
-           << pointeeType;
-  spirv::StorageClass storage =
-      llvm::cast<spirv::PointerType>(pointer).getStorageClass();
-  if (storage != spirv::StorageClass::Workgroup &&
-      storage != spirv::StorageClass::StorageBuffer &&
-      storage != spirv::StorageClass::PhysicalStorageBuffer)
-    return op->emitError(
-               "Pointer storage class must be Workgroup, StorageBuffer or "
-               "PhysicalStorageBufferEXT but provided ")
-           << stringifyStorageClass(storage);
-  return success();
-}
-
-LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
-                                          getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
-                                                     OperationState &result) {
-  SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
-  Type strideType = parser.getBuilder().getIntegerType(32);
-  Type columnMajorType = parser.getBuilder().getIntegerType(1);
-  Type ptrType;
-  Type elementType;
-  if (parser.parseOperandList(operandInfo, 4) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(ptrType) || parser.parseComma() ||
-      parser.parseType(elementType)) {
-    return failure();
-  }
-  if (parser.resolveOperands(
-          operandInfo, {ptrType, elementType, strideType, columnMajorType},
-          parser.getNameLoc(), result.operands)) {
-    return failure();
-  }
-
-  return success();
-}
-
-void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
-  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
-          << ", " << getColumnmajor();
-  // Print optional memory access attribute.
-  if (auto memAccess = getMemoryAccess())
-    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
-}
-
-LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
-                                          getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixMulAdd
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) {
-  if (op.getC().getType() != op.getResult().getType())
-    return op.emitOpError("result and third operand must have the same type");
-  auto typeA = llvm::cast<spirv::CooperativeMatrixNVType>(op.getA().getType());
-  auto typeB = llvm::cast<spirv::CooperativeMatrixNVType>(op.getB().getType());
-  auto typeC = llvm::cast<spirv::CooperativeMatrixNVType>(op.getC().getType());
-  auto typeR =
-      llvm::cast<spirv::CooperativeMatrixNVType>(op.getResult().getType());
-  if (typeA.getRows() != typeR.getRows() ||
-      typeA.getColumns() != typeB.getRows() ||
-      typeB.getColumns() != typeR.getColumns())
-    return op.emitOpError("matrix size must match");
-  if (typeR.getScope() != typeA.getScope() ||
-      typeR.getScope() != typeB.getScope() ||
-      typeR.getScope() != typeC.getScope())
-    return op.emitOpError("matrix scope must match");
-  auto elementTypeA = typeA.getElementType();
-  auto elementTypeB = typeB.getElementType();
-  if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
-    if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
-        llvm::cast<IntegerType>(elementTypeB).getWidth())
-      return op.emitOpError(
-          "matrix A and B integer element types must be the same bit width");
-  } else if (elementTypeA != elementTypeB) {
-    return op.emitOpError(
-        "matrix A and B non-integer element types must match");
-  }
-  if (typeR.getElementType() != typeC.getElementType())
-    return op.emitOpError("matrix accumulator element type must match");
-  return success();
-}
-
-LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
-  return verifyCoopMatrixMulAddNV(*this);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
-  Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
-  if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
-      !llvm::isa<VectorType>(pointeeType))
-    return op->emitError(
-               "Pointer must point to a scalar or vector type but provided ")
-           << pointeeType;
-  spirv::StorageClass storage =
-      llvm::cast<spirv::PointerType>(pointer).getStorageClass();
-  if (storage != spirv::StorageClass::Workgroup &&
-      storage != spirv::StorageClass::CrossWorkgroup &&
-      storage != spirv::StorageClass::UniformConstant &&
-      storage != spirv::StorageClass::Generic)
-    return op->emitError("Pointer storage class must be Workgroup or "
-                         "CrossWorkgroup but provided ")
-           << stringifyStorageClass(storage);
-  return success();
-}
-
-LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
-  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
-                                         getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixStore
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
-  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
-                                         getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixMad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
-  if (op.getC().getType() != op.getResult().getType())
-    return op.emitOpError("result and third operand must have the same type");
-  auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
-  auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
-  auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
-  auto typeR =
-      llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
-  if (typeA.getRows() != typeR.getRows() ||
-      typeA.getColumns() != typeB.getRows() ||
-      typeB.getColumns() != typeR.getColumns())
-    return op.emitOpError("matrix size must match");
-  if (typeR.getScope() != typeA.getScope() ||
-      typeR.getScope() != typeB.getScope() ||
-      typeR.getScope() != typeC.getScope())
-    return op.emitOpError("matrix scope must match");
-  if (typeA.getElementType() != typeB.getElementType() ||
-      typeR.getElementType() != typeC.getElementType())
-    return op.emitOpError("matrix element type must match");
-  return success();
-}
-
-LogicalResult spirv::INTELJointMatrixMadOp::verify() {
-  return verifyJointMatrixMad(*this);
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.MatrixTimesScalar
 //===----------------------------------------------------------------------===//
@@ -5058,140 +4547,6 @@ LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
 
 LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
 
-//===----------------------------------------------------------------------===//
-// Integer Dot Product ops
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyIntegerDotProduct(Operation *op) {
-  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
-         "Not an integer dot product op?");
-  assert(op->getNumResults() == 1 && "Expected a single result");
-
-  Type factorTy = op->getOperand(0).getType();
-  if (op->getOperand(1).getType() != factorTy)
-    return op->emitOpError("requires the same type for both vector operands");
-
-  unsigned expectedNumAttrs = 0;
-  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    ++expectedNumAttrs;
-    auto packedVectorFormat =
-        llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(kPackedVectorFormatAttrName));
-    if (!packedVectorFormat)
-      return op->emitOpError("requires Packed Vector Format attribute for "
-                             "integer vector operands");
-
-    assert(packedVectorFormat.getValue() ==
-               spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
-           "Unknown Packed Vector Format");
-    if (intTy.getWidth() != 32)
-      return op->emitOpError(
-          llvm::formatv("with specified Packed Vector Format ({0}) requires "
-                        "integer vector operands to be 32-bits wide",
-                        packedVectorFormat.getValue()));
-  } else {
-    if (op->hasAttr(kPackedVectorFormatAttrName))
-      return op->emitOpError(llvm::formatv(
-          "with invalid format attribute for vector operands of type '{0}'",
-          factorTy));
-  }
-
-  if (op->getAttrs().size() > expectedNumAttrs)
-    return op->emitError(
-        "op only supports the 'format' #spirv.packed_vector_format attribute");
-
-  Type resultTy = op->getResultTypes().front();
-  bool hasAccumulator = op->getNumOperands() == 3;
-  if (hasAccumulator && op->getOperand(2).getType() != resultTy)
-    return op->emitOpError(
-        "requires the same accumulator operand and result types");
-
-  unsigned factorBitWidth = getBitWidth(factorTy);
-  unsigned resultBitWidth = getBitWidth(resultTy);
-  if (factorBitWidth > resultBitWidth)
-    return op->emitOpError(
-        llvm::formatv("result type has insufficient bit-width ({0} bits) "
-                      "for the specified vector operand type ({1} bits)",
-                      resultBitWidth, factorBitWidth));
-
-  return success();
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
-  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
-  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
-static SmallVector<ArrayRef<spirv::Extension>, 1>
-getIntegerDotProductExtensions() {
-  // Requires the SPV_KHR_integer_dot_product extension, specified either
-  // explicitly or implied by target env's SPIR-V version >= 1.6.
-  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
-  return {extension};
-}
-
-static SmallVector<ArrayRef<spirv::Capability>, 1>
-getIntegerDotProductCapabilities(Operation *op) {
-  // Requires the the DotProduct capability and capabilities that depend on
-  // exact op types.
-  static const auto dotProductCap = spirv::Capability::DotProduct;
-  static const auto dotProductInput4x8BitPackedCap =
-      spirv::Capability::DotProductInput4x8BitPacked;
-  static const auto dotProductInput4x8BitCap =
-      spirv::Capability::DotProductInput4x8Bit;
-  static const auto dotProductInputAllCap =
-      spirv::Capability::DotProductInputAll;
-
-  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
-
-  Type factorTy = op->getOperand(0).getType();
-  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
-    auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
-        op->getAttr(kPackedVectorFormatAttrName));
-    if (formatAttr.getValue() ==
-        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
-      capabilities.push_back(dotProductInput4x8BitPackedCap);
-
-    return capabilities;
-  }
-
-  auto vecTy = llvm::cast<VectorType>(factorTy);
-  if (vecTy.getElementTypeBitWidth() == 8) {
-    capabilities.push_back(dotProductInput4x8BitCap);
-    return capabilities;
-  }
-
-  capabilities.push_back(dotProductInputAllCap);
-  return capabilities;
-}
-
-#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
-  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
-  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
-    return getIntegerDotProductExtensions();                                   \
-  }                                                                            \
-  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
-    return getIntegerDotProductCapabilities(*this);                            \
-  }                                                                            \
-  std::optional<spirv::Version> OpName::getMinVersion() {                      \
-    return getIntegerDotProductMinVersion();                                   \
-  }                                                                            \
-  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
-    return getIntegerDotProductMaxVersion();                                   \
-  }
-
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
-
-#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
-
 // TableGen'erated operation interfaces for querying versions, extensions, and
 // capabilities.
 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
new file mode 100644
index 00000000000000..43c0beaccc0fd3
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -0,0 +1,48 @@
+//===- SPIRVParsingUtils.cpp - MLIR SPIR-V Dialect Parsing Utils ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements common SPIR-V dialect parsing functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVParsingUtils.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+/// See the declaration in SPIRVParsingUtils.h for the accepted syntax.
+ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+                                        OperationState &state,
+                                        StringRef attrName) {
+  // The bracketed memory-access list is optional; when no '[' follows there
+  // is nothing to parse.
+  if (parser.parseOptionalLSquare())
+    return success();
+
+  // Parse the quoted enum value and store it on `state` under `attrName`.
+  spirv::MemoryAccess access;
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(access, parser, state,
+                                                       attrName))
+    return failure();
+
+  // An access mask that includes `Aligned` must be followed by a comma and
+  // the alignment, which is parsed as an i32-typed attribute.
+  if (spirv::bitEnumContainsAll(access, spirv::MemoryAccess::Aligned)) {
+    Attribute alignment;
+    Type i32 = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignment, i32, AttrNames::kAlignmentAttrName,
+                              state.attributes))
+      return failure();
+  }
+  // With or without an alignment, the list has to be closed by ']'.
+  return parser.parseRSquare();
+}
+
+} // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
new file mode 100644
index 00000000000000..fd2faf4b7b333f
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -0,0 +1,156 @@
+//===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <type_traits>
+
+namespace mlir::spirv {
+namespace AttrNames {
+// Attribute name strings. TODO: generate these strings using ODS.
+inline constexpr char kAlignmentAttrName[] = "alignment";
+inline constexpr char kBranchWeightAttrName[] = "branch_weights";
+inline constexpr char kCallee[] = "callee";
+inline constexpr char kClusterSize[] = "cluster_size";
+inline constexpr char kControl[] = "control";
+inline constexpr char kDefaultValueAttrName[] = "default_value";
+inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
+inline constexpr char kFnNameAttrName[] = "fn";
+inline constexpr char kGroupOperationAttrName[] = "group_operation";
+inline constexpr char kIndicesAttrName[] = "indices";
+inline constexpr char kInitializerAttrName[] = "initializer";
+inline constexpr char kInterfaceAttrName[] = "interface";
+inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
+inline constexpr char kMemoryAccessAttrName[] = "memory_access";
+inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
+inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
+inline constexpr char kPackedVectorFormatAttrName[] = "format";
+inline constexpr char kSemanticsAttrName[] = "semantics";
+inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
+inline constexpr char kSpecIdAttrName[] = "spec_id";
+inline constexpr char kTypeAttrName[] = "type";
+inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
+inline constexpr char kValueAttrName[] = "value";
+inline constexpr char kValuesAttrName[] = "values";
+inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+} // namespace AttrNames
+
+/// Builds a string array attribute by applying `stringifyFn` to each entry of
+/// `enumValues`; yields a null attribute when `enumValues` is empty.
+template <typename Ty>
+ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
+                                     function_ref<StringRef(Ty)> stringifyFn) {
+  if (enumValues.empty())
+    return nullptr;
+  SmallVector<StringRef, 1> names;
+  names.reserve(enumValues.size());
+  for (Ty enumerant : enumValues)
+    names.push_back(stringifyFn(enumerant));
+  return builder.getStrArrayAttr(names);
+}
+
+/// Parses the next keyword in `parser` as an enumerant of `EnumClass` and
+/// stores it in `value`; emits an error naming `attrName` if it is not one.
+template <typename EnumClass, typename ParserType>
+ParseResult
+parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  StringRef keyword;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&keyword))
+    return failure();
+
+  if (std::optional<EnumClass> attr =
+          spirv::symbolizeEnum<EnumClass>(keyword)) {
+    value = *attr;
+    return success();
+  }
+  return parser.emitError(loc, "invalid ")
+         << attrName << " attribute specification: " << keyword;
+}
+
+/// Parses the next attribute in `parser`, which must be a string attribute,
+/// and converts its text into the matching `EnumClass` enumerant in `value`.
+template <typename EnumClass>
+ParseResult
+parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
+                 StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  auto startLoc = parser.getCurrentLocation();
+  Attribute parsedAttr;
+  NamedAttrList scratchAttrs; // Local list only; discarded on return.
+  if (parser.parseAttribute(parsedAttr, parser.getBuilder().getNoneType(),
+                            attrName, scratchAttrs))
+    return failure();
+  auto strAttr = llvm::dyn_cast<StringAttr>(parsedAttr);
+  if (!strAttr)
+    return parser.emitError(startLoc, "expected ")
+           << attrName << " attribute specified as string";
+  if (auto enumerant = spirv::symbolizeEnum<EnumClass>(strAttr.getValue())) {
+    value = *enumerant;
+    return success();
+  }
+  return parser.emitError(startLoc, "invalid ")
+         << attrName << " attribute specification: " << parsedAttr;
+}
+
+/// Parses the next string attribute in `parser` as an enumerant of the given
+/// `EnumClass`, stores it in `value`, and additionally attaches it to `state`
+/// as an `EnumAttrClass` attribute registered under `attrName`.
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
+ParseResult
+parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
+                 StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  if (parseEnumStrAttr(value, parser, attrName))
+    return failure();
+  auto enumAttr = parser.getBuilder().getAttr<EnumAttrClass>(value);
+  state.addAttribute(attrName, enumAttr);
+  return success();
+}
+
+/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
+/// and adds it to `state` as an `EnumAttrClass` attribute named `attrName`,
+/// which is also the name reported in the diagnostic for an invalid keyword.
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
+ParseResult
+parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
+                     OperationState &state,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  if (parseEnumKeywordAttr(value, parser, attrName))
+    return failure();
+  state.addAttribute(attrName,
+                     parser.getBuilder().getAttr<EnumAttrClass>(value));
+  return success();
+}
+
+/// Parses the optional memory access (a.k.a. memory operand) list attached to
+/// a memory access operand/pointer and records it on `state` as `attrName`:
+///     (`[` memory-access `]`)?
+/// where:
+///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
+///         integer-literal | `"NonTemporal"`
+ParseResult parseMemoryAccessAttributes(
+    OpAsmParser &parser, OperationState &state,
+    StringRef attrName = AttrNames::kMemoryAccessAttrName);
+
+} // namespace mlir::spirv


        


More information about the Mlir-commits mailing list