[Mlir-commits] [mlir] b8bea83 - [mlir][spirv] Refactor vendor op definitions

Jakub Kuderski llvmlistbot at llvm.org
Tue Sep 6 10:36:00 PDT 2022


Author: Jakub Kuderski
Date: 2022-09-06T13:35:08-04:00
New Revision: b8bea837f34534c36fd07477e51b41dd222fb869

URL: https://github.com/llvm/llvm-project/commit/b8bea837f34534c36fd07477e51b41dd222fb869
DIFF: https://github.com/llvm/llvm-project/commit/b8bea837f34534c36fd07477e51b41dd222fb869.diff

LOG: [mlir][spirv] Refactor vendor op definitions

Use dedicated vendor op classes/categories. This lets us later change the
mnemonics of all vendor ops at once by changing the base class: `SPV_VendorOp`.

Issue: https://github.com/llvm/llvm-project/issues/56863

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
    mlir/utils/spirv/define_inst.sh
    mlir/utils/spirv/gen_spirv_dialect.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
index 7654c528b75de..ba9d5a7223297 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
@@ -262,7 +262,7 @@ def SPV_AtomicExchangeOp : SPV_Op<"AtomicExchange", []> {
 
 // -----
 
-def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
+def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> {
   let summary = "TBD";
 
   let description = [{
@@ -279,7 +279,7 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> {
 
     3) store the New Value back through Pointer.
 
-    The instruction’s result is the Original Value.
+    The instruction's result is the Original Value.
 
     Result Type must be a floating-point type scalar.
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 413af73c506de..bff4543c526e1 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -15,7 +15,7 @@
 
 // -----
 
-def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
+def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength",
   [NoSideEffect]> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
@@ -60,7 +60,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV",
 
 // -----
 
-def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
+def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
   let description = [{
@@ -136,7 +136,7 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
 
 // -----
 
-def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
+def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd",
   [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
@@ -210,7 +210,7 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV",
 
 // -----
 
-def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
+def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> {
   let summary = "See extension SPV_NV_cooperative_matrix";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
index 32a9aafe32f37..e640824133226 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td
@@ -92,7 +92,7 @@ def SPV_GroupBroadcastOp : SPV_Op<"GroupBroadcast",
 
 // -----
 
-def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
+def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> {
   let summary = "See extension SPV_KHR_shader_ballot";
 
   let description = [{
@@ -146,7 +146,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
 
 // -----
 
-def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> {
   let summary = "See extension SPV_INTEL_subgroups";
 
   let description = [{
@@ -197,7 +197,7 @@ def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
 
 // -----
 
-def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> {
   let summary = "See extension SPV_INTEL_subgroups";
 
   let description = [{

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
index aa45ef80b5e94..e95be56f26750 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td
@@ -15,12 +15,12 @@
 
 // -----
 
-def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL",
+def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength",
   [NoSideEffect]> {
   let summary = "See extension SPV_INTEL_joint_matrix";
 
   let description = [{
-    Return number of components owned by the current work-item in 
+    Return number of components owned by the current work-item in
     a joint matrix.
 
     Result Type must be an 32-bit unsigned integer type scalar.
@@ -60,7 +60,7 @@ def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTE
 
 // -----
 
-def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
+def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> {
   let summary = "See extension SPV_INTEL_joint_matrix";
 
   let description = [{
@@ -68,26 +68,26 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
 
     Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL.
 
-    Pointer is the pointer to load through. It specifies start of memory region where 
+    Pointer is the pointer to load through. It specifies start of memory region where
     elements of the matrix are stored and arranged according to Layout.
 
-    Stride is the number of elements in memory between beginnings of successive rows, 
+    Stride is the number of elements in memory between beginnings of successive rows,
     columns (or words) in the result. It must be a scalar integer type.
 
-    Layout indicates how the values loaded from memory are arranged. It must be the 
+    Layout indicates how the values loaded from memory are arranged. It must be the
     result of a constant instruction.
 
-    Scope is syncronization scope for operation on the matrix. It must be the result 
+    Scope is syncronization scope for operation on the matrix. It must be the result
     of a constant instruction with scalar integer type.
 
-    If present, any Memory Operands must begin with a memory operand literal. If not 
+    If present, any Memory Operands must begin with a memory operand literal. If not
     present, it is the same as specifying the memory operand None.
 
     #### Example:
     ```mlir
-    %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride 
-         {memory_access = #spv.memory_access<Volatile>} : 
-         (!spv.ptr<i32, CrossWorkgroup>, i32) -> 
+    %0 = spv.JointMatrixLoadINTEL <Subgroup> <RowMajor> %ptr, %stride
+         {memory_access = #spv.memory_access<Volatile>} :
+         (!spv.ptr<i32, CrossWorkgroup>, i32) ->
          !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup>
     ```
   }];
@@ -119,39 +119,39 @@ def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> {
 
 // -----
 
-def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
+def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad",
   [NoSideEffect, AllTypesMatch<["c", "result"]>]> {
   let summary = "See extension SPV_INTEL_joint_matrix";
 
   let description = [{
-    Multiply matrix A by matrix B and add matrix C to the result 
-    of the multiplication: A*B+C. Here A is a M x K matrix, B is 
+    Multiply matrix A by matrix B and add matrix C to the result
+    of the multiplication: A*B+C. Here A is a M x K matrix, B is
     a K x N matrix and C is a M x N matrix.
 
-    Behavior is undefined if sizes of operands do not meet the 
-    conditions above. All operands and the Result Type must be 
+    Behavior is undefined if sizes of operands do not meet the
+    conditions above. All operands and the Result Type must be
     OpTypeJointMatrixINTEL.
 
-    A must be a OpTypeJointMatrixINTEL whose Component Type is a 
-    signed numerical type, Row Count equals to M and Column Count 
+    A must be a OpTypeJointMatrixINTEL whose Component Type is a
+    signed numerical type, Row Count equals to M and Column Count
     equals to K
 
-    B must be a OpTypeJointMatrixINTEL whose Component Type is a 
-    signed numerical type, Row Count equals to K and Column Count 
+    B must be a OpTypeJointMatrixINTEL whose Component Type is a
+    signed numerical type, Row Count equals to K and Column Count
     equals to N
 
-    C and Result Type must be a OpTypeJointMatrixINTEL with Row 
+    C and Result Type must be a OpTypeJointMatrixINTEL with Row
     Count equals to M and Column Count equals to N
 
-    Scope is syncronization scope for operation on the matrix. 
-    It must be the result of a constant instruction with scalar 
+    Scope is syncronization scope for operation on the matrix.
+    It must be the result of a constant instruction with scalar
     integer type.
 
     #### Example:
     ```mlir
-    %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : 
-         !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, 
-         !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> 
+    %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c :
+         !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>,
+         !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup>
          -> !spv.jointmatrix<8x8xi32,  RowMajor, Subgroup>
     ```
 
@@ -182,38 +182,38 @@ def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL",
 
 // -----
 
-def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> {
+def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> {
   let summary = "See extension SPV_INTEL_joint_matrix";
 
   let description = [{
     Store a matrix through a pointer.
 
-    Pointer is the pointer to store through. It specifies 
-    start of memory region where elements of the matrix must 
+    Pointer is the pointer to store through. It specifies
+    start of memory region where elements of the matrix must
     be stored and arranged according to Layout.
 
-    Object is the matrix to store. It must be 
+    Object is the matrix to store. It must be
     OpTypeJointMatrixINTEL.
 
-    Stride is the number of elements in memory between beginnings 
-    of successive rows, columns (or words) of the Object. It must 
+    Stride is the number of elements in memory between beginnings
+    of successive rows, columns (or words) of the Object. It must
     be a scalar integer type.
 
-    Layout indicates how the values stored to memory are arranged. 
+    Layout indicates how the values stored to memory are arranged.
     It must be the result of a constant instruction.
 
-    Scope is syncronization scope for operation on the matrix. 
-    It must be the result of a constant instruction with scalar 
+    Scope is syncronization scope for operation on the matrix.
+    It must be the result of a constant instruction with scalar
     integer type.
 
-    If present, any Memory Operands must begin with a memory operand 
-    literal. If not present, it is the same as specifying the memory 
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory
     operand None.
 
     #### Example:
     ```mlir
-    spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride 
-    {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>, 
+    spv.JointMatrixStoreINTEL <Subgroup> <ColumnMajor> %ptr, %m, %stride
+    {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<i32, Workgroup>,
     !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32)
     ```
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 2f3f2f08488d9..d514f363137ec 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -18,7 +18,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 
 // -----
 
-def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> {
+def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> {
   let summary = "TBD";
 
   let description = [{

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 03a5f32ab7c80..7cb7588b5379a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1335,15 +1335,15 @@ void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
 // spv.AtomicFAddEXTOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult spirv::AtomicFAddEXTOp::verify() {
+LogicalResult spirv::EXTAtomicFAddOp::verify() {
   return ::verifyAtomicUpdateOp<FloatType>(getOperation());
 }
 
-ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
+ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
   return ::parseAtomicUpdateOp(parser, result, true);
 }
-void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
+void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
   ::printAtomicUpdateOp(*this, p);
 }
 
@@ -2646,7 +2646,7 @@ LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
 // spv.SubgroupBlockReadINTEL
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
+ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
                                                    OperationState &result) {
   // Parse the storage class specification
   spirv::StorageClass storageClass;
@@ -2669,11 +2669,11 @@ ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
   return success();
 }
 
-void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
+void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
   printer << " " << ptr() << " : " << getType();
 }
 
-LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
+LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
   if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
     return failure();
 
@@ -2684,7 +2684,7 @@ LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
 // spv.SubgroupBlockWriteINTEL
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
+ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
                                                     OperationState &result) {
   // Parse the storage class specification
   spirv::StorageClass storageClass;
@@ -2708,11 +2708,11 @@ ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
   return success();
 }
 
-void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
+void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
   printer << " " << ptr() << ", " << value() << " : " << value().getType();
 }
 
-LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
+LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
   if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
     return failure();
 
@@ -3816,7 +3816,7 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 // spv.CooperativeMatrixLoadNV
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
+ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
                                                     OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
@@ -3838,7 +3838,7 @@ ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
   return success();
 }
 
-void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
+void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
   printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
   // Print optional memory access attribute.
   if (auto memAccess = memory_access())
@@ -3865,7 +3865,7 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
   return success();
 }
 
-LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
   return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
                                         result().getType());
 }
@@ -3874,7 +3874,7 @@ LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
 // spv.CooperativeMatrixStoreNV
 //===----------------------------------------------------------------------===//
 
-ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
+ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
                                                      OperationState &result) {
   SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
@@ -3896,7 +3896,7 @@ ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
   return success();
 }
 
-void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
+void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
   printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
           << columnmajor();
   // Print optional memory access attribute.
@@ -3905,7 +3905,7 @@ void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
   printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
 }
 
-LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
   return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
                                         object().getType());
 }
@@ -3915,7 +3915,7 @@ LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult
-verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
+verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
   if (op.c().getType() != op.result().getType())
     return op.emitOpError("result and third operand must have the same type");
   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
@@ -3936,7 +3936,7 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   return success();
 }
 
-LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
+LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
   return verifyCoopMatrixMulAdd(*this);
 }
 
@@ -3963,7 +3963,7 @@ verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
 // spv.JointMatrixLoadINTEL
 //===----------------------------------------------------------------------===//
 
-LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
   return verifyPointerAndJointMatrixType(*this, pointer().getType(),
                                          result().getType());
 }
@@ -3972,7 +3972,7 @@ LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
 // spv.JointMatrixStoreINTEL
 //===----------------------------------------------------------------------===//
 
-LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
   return verifyPointerAndJointMatrixType(*this, pointer().getType(),
                                          object().getType());
 }
@@ -3981,7 +3981,7 @@ LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
 // spv.JointMatrixMadINTEL
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
+static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
   if (op.c().getType() != op.result().getType())
     return op.emitOpError("result and third operand must have the same type");
   auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
@@ -4002,7 +4002,7 @@ static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
   return success();
 }
 
-LogicalResult spirv::JointMatrixMadINTELOp::verify() {
+LogicalResult spirv::INTELJointMatrixMadOp::verify() {
   return verifyJointMatrixMad(*this);
 }
 

diff  --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 7f796aafb964a..b835fda37dfd4 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -240,7 +240,7 @@ ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
                                          PatternRewriter &rewriter) const {
   Value predicate = op->getOperand(0);
 
-  rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
+  rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
       op, op->getResult(0).getType(), predicate);
   return success();
 }

diff  --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh
index b932db8292606..eb37cef602a4a 100755
--- a/mlir/utils/spirv/define_inst.sh
+++ b/mlir/utils/spirv/define_inst.sh
@@ -23,13 +23,17 @@ file_name=$1
 baseclass=$2
 
 case $baseclass in
-  Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp)
+  Op | ArithmeticBinaryOp | ArithmeticUnaryOp \
+     | LogicalBinaryOp | LogicalUnaryOp \
+     | CastOp | ControlFlowOp | StructureOp \
+     | AtomicUpdateOp | AtomicUpdateWithValueOp \
+     | KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp )
   ;;
   *)
     echo "Usage : " $0 "<filename> <baseclass> (<opname>)*"
     echo "<filename> is the file name of MLIR SPIR-V op definitions spec"
     echo "<baseclass> must be one of " \
-      "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)"
+      "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)"
     exit 1;
   ;;
 esac

diff  --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index c0b145e17cedb..73071386aa332 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -730,15 +730,19 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
                '{{\n  let summary = {summary};\n\n  let description = '
                '[{{\n{description}}}];{availability}\n')
   else:
-    fmt_str = ('def SPV_{opname_src}Op : '
-               'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> '
+    fmt_str = ('def SPV_{vendor_name}{opname_src}Op : '
+               'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
                '{{\n  let summary = {summary};\n\n  let description = '
                '[{{\n{description}}}];{availability}\n')
 
+  vendor_name = ''
   inst_category = existing_info.get('inst_category', 'Op')
   if inst_category == 'Op':
     fmt_str +='\n  let arguments = (ins{args});\n\n'\
               '  let results = (outs{results});\n'
+  elif inst_category.endswith('VendorOp'):
+    vendor_name = inst_category.split('VendorOp')[0].upper()
+    assert len(vendor_name) != 0, 'Invalid instruction category'
 
   fmt_str +='{extras}'\
             '}}\n'
@@ -746,6 +750,9 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
   opname_src = instruction['opname']
   if opname.startswith('Op'):
     opname_src = opname_src[2:]
+  if len(vendor_name) > 0:
+    assert opname_src.endswith(vendor_name), "op name does not match the instruction category"
+    opname_src = opname_src[:-len(vendor_name)]
 
   category_args = existing_info.get('category_args', '')
 
@@ -759,7 +766,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
 
   # Format summary. If the summary can fit in the same line, we print it out
   # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
-  summary = summary.strip();
+  summary = summary.strip()
   if len(summary) + len('  let summary = "";') <= 80:
     summary = '"{}"'.format(summary)
   else:
@@ -815,6 +822,7 @@ def get_op_definition(instruction, opname, doc, existing_info, capability_mappin
       opcode=instruction['opcode'],
       category_args=category_args,
       inst_category=inst_category,
+      vendor_name=vendor_name,
       traits=existing_info.get('traits', ''),
       summary=summary,
       description=description,


        


More information about the Mlir-commits mailing list