[Mlir-commits] [mlir] replaces hardcoded attribute names in SPIRV dialect parsing code (PR #81552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 15 10:27:33 PST 2024


https://github.com/tw-ilson updated https://github.com/llvm/llvm-project/pull/81552

>From c794f24655d131fb698a7cf2310b1f50259b7d16 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Mon, 12 Feb 2024 00:31:59 -0500
Subject: [PATCH 1/4] Replace hardcoded attribute names in SPIRV parsing code

---
 mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp       |  3 +-
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  |  6 +-
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 12 ++--
 .../Dialect/SPIRV/IR/IntegerDotProductOps.cpp |  8 ++-
 mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp       | 30 +++++-----
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 41 +++++++++-----
 .../Dialect/SPIRV/IR/SPIRVParsingUtils.cpp    |  4 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h | 55 ++++++++++---------
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   |  2 +-
 9 files changed, 90 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..d84133d5933415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
+  StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
   auto memorySemantics =
-      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
   if (failed(verifyMemorySemantics(op, memorySemantics))) {
     return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..650b5448f041b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,7 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(kBranchWeightAttrName,
+    result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
                         builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
@@ -199,11 +199,11 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+  return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..ea82f8442dcfe0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -26,9 +26,9 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
   if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                kExecutionScopeAttrName) ||
+                                                ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
       spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  kGroupOperationAttrName) ||
+                                                  GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -61,11 +61,11 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
   printer
       << " \""
       << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+             groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" \""
       << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+             groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" " << groupOp->getOperand(0);
 
@@ -76,14 +76,14 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
 
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+      groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+      groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..3abc4361321d04 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,10 +33,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(kPackedVectorFormatAttrName));
+            op->getAttr(packedVectorFormatAttrName));
     if (!packedVectorFormat)
       return op->emitOpError("requires Packed Vector Format attribute for "
                              "integer vector operands");
@@ -50,7 +51,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
   } else {
-    if (op->hasAttr(kPackedVectorFormatAttrName))
+    if (op->hasAttr(packedVectorFormatAttrName))
       return op->emitOpError(llvm::formatv(
           "with invalid format attribute for vector operands of type '{0}'",
           factorTy));
@@ -99,9 +100,10 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
-        op->getAttr(kPackedVectorFormatAttrName));
+        op->getAttr(packedVectorFormatAttrName));
     if (formatAttr.getValue() ==
         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
       capabilities.push_back(dotProductInput4x8BitPackedCap);
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..eb7900fd6ef8fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,17 +38,19 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
+  StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
-          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
                                 spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
+    StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
                               state.attributes)) {
       return failure();
     }
@@ -72,7 +74,7 @@ static void printSourceMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -80,7 +82,7 @@ static void printSourceMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -98,7 +100,7 @@ static void printMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -106,7 +108,7 @@ static void printMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -136,11 +138,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -157,11 +159,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -180,11 +182,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -201,11 +203,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kSourceAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..c27df9e30f24e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,13 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
   Type compositeType;
   SMLoc attrLocation;
 
   if (parser.parseOperand(compositeInfo) ||
       parser.getCurrentLocation(&attrLocation) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(compositeType) ||
       parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
     return failure();
@@ -513,11 +514,12 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
 
   return failure(
       parser.parseOperandList(operands, 2) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(objectType) ||
       parser.parseKeywordType("into", compositeType) ||
       parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +561,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+  if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
   Type type = NoneType::get(parser.getContext());
@@ -822,7 +825,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(kInterfaceAttrName,
+  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
                       parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
@@ -875,7 +878,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  result.addAttribute(kValuesAttrName,
+  StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+  result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
 }
@@ -1150,16 +1154,17 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
+  StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
   }
 
   // Parse optional initializer
-  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
     FlatSymbolRefAttr initSymbol;
     if (parser.parseLParen() ||
-        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                               result.attributes) ||
         parser.parseRParen())
       return failure();
@@ -1170,6 +1175,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
+  StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1177,7 +1183,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   if (!llvm::isa<spirv::PointerType>(type)) {
     return parser.emitError(loc, "expected spirv.ptr type");
   }
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
@@ -1191,15 +1197,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getSymName());
   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
 
+  StringRef initializerAttrName = this->getInitializerAttrName().strref();
   // Print optional initializer
   if (auto initializer = this->getInitializer()) {
-    printer << " " << kInitializerAttrName << '(';
+    printer << " " << initializerAttrName << '(';
     printer.printSymbolName(*initializer);
     printer << ')';
-    elidedAttrs.push_back(kInitializerAttrName);
+    elidedAttrs.push_back(initializerAttrName);
   }
 
-  elidedAttrs.push_back(kTypeAttrName);
+  StringRef typeAttrName = this->getTypeAttrName().strref();
+  elidedAttrs.push_back(typeAttrName);
   spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
 }
@@ -1220,7 +1228,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
   }
 
   if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+          (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1594,6 +1602,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
+  StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1609,7 +1618,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, kDefaultValueAttrName,
+      parser.parseAttribute(valueAttr, defaultValueAttrName,
                             result.attributes))
     return failure();
 
@@ -1783,14 +1792,16 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  result.addAttribute(kCompositeSpecConstituentsName,
+  StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+  result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
   Type type;
   if (parser.parseColonType(type))
     return failure();
 
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 27c373300aee8e..f48b0b93ab5fe8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,6 +12,8 @@
 
 #include "SPIRVParsingUtils.h"
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
 #include "llvm/ADT/StringExtras.h"
 
 using namespace mlir::spirv::AttrNames;
@@ -39,7 +41,7 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
         parser.parseAttribute(alignmentAttr, i32Type,
-                              AttrNames::kAlignmentAttrName,
+                              CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
                               state.attributes)) {
       return failure();
     }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 625c82f6e8e899..d4c0deef6aae3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,34 +20,35 @@
 
 namespace mlir::spirv {
 namespace AttrNames {
+inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
+inline constexpr char kControl[] = "control"; // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+
 // TODO: generate these strings using ODS.
-inline constexpr char kAlignmentAttrName[] = "alignment";
-inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-inline constexpr char kCallee[] = "callee";
-inline constexpr char kClusterSize[] = "cluster_size";
-inline constexpr char kControl[] = "control";
-inline constexpr char kDefaultValueAttrName[] = "default_value";
-inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-inline constexpr char kFnNameAttrName[] = "fn";
-inline constexpr char kGroupOperationAttrName[] = "group_operation";
-inline constexpr char kIndicesAttrName[] = "indices";
-inline constexpr char kInitializerAttrName[] = "initializer";
-inline constexpr char kInterfaceAttrName[] = "interface";
-inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-inline constexpr char kMemoryAccessAttrName[] = "memory_access";
-inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-inline constexpr char kPackedVectorFormatAttrName[] = "format";
-inline constexpr char kSemanticsAttrName[] = "semantics";
-inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-inline constexpr char kSpecIdAttrName[] = "spec_id";
-inline constexpr char kTypeAttrName[] = "type";
-inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-inline constexpr char kValueAttrName[] = "value";
-inline constexpr char kValuesAttrName[] = "values";
-inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+// inline constexpr char kAlignmentAttrName[] = "alignment";
+// inline constexpr char kBranchWeightAttrName[] = "branch_weights";
+// inline constexpr char kCallee[] = "callee";
+// inline constexpr char kDefaultValueAttrName[] = "default_value";
+// inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+// inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
+// inline constexpr char kGroupOperationAttrName[] = "group_operation";
+// inline constexpr char kIndicesAttrName[] = "indices";
+// inline constexpr char kInitializerAttrName[] = "initializer";
+// inline constexpr char kInterfaceAttrName[] = "interface";
+// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
+// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
+// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
+// inline constexpr char kPackedVectorFormatAttrName[] = "format";
+// inline constexpr char kSemanticsAttrName[] = "semantics";
+// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
+// inline constexpr char kTypeAttrName[] = "type";
+// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
+// inline constexpr char kValueAttrName[] = "value";
+// inline constexpr char kValuesAttrName[] = "values";
+// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
 } // namespace AttrNames
 
 template <typename Ty>
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 71326049af0579..d34b367e98e3d2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1142,7 +1142,7 @@ void OpEmitter::genAttrNameGetters() {
     } else {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
-  assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  // assert(name.getStringRef() == getOperationName() && "invalid operation name");
   assert(name.isRegistered() && "Operation isn't registered, missing a "
         "dependent dialect loading?");
   return name.getAttributeNames()[index];

>From 1cd7b881fe305cee24c0a0145494da4a6315815e Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Wed, 14 Feb 2024 14:25:56 -0500
Subject: [PATCH 2/4] Change SPIRV parser to use ODS-generated attribute names

---
 mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp       |  29 ++--
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  |  14 +-
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 141 ++++++++++--------
 .../Dialect/SPIRV/IR/IntegerDotProductOps.cpp |   9 +-
 mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp       |  56 ++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  39 +++--
 .../Dialect/SPIRV/IR/SPIRVParsingUtils.cpp    |  29 ----
 mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h |  35 +----
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   |   2 +-
 9 files changed, 183 insertions(+), 171 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index d84133d5933415..948d48980f2e84 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
 }
 
 // Verifies an atomic update op.
-template <typename ExpectedElementType>
+template <typename AtomicOpTy, typename ExpectedElementType>
 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
   auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
   auto elementType = ptrType.getPointeeType();
@@ -42,7 +42,8 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
-  StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
+  StringAttr semanticsAttrName =
+      AtomicOpTy::getSemanticsAttrName(op->getName());
   auto memorySemantics =
       op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
@@ -57,7 +58,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicAndOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -65,7 +66,7 @@ LogicalResult AtomicAndOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIAddOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -73,7 +74,7 @@ LogicalResult AtomicIAddOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult EXTAtomicFAddOp::verify() {
-  return verifyAtomicUpdateOp<FloatType>(getOperation());
+  return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -81,7 +82,7 @@ LogicalResult EXTAtomicFAddOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIDecrementOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -89,7 +90,7 @@ LogicalResult AtomicIDecrementOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicIIncrementOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -97,7 +98,7 @@ LogicalResult AtomicIIncrementOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicISubOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -105,7 +106,7 @@ LogicalResult AtomicISubOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicOrOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -113,7 +114,7 @@ LogicalResult AtomicOrOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicSMaxOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -121,7 +122,7 @@ LogicalResult AtomicSMaxOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicSMinOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -129,7 +130,7 @@ LogicalResult AtomicSMinOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicUMaxOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -137,7 +138,7 @@ LogicalResult AtomicUMaxOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicUMinOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
 }
 
 //===----------------------------------------------------------------------===//
@@ -145,7 +146,7 @@ LogicalResult AtomicUMinOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicXorOp::verify() {
-  return verifyAtomicUpdateOp<IntegerType>(getOperation());
+  return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
 }
 
 } // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 650b5448f041b5..c441cd4e637b29 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,8 +87,11 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
-                        builder.getArrayAttr({trueWeight, falseWeight}));
+    StringAttr branchWeightsAttrName =
+        BranchConditionalOp::getBranchWeightsAttrName(result.name);
+    result.addAttribute(
+        branchWeightsAttrName,
+        builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
   // Parse the true branch.
@@ -199,11 +202,14 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
+  return (*this)->getAttrOfType<SymbolRefAttr>(
+      FunctionCallOp::getCalleeAttrName((*this)->getName()));
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
+  (*this)->setAttr(
+      FunctionCallOp::getCalleeAttrName((*this)->getName()),
+      callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ea82f8442dcfe0..3da6870e4ae883 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -20,15 +20,18 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+template <typename ArithmeticOpTy>
 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
                                                     OperationState &state) {
   spirv::Scope executionScope;
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
-  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
-      spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+          executionScope, parser, state,
+          ArithmeticOpTy::getExecutionScopeAttrName(state.name)) ||
+      spirv::parseEnumStrAttr<GroupOperationAttr>(
+          groupOperation, parser, state,
+          ArithmeticOpTy::getGroupOperationAttrName(state.name)) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -56,34 +59,44 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   return parser.addTypeToList(resultType, state.types);
 }
 
+template <typename ArithmeticOpTy>
 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                              OpAsmPrinter &printer) {
-  printer
-      << " \""
-      << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
-                 .getValue())
-      << "\" \""
-      << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
-                 .getValue())
-      << "\" " << groupOp->getOperand(0);
+  printer << " \""
+          << stringifyScope(groupOp
+                                ->getAttrOfType<spirv::ScopeAttr>(
+                                    ArithmeticOpTy::getExecutionScopeAttrName(
+                                        groupOp->getName()))
+                                .getValue())
+          << "\" \""
+          << stringifyGroupOperation(
+                 groupOp
+                     ->getAttrOfType<GroupOperationAttr>(
+                         ArithmeticOpTy::getGroupOperationAttrName(
+                             groupOp->getName()))
+                     .getValue())
+          << "\" " << groupOp->getOperand(0);
 
   if (groupOp->getNumOperands() > 1)
     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
   printer << " : " << groupOp->getResult(0).getType();
 }
 
+template <typename ArithmeticOpTy>
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
+      groupOp
+          ->getAttrOfType<spirv::ScopeAttr>(
+              ArithmeticOpTy::getExecutionScopeAttrName(groupOp->getName()))
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
+      groupOp
+          ->getAttrOfType<GroupOperationAttr>(
+              ArithmeticOpTy::getGroupOperationAttrName(groupOp->getName()))
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
@@ -206,16 +219,17 @@ LogicalResult GroupNonUniformElectOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFAddOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
 }
 
 ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -223,16 +237,17 @@ void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -240,16 +255,17 @@ void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
 }
 
 ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -257,16 +273,17 @@ void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformFMulOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
 }
 
 ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -274,16 +291,17 @@ void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformIAddOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
 }
 
 ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -291,16 +309,17 @@ void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformIMulOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
 }
 
 ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -308,16 +327,17 @@ void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformSMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -325,16 +345,17 @@ void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformSMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
 }
 
 ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -342,16 +363,17 @@ void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformUMaxOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
 }
 
 ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,16 +381,17 @@ void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformUMinOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
 }
 
 ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -376,16 +399,17 @@ void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseAndOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -393,16 +417,17 @@ void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseOrOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -410,16 +435,17 @@ void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformBitwiseXorOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
 }
 
 ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -427,16 +453,17 @@ void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalAndOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -444,16 +471,17 @@ void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalOrOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
@@ -461,16 +489,17 @@ void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult GroupNonUniformLogicalXorOp::verify() {
-  return verifyGroupNonUniformArithmeticOp(*this);
+  return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
 }
 
 ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
-  printGroupNonUniformArithmeticOp(*this, p);
+  printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this, p);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 3abc4361321d04..0ef1e7bc0ca8be 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -25,6 +25,7 @@ namespace mlir::spirv {
 // Integer Dot Product ops
 //===----------------------------------------------------------------------===//
 
+template <typename IntegerDotProductOpTy>
 static LogicalResult verifyIntegerDotProduct(Operation *op) {
   assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
          "Not an integer dot product op?");
@@ -33,7 +34,8 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
-  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+  StringAttr packedVectorFormatAttrName =
+      IntegerDotProductOpTy::getFormatAttrName(op->getName());
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
@@ -100,7 +102,8 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
-  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+  StringAttr packedVectorFormatAttrName =
+      SDotAccSatOp::getFormatAttrName(op->getName());
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
         op->getAttr(packedVectorFormatAttrName));
@@ -122,7 +125,9 @@ getIntegerDotProductCapabilities(Operation *op) {
 }
 
 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
-  LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); }    \
+  LogicalResult OpName::verify() {                                             \
+    return verifyIntegerDotProduct<OpName>(*this);                             \
+  }                                                                            \
   SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
     return getIntegerDotProductExtensions();                                   \
   }                                                                            \
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index eb7900fd6ef8fb..25afebf9f92efc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -25,10 +25,49 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
+///     (`[` memory-access `]`)?
+/// where:
+///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
+///         integer-literal | `"NonTemporal"`
+/// `MemoryOpTy` supplies the memory-access and alignment attribute names.
+template <typename MemoryOpTy>
+static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+                                               OperationState &state) {
+  // Parse an optional list of attributes starting with '['
+  if (parser.parseOptionalLSquare()) {
+    // Nothing to do
+    return success();
+  }
+
+  spirv::MemoryAccess memoryAccessAttr;
+  StringAttr memoryAccessAttrName =
+      MemoryOpTy::getMemoryAccessAttrName(state.name);
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
+    return failure();
+
+  if (spirv::bitEnumContainsAll(memoryAccessAttr,
+                                spirv::MemoryAccess::Aligned)) {
+    // Parse integer attribute for alignment.
+    Attribute alignmentAttr;
+    StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.parseComma() ||
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
+                              state.attributes)) {
+      return failure();
+    }
+  }
+  return parser.parseRSquare();
+}
+
 // TODO Make sure to merge this and the previous function into one template
 // parameterized by memory access attribute name and alignment. Doing so now
 // results in VS2017 in producing an internal error (at the call site) that's
 // not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                      OperationState &state) {
   // Parse an optional list of attributes staring with '['
@@ -38,7 +77,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
+  StringAttr memoryAccessAttrName =
+      MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
           memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
@@ -47,7 +87,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                  spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
-    StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
+    StringAttr alignmentAttrName =
+        MemoryOpTy::getSourceAlignmentAttrName(state.name);
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
         parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
@@ -182,7 +223,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+  auto memAccessAttr =
+      op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
@@ -378,7 +420,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::UnresolvedOperand ptrInfo;
   Type elementType;
   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
-      parseMemoryAccessAttributes(parser, result) ||
+      parseMemoryAccessAttributes<LoadOp>(parser, result) ||
       parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
       parser.parseType(elementType)) {
     return failure();
@@ -427,7 +469,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
   Type elementType;
   if (parseEnumStrAttr(storageClass, parser) ||
       parser.parseOperandList(operandInfo, 2) ||
-      parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parseMemoryAccessAttributes<StoreOp>(parser, result) ||
+      parser.parseColon() || parser.parseType(elementType)) {
     return failure();
   }
@@ -501,13 +543,13 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
       parseEnumStrAttr(sourceStorageClass, parser) ||
       parser.parseOperand(sourcePtrInfo) ||
-      parseMemoryAccessAttributes(parser, result)) {
+      parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
     return failure();
   }
 
   if (!parser.parseOptionalComma()) {
     // Parse 2nd memory access attributes.
-    if (parseSourceMemoryAccessAttributes(parser, result)) {
+    if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
       return failure();
     }
   }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c27df9e30f24e3..5d6ddaa876575c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,7 +457,8 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
-  StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+  StringRef indicesAttrName =
+      spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
   Type compositeType;
   SMLoc attrLocation;
 
@@ -514,7 +515,8 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
-  StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+  StringRef indicesAttrName =
+      spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
 
   return failure(
@@ -561,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+  StringRef valueAttrName =
+      spirv::ConstantOp::getValueAttrName(result.name).strref();
   if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
@@ -825,8 +828,9 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
-                      parser.getBuilder().getArrayAttr(interfaceVars));
+  result.addAttribute(
+      spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+      parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
 
@@ -878,7 +882,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+  StringRef valuesAttrName =
+      spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
   result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
@@ -1154,7 +1159,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
-  StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+  StringRef initializerAttrName =
+      spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
@@ -1175,7 +1181,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
-  StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+  StringRef typeAttrName =
+      spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1227,8 +1234,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
            << stringifyStorageClass(storageClass) << "'";
   }
 
-  if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
+  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+          this->getInitializerAttrName().strref())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1602,7 +1609,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
-  StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+  StringRef defaultValueAttrName =
+      spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1618,8 +1626,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, defaultValueAttrName,
-                            result.attributes))
+      parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
     return failure();
 
   return success();
@@ -1792,7 +1799,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+  StringAttr compositeSpecConstituentsName =
+      spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
   result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
@@ -1800,7 +1808,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseColonType(type))
     return failure();
 
-  StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+  StringAttr typeAttrName =
+      spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
   result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index f48b0b93ab5fe8..1fd25f1ac92e06 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -20,35 +20,6 @@ using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
 
-ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
-                                        OperationState &state,
-                                        StringRef attrName) {
-  // Parse an optional list of attributes staring with '['
-  if (parser.parseOptionalLSquare()) {
-    // Nothing to do
-    return success();
-  }
-
-  spirv::MemoryAccess memoryAccessAttr;
-  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
-                                                       state, attrName))
-    return failure();
-
-  if (spirv::bitEnumContainsAll(memoryAccessAttr,
-                                spirv::MemoryAccess::Aligned)) {
-    // Parse integer attribute for alignment.
-    Attribute alignmentAttr;
-    Type i32Type = parser.getBuilder().getIntegerType(32);
-    if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type,
-                              CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
-                              state.attributes)) {
-      return failure();
-    }
-  }
-  return parser.parseRSquare();
-}
-
 ParseResult parseVariableDecorations(OpAsmParser &parser,
                                      OperationState &state) {
   auto builtInName = llvm::convertToSnakeFromCamelCase(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index d4c0deef6aae3d..18c9e25227845c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,35 +20,12 @@
 
 namespace mlir::spirv {
 namespace AttrNames {
-inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+
 inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
 inline constexpr char kControl[] = "control"; // no ODS generation
 inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
 inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
 
-// TODO: generate these strings using ODS.
-// inline constexpr char kAlignmentAttrName[] = "alignment";
-// inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-// inline constexpr char kCallee[] = "callee";
-// inline constexpr char kDefaultValueAttrName[] = "default_value";
-// inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-// inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-// inline constexpr char kGroupOperationAttrName[] = "group_operation";
-// inline constexpr char kIndicesAttrName[] = "indices";
-// inline constexpr char kInitializerAttrName[] = "initializer";
-// inline constexpr char kInterfaceAttrName[] = "interface";
-// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-// inline constexpr char kPackedVectorFormatAttrName[] = "format";
-// inline constexpr char kSemanticsAttrName[] = "semantics";
-// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-// inline constexpr char kTypeAttrName[] = "type";
-// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-// inline constexpr char kValueAttrName[] = "value";
-// inline constexpr char kValuesAttrName[] = "values";
-// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
 } // namespace AttrNames
 
 template <typename Ty>
@@ -144,16 +121,6 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
   return success();
 }
 
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-///     (`[` memory-access `]`)?
-/// where:
-///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-///         integer-literal | `"NonTemporal"`
-ParseResult parseMemoryAccessAttributes(
-    OpAsmParser &parser, OperationState &state,
-    StringRef attrName = AttrNames::kMemoryAccessAttrName);
-
 ParseResult parseVariableDecorations(OpAsmParser &parser,
                                      OperationState &state);
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d34b367e98e3d2..71326049af0579 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1142,7 +1142,7 @@ void OpEmitter::genAttrNameGetters() {
     } else {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
-  // assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  assert(name.getStringRef() == getOperationName() && "invalid operation name");
   assert(name.isRegistered() && "Operation isn't registered, missing a "
         "dependent dialect loading?");
   return name.getAttributeNames()[index];

>From 3bd6a1b7c5a5d667df9b3c09aa6fe30198026eb6 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Wed, 14 Feb 2024 21:10:02 -0500
Subject: [PATCH 3/4] fix attribute-name lookups in SPIR-V op parsers

---
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  | 10 +--
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 90 ++++++++++++-------
 .../Dialect/SPIRV/IR/IntegerDotProductOps.cpp |  9 +-
 mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp       | 47 +++++-----
 mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h |  6 +-
 5 files changed, 92 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index c441cd4e637b29..60dec9ce503061 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -89,9 +89,8 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
 
     StringAttr branchWeightsAttrName =
         BranchConditionalOp::getBranchWeightsAttrName(result.name);
-    result.addAttribute(
-        branchWeightsAttrName,
-        builder.getArrayAttr({trueWeight, falseWeight}));
+    result.addAttribute(branchWeightsAttrName,
+                        builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
   // Parse the true branch.
@@ -207,9 +206,8 @@ CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(
-      FunctionCallOp::getCalleeAttrName((*this)->getName()),
-      callee.get<SymbolRefAttr>());
+  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()),
+                   callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 3da6870e4ae883..7c1d3c30f7f0d9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -28,10 +28,12 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   OpAsmParser::UnresolvedOperand valueInfo;
   if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
           executionScope, parser, state,
-          GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(state.name)) ||
+          GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
+              state.name)) ||
       spirv::parseEnumStrAttr<GroupOperationAttr>(
           groupOperation, parser, state,
-          GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(state.name)) ||
+          GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
+              state.name)) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -62,20 +64,22 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
 template <typename GroupNonUniformArithmeticOpTy>
 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                              OpAsmPrinter &printer) {
-  printer << " \""
-          << stringifyScope(groupOp
-                                ->getAttrOfType<spirv::ScopeAttr>(
-                                    GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
-                                        groupOp->getName()))
-                                .getValue())
-          << "\" \""
-          << stringifyGroupOperation(
-                 groupOp
-                     ->getAttrOfType<GroupOperationAttr>(
-                         GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
-                             groupOp->getName()))
-                     .getValue())
-          << "\" " << groupOp->getOperand(0);
+  printer
+      << " \""
+      << stringifyScope(
+             groupOp
+                 ->getAttrOfType<spirv::ScopeAttr>(
+                     GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
+                         groupOp->getName()))
+                 .getValue())
+      << "\" \""
+      << stringifyGroupOperation(
+             groupOp
+                 ->getAttrOfType<GroupOperationAttr>(
+                     GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
+                         groupOp->getName()))
+                 .getValue())
+      << "\" " << groupOp->getOperand(0);
 
   if (groupOp->getNumOperands() > 1)
     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -87,7 +91,8 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
       groupOp
           ->getAttrOfType<spirv::ScopeAttr>(
-              GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(groupOp->getName()))
+              GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
+                  groupOp->getName()))
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
@@ -96,7 +101,8 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   GroupOperation operation =
       groupOp
           ->getAttrOfType<GroupOperationAttr>(
-              GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(groupOp->getName()))
+              GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
+                  groupOp->getName()))
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
@@ -224,7 +230,8 @@ LogicalResult GroupNonUniformFAddOp::verify() {
 
 ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
@@ -241,7 +248,8 @@ LogicalResult GroupNonUniformFMaxOp::verify() {
 
 ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
@@ -258,7 +266,8 @@ LogicalResult GroupNonUniformFMinOp::verify() {
 
 ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
@@ -275,7 +284,8 @@ LogicalResult GroupNonUniformFMulOp::verify() {
 
 ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
@@ -292,7 +302,8 @@ LogicalResult GroupNonUniformIAddOp::verify() {
 
 ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
@@ -309,7 +320,8 @@ LogicalResult GroupNonUniformIMulOp::verify() {
 
 ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
@@ -326,7 +338,8 @@ LogicalResult GroupNonUniformSMaxOp::verify() {
 
 ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
@@ -343,7 +356,8 @@ LogicalResult GroupNonUniformSMinOp::verify() {
 
 ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
@@ -360,7 +374,8 @@ LogicalResult GroupNonUniformUMaxOp::verify() {
 
 ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
@@ -377,7 +392,8 @@ LogicalResult GroupNonUniformUMinOp::verify() {
 
 ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
+                                                                 result);
 }
 
 void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
@@ -394,7 +410,8 @@ LogicalResult GroupNonUniformBitwiseAndOp::verify() {
 
 ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
@@ -411,7 +428,8 @@ LogicalResult GroupNonUniformBitwiseOrOp::verify() {
 
 ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
@@ -428,7 +446,8 @@ LogicalResult GroupNonUniformBitwiseXorOp::verify() {
 
 ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
@@ -445,7 +464,8 @@ LogicalResult GroupNonUniformLogicalAndOp::verify() {
 
 ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
@@ -462,7 +482,8 @@ LogicalResult GroupNonUniformLogicalOrOp::verify() {
 
 ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
                                               OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
+                                                                      result);
 }
 
 void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
@@ -479,7 +500,8 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
 
 ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser, result);
+  return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
+                                                                       result);
 }
 
 void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 0ef1e7bc0ca8be..fd7cfa9b0d4bbf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -87,6 +87,7 @@ getIntegerDotProductExtensions() {
   return {extension};
 }
 
+template <typename IntegerDotProductOpTy>
 static SmallVector<ArrayRef<spirv::Capability>, 1>
 getIntegerDotProductCapabilities(Operation *op) {
   // Requires the the DotProduct capability and capabilities that depend on
@@ -103,7 +104,7 @@ getIntegerDotProductCapabilities(Operation *op) {
 
   Type factorTy = op->getOperand(0).getType();
   StringAttr packedVectorFormatAttrName =
-      SDotAccSatOp::getFormatAttrName(op->getName());
+      IntegerDotProductOpTy::getFormatAttrName(op->getName());
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
         op->getAttr(packedVectorFormatAttrName));
@@ -125,12 +126,14 @@ getIntegerDotProductCapabilities(Operation *op) {
 }
 
 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
-  LogicalResult OpName::verify() { return verifyIntegerDotProduct<OpName>(*this); }    \
+  LogicalResult OpName::verify() {                                             \
+    return verifyIntegerDotProduct<OpName>(*this);                             \
+  }                                                                            \
   SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
     return getIntegerDotProductExtensions();                                   \
   }                                                                            \
   SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
-    return getIntegerDotProductCapabilities(*this);                            \
+    return getIntegerDotProductCapabilities<OpName>(*this);                            \
   }                                                                            \
   std::optional<spirv::Version> OpName::getMinVersion() {                      \
     return getIntegerDotProductMinVersion();                                   \
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 25afebf9f92efc..a594ae9704f3c5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -31,7 +31,7 @@ namespace mlir::spirv {
 /// where:
 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
 ///         integer-literal | `"NonTemporal"`
-template<typename MemoryOpTy>
+template <typename MemoryOpTy>
 ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                         OperationState &state) {
   // Parse an optional list of attributes staring with '['
@@ -41,9 +41,10 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  StringAttr memoryAccessAttrName = MemoryOpTy::getMemoryAccessAttrName(state.name);
-  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
-                                                       state, memoryAccessAttrName))
+  StringAttr memoryAccessAttrName =
+      MemoryOpTy::getMemoryAccessAttrName(state.name);
+  if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -53,10 +54,8 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
     StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(
-            alignmentAttr, i32Type,
-            alignmentAttrName,
-            state.attributes)) {
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
+                              state.attributes)) {
       return failure();
     }
   }
@@ -67,7 +66,7 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
 // parameterized by memory access attribute name and alignment. Doing so now
 // results in VS2017 in producing an internal error (at the call site) that's
 // not detailed enough to understand what is happening.
-template<typename MemoryOpTy>
+template <typename MemoryOpTy>
 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                      OperationState &state) {
   // Parse an optional list of attributes staring with '['
@@ -78,7 +77,7 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
 
   spirv::MemoryAccess memoryAccessAttr;
   StringRef memoryAccessAttrName =
-      MemoryOpTy::getMemoryAccessAttrName(state.name).strref();
+      MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
           memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
@@ -115,7 +114,7 @@ static void printSourceMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
+    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -123,7 +122,7 @@ static void printSourceMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
+        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
         printer << ", " << *alignment;
       }
     }
@@ -141,7 +140,7 @@ static void printMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
+    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -149,7 +148,7 @@ static void printMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
+        elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
         printer << ", " << *alignment;
       }
     }
@@ -179,11 +178,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
+  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -200,11 +199,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+    if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -224,11 +223,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // present.
   auto *op = memoryOp.getOperation();
   auto memAccessAttr =
-      op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+      op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -245,11 +244,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -469,8 +468,8 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
   Type elementType;
   if (parseEnumStrAttr(storageClass, parser) ||
       parser.parseOperandList(operandInfo, 2) ||
-      parseMemoryAccessAttributes<StoreOp>(parser, result) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parseMemoryAccessAttributes<StoreOp>(parser, result) ||
+      parser.parseColon() || parser.parseType(elementType)) {
     return failure();
   }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 18c9e25227845c..858b94f7be8b08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -22,9 +22,9 @@ namespace mlir::spirv {
 namespace AttrNames {
 
 inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
-inline constexpr char kControl[] = "control"; // no ODS generation
-inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
-inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+inline constexpr char kControl[] = "control";          // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn";        // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id";   // no ODS generation
 
 } // namespace AttrNames
 

>From 0c41cd1d73e38e438eda273246becba29b884d75 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Thu, 15 Feb 2024 13:27:22 -0500
Subject: [PATCH 4/4] remove redundant .strref() calls on attribute-name getters

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5d6ddaa876575c..7cdd3fa7a19fdd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -458,7 +458,7 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
   StringRef indicesAttrName =
-      spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+      spirv::CompositeExtractOp::getIndicesAttrName(result.name);
   Type compositeType;
   SMLoc attrLocation;
 
@@ -516,7 +516,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   Type objectType, compositeType;
   Attribute indicesAttr;
   StringRef indicesAttrName =
-      spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+      spirv::CompositeInsertOp::getIndicesAttrName(result.name);
   auto loc = parser.getCurrentLocation();
 
   return failure(
@@ -564,7 +564,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
   StringRef valueAttrName =
-      spirv::ConstantOp::getValueAttrName(result.name).strref();
+      spirv::ConstantOp::getValueAttrName(result.name);
   if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
@@ -829,7 +829,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
       return failure();
   }
   result.addAttribute(
-      spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+      spirv::EntryPointOp::getInterfaceAttrName(result.name),
       parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
@@ -883,7 +883,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
   StringRef valuesAttrName =
-      spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+      spirv::ExecutionModeOp::getValuesAttrName(result.name);
   result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
@@ -1160,7 +1160,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   // Parse variable name.
   StringAttr nameAttr;
   StringRef initializerAttrName =
-      spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+      spirv::GlobalVariableOp::getInitializerAttrName(result.name);
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
@@ -1182,7 +1182,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
 
   Type type;
   StringRef typeAttrName =
-      spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+      spirv::GlobalVariableOp::getTypeAttrName(result.name);
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1204,7 +1204,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getSymName());
   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
 
-  StringRef initializerAttrName = this->getInitializerAttrName().strref();
+  StringRef initializerAttrName = this->getInitializerAttrName();
   // Print optional initializer
   if (auto initializer = this->getInitializer()) {
     printer << " " << initializerAttrName << '(';
@@ -1213,7 +1213,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
     elidedAttrs.push_back(initializerAttrName);
   }
 
-  StringRef typeAttrName = this->getTypeAttrName().strref();
+  StringRef typeAttrName = this->getTypeAttrName();
   elidedAttrs.push_back(typeAttrName);
   spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
@@ -1235,7 +1235,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
   }
 
   if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
-          this->getInitializerAttrName().strref())) {
+          this->getInitializerAttrName())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1610,7 +1610,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   StringAttr nameAttr;
   Attribute valueAttr;
   StringRef defaultValueAttrName =
-      spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+      spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))



More information about the Mlir-commits mailing list