[Mlir-commits] [mlir] b8bea83 - [mlir][spirv] Refactor vendor op definitions
Jakub Kuderski
llvmlistbot at llvm.org
Tue Sep 6 10:36:00 PDT 2022
Author: Jakub Kuderski
Date: 2022-09-06T13:35:08-04:00
New Revision: b8bea837f34534c36fd07477e51b41dd222fb869
URL: https://github.com/llvm/llvm-project/commit/b8bea837f34534c36fd07477e51b41dd222fb869
DIFF: https://github.com/llvm/llvm-project/commit/b8bea837f34534c36fd07477e51b41dd222fb869.diff
LOG: [mlir][spirv] Refactor vendor op definitions
Use dedicated vendor op classes/categories. This makes it possible to later
change the mnemonics of all vendor ops in one place, by changing their common base class `SPV_VendorOp`.
Issue: https://github.com/llvm/llvm-project/issues/56863
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/utils/spirv/define_inst.sh
mlir/utils/spirv/gen_spirv_dialect.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 7654c528b75de..ba9d5a7223297 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
// -----
-def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
+def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
let summary = "TBD";
let description = [{
@@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
3) store the New Value back through Pointer.
- The instruction’s result is the Original Value.
+ The instruction's result is the Original Value.
Result Type must be a floating-point type scalar.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 413af73c506de..bff4543c526e1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -15,7 +15,7 @@
// -----
-def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
+def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
[NoSideEffect]> {
let summary = "See extension SPV_NV_cooperative_matrix";
@@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
// -----
-def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
+def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{
@@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
// -----
-def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
+def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_NV_cooperative_matrix";
@@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
// -----
-def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
+def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
let summary = "See extension SPV_NV_cooperative_matrix";
let description = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
index 32a9aafe32f37..e640824133226 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
@@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
// -----
-def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
+def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
let summary = "See extension SPV_KHR_shader_ballot";
let description = [{
@@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
// -----
-def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{
@@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
// -----
-def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
let summary = "See extension SPV_INTEL_subgroups";
let description = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
index aa45ef80b5e94..e95be56f26750 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
@@ -15,12 +15,12 @@
// -----
-def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
+def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
[NoSideEffect]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
- Return number of components owned by the current work-item in
+ Return number of components owned by the current work-item in
a joint matrix.
Result Type must be an 32-bit unsigned integer type scalar.
@@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
// -----
-def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
+def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
@@ -68,26 +68,26 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
- Pointer is the pointer to load through. It specifies start of memory region where
+ Pointer is the pointer to load through. It specifies start of memory region where
elements of the matrix are stored and arranged according to Layout.
- Stride is the number of elements in memory between beginnings of successive rows,
+ Stride is the number of elements in memory between beginnings of successive rows,
columns (or words) in the result. It must be a scalar integer type.
- Layout indicates how the values loaded from memory are arranged. It must be the
+ Layout indicates how the values loaded from memory are arranged. It must be the
result of a constant instruction.
- Scope is syncronization scope for operation on the matrix. It must be the result
+ Scope is syncronization scope for operation on the matrix. It must be the result
of a constant instruction with scalar integer type.
- If present, any Memory Operands must begin with a memory operand literal. If not
+ If present, any Memory Operands must begin with a memory operand literal. If not
present, it is the same as specifying the memory operand None.
#### Example:
```mlir
- %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
- {memory_access = #spv.memory_access<Volatile>} :
- (!spv.ptr<i32, CrossWorkgroup>, i32) ->
+ %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
+ {memory_access = #spv.memory_access<Volatile>} :
+ (!spv.ptr<i32, CrossWorkgroup>, i32) ->
!spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
```
}];
@@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
// -----
-def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
+def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
[NoSideEffect, AllTypesMatch<["c", "result"]>]> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
- Multiply matrix A by matrix B and add matrix C to the result
- of the multiplication: A*B+C. Here A is a M x K matrix, B is
+ Multiply matrix A by matrix B and add matrix C to the result
+ of the multiplication: A*B+C. Here A is a M x K matrix, B is
a K x N matrix and C is a M x N matrix.
- Behavior is undefined if sizes of operands do not meet the
- conditions above. All operands and the Result Type must be
+ Behavior is undefined if sizes of operands do not meet the
+ conditions above. All operands and the Result Type must be
OpTypeJointMatrixINTEL.
- A must be a OpTypeJointMatrixINTEL whose Component Type is a
- signed numerical type, Row Count equals to M and Column Count
+ A must be a OpTypeJointMatrixINTEL whose Component Type is a
+ signed numerical type, Row Count equals to M and Column Count
equals to K
- B must be a OpTypeJointMatrixINTEL whose Component Type is a
- signed numerical type, Row Count equals to K and Column Count
+ B must be a OpTypeJointMatrixINTEL whose Component Type is a
+ signed numerical type, Row Count equals to K and Column Count
equals to N
- C and Result Type must be a OpTypeJointMatrixINTEL with Row
+ C and Result Type must be a OpTypeJointMatrixINTEL with Row
Count equals to M and Column Count equals to N
- Scope is syncronization scope for operation on the matrix.
- It must be the result of a constant instruction with scalar
+ Scope is syncronization scope for operation on the matrix.
+ It must be the result of a constant instruction with scalar
integer type.
#### Example:
```mlir
- %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
- !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
- !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
+ %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
+ !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
+ !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
-> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
```
@@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
// -----
-def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
+def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
let summary = "See extension SPV_INTEL_joint_matrix";
let description = [{
Store a matrix through a pointer.
- Pointer is the pointer to store through. It specifies
- start of memory region where elements of the matrix must
+ Pointer is the pointer to store through. It specifies
+ start of memory region where elements of the matrix must
be stored and arranged according to Layout.
- Object is the matrix to store. It must be
+ Object is the matrix to store. It must be
OpTypeJointMatrixINTEL.
- Stride is the number of elements in memory between beginnings
- of successive rows, columns (or words) of the Object. It must
+ Stride is the number of elements in memory between beginnings
+ of successive rows, columns (or words) of the Object. It must
be a scalar integer type.
- Layout indicates how the values stored to memory are arranged.
+ Layout indicates how the values stored to memory are arranged.
It must be the result of a constant instruction.
- Scope is syncronization scope for operation on the matrix.
- It must be the result of a constant instruction with scalar
+ Scope is syncronization scope for operation on the matrix.
+ It must be the result of a constant instruction with scalar
integer type.
- If present, any Memory Operands must begin with a memory operand
- literal. If not present, it is the same as specifying the memory
+ If present, any Memory Operands must begin with a memory operand
+ literal. If not present, it is the same as specifying the memory
operand None.
#### Example:
```mlir
- spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
- {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
+ spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
+ {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
!spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
```
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 2f3f2f08488d9..d514f363137ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
// -----
-def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
+def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
let summary = "TBD";
let description = [{
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 03a5f32ab7c80..7cb7588b5379a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
// spv.AtomicFAddEXTOp
//===----------------------------------------------------------------------===//
-LogicalResult spirv::AtomicFAddEXTOp::verify() {
+LogicalResult spirv::EXTAtomicFAddOp::verify() {
return ::verifyAtomicUpdateOp<FloatType>(getOperation());
}
-ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
+ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
return ::parseAtomicUpdateOp(parser, result, true);
}
-void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
+void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
::printAtomicUpdateOp(*this, p);
}
@@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
// spv.SubgroupBlockReadINTEL
//===----------------------------------------------------------------------===//
-ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
+ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
@@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
return success();
}
-void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
+void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << " : " << getType();
}
-LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
+LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();
@@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
// spv.SubgroupBlockWriteINTEL
//===----------------------------------------------------------------------===//
-ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
+ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse the storage class specification
spirv::StorageClass storageClass;
@@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
return success();
}
-void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
+void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
printer << " " << ptr() << ", " << value() << " : " << value().getType();
}
-LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
+LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
return failure();
@@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
// spv.CooperativeMatrixLoadNV
//===----------------------------------------------------------------------===//
-ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
+ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
@@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
return success();
}
-void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
+void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
// Print optional memory access attribute.
if (auto memAccess = memory_access())
@@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
return success();
}
-LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
result().getType());
}
@@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
// spv.CooperativeMatrixStoreNV
//===----------------------------------------------------------------------===//
-ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
+ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
@@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
return success();
}
-void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
+void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
<< columnmajor();
// Print optional memory access attribute.
@@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
}
-LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
object().getType());
}
@@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
//===----------------------------------------------------------------------===//
static LogicalResult
-verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
+verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
@@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
return success();
}
-LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
return verifyCoopMatrixMulAdd(*this);
}
@@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
// spv.JointMatrixLoadINTEL
//===----------------------------------------------------------------------===//
-LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
result().getType());
}
@@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
// spv.JointMatrixStoreINTEL
//===----------------------------------------------------------------------===//
-LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
return verifyPointerAndJointMatrixType(*this, pointer().getType(),
object().getType());
}
@@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
// spv.JointMatrixMadINTEL
//===----------------------------------------------------------------------===//
-static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
+static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
if (op.c().getType() != op.result().getType())
return op.emitOpError("result and third operand must have the same type");
auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
@@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
return success();
}
-LogicalResult spirv::JointMatrixMadINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixMadOp::verify() {
return verifyJointMatrixMad(*this);
}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 7f796aafb964a..b835fda37dfd4 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Value predicate = op->getOperand(0);
- rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
+ rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
op, op->getResult(0).getType(), predicate);
return success();
}
diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh
index b932db8292606..eb37cef602a4a 100755
--- a/mlir/utils/spirv/define_inst.sh
+++ b/mlir/utils/spirv/define_inst.sh
@@ -23,13 +23,17 @@ file_name=$1
baseclass=$2
case $baseclass in
- Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
+ Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
+ | LogicalBinaryOp | LogicalUnaryOp \
+ | CastOp | ControlFlowOp | StructureOp \
+ | AtomicUpdateOp | AtomicUpdateWithValueOp \
+ | KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
;;
*)
echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
echo "<baseclass> must be one of " \
- "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
+ "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
exit 1;
;;
esac
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index c0b145e17cedb..73071386aa332 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')
else:
- fmt_str = ('def SPV_{opname_src}Op : '
- 'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
+ fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
+ 'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
'{{\n let summary = {summary};\n\n let description = '
'[{{\n{description}}}];{availability}\n')
+ vendor_name = ''
inst_category = existing_info.get('inst_category', 'Op')
if inst_category == 'Op':
fmt_str +='\n let arguments = (ins{args});\n\n'\
' let results = (outs{results});\n'
+ elif inst_category.endswith('VendorOp'):
+ vendor_name = inst_category.split('VendorOp')[0].upper()
+ assert len(vendor_name) != 0, 'Invalid instruction category'
fmt_str +='{extras}'\
'}}\n'
@@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opname_src = instruction['opname']
if opname.startswith('Op'):
opname_src = opname_src[2:]
+ if len(vendor_name) > 0:
+ assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
+ opname_src = opname_src[:-len(vendor_name)]
category_args = existing_info.get('category_args', '')
@@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
# Format summary. If the summary can fit in the same line, we print it out
# as a "-quoted string; otherwise, wrap the lines using "[{...}]".
- summary = summary.strip();
+ summary = summary.strip()
if len(summary) + len(' let summary = "";') <= 80:
summary = '"{}"'.format(summary)
else:
@@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
opcode=instruction['opcode'],
category_args=category_args,
inst_category=inst_category,
+ vendor_name=vendor_name,
traits=existing_info.get('traits', ''),
summary=summary,
description=description,
More information about the Mlir-commits mailing list