[Mlir-commits] [mlir] 1d5e3b2 - [mlir][spirv] Use ODS generated attribute names for op definitions (#81552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 25 16:47:29 PST 2024


Author: tw-ilson
Date: 2024-02-25T16:47:25-08:00
New Revision: 1d5e3b2d6559a853c544099e4cf1d46f44f83368

URL: https://github.com/llvm/llvm-project/commit/1d5e3b2d6559a853c544099e4cf1d46f44f83368
DIFF: https://github.com/llvm/llvm-project/commit/1d5e3b2d6559a853c544099e4cf1d46f44f83368.diff

LOG: [mlir][spirv] Use ODS generated attribute names for op definitions (#81552)

Since ODS generates getter functions for SPIR-V operations' attribute
names, we replace instances of these hardcoded strings in the SPIR-V
dialect's op parser/printer with function calls for consistency.

Fixes https://github.com/llvm/llvm-project/issues/77627

---------

Co-authored-by: Lei Zhang <antiagainst at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
    mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
    mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..948d48980f2e84 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
 }
 
 // Verifies an atomic update op.
-template <typename ExpectedElementType>
+template <typename AtomicOpTy, typename ExpectedElementType>
 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
   auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
   auto elementType = ptrType.getPointeeType();
@@ -42,8 +42,10 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
+  StringAttr semanticsAttrName =
+      AtomicOpTy::getSemanticsAttrName(op->getName());
   auto memorySemantics =
-      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
   if (failed(verifyMemorySemantics(op, memorySemantics))) {
     return failure();
@@ -56,7 +58,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicAndOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -64,7 +66,7 @@ LogicalResult AtomicAndOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIAddOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -72,7 +74,7 @@ LogicalResult AtomicIAddOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult EXTAtomicFAddOp::verify() {
-  return verifyAtomicUpdateOp<FloatType>(getOperation());
+  return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -80,7 +82,7 @@ LogicalResult EXTAtomicFAddOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIDecrementOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -88,7 +90,7 @@ LogicalResult AtomicIDecrementOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIIncrementOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -96,7 +98,7 @@ LogicalResult AtomicIIncrementOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicISubOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -104,7 +106,7 @@ LogicalResult AtomicISubOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicOrOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,7 +114,7 @@ LogicalResult AtomicOrOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicSMaxOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -120,7 +122,7 @@ LogicalResult AtomicSMaxOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicSMinOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -128,7 +130,7 @@ LogicalResult AtomicSMinOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicUMaxOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -136,7 +138,7 @@ LogicalResult AtomicUMaxOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicUMinOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -144,7 +146,7 @@ LogicalResult AtomicUMinOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicXorOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
 }
 
 } // namespace mlir::spirv

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 7170a899069ee3..3e317319b68fc5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(kBranchWeightAttrName,
+    StringAttr branchWeightsAttrName =
+        BranchConditionalOp::getBranchWeightsAttrName(result.name);
+    result.addAttribute(branchWeightsAttrName,
                         builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
@@ -199,11 +201,11 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+  return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+  (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {

diff  --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..fcbef2be75f9a0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -20,15 +20,18 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+template <typename OpTy>
 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                     OperationState &state) {
   spirv::Scope executionScope;
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
-  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                kExecutionScopeAttrName) ||
-      spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  kGroupOperationAttrName) ||
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+          executionScope, parser, state,
+          OpTy::getExecutionScopeAttrName(state.name)) ||
+      spirv::parseEnumStrAttr<GroupOperationAttr>(
+          groupOperation, parser, state,
+          OpTy::getGroupOperationAttrName(state.name)) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -56,16 +59,23 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   return parser.addTypeToList(resultType, state.types);
 }
 
+template <typename GroupNonUniformArithmeticOpTy>
 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                              OpAsmPrinter &printer) {
   printer
       << " \""
       << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+             groupOp
+                 ->getAttrOfType<spirv::ScopeAttr>(
+                     GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
+                         groupOp->getName()))
                  .getValue())
       << "\" \""
       << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+             groupOp
+                 ->getAttrOfType<GroupOperationAttr>(
+                     GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
+                         groupOp->getName()))
                  .getValue())
       << "\" " << groupOp->getOperand(0);
 
@@ -74,16 +84,21 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
   printer << " : " << groupOp->getResult(0).getType();
 }
 
+template <typename OpTy>
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+      groupOp
+          ->getAttrOfType<spirv::ScopeAttr>(
+              OpTy::getExecutionScopeAttrName(groupOp->getName()))
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+      groupOp
+          ->getAttrOfType<GroupOperationAttr>(
+              OpTy::getGroupOperationAttrName(groupOp->getName()))
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
@@ -206,16 +221,17 @@ LogicalResult GroupNonUniformElectOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFAddOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
 }
 
 ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -223,16 +239,17 @@ void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -240,16 +257,17 @@ void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
 }
 
 ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -257,16 +275,17 @@ void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMulOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
 }
 
 ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -274,16 +293,17 @@ void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformIAddOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
 }
 
 ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -291,16 +311,17 @@ void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformIMulOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
 }
 
 ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -308,16 +329,17 @@ void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformSMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -325,16 +347,17 @@ void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformSMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
 }
 
 ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -342,16 +365,17 @@ void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformUMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,16 +383,17 @@ void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformUMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
 }
 
 ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -376,16 +401,17 @@ void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseAndOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -393,16 +419,17 @@ void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseOrOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -410,16 +437,17 @@ void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseXorOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -427,16 +455,17 @@ void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalAndOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -444,16 +473,17 @@ void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalOrOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -461,16 +491,17 @@ void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalXorOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..f5676f36a0f5f5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -25,6 +25,7 @@ namespace mlir::spirv {
 // Integer Dot Product ops
 //===----------------------------------------------------------------------===//
 
+template <typename IntegerDotProductOpTy>
 static LogicalResult verifyIntegerDotProduct(Operation *op) {
   assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
          "Not an integer dot product op?");
@@ -33,10 +34,12 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
+  StringAttr packedVectorFormatAttrName =
+      IntegerDotProductOpTy::getFormatAttrName(op->getName());
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(kPackedVectorFormatAttrName));
+            op->getAttr(packedVectorFormatAttrName));
     if (!packedVectorFormat)
       return op->emitOpError("requires Packed Vector Format attribute for "
                              "integer vector operands");
@@ -50,7 +53,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
   } else {
-    if (op->hasAttr(kPackedVectorFormatAttrName))
+    if (op->hasAttr(packedVectorFormatAttrName))
       return op->emitOpError(llvm::formatv(
           "with invalid format attribute for vector operands of type '{0}'",
           factorTy));
@@ -84,6 +87,7 @@ getIntegerDotProductExtensions() {
   return {extension};
 }
 
+template <typename IntegerDotProductOpTy>
 static SmallVector<ArrayRef<spirv::Capability>, 1>
 getIntegerDotProductCapabilities(Operation *op) {
   // Requires the the DotProduct capability and capabilities that depend on
@@ -99,9 +103,11 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
+  StringAttr packedVectorFormatAttrName =
+      IntegerDotProductOpTy::getFormatAttrName(op->getName());
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
-        op->getAttr(kPackedVectorFormatAttrName));
+        op->getAttr(packedVectorFormatAttrName));
     if (formatAttr.getValue() ==
         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
       capabilities.push_back(dotProductInput4x8BitPackedCap);
@@ -120,12 +126,14 @@ getIntegerDotProductCapabilities(Operation *op) {
 }
 
 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
-  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
+  LogicalResult OpName::verify() {                                             \
+    return verifyIntegerDotProduct<OpName>(*this);                             \
+  }                                                                            \
   SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
     return getIntegerDotProductExtensions();                                   \
   }                                                                            \
   SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
-    return getIntegerDotProductCapabilities(*this);                            \
+    return getIntegerDotProductCapabilities<OpName>(*this);                    \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMinVersion() {                      \
     return getIntegerDotProductMinVersion();                                   \

diff  --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..c4c7ff722175dc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -25,10 +25,48 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
+///     (`[` memory-access `]`)?
+/// where:
+///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
+///         integer-literal | `"NonTemporal"`
+template <typename MemoryOpTy>
+ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+                                        OperationState &state) {
+  // Parse an optional list of attributes starting with '['
+  if (parser.parseOptionalLSquare()) {
+    // Nothing to do
+    return success();
+  }
+
+  spirv::MemoryAccess memoryAccessAttr;
+  StringAttr memoryAccessAttrName =
+      MemoryOpTy::getMemoryAccessAttrName(state.name);
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
+    return failure();
+
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
+    // Parse integer attribute for alignment.
+    Attribute alignmentAttr;
+    StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
+                              state.attributes)) {
+      return failure();
+    }
+  }
+  return parser.parseRSquare();
+}
+
 // TODO Make sure to merge this and the previous function into one template
 // parameterized by memory access attribute name and alignment. Doing so now
 // results in VS2017 in producing an internal error (at the call site) that's
 // not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                      OperationState &state) {
   // Parse an optional list of attributes staring with '['
@@ -38,17 +76,21 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
+  StringRef memoryAccessAttrName =
+      MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
-          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
                                 spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
+    StringAttr alignmentAttrName =
+        MemoryOpTy::getSourceAlignmentAttrName(state.name);
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
                               state.attributes)) {
       return failure();
     }
@@ -72,7 +114,7 @@ static void printSourceMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -80,7 +122,7 @@ static void printSourceMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
         printer << ", " << *alignment;
       }
     }
@@ -98,7 +140,7 @@ static void printMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -106,7 +148,7 @@ static void printMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
         printer << ", " << *alignment;
       }
     }
@@ -136,11 +178,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -157,11 +199,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -180,11 +222,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -201,11 +243,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kSourceAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -376,7 +418,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand ptrInfo;
   Type elementType;
   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
-      parseMemoryAccessAttributes(parser, result) ||
+      parseMemoryAccessAttributes<LoadOp>(parser, result) ||
       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
       parser.parseType(elementType)) {
     return failure();
@@ -425,8 +467,8 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
   Type elementType;
   if (parseEnumStrAttr(storageClass, parser) ||
       parser.parseOperandList(operandInfo, 2) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parseMemoryAccessAttributes<StoreOp>(parser, result) ||
+      parser.parseColon() || parser.parseType(elementType)) {
     return failure();
   }
 
@@ -499,13 +541,13 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
       parseEnumStrAttr(sourceStorageClass, parser) ||
       parser.parseOperand(sourcePtrInfo) ||
-      parseMemoryAccessAttributes(parser, result)) {
+      parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
     return failure();
   }
 
   if (!parser.parseOptionalComma()) {
     // Parse 2nd memory access attributes.
-    if (parseSourceMemoryAccessAttributes(parser, result)) {
+    if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
       return failure();
     }
   }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..38415620fd4f96 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,14 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
+  StringRef indicesAttrName =
+      spirv::CompositeExtractOp::getIndicesAttrName(result.name);
   Type compositeType;
   SMLoc attrLocation;
 
   if (parser.parseOperand(compositeInfo) ||
       parser.getCurrentLocation(&attrLocation) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(compositeType) ||
       parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
     return failure();
@@ -513,11 +515,13 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
+  StringRef indicesAttrName =
+      spirv::CompositeInsertOp::getIndicesAttrName(result.name);
   auto loc = parser.getCurrentLocation();
 
   return failure(
       parser.parseOperandList(operands, 2) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(objectType) ||
       parser.parseKeywordType("into", compositeType) ||
       parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
+  if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
   Type type = NoneType::get(parser.getContext());
@@ -822,7 +827,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(kInterfaceAttrName,
+  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
                       parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
@@ -875,7 +880,9 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  result.addAttribute(kValuesAttrName,
+  StringRef valuesAttrName =
+      spirv::ExecutionModeOp::getValuesAttrName(result.name);
+  result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
 }
@@ -1150,16 +1157,18 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
+  StringRef initializerAttrName =
+      spirv::GlobalVariableOp::getInitializerAttrName(result.name);
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
   }
 
   // Parse optional initializer
-  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
     FlatSymbolRefAttr initSymbol;
     if (parser.parseLParen() ||
-        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                               result.attributes) ||
         parser.parseRParen())
       return failure();
@@ -1170,6 +1179,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
+  StringRef typeAttrName =
+      spirv::GlobalVariableOp::getTypeAttrName(result.name);
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1177,7 +1188,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   if (!llvm::isa<spirv::PointerType>(type)) {
     return parser.emitError(loc, "expected spirv.ptr type");
   }
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
@@ -1191,15 +1202,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getSymName());
   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
 
+  StringRef initializerAttrName = this->getInitializerAttrName();
   // Print optional initializer
   if (auto initializer = this->getInitializer()) {
-    printer << " " << kInitializerAttrName << '(';
+    printer << " " << initializerAttrName << '(';
     printer.printSymbolName(*initializer);
     printer << ')';
-    elidedAttrs.push_back(kInitializerAttrName);
+    elidedAttrs.push_back(initializerAttrName);
   }
 
-  elidedAttrs.push_back(kTypeAttrName);
+  StringRef typeAttrName = this->getTypeAttrName();
+  elidedAttrs.push_back(typeAttrName);
   spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
 }
@@ -1219,8 +1232,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
            << stringifyStorageClass(storageClass) << "'";
   }
 
-  if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+          this->getInitializerAttrName())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1594,6 +1607,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
+  StringRef defaultValueAttrName =
+      spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1609,8 +1624,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, kDefaultValueAttrName,
-                            result.attributes))
+      parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
     return failure();
 
   return success();
@@ -1783,14 +1797,18 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  result.addAttribute(kCompositeSpecConstituentsName,
+  StringAttr compositeSpecConstituentsName =
+      spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
+  result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
   Type type;
   if (parser.parseColonType(type))
     return failure();
 
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  StringAttr typeAttrName =
+      spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 27c373300aee8e..726fb7ce9fdb74 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -18,35 +18,6 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
-ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
-                                        OperationState &state,
-                                        StringRef attrName) {
-  // Parse an optional list of attributes staring with '['
-  if (parser.parseOptionalLSquare()) {
-    // Nothing to do
-    return success();
-  }
-
-  spirv::MemoryAccess memoryAccessAttr;
-  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
-                                                       state, attrName))
-    return failure();
-
-  if (spirv::bitEnumContainsAll(memoryAccessAttr,
-                                spirv::MemoryAccess::Aligned)) {
-    // Parse integer attribute for alignment.
-    Attribute alignmentAttr;
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type,
-                              AttrNames::kAlignmentAttrName,
-                              state.attributes)) {
-      return failure();
-    }
-  }
-  return parser.parseRSquare();
-}
-
 ParseResult parseVariableDecorations(OpAsmParser &parser,
                                      OperationState &state) {
   auto builtInName = llvm::convertToSnakeFromCamelCase(

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 625c82f6e8e899..858b94f7be8b08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,34 +20,12 @@
 
 namespace mlir::spirv {
 namespace AttrNames {
-// TODO: generate these strings using ODS.
-inline constexpr char kAlignmentAttrName[] = "alignment";
-inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-inline constexpr char kCallee[] = "callee";
-inline constexpr char kClusterSize[] = "cluster_size";
-inline constexpr char kControl[] = "control";
-inline constexpr char kDefaultValueAttrName[] = "default_value";
-inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-inline constexpr char kFnNameAttrName[] = "fn";
-inline constexpr char kGroupOperationAttrName[] = "group_operation";
-inline constexpr char kIndicesAttrName[] = "indices";
-inline constexpr char kInitializerAttrName[] = "initializer";
-inline constexpr char kInterfaceAttrName[] = "interface";
-inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-inline constexpr char kMemoryAccessAttrName[] = "memory_access";
-inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-inline constexpr char kPackedVectorFormatAttrName[] = "format";
-inline constexpr char kSemanticsAttrName[] = "semantics";
-inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-inline constexpr char kSpecIdAttrName[] = "spec_id";
-inline constexpr char kTypeAttrName[] = "type";
-inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-inline constexpr char kValueAttrName[] = "value";
-inline constexpr char kValuesAttrName[] = "values";
-inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+
+inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
+inline constexpr char kControl[] = "control";          // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn";        // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id";   // no ODS generation
+
 } // namespace AttrNames
 
 template <typename Ty>
@@ -143,16 +121,6 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
   return success();
 }
 
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-///     (`[` memory-access `]`)?
-/// where:
-///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-///         integer-literal | `"NonTemporal"`
-ParseResult parseMemoryAccessAttributes(
-    OpAsmParser &parser, OperationState &state,
-    StringRef attrName = AttrNames::kMemoryAccessAttrName);
-
 ParseResult parseVariableDecorations(OpAsmParser &parser,
                                      OperationState &state);
 


        


More information about the Mlir-commits mailing list