[Mlir-commits] [mlir] 9415241 - [mlir][spirv] Split op implementation file into subfiles. NFC.
Jakub Kuderski
llvmlistbot at llvm.org
Wed Jul 19 13:50:59 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-19T16:48:47-04:00
New Revision: 9415241c5ba700379d67006391e50204df4f32f4
URL: https://github.com/llvm/llvm-project/commit/9415241c5ba700379d67006391e50204df4f32f4
DIFF: https://github.com/llvm/llvm-project/commit/9415241c5ba700379d67006391e50204df4f32f4.diff
LOG: [mlir][spirv] Split op implementation file into subfiles. NFC.
The main op implementation file for SPIR-V grew past 5k LOC. This makes it
take a long time to compile and index with LSPs like clangd.
Pull out the first few SPIR-V extension ops into their own `.cpp` files,
just like we do with `.td` op definitions. This includes the
KHR/NV/Intel coop matrix and the integer dot prod extensions.
I plan to further split this in future revisions.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D155747
Added:
mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
Modified:
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Removed:
mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h b/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h
deleted file mode 100644
index 073bdcef2293fc..00000000000000
--- a/mlir/include/mlir/Dialect/SPIRV/IR/ParserUtils.h
+++ /dev/null
@@ -1,45 +0,0 @@
-//===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines utilities used for parsing types and ops for SPIR-V
-// dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_
-#define MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_
-
-#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-
-namespace mlir {
-
-/// Parses the next keyword in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass, typename ParserType>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- StringRef keyword;
- SmallVector<NamedAttribute, 1> attr;
- auto loc = parser.getCurrentLocation();
- if (parser.parseKeyword(&keyword))
- return failure();
- if (std::optional<EnumClass> attr =
- spirv::symbolizeEnum<EnumClass>(keyword)) {
- value = *attr;
- return success();
- }
- return parser.emitError(loc, "invalid ")
- << attrName << " attribute specification: " << keyword;
-}
-
-} // namespace mlir
-
-#endif // MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b59cd07f87f92f..70e2eb786e397c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,12 +3,16 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
+ CooperativeMatrixOps.cpp
+ IntegerDotProductOps.cpp
+ JointMatrixOps.cpp
SPIRVAttributes.cpp
SPIRVCanonicalization.cpp
SPIRVGLCanonicalization.cpp
SPIRVDialect.cpp
SPIRVEnums.cpp
SPIRVOps.cpp
+ SPIRVParsingUtils.cpp
SPIRVTypes.cpp
TargetAndABI.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
new file mode 100644
index 00000000000000..bdd87677866501
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -0,0 +1,306 @@
+//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Cooperative Matrix operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult KHRCooperativeMatrixLengthOp::verify() {
+  // The `type` attribute must name a KHR cooperative matrix type.
+  Type matrixTy = getCooperativeMatrixType();
+  if (isa<CooperativeMatrixType>(matrixTy))
+    return success();
+
+  return emitOpError(
+             "type attribute must be a '!spirv.coopmatrix' type, found ")
+         << matrixTy << " instead";
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+ParseResult KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  // Leading operands: `%pointer, %stride,` (each followed by a comma).
+  std::array<OpAsmParser::UnresolvedOperand, 2> operands = {};
+  for (OpAsmParser::UnresolvedOperand &operand : operands) {
+    if (parser.parseOperand(operand) || parser.parseComma())
+      return failure();
+  }
+
+  // Matrix layout keyword, stored under the `matrix_layout` attribute name.
+  CooperativeMatrixLayoutKHR layout;
+  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName))
+    return failure();
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  // Trailing types: `: <pointer-type> as <result-type>`.
+  Type ptrType;
+  Type resultType;
+  if (parser.parseColon() || parser.parseType(ptrType) ||
+      parser.parseKeywordType("as", resultType))
+    return failure();
+  result.addTypes(resultType);
+
+  // The stride operand is always resolved as a 32-bit integer.
+  Type i32Type = parser.getBuilder().getIntegerType(32);
+  return parser.resolveOperands(operands, {ptrType, i32Type},
+                                parser.getNameLoc(), result.operands);
+}
+
+/// Prints `%pointer, %stride, <layout> ["<memory-operand>"] : <ptr> as <type>`.
+void KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getMatrixLayout();
+  // Print optional memory operand attribute. Dereference the optional so the
+  // contained value is streamed, matching KHRCooperativeMatrixStoreOp::print.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << *memOperand << "\"]";
+  printer << " : " << getPointer().getType() << " as " << getType();
+}
+
+/// Checks that `pointer` (a SPIR-V pointer type) points to a scalar or vector.
+/// `coopMatrix` is accepted for future checks but is currently unused.
+static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
+                                                    Type coopMatrix) {
+  Type pointeeType = cast<PointerType>(pointer).getPointeeType();
+  if (isa<ScalarType, VectorType>(pointeeType)) {
+    // TODO: Verify the memory object behind the pointer:
+    // > If the Shader capability was declared, Pointer must point into an array
+    // > and any ArrayStride decoration on Pointer is ignored.
+    return success();
+  }
+
+  return op->emitError(
+             "Pointer must point to a scalar or vector type but provided ")
+         << pointeeType;
+}
+
+LogicalResult KHRCooperativeMatrixLoadOp::verify() {
+  Type ptrTy = getPointer().getType();
+  Type resultTy = getResult().getType();
+  return verifyPointerAndCoopMatrixType(*this, ptrTy, resultTy);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Leading operands: `%pointer, %object, %stride,` (each followed by a comma).
+  std::array<OpAsmParser::UnresolvedOperand, 3> operands = {};
+  for (size_t idx = 0, end = operands.size(); idx < end; ++idx) {
+    if (parser.parseOperand(operands[idx]) || parser.parseComma())
+      return failure();
+  }
+
+  // Matrix layout keyword, stored under the `matrix_layout` attribute name.
+  CooperativeMatrixLayoutKHR layout;
+  if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
+          layout, parser, result, kKhrCooperativeMatrixLayoutAttrName))
+    return failure();
+
+  if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
+    return failure();
+
+  // Trailing types: `: <pointer-type>, <object-type>`.
+  Type ptrType;
+  Type objectType;
+  bool typesFailed = parser.parseColon() || parser.parseType(ptrType) ||
+                     parser.parseComma() || parser.parseType(objectType);
+  if (typesFailed)
+    return failure();
+
+  // The stride operand is always resolved as a 32-bit integer.
+  Type i32Type = parser.getBuilder().getIntegerType(32);
+  return parser.resolveOperands(operands, {ptrType, objectType, i32Type},
+                                parser.getNameLoc(), result.operands);
+}
+
+void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+          << ", " << getMatrixLayout();
+
+  // The memory operand attribute is optional; emit it only when present.
+  if (auto memOperand = getMemoryOperand())
+    printer << " [\"" << *memOperand << "\"]";
+
+  Type ptrTy = getPointer().getType();
+  Type objectTy = getObject().getType();
+  printer << " : " << ptrTy << ", " << objectTy;
+}
+
+LogicalResult KHRCooperativeMatrixStoreOp::verify() {
+  Type ptrTy = getPointer().getType();
+  Type objectTy = getObject().getType();
+  return verifyPointerAndCoopMatrixType(*this, ptrTy, objectTy);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLength
+//===----------------------------------------------------------------------===//
+
+LogicalResult NVCooperativeMatrixLengthOp::verify() {
+  // The `type` attribute must name an NV cooperative matrix type.
+  Type matrixTy = getCooperativeMatrixType();
+  if (isa<CooperativeMatrixNVType>(matrixTy))
+    return success();
+
+  return emitOpError(
+             "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
+         << matrixTy << " instead";
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
+ParseResult NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
+                                             OperationState &result) {
+  // Syntax: `%pointer, %stride, %columnmajor [memory-access]
+  //          : <pointer-type> as <result-type>`.
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
+  if (parser.parseOperandList(operands, 3) ||
+      parseMemoryAccessAttributes(parser, result))
+    return failure();
+
+  Type ptrType;
+  Type resultType;
+  if (parser.parseColon() || parser.parseType(ptrType) ||
+      parser.parseKeywordType("as", resultType))
+    return failure();
+
+  // Stride resolves as i32 and the column-major flag as i1.
+  Type i32Type = parser.getBuilder().getIntegerType(32);
+  Type i1Type = parser.getBuilder().getIntegerType(1);
+  if (parser.resolveOperands(operands, {ptrType, i32Type, i1Type},
+                             parser.getNameLoc(), result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  return success();
+}
+
+void NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getStride() << ", "
+          << getColumnmajor();
+  if (auto memAccess = getMemoryAccess()) {
+    // Optional memory access attribute, printed as a quoted string in brackets.
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  }
+  printer << " : " << getPointer().getType() << " as " << getType();
+}
+
+/// Checks that `pointer` (a SPIR-V pointer type) points to a scalar or vector
+/// and lives in a storage class the NV cooperative matrix ops accept.
+/// `coopMatrix` is currently unused.
+static LogicalResult
+verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
+  auto ptrType = llvm::cast<PointerType>(pointer);
+
+  Type pointeeType = ptrType.getPointeeType();
+  if (!llvm::isa<ScalarType, VectorType>(pointeeType))
+    return op->emitError(
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+
+  // Only Workgroup, StorageBuffer, and PhysicalStorageBuffer are allowed.
+  StorageClass storage = ptrType.getStorageClass();
+  switch (storage) {
+  case StorageClass::Workgroup:
+  case StorageClass::StorageBuffer:
+  case StorageClass::PhysicalStorageBuffer:
+    return success();
+  default:
+    return op->emitError(
+               "Pointer storage class must be Workgroup, StorageBuffer or "
+               "PhysicalStorageBufferEXT but provided ")
+           << stringifyStorageClass(storage);
+  }
+}
+
+LogicalResult NVCooperativeMatrixLoadOp::verify() {
+  Type ptrTy = getPointer().getType();
+  return verifyPointerAndCoopMatrixNVType(*this, ptrTy, getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixStore
+//===----------------------------------------------------------------------===//
+
+ParseResult NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
+                                              OperationState &result) {
+  // Syntax: `%pointer, %object, %stride, %columnmajor [memory-access]
+  //          : <pointer-type>, <object-type>`.
+  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+  if (parser.parseOperandList(operands, 4) ||
+      parseMemoryAccessAttributes(parser, result))
+    return failure();
+
+  Type ptrType;
+  Type objectType;
+  if (parser.parseColon() || parser.parseType(ptrType) ||
+      parser.parseComma() || parser.parseType(objectType))
+    return failure();
+
+  // Stride resolves as i32 and the column-major flag as i1.
+  Type i32Type = parser.getBuilder().getIntegerType(32);
+  Type i1Type = parser.getBuilder().getIntegerType(1);
+  if (parser.resolveOperands(operands, {ptrType, objectType, i32Type, i1Type},
+                             parser.getNameLoc(), result.operands))
+    return failure();
+
+  return success();
+}
+
+/// Prints `%pointer, %object, %stride, %columnmajor ["<access>"]
+/// : <pointer-type>, <object-type>`.
+void NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
+  printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
+          << ", " << getColumnmajor();
+  // Print optional memory access attribute.
+  if (auto memAccess = getMemoryAccess())
+    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
+  // Use the named accessor for the stored object (operand #1) rather than a
+  // raw operand index, so the printed type is clearly `<object-type>`.
+  printer << " : " << getPointer().getType() << ", " << getObject().getType();
+}
+
+LogicalResult NVCooperativeMatrixStoreOp::verify() {
+  return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
+                                          getObject().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.NV.CooperativeMatrixMulAdd
+//===----------------------------------------------------------------------===//
+
+/// Verifies shapes, scopes, and element types of `spirv.NV.CooperativeMatrixMulAdd`.
+static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op) {
+  if (op.getC().getType() != op.getResult().getType())
+    return op.emitOpError("result and third operand must have the same type");
+
+  auto asCoopMatrix = [](Value value) {
+    return llvm::cast<CooperativeMatrixNVType>(value.getType());
+  };
+  CooperativeMatrixNVType typeA = asCoopMatrix(op.getA());
+  CooperativeMatrixNVType typeB = asCoopMatrix(op.getB());
+  CooperativeMatrixNVType typeC = asCoopMatrix(op.getC());
+  CooperativeMatrixNVType typeR = asCoopMatrix(op.getResult());
+
+  // A is MxK, B is KxN, and the result is MxN.
+  bool sizesMatch = typeA.getRows() == typeR.getRows() &&
+                    typeA.getColumns() == typeB.getRows() &&
+                    typeB.getColumns() == typeR.getColumns();
+  if (!sizesMatch)
+    return op.emitOpError("matrix size must match");
+
+  auto scope = typeR.getScope();
+  if (typeA.getScope() != scope || typeB.getScope() != scope ||
+      typeC.getScope() != scope)
+    return op.emitOpError("matrix scope must match");
+
+  Type elementTypeA = typeA.getElementType();
+  Type elementTypeB = typeB.getElementType();
+  auto intTypeA = llvm::dyn_cast<IntegerType>(elementTypeA);
+  auto intTypeB = llvm::dyn_cast<IntegerType>(elementTypeB);
+  if (intTypeA && intTypeB) {
+    // Integer inputs are only compared by bit width, not by signedness.
+    if (intTypeA.getWidth() != intTypeB.getWidth())
+      return op.emitOpError(
+          "matrix A and B integer element types must be the same bit width");
+  } else if (elementTypeA != elementTypeB) {
+    return op.emitOpError(
+        "matrix A and B non-integer element types must match");
+  }
+
+  if (typeR.getElementType() != typeC.getElementType())
+    return op.emitOpError("matrix accumulator element type must match");
+  return success();
+}
+
+LogicalResult NVCooperativeMatrixMulAddOp::verify() {
+  return verifyCoopMatrixMulAddNV(*this);
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
new file mode 100644
index 00000000000000..28efe4f046fcde
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -0,0 +1,158 @@
+//===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Integer Dot Product operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+//===----------------------------------------------------------------------===//
+// Integer Dot Product ops
+//===----------------------------------------------------------------------===//
+
+/// Shared verifier for the integer dot product ops (SDot, SUDot, UDot and
+/// their AccSat variants). Operands 0 and 1 are the factors; when a third
+/// operand is present it is the accumulator and must match the result type.
+static LogicalResult verifyIntegerDotProduct(Operation *op) {
+ assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
+ "Not an integer dot product op?");
+ assert(op->getNumResults() == 1 && "Expected a single result");
+
+ Type factorTy = op->getOperand(0).getType();
+ if (op->getOperand(1).getType() != factorTy)
+ return op->emitOpError("requires the same type for both vector operands");
+
+ // A scalar integer factor must carry the `format` (packed vector format)
+ // attribute and be 32 bits wide; any other factor type must not carry it.
+ // `expectedNumAttrs` counts the attributes allowed on the op.
+ unsigned expectedNumAttrs = 0;
+ if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
+ ++expectedNumAttrs;
+ auto packedVectorFormat =
+ llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
+ op->getAttr(kPackedVectorFormatAttrName));
+ if (!packedVectorFormat)
+ return op->emitOpError("requires Packed Vector Format attribute for "
+ "integer vector operands");
+
+ // Only the 4x8-bit packed format is handled here.
+ assert(packedVectorFormat.getValue() ==
+ spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
+ "Unknown Packed Vector Format");
+ if (intTy.getWidth() != 32)
+ return op->emitOpError(
+ llvm::formatv("with specified Packed Vector Format ({0}) requires "
+ "integer vector operands to be 32-bits wide",
+ packedVectorFormat.getValue()));
+ } else {
+ if (op->hasAttr(kPackedVectorFormatAttrName))
+ return op->emitOpError(llvm::formatv(
+ "with invalid format attribute for vector operands of type '{0}'",
+ factorTy));
+ }
+
+ // Reject any attribute beyond the optional `format` one.
+ if (op->getAttrs().size() > expectedNumAttrs)
+ return op->emitError(
+ "op only supports the 'format' #spirv.packed_vector_format attribute");
+
+ Type resultTy = op->getResultTypes().front();
+ bool hasAccumulator = op->getNumOperands() == 3;
+ if (hasAccumulator && op->getOperand(2).getType() != resultTy)
+ return op->emitOpError(
+ "requires the same accumulator operand and result types");
+
+ // The result's bit width must be at least the factor operand's bit width.
+ unsigned factorBitWidth = getBitWidth(factorTy);
+ unsigned resultBitWidth = getBitWidth(resultTy);
+ if (factorBitWidth > resultBitWidth)
+ return op->emitOpError(
+ llvm::formatv("result type has insufficient bit-width ({0} bits) "
+ "for the specified vector operand type ({1} bits)",
+ resultBitWidth, factorBitWidth));
+
+ return success();
+}
+
+static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
+  // Lowest SPIR-V version in which the op is available.
+  return spirv::Version::V_1_0;
+}
+
+static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
+  // Highest SPIR-V version in which the op is available.
+  return spirv::Version::V_1_6;
+}
+
+static SmallVector<ArrayRef<spirv::Extension>, 1>
+getIntegerDotProductExtensions() {
+  // Requires the SPV_KHR_integer_dot_product extension, specified either
+  // explicitly or implied by target env's SPIR-V version >= 1.6. The enumerant
+  // has static storage because the returned ArrayRef points at it.
+  static const spirv::Extension kIntegerDotProductExt =
+      spirv::Extension::SPV_KHR_integer_dot_product;
+  return {kIntegerDotProductExt};
+}
+
+static SmallVector<ArrayRef<spirv::Capability>, 1>
+getIntegerDotProductCapabilities(Operation *op) {
+  // Requires the DotProduct capability and capabilities that depend on
+  // exact op types. The capability enumerants are function-local statics
+  // because the returned ArrayRefs point at them and must outlive this call.
+  static const auto dotProductCap = spirv::Capability::DotProduct;
+  static const auto dotProductInput4x8BitPackedCap =
+      spirv::Capability::DotProductInput4x8BitPacked;
+  static const auto dotProductInput4x8BitCap =
+      spirv::Capability::DotProductInput4x8Bit;
+  static const auto dotProductInputAllCap =
+      spirv::Capability::DotProductInputAll;
+
+  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
+
+  Type factorTy = op->getOperand(0).getType();
+  if (llvm::isa<IntegerType>(factorTy)) {
+    // A scalar integer factor is a packed vector. The `format` attribute is
+    // expected to be present (verifyIntegerDotProduct rejects integer factors
+    // without it); llvm::cast asserts otherwise.
+    auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
+        op->getAttr(kPackedVectorFormatAttrName));
+    if (formatAttr.getValue() ==
+        spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
+      capabilities.push_back(dotProductInput4x8BitPackedCap);
+
+    return capabilities;
+  }
+
+  // Otherwise the factor is a vector: 8-bit elements need
+  // DotProductInput4x8Bit, any other element width needs DotProductInputAll.
+  auto vecTy = llvm::cast<VectorType>(factorTy);
+  if (vecTy.getElementTypeBitWidth() == 8) {
+    capabilities.push_back(dotProductInput4x8BitCap);
+    return capabilities;
+  }
+
+  capabilities.push_back(dotProductInputAllCap);
+  return capabilities;
+}
+
+// Defines, for op class `OpName`, the verifier and the availability methods
+// (extensions, capabilities, min/max SPIR-V version) by forwarding to the
+// shared integer dot product helpers in this file.
+#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
+ LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
+ SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
+ return getIntegerDotProductExtensions(); \
+ } \
+ SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
+ return getIntegerDotProductCapabilities(*this); \
+ } \
+ std::optional<spirv::Version> OpName::getMinVersion() { \
+ return getIntegerDotProductMinVersion(); \
+ } \
+ std::optional<spirv::Version> OpName::getMaxVersion() { \
+ return getIntegerDotProductMaxVersion(); \
+ }
+
+// Instantiate the shared implementation for each integer dot product op.
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(SUDotAccSatOp)
+SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(UDotAccSatOp)
+
+#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
new file mode 100644
index 00000000000000..63305ecdd0c4e9
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/JointMatrixOps.cpp
@@ -0,0 +1,84 @@
+//===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the Intel Joint Matrix operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixLoad
+//===----------------------------------------------------------------------===//
+
+/// Checks that `pointer` (a SPIR-V pointer type) points to a scalar or vector
+/// and uses a storage class accepted by the Intel joint matrix ops.
+/// `jointMatrix` is currently unused.
+static LogicalResult
+verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
+ Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
+ if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
+ !llvm::isa<VectorType>(pointeeType))
+ return op->emitError(
+ "Pointer must point to a scalar or vector type but provided ")
+ << pointeeType;
+ spirv::StorageClass storage =
+ llvm::cast<spirv::PointerType>(pointer).getStorageClass();
+ // NOTE(review): four storage classes are accepted here (Workgroup,
+ // CrossWorkgroup, UniformConstant, Generic), but the diagnostic names only
+ // the first two. Changing the text would also require updating any tests
+ // that match the current message.
+ if (storage != spirv::StorageClass::Workgroup &&
+ storage != spirv::StorageClass::CrossWorkgroup &&
+ storage != spirv::StorageClass::UniformConstant &&
+ storage != spirv::StorageClass::Generic)
+ return op->emitError("Pointer storage class must be Workgroup or "
+ "CrossWorkgroup but provided ")
+ << stringifyStorageClass(storage);
+ return success();
+}
+
+LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
+ return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
+ getResult().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixStore
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
+  // Stores share the pointer checks used by loads.
+  Type ptrTy = getPointer().getType();
+  Type objectTy = getObject().getType();
+  return verifyPointerAndJointMatrixType(*this, ptrTy, objectTy);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.JointMatrixMad
+//===----------------------------------------------------------------------===//
+
+/// Verifies shapes, scopes, and element types of `spirv.INTEL.JointMatrixMad`.
+static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
+  if (op.getC().getType() != op.getResult().getType())
+    return op.emitOpError("result and third operand must have the same type");
+
+  using MatrixTy = spirv::JointMatrixINTELType;
+  auto typeA = llvm::cast<MatrixTy>(op.getA().getType());
+  auto typeB = llvm::cast<MatrixTy>(op.getB().getType());
+  auto typeC = llvm::cast<MatrixTy>(op.getC().getType());
+  auto typeR = llvm::cast<MatrixTy>(op.getResult().getType());
+
+  // A is MxK, B is KxN, and the result is MxN.
+  bool sizesAgree = typeA.getRows() == typeR.getRows() &&
+                    typeA.getColumns() == typeB.getRows() &&
+                    typeB.getColumns() == typeR.getColumns();
+  if (!sizesAgree)
+    return op.emitOpError("matrix size must match");
+
+  auto scope = typeR.getScope();
+  bool scopesAgree = typeA.getScope() == scope && typeB.getScope() == scope &&
+                     typeC.getScope() == scope;
+  if (!scopesAgree)
+    return op.emitOpError("matrix scope must match");
+
+  bool elementsAgree = typeA.getElementType() == typeB.getElementType() &&
+                       typeR.getElementType() == typeC.getElementType();
+  if (!elementsAgree)
+    return op.emitOpError("matrix element type must match");
+  return success();
+}
+
+LogicalResult spirv::INTELJointMatrixMadOp::verify() {
+  return verifyJointMatrixMad(*this);
+}
+
+} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 2e1c7923e24126..124d4ed6e8e6ed 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -11,7 +11,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
+
+#include "SPIRVParsingUtils.h"
+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -341,11 +343,13 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return {};
Scope scope;
- if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+ if (parser.parseComma() ||
+ spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return {};
CooperativeMatrixUseKHR use;
- if (parser.parseComma() || parseEnumKeywordAttr(use, parser, "use <id>"))
+ if (parser.parseComma() ||
+ spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
return {};
if (parser.parseGreater())
@@ -376,7 +380,8 @@ static Type parseCooperativeMatrixNVType(SPIRVDialect const &dialect,
return Type();
Scope scope;
- if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+ if (parser.parseComma() ||
+ spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
if (parser.parseGreater())
@@ -407,10 +412,11 @@ static Type parseJointMatrixType(SPIRVDialect const &dialect,
return Type();
MatrixLayout matrixLayout;
if (parser.parseComma() ||
- parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
+ spirv::parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>"))
return Type();
Scope scope;
- if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
+ if (parser.parseComma() ||
+ spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
return Type();
if (parser.parseGreater())
return Type();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
new file mode 100644
index 00000000000000..efe596cd725c5e
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
@@ -0,0 +1,32 @@
+//===- SPIRVOpUtils.h - MLIR SPIR-V Dialect Op Definition Utilities -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_DIALECT_SPIRV_IR_SPIRVOPUTILS_H_
+#define MLIR_LIB_DIALECT_SPIRV_IR_SPIRVOPUTILS_H_
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+namespace mlir::spirv {
+
+/// Returns the bit width of the `type`: a fixed 64 for SPIR-V pointer types,
+/// the scalar width for integer/float types, and element count times element
+/// width for vectors of integers/floats. Any other type is unreachable.
+inline unsigned getBitWidth(Type type) {
+  if (isa<spirv::PointerType>(type)) {
+    // Just return 64 bits for pointer types for now.
+    // TODO: Make sure no caller relies on the actual pointer width value.
+    return 64;
+  }
+
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    assert(vectorType.getElementType().isIntOrFloat());
+    return vectorType.getNumElements() *
+           vectorType.getElementType().getIntOrFloatBitWidth();
+  }
+  llvm_unreachable("unhandled bit width computation for type");
+}
+
+} // namespace mlir::spirv
+
+#endif // MLIR_LIB_DIALECT_SPIRV_IR_SPIRVOPUTILS_H_
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2516c29fbc58a8..2184cec953fb05 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -12,7 +12,9 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
@@ -33,41 +35,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/FormatVariadic.h"
#include <cassert>
#include <numeric>
#include <type_traits>
using namespace mlir;
-
-// TODO: generate these strings using ODS.
-constexpr char kAlignmentAttrName[] = "alignment";
-constexpr char kBranchWeightAttrName[] = "branch_weights";
-constexpr char kCallee[] = "callee";
-constexpr char kClusterSize[] = "cluster_size";
-constexpr char kControl[] = "control";
-constexpr char kDefaultValueAttrName[] = "default_value";
-constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-constexpr char kExecutionScopeAttrName[] = "execution_scope";
-constexpr char kFnNameAttrName[] = "fn";
-constexpr char kGroupOperationAttrName[] = "group_operation";
-constexpr char kIndicesAttrName[] = "indices";
-constexpr char kInitializerAttrName[] = "initializer";
-constexpr char kInterfaceAttrName[] = "interface";
-constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-constexpr char kMemoryAccessAttrName[] = "memory_access";
-constexpr char kMemoryOperandAttrName[] = "memory_operand";
-constexpr char kMemoryScopeAttrName[] = "memory_scope";
-constexpr char kPackedVectorFormatAttrName[] = "format";
-constexpr char kSemanticsAttrName[] = "semantics";
-constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-constexpr char kSpecIdAttrName[] = "spec_id";
-constexpr char kTypeAttrName[] = "type";
-constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-constexpr char kValueAttrName[] = "value";
-constexpr char kValuesAttrName[] = "values";
-constexpr char kCompositeSpecConstituentsName[] = "constituents";
+using namespace mlir::spirv::AttrNames;
//===----------------------------------------------------------------------===//
// Common utility functions
@@ -158,79 +131,6 @@ static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
return success();
}
-template <typename Ty>
-static ArrayAttr
-getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
- function_ref<StringRef(Ty)> stringifyFn) {
- if (enumValues.empty()) {
- return nullptr;
- }
- SmallVector<StringRef, 1> enumValStrs;
- enumValStrs.reserve(enumValues.size());
- for (auto val : enumValues) {
- enumValStrs.emplace_back(stringifyFn(val));
- }
- return builder.getStrArrayAttr(enumValStrs);
-}
-
-/// Parses the next string attribute in `parser` as an enumerant of the given
-/// `EnumClass`.
-template <typename EnumClass>
-static ParseResult
-parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- static_assert(std::is_enum_v<EnumClass>);
- Attribute attrVal;
- NamedAttrList attr;
- auto loc = parser.getCurrentLocation();
- if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
- attrName, attr))
- return failure();
- if (!llvm::isa<StringAttr>(attrVal))
- return parser.emitError(loc, "expected ")
- << attrName << " attribute specified as string";
- auto attrOptional = spirv::symbolizeEnum<EnumClass>(
- llvm::cast<StringAttr>(attrVal).getValue());
- if (!attrOptional)
- return parser.emitError(loc, "invalid ")
- << attrName << " attribute specification: " << attrVal;
- value = *attrOptional;
- return success();
-}
-
-/// Parses the next string attribute in `parser` as an enumerant of the given
-/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
-/// attribute with the enum class's name as attribute name.
-template <typename EnumAttrClass,
- typename EnumClass = typename EnumAttrClass::ValueType>
-static ParseResult
-parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- static_assert(std::is_enum_v<EnumClass>);
- if (parseEnumStrAttr(value, parser))
- return failure();
- state.addAttribute(attrName,
- parser.getBuilder().getAttr<EnumAttrClass>(value));
- return success();
-}
-
-/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
-/// and inserts the enumerant into `state` as an 32-bit integer attribute with
-/// the enum class's name as attribute name.
-template <typename EnumAttrClass,
- typename EnumClass = typename EnumAttrClass::ValueType>
-static ParseResult
-parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
- OperationState &state,
- StringRef attrName = spirv::attributeName<EnumClass>()) {
- static_assert(std::is_enum_v<EnumClass>);
- if (parseEnumKeywordAttr(value, parser))
- return failure();
- state.addAttribute(attrName,
- parser.getBuilder().getAttr<EnumAttrClass>(value));
- return success();
-}
-
/// Parses Function, Selection and Loop control attributes. If no control is
/// specified, "None" is used as a default.
template <typename EnumAttrClass, typename EnumClass>
@@ -240,7 +140,7 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
if (succeeded(parser.parseOptionalKeyword(kControl))) {
EnumClass control;
if (parser.parseLParen() ||
- parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
+ spirv::parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
parser.parseRParen())
return failure();
return success();
@@ -252,40 +152,6 @@ parseControlAttribute(OpAsmParser &parser, OperationState &state,
return success();
}
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-/// (`[` memory-access `]`)?
-/// where:
-/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-/// integer-literal | `"NonTemporal"`
-static ParseResult
-parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state,
- StringRef attrName = kMemoryAccessAttrName) {
- // Parse an optional list of attributes staring with '['
- if (parser.parseOptionalLSquare()) {
- // Nothing to do
- return success();
- }
-
- spirv::MemoryAccess memoryAccessAttr;
- if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
- attrName))
- return failure();
-
- if (spirv::bitEnumContainsAll(memoryAccessAttr,
- spirv::MemoryAccess::Aligned)) {
- // Parse integer attribute for alignment.
- Attribute alignmentAttr;
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
- state.attributes)) {
- return failure();
- }
- }
- return parser.parseRSquare();
-}
-
// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
@@ -299,8 +165,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
- kSourceMemoryAccessAttrName))
+ if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+ memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -683,25 +549,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}
-// Get bit width of types.
-static unsigned getBitWidth(Type type) {
- if (llvm::isa<spirv::PointerType>(type)) {
- // Just return 64 bits for pointer types for now.
- // TODO: Make sure not caller relies on the actual pointer width value.
- return 64;
- }
-
- if (type.isIntOrFloat())
- return type.getIntOrFloatBitWidth();
-
- if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
- assert(vectorType.getElementType().isIntOrFloat());
- return vectorType.getNumElements() *
- vectorType.getElementType().getIntOrFloatBitWidth();
- }
- llvm_unreachable("unhandled bit width computation for type");
-}
-
/// Walks the given type hierarchy with the given indices, potentially down
/// to component granularity, to select an element type. Returns null type and
/// emits errors with the given loc on failure.
@@ -839,10 +686,10 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
Type type;
SMLoc loc;
- if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
- kMemoryScopeAttrName) ||
- parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
- kSemanticsAttrName) ||
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
+ kMemoryScopeAttrName) ||
+ spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+ memoryScope, parser, state, kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
return failure();
@@ -916,10 +763,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
spirv::Scope executionScope;
spirv::GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
- if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- kExecutionScopeAttrName) ||
- parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
- kGroupOperationAttrName) ||
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
+ kExecutionScopeAttrName) ||
+ spirv::parseEnumStrAttr<spirv::GroupOperationAttr>(
+ groupOperation, parser, state, kGroupOperationAttrName) ||
parser.parseOperand(valueInfo))
return failure();
@@ -1199,11 +1046,11 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
spirv::MemorySemantics equalSemantics, unequalSemantics;
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type type;
- if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
- kMemoryScopeAttrName) ||
- parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
+ kMemoryScopeAttrName) ||
+ spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
equalSemantics, parser, state, kEqualSemanticsAttrName) ||
- parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+ spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 3))
return failure();
@@ -3478,10 +3325,10 @@ ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
// Parse attributes
spirv::AddressingModel addrModel;
spirv::MemoryModel memoryModel;
- if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
- result) ||
- ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
- result))
+ if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
+ result) ||
+ spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
+ result))
return failure();
if (succeeded(parser.parseOptionalKeyword("requires"))) {
@@ -4028,364 +3875,6 @@ LogicalResult spirv::VectorShuffleOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::KHRCooperativeMatrixLengthOp::verify() {
- if (!isa<spirv::CooperativeMatrixType>(getCooperativeMatrixType())) {
- return emitOpError(
- "type attribute must be a '!spirv.coopmatrix' type, found ")
- << getCooperativeMatrixType() << " instead";
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::KHRCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
- OperationState &result) {
- std::array<OpAsmParser::UnresolvedOperand, 2> operandInfo = {};
- if (parser.parseOperand(operandInfo[0]) || parser.parseComma())
- return failure();
- if (parser.parseOperand(operandInfo[1]) || parser.parseComma())
- return failure();
-
- spirv::CooperativeMatrixLayoutKHR layout;
- if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
- layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
- return failure();
- }
-
- if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
- return failure();
-
- Type ptrType;
- Type elementType;
- if (parser.parseColon() || parser.parseType(ptrType) ||
- parser.parseKeywordType("as", elementType)) {
- return failure();
- }
- result.addTypes(elementType);
-
- Type strideType = parser.getBuilder().getIntegerType(32);
- if (parser.resolveOperands(operandInfo, {ptrType, strideType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void spirv::KHRCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getStride() << ", "
- << getMatrixLayout();
- // Print optional memory operand attribute.
- if (auto memOperand = getMemoryOperand())
- printer << " [\"" << memOperand << "\"]";
- printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
- Type coopMatrix) {
- auto pointerType = cast<spirv::PointerType>(pointer);
- Type pointeeType = pointerType.getPointeeType();
- if (!isa<spirv::ScalarType, VectorType>(pointeeType)) {
- return op->emitError(
- "Pointer must point to a scalar or vector type but provided ")
- << pointeeType;
- }
-
- // TODO: Verify the memory object behind the pointer:
- // > If the Shader capability was declared, Pointer must point into an array
- // > and any ArrayStride decoration on Pointer is ignored.
-
- return success();
-}
-
-LogicalResult spirv::KHRCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
- OperationState &result) {
- std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
- for (auto &op : operandInfo) {
- if (parser.parseOperand(op) || parser.parseComma())
- return failure();
- }
-
- spirv::CooperativeMatrixLayoutKHR layout;
- if (::parseEnumKeywordAttr<spirv::CooperativeMatrixLayoutKHRAttr>(
- layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
- return failure();
- }
-
- if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
- return failure();
-
- Type ptrType;
- Type objectType;
- if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
- parser.parseType(objectType)) {
- return failure();
- }
-
- Type strideType = parser.getBuilder().getIntegerType(32);
- if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void spirv::KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
- << ", " << getMatrixLayout();
-
- // Print optional memory operand attribute.
- if (auto memOperand = getMemoryOperand())
- printer << " [\"" << *memOperand << "\"]";
- printer << " : " << getPointer().getType() << ", " << getObject().getType();
-}
-
-LogicalResult spirv::KHRCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
- getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLength
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::NVCooperativeMatrixLengthOp::verify() {
- if (!isa<spirv::CooperativeMatrixNVType>(getCooperativeMatrixType())) {
- return emitOpError(
- "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
- << getCooperativeMatrixType() << " instead";
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
- Type strideType = parser.getBuilder().getIntegerType(32);
- Type columnMajorType = parser.getBuilder().getIntegerType(1);
- Type ptrType;
- Type elementType;
- if (parser.parseOperandList(operandInfo, 3) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
- return failure();
- }
- if (parser.resolveOperands(operandInfo,
- {ptrType, strideType, columnMajorType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- result.addTypes(elementType);
- return success();
-}
-
-void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getStride() << ", "
- << getColumnmajor();
- // Print optional memory access attribute.
- if (auto memAccess = getMemoryAccess())
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << getPointer().getType() << " as " << getType();
-}
-
-static LogicalResult
-verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix) {
- Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
- if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
- !llvm::isa<VectorType>(pointeeType))
- return op->emitError(
- "Pointer must point to a scalar or vector type but provided ")
- << pointeeType;
- spirv::StorageClass storage =
- llvm::cast<spirv::PointerType>(pointer).getStorageClass();
- if (storage != spirv::StorageClass::Workgroup &&
- storage != spirv::StorageClass::StorageBuffer &&
- storage != spirv::StorageClass::PhysicalStorageBuffer)
- return op->emitError(
- "Pointer storage class must be Workgroup, StorageBuffer or "
- "PhysicalStorageBufferEXT but provided ")
- << stringifyStorageClass(storage);
- return success();
-}
-
-LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
- return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
- getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixStore
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
- Type strideType = parser.getBuilder().getIntegerType(32);
- Type columnMajorType = parser.getBuilder().getIntegerType(1);
- Type ptrType;
- Type elementType;
- if (parser.parseOperandList(operandInfo, 4) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(ptrType) || parser.parseComma() ||
- parser.parseType(elementType)) {
- return failure();
- }
- if (parser.resolveOperands(
- operandInfo, {ptrType, elementType, strideType, columnMajorType},
- parser.getNameLoc(), result.operands)) {
- return failure();
- }
-
- return success();
-}
-
-void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
- printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
- << ", " << getColumnmajor();
- // Print optional memory access attribute.
- if (auto memAccess = getMemoryAccess())
- printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
-}
-
-LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
- return verifyPointerAndCoopMatrixNVType(*this, getPointer().getType(),
- getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.NV.CooperativeMatrixMulAdd
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyCoopMatrixMulAddNV(spirv::NVCooperativeMatrixMulAddOp op) {
- if (op.getC().getType() != op.getResult().getType())
- return op.emitOpError("result and third operand must have the same type");
- auto typeA = llvm::cast<spirv::CooperativeMatrixNVType>(op.getA().getType());
- auto typeB = llvm::cast<spirv::CooperativeMatrixNVType>(op.getB().getType());
- auto typeC = llvm::cast<spirv::CooperativeMatrixNVType>(op.getC().getType());
- auto typeR =
- llvm::cast<spirv::CooperativeMatrixNVType>(op.getResult().getType());
- if (typeA.getRows() != typeR.getRows() ||
- typeA.getColumns() != typeB.getRows() ||
- typeB.getColumns() != typeR.getColumns())
- return op.emitOpError("matrix size must match");
- if (typeR.getScope() != typeA.getScope() ||
- typeR.getScope() != typeB.getScope() ||
- typeR.getScope() != typeC.getScope())
- return op.emitOpError("matrix scope must match");
- auto elementTypeA = typeA.getElementType();
- auto elementTypeB = typeB.getElementType();
- if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
- if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
- llvm::cast<IntegerType>(elementTypeB).getWidth())
- return op.emitOpError(
- "matrix A and B integer element types must be the same bit width");
- } else if (elementTypeA != elementTypeB) {
- return op.emitOpError(
- "matrix A and B non-integer element types must match");
- }
- if (typeR.getElementType() != typeC.getElementType())
- return op.emitOpError("matrix accumulator element type must match");
- return success();
-}
-
-LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
- return verifyCoopMatrixMulAddNV(*this);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixLoad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult
-verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
- Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
- if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
- !llvm::isa<VectorType>(pointeeType))
- return op->emitError(
- "Pointer must point to a scalar or vector type but provided ")
- << pointeeType;
- spirv::StorageClass storage =
- llvm::cast<spirv::PointerType>(pointer).getStorageClass();
- if (storage != spirv::StorageClass::Workgroup &&
- storage != spirv::StorageClass::CrossWorkgroup &&
- storage != spirv::StorageClass::UniformConstant &&
- storage != spirv::StorageClass::Generic)
- return op->emitError("Pointer storage class must be Workgroup or "
- "CrossWorkgroup but provided ")
- << stringifyStorageClass(storage);
- return success();
-}
-
-LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
- return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
- getResult().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixStore
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
- return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
- getObject().getType());
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTEL.JointMatrixMad
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
- if (op.getC().getType() != op.getResult().getType())
- return op.emitOpError("result and third operand must have the same type");
- auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
- auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
- auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
- auto typeR =
- llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
- if (typeA.getRows() != typeR.getRows() ||
- typeA.getColumns() != typeB.getRows() ||
- typeB.getColumns() != typeR.getColumns())
- return op.emitOpError("matrix size must match");
- if (typeR.getScope() != typeA.getScope() ||
- typeR.getScope() != typeB.getScope() ||
- typeR.getScope() != typeC.getScope())
- return op.emitOpError("matrix scope must match");
- if (typeA.getElementType() != typeB.getElementType() ||
- typeR.getElementType() != typeC.getElementType())
- return op.emitOpError("matrix element type must match");
- return success();
-}
-
-LogicalResult spirv::INTELJointMatrixMadOp::verify() {
- return verifyJointMatrixMad(*this);
-}
-
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesScalar
//===----------------------------------------------------------------------===//
@@ -5058,140 +4547,6 @@ LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
-//===----------------------------------------------------------------------===//
-// Integer Dot Product ops
-//===----------------------------------------------------------------------===//
-
-static LogicalResult verifyIntegerDotProduct(Operation *op) {
- assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
- "Not an integer dot product op?");
- assert(op->getNumResults() == 1 && "Expected a single result");
-
- Type factorTy = op->getOperand(0).getType();
- if (op->getOperand(1).getType() != factorTy)
- return op->emitOpError("requires the same type for both vector operands");
-
- unsigned expectedNumAttrs = 0;
- if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
- ++expectedNumAttrs;
- auto packedVectorFormat =
- llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
- if (!packedVectorFormat)
- return op->emitOpError("requires Packed Vector Format attribute for "
- "integer vector operands");
-
- assert(packedVectorFormat.getValue() ==
- spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
- "Unknown Packed Vector Format");
- if (intTy.getWidth() != 32)
- return op->emitOpError(
- llvm::formatv("with specified Packed Vector Format ({0}) requires "
- "integer vector operands to be 32-bits wide",
- packedVectorFormat.getValue()));
- } else {
- if (op->hasAttr(kPackedVectorFormatAttrName))
- return op->emitOpError(llvm::formatv(
- "with invalid format attribute for vector operands of type '{0}'",
- factorTy));
- }
-
- if (op->getAttrs().size() > expectedNumAttrs)
- return op->emitError(
- "op only supports the 'format' #spirv.packed_vector_format attribute");
-
- Type resultTy = op->getResultTypes().front();
- bool hasAccumulator = op->getNumOperands() == 3;
- if (hasAccumulator && op->getOperand(2).getType() != resultTy)
- return op->emitOpError(
- "requires the same accumulator operand and result types");
-
- unsigned factorBitWidth = getBitWidth(factorTy);
- unsigned resultBitWidth = getBitWidth(resultTy);
- if (factorBitWidth > resultBitWidth)
- return op->emitOpError(
- llvm::formatv("result type has insufficient bit-width ({0} bits) "
- "for the specified vector operand type ({1} bits)",
- resultBitWidth, factorBitWidth));
-
- return success();
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
- return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
-}
-
-static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
- return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
-}
-
-static SmallVector<ArrayRef<spirv::Extension>, 1>
-getIntegerDotProductExtensions() {
- // Requires the SPV_KHR_integer_dot_product extension, specified either
- // explicitly or implied by target env's SPIR-V version >= 1.6.
- static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
- return {extension};
-}
-
-static SmallVector<ArrayRef<spirv::Capability>, 1>
-getIntegerDotProductCapabilities(Operation *op) {
- // Requires the the DotProduct capability and capabilities that depend on
- // exact op types.
- static const auto dotProductCap = spirv::Capability::DotProduct;
- static const auto dotProductInput4x8BitPackedCap =
- spirv::Capability::DotProductInput4x8BitPacked;
- static const auto dotProductInput4x8BitCap =
- spirv::Capability::DotProductInput4x8Bit;
- static const auto dotProductInputAllCap =
- spirv::Capability::DotProductInputAll;
-
- SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
-
- Type factorTy = op->getOperand(0).getType();
- if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
- auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
- if (formatAttr.getValue() ==
- spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
- capabilities.push_back(dotProductInput4x8BitPackedCap);
-
- return capabilities;
- }
-
- auto vecTy = llvm::cast<VectorType>(factorTy);
- if (vecTy.getElementTypeBitWidth() == 8) {
- capabilities.push_back(dotProductInput4x8BitCap);
- return capabilities;
- }
-
- capabilities.push_back(dotProductInputAllCap);
- return capabilities;
-}
-
-#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
- LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
- SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
- return getIntegerDotProductExtensions(); \
- } \
- SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
- return getIntegerDotProductCapabilities(*this); \
- } \
- std::optional<spirv::Version> OpName::getMinVersion() { \
- return getIntegerDotProductMinVersion(); \
- } \
- std::optional<spirv::Version> OpName::getMaxVersion() { \
- return getIntegerDotProductMaxVersion(); \
- }
-
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
-SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
-
-#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
-
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
new file mode 100644
index 00000000000000..43c0beaccc0fd3
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -0,0 +1,48 @@
+//===- SPIRVParsingUtils.cpp - MLIR SPIR-V Dialect Parsing Utilities ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements common SPIR-V dialect parsing functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVParsingUtils.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+ OperationState &state,
+ StringRef attrName) {
+ // Parses: ('[' memory-access-string (',' i32-alignment)? ']')? -- optional.
+ if (parser.parseOptionalLSquare()) {
+ // No leading '[': the list is absent, so no attribute is added to `state`.
+ return success();
+ }
+
+ spirv::MemoryAccess memoryAccessAttr;
+ if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
+ state, attrName))
+ return failure();
+
+ if (spirv::bitEnumContainsAll(memoryAccessAttr,
+ spirv::MemoryAccess::Aligned)) {
+ // Aligned needs ", <i32>"; stored as kAlignmentAttrName whatever attrName is.
+ Attribute alignmentAttr;
+ Type i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseComma() ||
+ parser.parseAttribute(alignmentAttr, i32Type,
+ AttrNames::kAlignmentAttrName,
+ state.attributes)) {
+ return failure();
+ }
+ }
+ return parser.parseRSquare();
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
new file mode 100644
index 00000000000000..fd2faf4b7b333f
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -0,0 +1,156 @@
+//===- SPIRVParsingUtils.h - MLIR SPIR-V Dialect Parsing Utilities --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_DIALECT_SPIRV_IR_SPIRVPARSINGUTILS_H
+#define MLIR_LIB_DIALECT_SPIRV_IR_SPIRVPARSINGUTILS_H
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <optional>
+#include <type_traits>
+
+namespace mlir::spirv {
+/// Attribute names shared by the SPIR-V op parsing code.
+namespace AttrNames {
+// TODO: generate these strings using ODS.
+inline constexpr char kAlignmentAttrName[] = "alignment";
+inline constexpr char kBranchWeightAttrName[] = "branch_weights";
+inline constexpr char kCallee[] = "callee";
+inline constexpr char kClusterSize[] = "cluster_size";
+inline constexpr char kControl[] = "control";
+inline constexpr char kDefaultValueAttrName[] = "default_value";
+inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
+inline constexpr char kFnNameAttrName[] = "fn";
+inline constexpr char kGroupOperationAttrName[] = "group_operation";
+inline constexpr char kIndicesAttrName[] = "indices";
+inline constexpr char kInitializerAttrName[] = "initializer";
+inline constexpr char kInterfaceAttrName[] = "interface";
+inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
+inline constexpr char kMemoryAccessAttrName[] = "memory_access";
+inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
+inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
+inline constexpr char kPackedVectorFormatAttrName[] = "format";
+inline constexpr char kSemanticsAttrName[] = "semantics";
+inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
+inline constexpr char kSpecIdAttrName[] = "spec_id";
+inline constexpr char kTypeAttrName[] = "type";
+inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
+inline constexpr char kValueAttrName[] = "value";
+inline constexpr char kValuesAttrName[] = "values";
+inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+} // namespace AttrNames
+
+/// Builds an array attribute holding the string form of each enumerant in
+/// `enumValues`, as produced by `stringifyFn`. Returns a null `ArrayAttr` when
+/// `enumValues` is empty.
+template <typename Ty>
+ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
+                                     function_ref<StringRef(Ty)> stringifyFn) {
+  if (enumValues.empty()) {
+    return nullptr;
+  }
+  SmallVector<StringRef, 1> enumValStrs;
+  enumValStrs.reserve(enumValues.size());
+  for (auto val : enumValues) {
+    enumValStrs.emplace_back(stringifyFn(val));
+  }
+  return builder.getStrArrayAttr(enumValStrs);
+}
+
+/// Parses the next keyword in `parser` as an enumerant of the given
+/// `EnumClass` and stores it in `value`. `parser` may be any parser type that
+/// provides `getCurrentLocation`, `parseKeyword` and `emitError`. `attrName`
+/// is only used in the diagnostic emitted for an unknown keyword.
+template <typename EnumClass, typename ParserType>
+ParseResult
+parseEnumKeywordAttr(EnumClass &value, ParserType &parser,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  StringRef keyword;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&keyword))
+    return failure();
+
+  if (std::optional<EnumClass> enumerant =
+          spirv::symbolizeEnum<EnumClass>(keyword)) {
+    value = *enumerant;
+    return success();
+  }
+  return parser.emitError(loc, "invalid ")
+         << attrName << " attribute specification: " << keyword;
+}
+
+/// Parses the next attribute in `parser` as a string naming an enumerant of
+/// the given `EnumClass` and stores it in `value`. Nothing is added to any
+/// operation state; `attrName` labels the parsed attribute and is used in
+/// diagnostics.
+template <typename EnumClass>
+ParseResult
+parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
+                 StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  Attribute attrVal;
+  // Scratch list: the parsed attribute is inspected here, not kept.
+  NamedAttrList attr;
+  auto loc = parser.getCurrentLocation();
+  if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
+                            attrName, attr))
+    return failure();
+  auto strAttr = llvm::dyn_cast<StringAttr>(attrVal);
+  if (!strAttr)
+    return parser.emitError(loc, "expected ")
+           << attrName << " attribute specified as string";
+  auto attrOptional = spirv::symbolizeEnum<EnumClass>(strAttr.getValue());
+  if (!attrOptional)
+    return parser.emitError(loc, "invalid ")
+           << attrName << " attribute specification: " << attrVal;
+  value = *attrOptional;
+  return success();
+}
+
+/// Parses the next string attribute in `parser` as an enumerant of the given
+/// `EnumClass`, returns it through `value`, and adds it to `state` as an
+/// `EnumAttrClass` attribute named `attrName` (by default, the enum's
+/// attribute name).
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
+ParseResult
+parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
+                 StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  if (parseEnumStrAttr(value, parser, attrName))
+    return failure();
+  state.addAttribute(attrName,
+                     parser.getBuilder().getAttr<EnumAttrClass>(value));
+  return success();
+}
+
+/// Parses the next keyword in `parser` as an enumerant of the given
+/// `EnumClass`, returns it through `value`, and adds it to `state` as an
+/// `EnumAttrClass` attribute named `attrName` (by default, the enum's
+/// attribute name).
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
+ParseResult
+parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
+                     OperationState &state,
+                     StringRef attrName = spirv::attributeName<EnumClass>()) {
+  static_assert(std::is_enum_v<EnumClass>);
+  // NOTE(review): unlike parseEnumStrAttr, `attrName` is not forwarded here,
+  // so an invalid keyword is reported using the enum's default attribute name
+  // rather than `attrName`. Kept as-is to preserve existing diagnostic text;
+  // confirm whether forwarding is intended.
+  if (parseEnumKeywordAttr(value, parser))
+    return failure();
+  state.addAttribute(attrName,
+                     parser.getBuilder().getAttr<EnumAttrClass>(value));
+  return success();
+}
+
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
+///     (`[` memory-access `]`)?
+/// where:
+///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
+///         integer-literal | `"NonTemporal"`
+/// The memory-access enumerant is added to `state` under `attrName`; for an
+/// `Aligned` access the integer literal is added as an i32 attribute named
+/// `AttrNames::kAlignmentAttrName`.
+ParseResult parseMemoryAccessAttributes(
+    OpAsmParser &parser, OperationState &state,
+    StringRef attrName = AttrNames::kMemoryAccessAttrName);
+
+} // namespace mlir::spirv
+
+#endif // MLIR_LIB_DIALECT_SPIRV_IR_SPIRVPARSINGUTILS_H
More information about the Mlir-commits
mailing list