[Mlir-commits] [mlir] ab6827f - [mlir][spirv] Extract Atomic/Cast/Group op implementation. NFC.
Jakub Kuderski
llvmlistbot at llvm.org
Thu Jul 20 08:16:07 PDT 2023
Author: Jakub Kuderski
Date: 2023-07-20T11:15:30-04:00
New Revision: ab6827f2d4321ea673ad156be713e345ef354ea0
URL: https://github.com/llvm/llvm-project/commit/ab6827f2d4321ea673ad156be713e345ef354ea0
DIFF: https://github.com/llvm/llvm-project/commit/ab6827f2d4321ea673ad156be713e345ef354ea0.diff
LOG: [mlir][spirv] Extract Atomic/Cast/Group op implementation. NFC.
Continue the work outlined in D155747 and split the main SPIR-V ops
implementation file into a few smaller and quicker to compile files.
This organization matches the op definition organization in `.td` files.
In this patch, extract atomic, cast/conversion, and group op
implementation into separate files.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D155777
Added:
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index 9109d41303d833..183ec617c2a38f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -51,7 +51,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
If Result Type has a different number of components than Operand, the
total number of bits in Result Type must equal the total number of bits
- in Operand. Let L be the type, either Result Type or Operand’s type,
+ in Operand. Let L be the type, either Result Type or Operand's type,
that has the larger number of components. Let S be the other type, with
the smaller number of components. The number of components in L must be
an integer multiple of the number of components in S. The first
@@ -335,17 +335,17 @@ def SPIRV_UConvertOp : SPIRV_CastOp<"UConvert",
def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
let summary = [{
- Bit pattern-preserving conversion of a pointer to
+ Bit pattern-preserving conversion of a pointer to
an unsigned scalar integer of possibly different bit width.
}];
let description = [{
Result Type must be a scalar of integer type, whose Signedness operand is 0.
- Pointer must be a physical pointer type. If the bit width of Pointer is
- smaller than that of Result Type, the conversion zero extends Pointer.
- If the bit width of Pointer is larger than that of Result Type,
- the conversion truncates Pointer.
+ Pointer must be a physical pointer type. If the bit width of Pointer is
+ smaller than that of Result Type, the conversion zero extends Pointer.
+ If the bit width of Pointer is larger than that of Result Type,
+ the conversion truncates Pointer.
For same bit width Pointer and Result Type, this is the same as OpBitcast.
@@ -359,7 +359,7 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
#### Example:
```mlir
- %1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
+ %1 = spirv.ConvertPtrToU %0 : !spirv.ptr<i32, Generic> to i32
```
}];
@@ -390,18 +390,18 @@ def SPIRV_ConvertPtrToUOp : SPIRV_Op<"ConvertPtrToU", []> {
def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
let summary = [{
- Bit pattern-preserving conversion of an unsigned scalar integer
+ Bit pattern-preserving conversion of an unsigned scalar integer
to a pointer.
}];
let description = [{
Result Type must be a physical pointer type.
- Integer Value must be a scalar of integer type, whose Signedness
- operand is 0. If the bit width of Integer Value is smaller
+ Integer Value must be a scalar of integer type, whose Signedness
+ operand is 0. If the bit width of Integer Value is smaller
than that of Result Type, the conversion zero extends Integer Value.
- If the bit width of Integer Value is larger than that of Result Type,
- the conversion truncates Integer Value.
+ If the bit width of Integer Value is larger than that of Result Type,
+ the conversion truncates Integer Value.
For same-width Integer Value and Result Type, this is the same as OpBitcast.
@@ -415,7 +415,7 @@ def SPIRV_ConvertUToPtrOp : SPIRV_Op<"ConvertUToPtr", [UnsignedOp]> {
#### Example:
```mlir
- %1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
+ %1 = spirv.ConvertUToPtr %0 : i32 to !spirv.ptr<i32, Generic>
```
}];
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
new file mode 100644
index 00000000000000..3efa955e7d8b87
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -0,0 +1,441 @@
+//===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the atomic operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+// Parses the custom assembly of an atomic update op:
+//
+//   "<memory-scope>" "<memory-semantics>" %pointer [, %value] : <pointer-type>
+//
+// If the update op does not take a value (like AtomicIIncrement) `hasValue`
+// must be false. The value operand (if present) and the single result are
+// both typed with the pointee type of the parsed pointer type.
+static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
+                                       OperationState &state, bool hasValue) {
+  spirv::Scope scope;
+  spirv::MemorySemantics memorySemantics;
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
+  Type type;
+  SMLoc loc;
+  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
+                                         kMemoryScopeAttrName) ||
+      parseEnumStrAttr<spirv::MemorySemanticsAttr>(memorySemantics, parser,
+                                                   state, kSemanticsAttrName) ||
+      parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
+      parser.getCurrentLocation(&loc) || parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType)
+    return parser.emitError(loc, "expected pointer type");
+
+  SmallVector<Type, 2> operandTypes;
+  operandTypes.push_back(ptrType);
+  if (hasValue)
+    operandTypes.push_back(ptrType.getPointeeType());
+  if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
+                             state.operands))
+    return failure();
+  return parser.addTypeToList(ptrType.getPointeeType(), state.types);
+}
+
+// Prints an atomic update op as: the quoted memory scope, the quoted memory
+// semantics, the operand list, and finally the type of the pointer operand.
+static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
+  spirv::Scope scope =
+      op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName).getValue();
+  spirv::MemorySemantics semantics =
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+          .getValue();
+  Type pointerType = op->getOperand(0).getType();
+  printer << " \"" << spirv::stringifyScope(scope) << "\" \""
+          << spirv::stringifyMemorySemantics(semantics) << "\" "
+          << op->getOperands() << " : " << pointerType;
+}
+
+// Returns a human-readable name for the type category `T`, used only in
+// verifier diagnostics. Only the specializations below are defined.
+template <typename T>
+static StringRef stringifyTypeName();
+
+template <>
+StringRef stringifyTypeName<IntegerType>() { return "integer"; }
+
+template <>
+StringRef stringifyTypeName<FloatType>() { return "float"; }
+
+// Verifies an atomic update op:
+//  - the pointer operand must point to an `ExpectedElementType` value;
+//  - if a second (value) operand exists, its type must equal the pointee type;
+//  - the memory semantics attribute must pass verifyMemorySemantics.
+template <typename ExpectedElementType>
+static LogicalResult verifyAtomicUpdateOp(Operation *op) {
+  Type pointeeType =
+      llvm::cast<spirv::PointerType>(op->getOperand(0).getType())
+          .getPointeeType();
+  if (!llvm::isa<ExpectedElementType>(pointeeType)) {
+    return op->emitOpError() << "pointer operand must point to an "
+                             << stringifyTypeName<ExpectedElementType>()
+                             << " value, found " << pointeeType;
+  }
+
+  // Ops such as AtomicIIncrement/AtomicIDecrement only have the pointer.
+  bool hasValueOperand = op->getNumOperands() > 1;
+  if (hasValueOperand) {
+    Type valueType = op->getOperand(1).getType();
+    if (valueType != pointeeType) {
+      return op->emitOpError("expected value to have the same type as the "
+                             "pointer operand's pointee type ")
+             << pointeeType << ", but found " << valueType;
+    }
+  }
+
+  spirv::MemorySemantics semantics =
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+          .getValue();
+  return verifyMemorySemantics(op, semantics);
+}
+
+// Prints an atomic compare-exchange op as: the quoted memory scope, the
+// quoted "equal" and "unequal" memory semantics, the operand list, and the
+// type of the pointer operand.
+template <typename T>
+static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
+  StringRef scope = stringifyScope(atomOp.getMemoryScope());
+  StringRef equal = stringifyMemorySemantics(atomOp.getEqualSemantics());
+  StringRef unequal = stringifyMemorySemantics(atomOp.getUnequalSemantics());
+  printer << " \"" << scope << "\" \"" << equal << "\" \"" << unequal << "\" "
+          << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
+}
+
+// Parses an atomic compare-exchange op:
+//
+//   "<scope>" "<equal-semantics>" "<unequal-semantics>"
+//       %pointer, %value, %comparator : <pointer-type>
+//
+// The value and comparator operands, as well as the result, all take the
+// pointee type of the parsed pointer type.
+static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
+                                                  OperationState &state) {
+  spirv::Scope memoryScope;
+  spirv::MemorySemantics equalSemantics, unequalSemantics;
+  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
+  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
+                                         kMemoryScopeAttrName))
+    return failure();
+  if (parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+          equalSemantics, parser, state, kEqualSemanticsAttrName))
+    return failure();
+  if (parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+          unequalSemantics, parser, state, kUnequalSemanticsAttrName))
+    return failure();
+  if (parser.parseOperandList(operandInfo, 3))
+    return failure();
+
+  // Remember where the type starts so a bad type is reported at that spot.
+  SMLoc typeLoc = parser.getCurrentLocation();
+  Type type;
+  if (parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType)
+    return parser.emitError(typeLoc, "expected pointer type");
+
+  Type pointeeType = ptrType.getPointeeType();
+  if (parser.resolveOperands(operandInfo, {ptrType, pointeeType, pointeeType},
+                             parser.getNameLoc(), state.operands))
+    return failure();
+
+  return parser.addTypeToList(pointeeType, state.types);
+}
+
+// Verifies that the value operand, the comparator operand and the pointer's
+// pointee type all match the op's result type.
+template <typename T>
+static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
+  // According to the spec:
+  // "The type of Value must be the same as Result Type. The type of the value
+  // pointed to by Pointer must be the same as Result Type. This type must also
+  // match the type of Comparator."
+  Type resultType = atomOp.getType();
+
+  Type valueType = atomOp.getValue().getType();
+  if (valueType != resultType)
+    return atomOp.emitOpError("value operand must have the same type as the op "
+                              "result, but found ")
+           << valueType << " vs " << resultType;
+
+  Type comparatorType = atomOp.getComparator().getType();
+  if (comparatorType != resultType)
+    return atomOp.emitOpError(
+               "comparator operand must have the same type as the op "
+               "result, but found ")
+           << comparatorType << " vs " << resultType;
+
+  auto pointerType =
+      llvm::cast<spirv::PointerType>(atomOp.getPointer().getType());
+  Type pointeeType = pointerType.getPointeeType();
+  if (pointeeType != resultType)
+    return atomOp.emitOpError(
+               "pointer operand's pointee type must have the same "
+               "as the op result type, but found ")
+           << pointeeType << " vs " << resultType;
+
+  // TODO: Unequal cannot be set to Release or Acquire and Release.
+  // In addition, Unequal cannot be set to a stronger memory-order then Equal.
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicAndOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicAndOp::verify() {
+  // The pointee must be an integer.
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicAndOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicCompareExchangeOp
+//===----------------------------------------------------------------------===//
+
+// The strong and weak compare-exchange ops share assembly format and
+// verification through the *AtomicCompareExchangeImpl helpers.
+LogicalResult AtomicCompareExchangeOp::verify() {
+  return verifyAtomicCompareExchangeImpl(*this);
+}
+
+ParseResult AtomicCompareExchangeOp::parse(OpAsmParser &parser,
+                                           OperationState &result) {
+  return parseAtomicCompareExchangeImpl(parser, result);
+}
+
+void AtomicCompareExchangeOp::print(OpAsmPrinter &printer) {
+  printAtomicCompareExchangeImpl(*this, printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicCompareExchangeWeakOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicCompareExchangeWeakOp::verify() {
+  return verifyAtomicCompareExchangeImpl(*this);
+}
+
+ParseResult AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  return parseAtomicCompareExchangeImpl(parser, result);
+}
+
+void AtomicCompareExchangeWeakOp::print(OpAsmPrinter &printer) {
+  printAtomicCompareExchangeImpl(*this, printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicExchange
+//===----------------------------------------------------------------------===//
+
+// Prints: "<scope>" "<semantics>" %pointer, %value : <pointer-type>
+void AtomicExchangeOp::print(OpAsmPrinter &printer) {
+  StringRef scope = stringifyScope(getMemoryScope());
+  StringRef semantics = stringifyMemorySemantics(getSemantics());
+  printer << " \"" << scope << "\" \"" << semantics << "\" " << getOperands()
+          << " : " << getPointer().getType();
+}
+
+// Parses the form produced by AtomicExchangeOp::print. The value operand and
+// the result take the pointee type of the pointer type.
+ParseResult AtomicExchangeOp::parse(OpAsmParser &parser,
+                                    OperationState &result) {
+  spirv::Scope memoryScope;
+  spirv::MemorySemantics semantics;
+  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
+  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
+                                         kMemoryScopeAttrName))
+    return failure();
+  if (parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
+                                                   kSemanticsAttrName))
+    return failure();
+  if (parser.parseOperandList(operandInfo, 2))
+    return failure();
+
+  SMLoc typeLoc = parser.getCurrentLocation();
+  Type type;
+  if (parser.parseColonType(type))
+    return failure();
+
+  auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
+  if (!ptrType)
+    return parser.emitError(typeLoc, "expected pointer type");
+
+  Type pointeeType = ptrType.getPointeeType();
+  if (parser.resolveOperands(operandInfo, {ptrType, pointeeType},
+                             parser.getNameLoc(), result.operands))
+    return failure();
+
+  return parser.addTypeToList(pointeeType, result.types);
+}
+
+// Verifies that both the value operand type and the pointer's pointee type
+// equal the result type.
+LogicalResult AtomicExchangeOp::verify() {
+  Type resultType = getType();
+
+  Type valueType = getValue().getType();
+  if (resultType != valueType)
+    return emitOpError("value operand must have the same type as the op "
+                       "result, but found ")
+           << valueType << " vs " << resultType;
+
+  auto pointerType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  Type pointeeType = pointerType.getPointeeType();
+  if (resultType != pointeeType)
+    return emitOpError("pointer operand's pointee type must have the same "
+                       "as the op result type, but found ")
+           << pointeeType << " vs " << resultType;
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicIAddOp
+//===----------------------------------------------------------------------===//
+
+// The atomic update ops below all delegate to the shared atomic-update
+// helpers; they differ only in the expected pointee element type passed to
+// verifyAtomicUpdateOp and in whether a value operand is parsed.
+
+LogicalResult AtomicIAddOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicIAddOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.EXT.AtomicFAddOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult EXTAtomicFAddOp::verify() {
+  // Unlike the other update ops, the pointee must be a float.
+  return verifyAtomicUpdateOp<FloatType>(getOperation());
+}
+
+ParseResult EXTAtomicFAddOp::parse(OpAsmParser &parser,
+                                   OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void EXTAtomicFAddOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicIDecrementOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicIDecrementOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicIDecrementOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  // Decrement takes only the pointer operand.
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/false);
+}
+
+void AtomicIDecrementOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicIIncrementOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicIIncrementOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicIIncrementOp::parse(OpAsmParser &parser,
+                                      OperationState &result) {
+  // Increment takes only the pointer operand.
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/false);
+}
+
+void AtomicIIncrementOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicISubOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicISubOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicISubOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicOrOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicOrOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicOrOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicSMaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicSMaxOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicSMaxOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicSMinOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicSMinOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicSMinOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicUMaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicUMaxOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicUMaxOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicUMinOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicUMinOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicUMinOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.AtomicXorOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtomicXorOp::verify() {
+  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+}
+
+ParseResult AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) {
+  return parseAtomicUpdateOp(parser, result, /*hasValue=*/true);
+}
+
+void AtomicXorOp::print(OpAsmPrinter &printer) {
+  printAtomicUpdateOp(getOperation(), printer);
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index 70e2eb786e397c..d36e2ad8a73e85 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,7 +3,10 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
+ AtomicOps.cpp
+ CastOps.cpp
CooperativeMatrixOps.cpp
+ GroupOps.cpp
IntegerDotProductOps.cpp
JointMatrixOps.cpp
SPIRVAttributes.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
new file mode 100644
index 00000000000000..f24da2ca5c3f24
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -0,0 +1,339 @@
+//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the cast and conversion operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+// Verifies the bit widths of a cast op's operand and result.
+//
+// - If `skipBitWidthCheck` is true, nothing is checked and success is
+//   returned immediately.
+// - Otherwise the element types are extracted: the types themselves for
+//   scalars, or the element types for vector / cooperative matrix / joint
+//   matrix types (operand and result must then be the same kind of type).
+//   If `requireSameBitWidth` is true the element bit widths must be equal;
+//   if it is false they must differ.
+static LogicalResult verifyCastOp(Operation *op,
+                                  bool requireSameBitWidth = true,
+                                  bool skipBitWidthCheck = false) {
+  // Some CastOps have no limit on bit widths for result and operand type.
+  if (skipBitWidthCheck)
+    return success();
+
+  Type operandType = op->getOperand(0).getType();
+  Type resultType = op->getResult(0).getType();
+
+  // ODS checks that result type and operand type have the same shape. Check
+  // that composite types match and extract the element types, if any.
+  using TypePair = std::pair<Type, Type>;
+  auto [operandElemTy, resultElemTy] =
+      TypeSwitch<Type, TypePair>(operandType)
+          .Case<VectorType, spirv::CooperativeMatrixType,
+                spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
+              [resultType](auto concreteOperandTy) -> TypePair {
+                if (auto concreteResultTy =
+                        dyn_cast<decltype(concreteOperandTy)>(resultType)) {
+                  return {concreteOperandTy.getElementType(),
+                          concreteResultTy.getElementType()};
+                }
+                // Different kinds of composite: report via a null pair.
+                return {};
+              })
+          .Default([resultType](Type scalarOperandTy) -> TypePair {
+            return {scalarOperandTy, resultType};
+          });
+
+  if (!operandElemTy || !resultElemTy)
+    return op->emitOpError("incompatible operand and result types");
+
+  // NOTE(review): getIntOrFloatBitWidth() requires int/float element types;
+  // this presumably relies on the ODS constraints of the calling ops.
+  unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
+  unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
+  bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
+
+  if (requireSameBitWidth) {
+    if (!isSameBitWidth) {
+      return op->emitOpError(
+                 "expected the same bit widths for operand type and result "
+                 "type, but provided ")
+             << operandElemTy << " and " << resultElemTy;
+    }
+    return success();
+  }
+
+  if (isSameBitWidth) {
+    return op->emitOpError(
+               "expected the different bit widths for operand type and result "
+               "type, but provided ")
+           << operandElemTy << " and " << resultElemTy;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.BitcastOp
+//===----------------------------------------------------------------------===//
+
+// Verifies spirv.Bitcast. The rules are:
+//  - the result type must differ from the operand type;
+//  - pointers may only be cast to pointers, and non-pointers to non-pointers;
+//  - operand and result must have the same bit width, as computed by
+//    getBitWidth.
+LogicalResult BitcastOp::verify() {
+  // TODO: The SPIR-V spec validation rules are different for different
+  // versions.
+  auto operandType = getOperand().getType();
+  auto resultType = getResult().getType();
+  if (operandType == resultType) {
+    return emitError("result type must be different from operand type");
+  }
+  if (llvm::isa<spirv::PointerType>(operandType) &&
+      !llvm::isa<spirv::PointerType>(resultType)) {
+    return emitError(
+        "unhandled bit cast conversion from pointer type to non-pointer type");
+  }
+  if (!llvm::isa<spirv::PointerType>(operandType) &&
+      llvm::isa<spirv::PointerType>(resultType)) {
+    return emitError(
+        "unhandled bit cast conversion from non-pointer type to pointer type");
+  }
+  auto operandBitWidth = getBitWidth(operandType);
+  auto resultBitWidth = getBitWidth(resultType);
+  if (operandBitWidth != resultBitWidth) {
+    return emitOpError("mismatch in result type bitwidth ")
+           << resultBitWidth << " and operand type bitwidth "
+           << operandBitWidth;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertPtrToUOp
+//===----------------------------------------------------------------------===//
+
+// Verifies spirv.ConvertPtrToU. The result must be a signless scalar integer
+// (the SPIR-V "Signedness 0" requirement). Inside a spirv.module, the pointer
+// operand must be physical: the Logical addressing model is rejected, and
+// under PhysicalStorageBuffer64 the pointer must be in the
+// PhysicalStorageBuffer storage class.
+LogicalResult ConvertPtrToUOp::verify() {
+  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
+  // dyn_cast (not cast) so that a non-scalar result is diagnosed by the null
+  // check below instead of tripping the assertion inside llvm::cast.
+  auto resultType = llvm::dyn_cast<spirv::ScalarType>(getResult().getType());
+  if (!resultType || !resultType.isSignlessInteger())
+    return emitError("result must be a scalar type of unsigned integer");
+  // The addressing-model checks need an enclosing spirv.module.
+  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
+  if (!spirvModule)
+    return success();
+  auto addressingModel = spirvModule.getAddressingModel();
+  if ((addressingModel == spirv::AddressingModel::Logical) ||
+      (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
+       operandType.getStorageClass() !=
+           spirv::StorageClass::PhysicalStorageBuffer))
+    return emitError("operand must be a physical pointer");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertUToPtrOp
+//===----------------------------------------------------------------------===//
+
+// Verifies spirv.ConvertUToPtr. The integer operand must be a signless
+// scalar integer (the SPIR-V "Signedness 0" requirement). Inside a
+// spirv.module, the result pointer must be physical: the Logical addressing
+// model is rejected, and under PhysicalStorageBuffer64 the result must be in
+// the PhysicalStorageBuffer storage class.
+LogicalResult ConvertUToPtrOp::verify() {
+  // dyn_cast (not cast) so that a non-scalar operand is diagnosed by the null
+  // check below instead of tripping the assertion inside llvm::cast.
+  auto operandType = llvm::dyn_cast<spirv::ScalarType>(getOperand().getType());
+  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
+  // The check is on the operand, so the diagnostic names the operand.
+  if (!operandType || !operandType.isSignlessInteger())
+    return emitError("operand must be a scalar type of unsigned integer");
+  // The addressing-model checks need an enclosing spirv.module.
+  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
+  if (!spirvModule)
+    return success();
+  auto addressingModel = spirvModule.getAddressingModel();
+  if ((addressingModel == spirv::AddressingModel::Logical) ||
+      (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
+       resultType.getStorageClass() !=
+           spirv::StorageClass::PhysicalStorageBuffer))
+    return emitError("result must be a physical pointer");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.PtrCastToGenericOp
+//===----------------------------------------------------------------------===//
+
+// Verifies spirv.PtrCastToGeneric: the operand must point into the
+// Workgroup, CrossWorkgroup or Function storage class, the result must be a
+// Generic pointer, and both must share the same pointee type.
+LogicalResult PtrCastToGenericOp::verify() {
+  auto srcPtrTy = llvm::cast<spirv::PointerType>(getPointer().getType());
+  auto dstPtrTy = llvm::cast<spirv::PointerType>(getResult().getType());
+
+  auto isCastableToGeneric = [](spirv::StorageClass storage) {
+    switch (storage) {
+    case spirv::StorageClass::Workgroup:
+    case spirv::StorageClass::CrossWorkgroup:
+    case spirv::StorageClass::Function:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!isCastableToGeneric(srcPtrTy.getStorageClass()))
+    return emitError("pointer must point to the Workgroup, CrossWorkgroup"
+                     ", or Function Storage Class");
+
+  if (dstPtrTy.getStorageClass() != spirv::StorageClass::Generic)
+    return emitError("result type must be of storage class Generic");
+
+  Type srcPointee = srcPtrTy.getPointeeType();
+  Type dstPointee = dstPtrTy.getPointeeType();
+  if (srcPointee == dstPointee)
+    return success();
+  return emitOpError("pointer operand's pointee type must have the same "
+                     "as the op result type, but found ")
+         << srcPointee << " vs " << dstPointee;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GenericCastToPtrOp
+//===----------------------------------------------------------------------===//
+
+// Shared verifier for GenericCastToPtrOp and GenericCastToPtrExplicitOp, whose
+// rules were previously duplicated verbatim. The operand must be a Generic
+// pointer. The result must point into the Workgroup, CrossWorkgroup or
+// Function storage class. Both must have the same pointee type.
+template <typename OpTy>
+static LogicalResult verifyGenericCastToPtrImpl(OpTy op) {
+  auto operandType = llvm::cast<spirv::PointerType>(op.getPointer().getType());
+  auto resultType = llvm::cast<spirv::PointerType>(op.getResult().getType());
+
+  spirv::StorageClass operandStorage = operandType.getStorageClass();
+  if (operandStorage != spirv::StorageClass::Generic)
+    return op.emitError("pointer type must be of storage class Generic");
+
+  spirv::StorageClass resultStorage = resultType.getStorageClass();
+  if (resultStorage != spirv::StorageClass::Workgroup &&
+      resultStorage != spirv::StorageClass::CrossWorkgroup &&
+      resultStorage != spirv::StorageClass::Function)
+    return op.emitError("result must point to the Workgroup, CrossWorkgroup, "
+                        "or Function Storage Class");
+
+  Type operandPointeeType = operandType.getPointeeType();
+  Type resultPointeeType = resultType.getPointeeType();
+  if (operandPointeeType != resultPointeeType)
+    return op.emitOpError("pointer operand's pointee type must have the same "
+                          "as the op result type, but found ")
+           << operandPointeeType << " vs " << resultPointeeType;
+  return success();
+}
+
+LogicalResult GenericCastToPtrOp::verify() {
+  return verifyGenericCastToPtrImpl(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GenericCastToPtrExplicitOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GenericCastToPtrExplicitOp::verify() {
+  return verifyGenericCastToPtrImpl(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertFToSOp
+//===----------------------------------------------------------------------===//
+
+// The float <-> integer conversions below place no restriction on operand or
+// result bit widths, so they ask verifyCastOp to skip its bit-width check.
+
+LogicalResult ConvertFToSOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
+                      /*skipBitWidthCheck=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertFToUOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertFToUOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
+                      /*skipBitWidthCheck=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertSToFOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertSToFOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
+                      /*skipBitWidthCheck=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ConvertUToFOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConvertUToFOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
+                      /*skipBitWidthCheck=*/true);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertBF16ToFOp
+//===----------------------------------------------------------------------===//
+
+// Shared verifier for the INTEL bf16 conversion ops, whose bodies were
+// previously duplicated verbatim. When the operand is a vector, the result
+// must have the same number of elements.
+static LogicalResult verifyBF16ConversionNumElements(Operation *op) {
+  Type operandType = op->getOperand(0).getType();
+  Type resultType = op->getResult(0).getType();
+  // ODS checks that vector result type and vector operand type have the same
+  // shape.
+  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
+    unsigned operandNumElements = vectorType.getNumElements();
+    unsigned resultNumElements =
+        llvm::cast<VectorType>(resultType).getNumElements();
+    if (operandNumElements != resultNumElements) {
+      return op->emitOpError(
+          "operand and result must have same number of elements");
+    }
+  }
+  return success();
+}
+
+LogicalResult INTELConvertBF16ToFOp::verify() {
+  return verifyBF16ConversionNumElements(getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INTELConvertFToBF16Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult INTELConvertFToBF16Op::verify() {
+  return verifyBF16ConversionNumElements(getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.FConvertOp
+//===----------------------------------------------------------------------===//
+
+// FConvert/SConvert/UConvert change width within the same numeric kind:
+// verifyCastOp requires the operand and result bit widths to differ.
+
+LogicalResult FConvertOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.SConvertOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SConvertOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.UConvertOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UConvertOp::verify() {
+  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false);
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
new file mode 100644
index 00000000000000..84bf3de2f43aab
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -0,0 +1,407 @@
+//===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the group operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+
+#include "SPIRVOpUtils.h"
+#include "SPIRVParsingUtils.h"
+
+using namespace mlir::spirv::AttrNames;
+
+namespace mlir::spirv {
+
+/// Parses the custom assembly shared by the GroupNonUniform arithmetic ops:
+///   "<scope>" "<group-operation>" %value [<kClusterSize>(%size)] : <type>
+/// The value operand and the result both take the trailing type; the
+/// optional cluster size is resolved as a 32-bit integer.
+static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
+                                                    OperationState &state) {
+  // Leading enum attributes: execution scope, then group operation.
+  spirv::Scope executionScope;
+  GroupOperation groupOperation;
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
+                                                kExecutionScopeAttrName) ||
+      spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
+                                                  kGroupOperationAttrName))
+    return failure();
+
+  OpAsmParser::UnresolvedOperand valueOperand;
+  if (parser.parseOperand(valueOperand))
+    return failure();
+
+  // Optional parenthesized cluster size introduced by the kClusterSize keyword.
+  std::optional<OpAsmParser::UnresolvedOperand> clusterSizeOperand;
+  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
+    OpAsmParser::UnresolvedOperand sizeOperand;
+    if (parser.parseLParen() || parser.parseOperand(sizeOperand) ||
+        parser.parseRParen())
+      return failure();
+    clusterSizeOperand = sizeOperand;
+  }
+
+  Type resultType;
+  if (parser.parseColonType(resultType) ||
+      parser.resolveOperand(valueOperand, resultType, state.operands))
+    return failure();
+
+  if (clusterSizeOperand &&
+      parser.resolveOperand(*clusterSizeOperand,
+                            parser.getBuilder().getIntegerType(32),
+                            state.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+/// Prints the custom assembly parsed by parseGroupNonUniformArithmeticOp.
+static void printGroupNonUniformArithmeticOp(Operation *groupOp,
+                                             OpAsmPrinter &printer) {
+  auto scopeAttr =
+      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName);
+  auto groupOperationAttr =
+      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName);
+
+  printer << " \"" << stringifyScope(scopeAttr.getValue()) << "\" \""
+          << stringifyGroupOperation(groupOperationAttr.getValue()) << "\" "
+          << groupOp->getOperand(0);
+
+  // A second operand, when present, is the cluster size.
+  if (groupOp->getNumOperands() > 1)
+    printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
+
+  printer << " : " << groupOp->getResult(0).getType();
+}
+
+/// Verifies the constraints shared by the GroupNonUniform arithmetic ops:
+///   - the execution scope is 'Workgroup' or 'Subgroup';
+///   - a cluster size operand is present for 'ClusteredReduce';
+///   - when present, the cluster size comes from a spirv.Constant and is a
+///     positive power of two.
+static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
+  spirv::Scope scope =
+      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+          .getValue();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp->emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  GroupOperation operation =
+      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+          .getValue();
+  if (operation == GroupOperation::ClusteredReduce &&
+      groupOp->getNumOperands() == 1)
+    return groupOp->emitOpError("cluster size operand must be provided for "
+                                "'ClusteredReduce' group operation");
+
+  // Operand #1, when present, is the cluster size.
+  if (groupOp->getNumOperands() > 1) {
+    Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
+    int32_t clusterSize = 0;
+
+    // TODO: support specialization constant here.
+    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
+      return groupOp->emitOpError(
+          "cluster size operand must come from a constant op");
+
+    // isPowerOf2_32 takes a uint32_t, so INT32_MIN would convert to
+    // 0x80000000 and be accepted; reject non-positive sizes explicitly.
+    if (clusterSize <= 0 || !llvm::isPowerOf2_32(clusterSize))
+      return groupOp->emitOpError(
+          "cluster size operand must be a power of two");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupBroadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupBroadcastOp::verify() {
+  spirv::Scope scope = getExecutionScope();
+  bool validScope =
+      scope == spirv::Scope::Workgroup || scope == spirv::Scope::Subgroup;
+  if (!validScope)
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+  // Non-vector local ids are not checked further; a vector local id must
+  // have exactly 2 or 3 components.
+  auto localIdVecTy = llvm::dyn_cast<VectorType>(getLocalid().getType());
+  if (!localIdVecTy)
+    return success();
+
+  auto numComponents = localIdVecTy.getNumElements();
+  if (numComponents == 2 || numComponents == 3)
+    return success();
+
+  return emitOpError("localid is a vector and can be with only "
+                     " 2 or 3 components, actual number is ")
+         << numComponents;
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBallotOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBallotOp::verify() {
+  // Only 'Workgroup' and 'Subgroup' execution scopes are accepted.
+  switch (getExecutionScope()) {
+  case spirv::Scope::Workgroup:
+  case spirv::Scope::Subgroup:
+    return success();
+  default:
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformBroadcast
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformBroadcastOp::verify() {
+  spirv::Scope scope = getExecutionScope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+
+  // Prefer the target environment of the enclosing spirv.module; fall back to
+  // the default target environment when there is no such module.
+  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
+  auto targetEnv = spirvModule ? spirv::lookupTargetEnvOrDefault(spirvModule)
+                               : spirv::getDefaultTargetEnv(getContext());
+
+  // SPIR-V spec: "Before version 1.5, Id must come from a constant
+  // instruction." From 1.5 on, any value is accepted.
+  if (targetEnv.getVersion() >= spirv::Version::V_1_5)
+    return success();
+
+  // Both normal constants (spirv.Constant) and spec constants (referenced via
+  // spirv.mlir.referenceof) count as constant instructions here.
+  Operation *idOp = getId().getDefiningOp();
+  bool isConstant =
+      idOp && isa<spirv::ConstantOp, spirv::ReferenceOfOp>(idOp);
+  if (!isConstant)
+    return emitOpError("id must be the result of a constant op");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformShuffle*
+//===----------------------------------------------------------------------===//
+
+/// Shared verifier for the GroupNonUniformShuffle* ops: the execution scope
+/// must be 'Workgroup' or 'Subgroup', and the trailing operand must not be a
+/// signed integer (signless or unsigned only).
+template <typename OpTy>
+static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
+  switch (op.getExecutionScope()) {
+  case spirv::Scope::Workgroup:
+  case spirv::Scope::Subgroup:
+    break;
+  default:
+    return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+  }
+
+  Type lastOperandTy = op.getOperands().back().getType();
+  if (lastOperandTy.isSignedInteger())
+    return op.emitOpError("second operand must be a singless/unsigned integer");
+
+  return success();
+}
+
+// All shuffle variants share one verifier.
+LogicalResult GroupNonUniformShuffleOp::verify() {
+  return verifyGroupNonUniformShuffleOp<GroupNonUniformShuffleOp>(*this);
+}
+
+LogicalResult GroupNonUniformShuffleDownOp::verify() {
+  return verifyGroupNonUniformShuffleOp<GroupNonUniformShuffleDownOp>(*this);
+}
+
+LogicalResult GroupNonUniformShuffleUpOp::verify() {
+  return verifyGroupNonUniformShuffleOp<GroupNonUniformShuffleUpOp>(*this);
+}
+
+LogicalResult GroupNonUniformShuffleXorOp::verify() {
+  return verifyGroupNonUniformShuffleOp<GroupNonUniformShuffleXorOp>(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformElectOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GroupNonUniformElectOp::verify() {
+  const spirv::Scope executionScope = getExecutionScope();
+  const bool scopeIsValid = executionScope == spirv::Scope::Subgroup ||
+                            executionScope == spirv::Scope::Workgroup;
+  if (scopeIsValid)
+    return success();
+  return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformFAddOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformFAddOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformFAddOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformFMaxOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformFMaxOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformFMaxOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformFMinOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformFMinOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformFMinOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformFMulOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformFMulOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformFMulOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformIAddOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformIAddOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformIAddOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformIMulOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformIMulOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformIMulOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformSMaxOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformSMaxOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformSMaxOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformSMinOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformSMinOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformSMinOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformUMaxOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformUMaxOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformUMaxOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GroupNonUniformUMinOp
+//===----------------------------------------------------------------------===//
+
+// Verification, parsing and printing delegate to the shared GroupNonUniform
+// arithmetic helpers.
+LogicalResult GroupNonUniformUMinOp::verify() {
+  return verifyGroupNonUniformArithmeticOp(getOperation());
+}
+
+ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
+                                         OperationState &state) {
+  return parseGroupNonUniformArithmeticOp(parser, state);
+}
+
+void GroupNonUniformUMinOp::print(OpAsmPrinter &printer) {
+  printGroupNonUniformArithmeticOp(getOperation(), printer);
+}
+
+//===----------------------------------------------------------------------===//
+// Group op verification
+//===----------------------------------------------------------------------===//
+
+/// Shared execution-scope check used by the Group*Op verifiers in this file:
+/// only 'Workgroup' and 'Subgroup' are accepted.
+template <typename Op>
+static LogicalResult verifyGroupOp(Op op) {
+  switch (op.getExecutionScope()) {
+  case spirv::Scope::Subgroup:
+  case spirv::Scope::Workgroup:
+    return success();
+  default:
+    return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
+  }
+}
+
+// Each of these ops only needs the shared execution-scope check.
+LogicalResult GroupIAddOp::verify() {
+  return verifyGroupOp<GroupIAddOp>(*this);
+}
+
+LogicalResult GroupFAddOp::verify() {
+  return verifyGroupOp<GroupFAddOp>(*this);
+}
+
+LogicalResult GroupFMinOp::verify() {
+  return verifyGroupOp<GroupFMinOp>(*this);
+}
+
+LogicalResult GroupUMinOp::verify() {
+  return verifyGroupOp<GroupUMinOp>(*this);
+}
+
+LogicalResult GroupSMinOp::verify() {
+  return verifyGroupOp<GroupSMinOp>(*this);
+}
+
+LogicalResult GroupFMaxOp::verify() {
+  return verifyGroupOp<GroupFMaxOp>(*this);
+}
+
+LogicalResult GroupUMaxOp::verify() {
+  return verifyGroupOp<GroupUMaxOp>(*this);
+}
+
+LogicalResult GroupSMaxOp::verify() {
+  return verifyGroupOp<GroupSMaxOp>(*this);
+}
+
+LogicalResult GroupIMulKHROp::verify() {
+  return verifyGroupOp<GroupIMulKHROp>(*this);
+}
+
+LogicalResult GroupFMulKHROp::verify() {
+  return verifyGroupOp<GroupFMulKHROp>(*this);
+}
+
+} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
index efe596cd725c5e..fff06bb5a7f207 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpUtils.h
@@ -29,4 +29,9 @@ inline unsigned getBitWidth(Type type) {
llvm_unreachable("unhandled bit width computation for type");
}
+/// Returns failure when `op` is not a spirv.Constant (including a null `op`).
+/// Otherwise it presumably stores the constant's integer value in `value`;
+/// see the definition in SPIRVOps.cpp for the exact accepted attribute kinds.
+LogicalResult extractValueFromConstOp(Operation *op, int32_t &value);
+
+/// Checks the memory-semantics mask `memorySemantics` used by `op` against
+/// SPIR-V rules. One of those rules is that at most one of the memory-order
+/// bits (Acquire, Release, ...) may be set; see the definition for the rest.
+LogicalResult verifyMemorySemantics(Operation *op,
+ spirv::MemorySemantics memorySemantics);
+
} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2184cec953fb05..47ffdc1cdad18f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,7 +34,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <numeric>
#include <type_traits>
@@ -112,7 +111,7 @@ static bool isDirectInModuleLikeOp(Operation *op) {
return op && op->hasTrait<OpTrait::SymbolTable>();
}
-static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
+LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
if (!constOp) {
return failure();
@@ -293,61 +292,6 @@ static LogicalResult verifyImageOperands(Op imageOp,
return success();
}
-static LogicalResult verifyCastOp(Operation *op,
- bool requireSameBitWidth = true,
- bool skipBitWidthCheck = false) {
- // Some CastOps have no limit on bit widths for result and operand type.
- if (skipBitWidthCheck)
- return success();
-
- Type operandType = op->getOperand(0).getType();
- Type resultType = op->getResult(0).getType();
-
- // ODS checks that result type and operand type have the same shape. Check
- // that composite types match and extract the element types, if any.
- using TypePair = std::pair<Type, Type>;
- auto [operandElemTy, resultElemTy] =
- TypeSwitch<Type, TypePair>(operandType)
- .Case<VectorType, spirv::CooperativeMatrixType,
- spirv::CooperativeMatrixNVType, spirv::JointMatrixINTELType>(
- [resultType](auto concreteOperandTy) -> TypePair {
- if (auto concreteResultTy =
- dyn_cast<decltype(concreteOperandTy)>(resultType)) {
- return {concreteOperandTy.getElementType(),
- concreteResultTy.getElementType()};
- }
- return {};
- })
- .Default([resultType](Type operandType) -> TypePair {
- return {operandType, resultType};
- });
-
- if (!operandElemTy || !resultElemTy)
- return op->emitOpError("incompatible operand and result types");
-
- unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
- unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
- bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
-
- if (requireSameBitWidth) {
- if (!isSameBitWidth) {
- return op->emitOpError(
- "expected the same bit widths for operand type and result "
- "type, but provided ")
- << operandElemTy << " and " << resultElemTy;
- }
- return success();
- }
-
- if (isSameBitWidth) {
- return op->emitOpError(
- "expected the different bit widths for operand type and result "
- "type, but provided ")
- << operandElemTy << " and " << resultElemTy;
- }
- return success();
-}
-
template <typename MemoryOpTy>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// ODS checks for attributes values. Just need to verify that if the
@@ -432,8 +376,9 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
-static LogicalResult
-verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
+LogicalResult
+spirv::verifyMemorySemantics(Operation *op,
+ spirv::MemorySemantics memorySemantics) {
// According to the SPIR-V specification:
// "Despite being a mask and allowing multiple bits to be combined, it is
// invalid for more than one of these four bits to be set: Acquire, Release,
@@ -672,178 +617,6 @@ static void printArithmeticExtendedBinaryOp(Operation *op,
printer << " : " << op->getResultTypes().front();
}
-//===----------------------------------------------------------------------===//
-// Common parsers and printers
-//===----------------------------------------------------------------------===//
-
-// Parses an atomic update op. If the update op does not take a value (like
-// AtomicIIncrement) `hasValue` must be false.
-static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
- OperationState &state, bool hasValue) {
- spirv::Scope scope;
- spirv::MemorySemantics memoryScope;
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
- OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
- Type type;
- SMLoc loc;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
- kMemoryScopeAttrName) ||
- spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
- memoryScope, parser, state, kSemanticsAttrName) ||
- parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
- parser.getCurrentLocation(&loc) || parser.parseColonType(type))
- return failure();
-
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
- if (!ptrType)
- return parser.emitError(loc, "expected pointer type");
-
- SmallVector<Type, 2> operandTypes;
- operandTypes.push_back(ptrType);
- if (hasValue)
- operandTypes.push_back(ptrType.getPointeeType());
- if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
- state.operands))
- return failure();
- return parser.addTypeToList(ptrType.getPointeeType(), state.types);
-}
-
-// Prints an atomic update op.
-static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
- printer << " \"";
- auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
- printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
- auto memorySemanticsAttr =
- op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
- printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
- << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
-}
-
-template <typename T>
-static StringRef stringifyTypeName();
-
-template <>
-StringRef stringifyTypeName<IntegerType>() {
- return "integer";
-}
-
-template <>
-StringRef stringifyTypeName<FloatType>() {
- return "float";
-}
-
-// Verifies an atomic update op.
-template <typename ExpectedElementType>
-static LogicalResult verifyAtomicUpdateOp(Operation *op) {
- auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
- auto elementType = ptrType.getPointeeType();
- if (!llvm::isa<ExpectedElementType>(elementType))
- return op->emitOpError() << "pointer operand must point to an "
- << stringifyTypeName<ExpectedElementType>()
- << " value, found " << elementType;
-
- if (op->getNumOperands() > 1) {
- auto valueType = op->getOperand(1).getType();
- if (valueType != elementType)
- return op->emitOpError("expected value to have the same type as the "
- "pointer operand's pointee type ")
- << elementType << ", but found " << valueType;
- }
- auto memorySemantics =
- op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
- .getValue();
- if (failed(verifyMemorySemantics(op, memorySemantics))) {
- return failure();
- }
- return success();
-}
-
-static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
- OperationState &state) {
- spirv::Scope executionScope;
- spirv::GroupOperation groupOperation;
- OpAsmParser::UnresolvedOperand valueInfo;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- kExecutionScopeAttrName) ||
- spirv::parseEnumStrAttr<spirv::GroupOperationAttr>(
- groupOperation, parser, state, kGroupOperationAttrName) ||
- parser.parseOperand(valueInfo))
- return failure();
-
- std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
- if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
- clusterSizeInfo = OpAsmParser::UnresolvedOperand();
- if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
- parser.parseRParen())
- return failure();
- }
-
- Type resultType;
- if (parser.parseColonType(resultType))
- return failure();
-
- if (parser.resolveOperand(valueInfo, resultType, state.operands))
- return failure();
-
- if (clusterSizeInfo) {
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
- return failure();
- }
-
- return parser.addTypeToList(resultType, state.types);
-}
-
-static void printGroupNonUniformArithmeticOp(Operation *groupOp,
- OpAsmPrinter &printer) {
- printer
- << " \""
- << stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
- .getValue())
- << "\" \""
- << stringifyGroupOperation(groupOp
- ->getAttrOfType<spirv::GroupOperationAttr>(
- kGroupOperationAttrName)
- .getValue())
- << "\" " << groupOp->getOperand(0);
-
- if (groupOp->getNumOperands() > 1)
- printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
- printer << " : " << groupOp->getResult(0).getType();
-}
-
-static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
- spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
- .getValue();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return groupOp->emitOpError(
- "execution scope must be 'Workgroup' or 'Subgroup'");
-
- spirv::GroupOperation operation =
- groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
- .getValue();
- if (operation == spirv::GroupOperation::ClusteredReduce &&
- groupOp->getNumOperands() == 1)
- return groupOp->emitOpError("cluster size operand must be provided for "
- "'ClusteredReduce' group operation");
- if (groupOp->getNumOperands() > 1) {
- Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
- int32_t clusterSize = 0;
-
- // TODO: support specialization constant here.
- if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
- return groupOp->emitOpError(
- "cluster size operand must come from a constant op");
-
- if (!llvm::isPowerOf2_32(clusterSize))
- return groupOp->emitOpError(
- "cluster size operand must be a power of two");
- }
- return success();
-}
-
/// Result of a logical op must be a scalar or vector of boolean type.
static Type getUnaryOpResultType(Type operandType) {
Builder builder(operandType.getContext());
@@ -901,7 +674,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
// TODO: this should be relaxed to allow
// integer literals of other bitwidths.
- if (failed(extractValueFromConstOp(op, index))) {
+ if (failed(spirv::extractValueFromConstOp(op, index))) {
emitError(
baseLoc,
"'spirv.AccessChain' index must be an integer spirv.Constant to "
@@ -1032,514 +805,6 @@ LogicalResult spirv::AddressOfOp::verify() {
return success();
}
-template <typename T>
-static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
- printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
- << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
- << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
- << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
-}
-
-static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
- OperationState &state) {
- spirv::Scope memoryScope;
- spirv::MemorySemantics equalSemantics, unequalSemantics;
- SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
- Type type;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
- kMemoryScopeAttrName) ||
- spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
- equalSemantics, parser, state, kEqualSemanticsAttrName) ||
- spirv::parseEnumStrAttr<spirv::MemorySemanticsAttr>(
- unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
- parser.parseOperandList(operandInfo, 3))
- return failure();
-
- auto loc = parser.getCurrentLocation();
- if (parser.parseColonType(type))
- return failure();
-
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
- if (!ptrType)
- return parser.emitError(loc, "expected pointer type");
-
- if (parser.resolveOperands(
- operandInfo,
- {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
- parser.getNameLoc(), state.operands))
- return failure();
-
- return parser.addTypeToList(ptrType.getPointeeType(), state.types);
-}
-
-template <typename T>
-static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
- // According to the spec:
- // "The type of Value must be the same as Result Type. The type of the value
- // pointed to by Pointer must be the same as Result Type. This type must also
- // match the type of Comparator."
- if (atomOp.getType() != atomOp.getValue().getType())
- return atomOp.emitOpError("value operand must have the same type as the op "
- "result, but found ")
- << atomOp.getValue().getType() << " vs " << atomOp.getType();
-
- if (atomOp.getType() != atomOp.getComparator().getType())
- return atomOp.emitOpError(
- "comparator operand must have the same type as the op "
- "result, but found ")
- << atomOp.getComparator().getType() << " vs " << atomOp.getType();
-
- Type pointeeType =
- llvm::cast<spirv::PointerType>(atomOp.getPointer().getType())
- .getPointeeType();
- if (atomOp.getType() != pointeeType)
- return atomOp.emitOpError(
- "pointer operand's pointee type must have the same "
- "as the op result type, but found ")
- << pointeeType << " vs " << atomOp.getType();
-
- // TODO: Unequal cannot be set to Release or Acquire and Release.
- // In addition, Unequal cannot be set to a stronger memory-order then Equal.
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicAndOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicAndOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicCompareExchangeOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicCompareExchangeOp::verify() {
- return ::verifyAtomicCompareExchangeImpl(*this);
-}
-
-ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicCompareExchangeImpl(parser, result);
-}
-void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
- ::printAtomicCompareExchangeImpl(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicCompareExchangeWeakOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
- return ::verifyAtomicCompareExchangeImpl(*this);
-}
-
-ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicCompareExchangeImpl(parser, result);
-}
-void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
- ::printAtomicCompareExchangeImpl(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicExchange
-//===----------------------------------------------------------------------===//
-
-void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
- printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
- << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
- << " : " << getPointer().getType();
-}
-
-ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
- OperationState &result) {
- spirv::Scope memoryScope;
- spirv::MemorySemantics semantics;
- SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
- Type type;
- if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
- kMemoryScopeAttrName) ||
- parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
- kSemanticsAttrName) ||
- parser.parseOperandList(operandInfo, 2))
- return failure();
-
- auto loc = parser.getCurrentLocation();
- if (parser.parseColonType(type))
- return failure();
-
- auto ptrType = llvm::dyn_cast<spirv::PointerType>(type);
- if (!ptrType)
- return parser.emitError(loc, "expected pointer type");
-
- if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
- parser.getNameLoc(), result.operands))
- return failure();
-
- return parser.addTypeToList(ptrType.getPointeeType(), result.types);
-}
-
-LogicalResult spirv::AtomicExchangeOp::verify() {
- if (getType() != getValue().getType())
- return emitOpError("value operand must have the same type as the op "
- "result, but found ")
- << getValue().getType() << " vs " << getType();
-
- Type pointeeType =
- llvm::cast<spirv::PointerType>(getPointer().getType()).getPointeeType();
- if (getType() != pointeeType)
- return emitOpError("pointer operand's pointee type must have the same "
- "as the op result type, but found ")
- << pointeeType << " vs " << getType();
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicIAddOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicIAddOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.EXT.AtomicFAddOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::EXTAtomicFAddOp::verify() {
- return ::verifyAtomicUpdateOp<FloatType>(getOperation());
-}
-
-ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicIDecrementOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicIDecrementOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, false);
-}
-void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicIIncrementOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicIIncrementOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, false);
-}
-void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicISubOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicISubOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicOrOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicOrOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicSMaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicSMaxOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicSMinOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicSMinOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicUMaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicUMaxOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicUMinOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicUMinOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.AtomicXorOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::AtomicXorOp::verify() {
- return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
-}
-
-ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return ::parseAtomicUpdateOp(parser, result, true);
-}
-void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
- ::printAtomicUpdateOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.BitcastOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::BitcastOp::verify() {
- // TODO: The SPIR-V spec validation rules are different for different
- // versions.
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- if (operandType == resultType) {
- return emitError("result type must be different from operand type");
- }
- if (llvm::isa<spirv::PointerType>(operandType) &&
- !llvm::isa<spirv::PointerType>(resultType)) {
- return emitError(
- "unhandled bit cast conversion from pointer type to non-pointer type");
- }
- if (!llvm::isa<spirv::PointerType>(operandType) &&
- llvm::isa<spirv::PointerType>(resultType)) {
- return emitError(
- "unhandled bit cast conversion from non-pointer type to pointer type");
- }
- auto operandBitWidth = getBitWidth(operandType);
- auto resultBitWidth = getBitWidth(resultType);
- if (operandBitWidth != resultBitWidth) {
- return emitOpError("mismatch in result type bitwidth ")
- << resultBitWidth << " and operand type bitwidth "
- << operandBitWidth;
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ConvertPtrToUOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertPtrToUOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
- if (!resultType || !resultType.isSignlessInteger())
- return emitError("result must be a scalar type of unsigned integer");
- auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
- if (!spirvModule)
- return success();
- auto addressingModel = spirvModule.getAddressingModel();
- if ((addressingModel == spirv::AddressingModel::Logical) ||
- (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
- operandType.getStorageClass() !=
- spirv::StorageClass::PhysicalStorageBuffer))
- return emitError("operand must be a physical pointer");
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ConvertUToPtrOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertUToPtrOp::verify() {
- auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
- if (!operandType || !operandType.isSignlessInteger())
- return emitError("result must be a scalar type of unsigned integer");
- auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
- if (!spirvModule)
- return success();
- auto addressingModel = spirvModule.getAddressingModel();
- if ((addressingModel == spirv::AddressingModel::Logical) ||
- (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
- resultType.getStorageClass() !=
- spirv::StorageClass::PhysicalStorageBuffer))
- return emitError("result must be a physical pointer");
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.PtrCastToGenericOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::PtrCastToGenericOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
-
- spirv::StorageClass operandStorage = operandType.getStorageClass();
- if (operandStorage != spirv::StorageClass::Workgroup &&
- operandStorage != spirv::StorageClass::CrossWorkgroup &&
- operandStorage != spirv::StorageClass::Function)
- return emitError("pointer must point to the Workgroup, CrossWorkgroup"
- ", or Function Storage Class");
-
- spirv::StorageClass resultStorage = resultType.getStorageClass();
- if (resultStorage != spirv::StorageClass::Generic)
- return emitError("result type must be of storage class Generic");
-
- Type operandPointeeType = operandType.getPointeeType();
- Type resultPointeeType = resultType.getPointeeType();
- if (operandPointeeType != resultPointeeType)
- return emitOpError("pointer operand's pointee type must have the same "
- "as the op result type, but found ")
- << operandPointeeType << " vs " << resultPointeeType;
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GenericCastToPtrOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GenericCastToPtrOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
-
- spirv::StorageClass operandStorage = operandType.getStorageClass();
- if (operandStorage != spirv::StorageClass::Generic)
- return emitError("pointer type must be of storage class Generic");
-
- spirv::StorageClass resultStorage = resultType.getStorageClass();
- if (resultStorage != spirv::StorageClass::Workgroup &&
- resultStorage != spirv::StorageClass::CrossWorkgroup &&
- resultStorage != spirv::StorageClass::Function)
- return emitError("result must point to the Workgroup, CrossWorkgroup, "
- "or Function Storage Class");
-
- Type operandPointeeType = operandType.getPointeeType();
- Type resultPointeeType = resultType.getPointeeType();
- if (operandPointeeType != resultPointeeType)
- return emitOpError("pointer operand's pointee type must have the same "
- "as the op result type, but found ")
- << operandPointeeType << " vs " << resultPointeeType;
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GenericCastToPtrExplicitOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
- auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
- auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
-
- spirv::StorageClass operandStorage = operandType.getStorageClass();
- if (operandStorage != spirv::StorageClass::Generic)
- return emitError("pointer type must be of storage class Generic");
-
- spirv::StorageClass resultStorage = resultType.getStorageClass();
- if (resultStorage != spirv::StorageClass::Workgroup &&
- resultStorage != spirv::StorageClass::CrossWorkgroup &&
- resultStorage != spirv::StorageClass::Function)
- return emitError("result must point to the Workgroup, CrossWorkgroup, "
- "or Function Storage Class");
-
- Type operandPointeeType = operandType.getPointeeType();
- Type resultPointeeType = resultType.getPointeeType();
- if (operandPointeeType != resultPointeeType)
- return emitOpError("pointer operand's pointee type must have the same "
- "as the op result type, but found ")
- << operandPointeeType << " vs " << resultPointeeType;
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.BranchOp
//===----------------------------------------------------------------------===//
@@ -2065,84 +1330,6 @@ LogicalResult spirv::ControlBarrierOp::verify() {
return verifyMemorySemantics(getOperation(), getMemorySemantics());
}
-//===----------------------------------------------------------------------===//
-// spirv.ConvertFToSOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertFToSOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false,
- /*skipBitWidthCheck=*/true);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ConvertFToUOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertFToUOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false,
- /*skipBitWidthCheck=*/true);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ConvertSToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertSToFOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false,
- /*skipBitWidthCheck=*/true);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.ConvertUToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::ConvertUToFOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false,
- /*skipBitWidthCheck=*/true);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertBF16ToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertFToBF16Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::INTELConvertFToBF16Op::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//
@@ -2253,30 +1440,6 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
});
}
-//===----------------------------------------------------------------------===//
-// spirv.FConvertOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::FConvertOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.SConvertOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::SConvertOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.UConvertOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::UConvertOp::verify() {
- return verifyCastOp(*this, /*requireSameBitWidth=*/false);
-}
-
//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//
@@ -2641,90 +1804,6 @@ LogicalResult spirv::GlobalVariableOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.GroupBroadcast
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupBroadcastOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType()))
- if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
- return emitOpError("localid is a vector and can be with only "
- " 2 or 3 components, actual number is ")
- << localIdTy.getNumElements();
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBallotOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformBallotOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformBroadcast
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- // SPIR-V spec: "Before version 1.5, Id must come from a
- // constant instruction.
- auto targetEnv = spirv::getDefaultTargetEnv(getContext());
- if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
- targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
-
- if (targetEnv.getVersion() < spirv::Version::V_1_5) {
- auto *idOp = getId().getDefiningOp();
- if (!idOp || !isa<spirv::ConstantOp, // for normal constant
- spirv::ReferenceOfOp>(idOp)) // for spec constant
- return emitOpError("id must be the result of a constant op");
- }
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformShuffle*
-//===----------------------------------------------------------------------===//
-
-template <typename OpTy>
-static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
- spirv::Scope scope = op.getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- if (op.getOperands().back().getType().isSignedInteger())
- return op.emitOpError("second operand must be a singless/unsigned integer");
-
- return success();
-}
-
-LogicalResult spirv::GroupNonUniformShuffleOp::verify() {
- return verifyGroupNonUniformShuffleOp(*this);
-}
-LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() {
- return verifyGroupNonUniformShuffleOp(*this);
-}
-LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() {
- return verifyGroupNonUniformShuffleOp(*this);
-}
-LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
- return verifyGroupNonUniformShuffleOp(*this);
-}
-
//===----------------------------------------------------------------------===//
// spirv.INTEL.SubgroupBlockRead
//===----------------------------------------------------------------------===//
@@ -2803,178 +1882,6 @@ LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformElectOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformElectOp::verify() {
- spirv::Scope scope = getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformFAddOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformFAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformFMaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformFMinOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformFMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformFMulOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformFMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformIAddOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformIAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformIMulOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformIMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformSMaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformSMinOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformSMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformUMaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GroupNonUniformUMinOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GroupNonUniformUMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
-}
-
-ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
- OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
-}
-void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
-}
-
//===----------------------------------------------------------------------===//
// spirv.IAddCarryOp
//===----------------------------------------------------------------------===//
@@ -4514,39 +3421,6 @@ LogicalResult spirv::VectorTimesScalarOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// Group ops
-//===----------------------------------------------------------------------===//
-
-template <typename Op>
-static LogicalResult verifyGroupOp(Op op) {
- spirv::Scope scope = op.getExecutionScope();
- if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
- return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
-
- return success();
-}
-
-LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
-
-LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
-
// TableGen'erated operation interfaces for querying versions, extensions, and
// capabilities.
#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
More information about the Mlir-commits
mailing list