[Mlir-commits] [mlir] replaces attribute names in SPIRV dialect parsing code (PR #81552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 12 16:05:45 PST 2024


https://github.com/tw-ilson created https://github.com/llvm/llvm-project/pull/81552

None

>From c794f24655d131fb698a7cf2310b1f50259b7d16 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Mon, 12 Feb 2024 00:31:59 -0500
Subject: [PATCH] replaces attribute names in SPIRV parsing code

---
 mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp       |  3 +-
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  |  6 +-
 mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp        | 12 ++--
 .../Dialect/SPIRV/IR/IntegerDotProductOps.cpp |  8 ++-
 mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp       | 30 +++++-----
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 41 +++++++++-----
 .../Dialect/SPIRV/IR/SPIRVParsingUtils.cpp    |  4 +-
 mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h | 55 ++++++++++---------
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   |  2 +-
 9 files changed, 90 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..d84133d5933415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
+  StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
   auto memorySemantics =
-      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
   if (failed(verifyMemorySemantics(op, memorySemantics))) {
     return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..650b5448f041b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,7 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(kBranchWeightAttrName,
+    result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
                         builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
@@ -199,11 +199,11 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+  return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..ea82f8442dcfe0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -26,9 +26,9 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
   if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                kExecutionScopeAttrName) ||
+                                                ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
       spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  kGroupOperationAttrName) ||
+                                                  GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -61,11 +61,11 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
   printer
       << " \""
       << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+             groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" \""
       << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+             groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
                  .getValue())
       << "\" " << groupOp->getOperand(0);
 
@@ -76,14 +76,14 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
 
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+      groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+      groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..3abc4361321d04 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,10 +33,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
-            op->getAttr(kPackedVectorFormatAttrName));
+            op->getAttr(packedVectorFormatAttrName));
     if (!packedVectorFormat)
       return op->emitOpError("requires Packed Vector Format attribute for "
                              "integer vector operands");
@@ -50,7 +51,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
                         "integer vector operands to be 32-bits wide",
                         packedVectorFormat.getValue()));
   } else {
-    if (op->hasAttr(kPackedVectorFormatAttrName))
+    if (op->hasAttr(packedVectorFormatAttrName))
       return op->emitOpError(llvm::formatv(
           "with invalid format attribute for vector operands of type '{0}'",
           factorTy));
@@ -99,9 +100,10 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
+  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
-        op->getAttr(kPackedVectorFormatAttrName));
+        op->getAttr(packedVectorFormatAttrName));
     if (formatAttr.getValue() ==
         spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
       capabilities.push_back(dotProductInput4x8BitPackedCap);
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..eb7900fd6ef8fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,17 +38,19 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
+  StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
-          memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+          memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
 
   if (spirv::bitEnumContainsAll(memoryAccessAttr,
                                 spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
+    StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+        parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
                               state.attributes)) {
       return failure();
     }
@@ -72,7 +74,7 @@ static void printSourceMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -80,7 +82,7 @@ static void printSourceMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kSourceAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -98,7 +100,7 @@ static void printMemoryAccessAttribute(
   // Print optional memory access attribute.
   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                               : memoryOp.getMemoryAccess())) {
-    elidedAttrs.push_back(kMemoryAccessAttrName);
+    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
 
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
@@ -106,7 +108,7 @@ static void printMemoryAccessAttribute(
       // Print integer alignment attribute.
       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                                : memoryOp.getAlignment())) {
-        elidedAttrs.push_back(kAlignmentAttrName);
+        elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
         printer << ", " << *alignment;
       }
     }
@@ -136,11 +138,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -157,11 +159,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
@@ -180,11 +182,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
@@ -201,11 +203,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
 
   if (spirv::bitEnumContainsAll(memAccess.getValue(),
                                 spirv::MemoryAccess::Aligned)) {
-    if (!op->getAttr(kSourceAlignmentAttrName)) {
+    if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError("missing alignment value");
     }
   } else {
-    if (op->getAttr(kSourceAlignmentAttrName)) {
+    if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
       return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..c27df9e30f24e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,13 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
   Type compositeType;
   SMLoc attrLocation;
 
   if (parser.parseOperand(compositeInfo) ||
       parser.getCurrentLocation(&attrLocation) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(compositeType) ||
       parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
     return failure();
@@ -513,11 +514,12 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
+  StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
 
   return failure(
       parser.parseOperandList(operands, 2) ||
-      parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+      parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
       parser.parseColonType(objectType) ||
       parser.parseKeywordType("into", compositeType) ||
       parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +561,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+  if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
   Type type = NoneType::get(parser.getContext());
@@ -822,7 +825,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(kInterfaceAttrName,
+  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
                       parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
@@ -875,7 +878,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  result.addAttribute(kValuesAttrName,
+  StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+  result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
 }
@@ -1150,16 +1154,17 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
+  StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
   }
 
   // Parse optional initializer
-  if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
     FlatSymbolRefAttr initSymbol;
     if (parser.parseLParen() ||
-        parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                               result.attributes) ||
         parser.parseRParen())
       return failure();
@@ -1170,6 +1175,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
+  StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1177,7 +1183,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   if (!llvm::isa<spirv::PointerType>(type)) {
     return parser.emitError(loc, "expected spirv.ptr type");
   }
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
@@ -1191,15 +1197,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getSymName());
   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
 
+  StringRef initializerAttrName = this->getInitializerAttrName().strref();
   // Print optional initializer
   if (auto initializer = this->getInitializer()) {
-    printer << " " << kInitializerAttrName << '(';
+    printer << " " << initializerAttrName << '(';
     printer.printSymbolName(*initializer);
     printer << ')';
-    elidedAttrs.push_back(kInitializerAttrName);
+    elidedAttrs.push_back(initializerAttrName);
   }
 
-  elidedAttrs.push_back(kTypeAttrName);
+  StringRef typeAttrName = this->getTypeAttrName().strref();
+  elidedAttrs.push_back(typeAttrName);
   spirv::printVariableDecorations(*this, printer, elidedAttrs);
   printer << " : " << getType();
 }
@@ -1220,7 +1228,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
   }
 
   if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+          (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1594,6 +1602,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
+  StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1609,7 +1618,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, kDefaultValueAttrName,
+      parser.parseAttribute(valueAttr, defaultValueAttrName,
                             result.attributes))
     return failure();
 
@@ -1783,14 +1792,16 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  result.addAttribute(kCompositeSpecConstituentsName,
+  StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+  result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
   Type type;
   if (parser.parseColonType(type))
     return failure();
 
-  result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+  StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+  result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 27c373300aee8e..f48b0b93ab5fe8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,6 +12,8 @@
 
 #include "SPIRVParsingUtils.h"
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
 #include "llvm/ADT/StringExtras.h"
 
 using namespace mlir::spirv::AttrNames;
@@ -39,7 +41,7 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
         parser.parseAttribute(alignmentAttr, i32Type,
-                              AttrNames::kAlignmentAttrName,
+                              CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
                               state.attributes)) {
       return failure();
     }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 625c82f6e8e899..d4c0deef6aae3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,34 +20,35 @@
 
 namespace mlir::spirv {
 namespace AttrNames {
+inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // Kept: still needed as a hardcoded string by code below; can probably be replaced later.
+inline constexpr char kClusterSize[] = "cluster_size"; // Kept: this name is not generated by ODS.
+inline constexpr char kControl[] = "control"; // Kept: this name is not generated by ODS.
+inline constexpr char kFnNameAttrName[] = "fn"; // Kept: this name is not generated by ODS.
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // Kept: this name is not generated by ODS.
+
 // TODO: generate these strings using ODS.
-inline constexpr char kAlignmentAttrName[] = "alignment";
-inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-inline constexpr char kCallee[] = "callee";
-inline constexpr char kClusterSize[] = "cluster_size";
-inline constexpr char kControl[] = "control";
-inline constexpr char kDefaultValueAttrName[] = "default_value";
-inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-inline constexpr char kFnNameAttrName[] = "fn";
-inline constexpr char kGroupOperationAttrName[] = "group_operation";
-inline constexpr char kIndicesAttrName[] = "indices";
-inline constexpr char kInitializerAttrName[] = "initializer";
-inline constexpr char kInterfaceAttrName[] = "interface";
-inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-inline constexpr char kMemoryAccessAttrName[] = "memory_access";
-inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-inline constexpr char kPackedVectorFormatAttrName[] = "format";
-inline constexpr char kSemanticsAttrName[] = "semantics";
-inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-inline constexpr char kSpecIdAttrName[] = "spec_id";
-inline constexpr char kTypeAttrName[] = "type";
-inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-inline constexpr char kValueAttrName[] = "value";
-inline constexpr char kValuesAttrName[] = "values";
-inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+// inline constexpr char kAlignmentAttrName[] = "alignment";
+// inline constexpr char kBranchWeightAttrName[] = "branch_weights";
+// inline constexpr char kCallee[] = "callee";
+// inline constexpr char kDefaultValueAttrName[] = "default_value";
+// inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+// inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
+// inline constexpr char kGroupOperationAttrName[] = "group_operation";
+// inline constexpr char kIndicesAttrName[] = "indices";
+// inline constexpr char kInitializerAttrName[] = "initializer";
+// inline constexpr char kInterfaceAttrName[] = "interface";
+// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
+// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
+// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
+// inline constexpr char kPackedVectorFormatAttrName[] = "format";
+// inline constexpr char kSemanticsAttrName[] = "semantics";
+// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
+// inline constexpr char kTypeAttrName[] = "type";
+// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
+// inline constexpr char kValueAttrName[] = "value";
+// inline constexpr char kValuesAttrName[] = "values";
+// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
 } // namespace AttrNames
 
 template <typename Ty>
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 71326049af0579..d34b367e98e3d2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1142,7 +1142,7 @@ void OpEmitter::genAttrNameGetters() {
     } else {
       const char *const getAttrName = R"(
   assert(index < {0} && "invalid attribute index");
-  assert(name.getStringRef() == getOperationName() && "invalid operation name");
+  // assert(name.getStringRef() == getOperationName() && "invalid operation name");
   assert(name.isRegistered() && "Operation isn't registered, missing a "
         "dependent dialect loading?");
   return name.getAttributeNames()[index];



More information about the Mlir-commits mailing list