[Mlir-commits] [mlir] replaces hardcoded attribute names in SPIRV dialect parsing code (PR #81552)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 12 16:08:40 PST 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 07bf1ddb4eb0abfff20542fd4459bace1f72107f c794f24655d131fb698a7cf2310b1f50259b7d16 -- mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index d84133d593..2f2f510004 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,7 +42,8 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              << stringifyTypeName<ExpectedElementType>()
                              << " value, found " << elementType;
 
-  StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
+  StringRef semanticsAttrName =
+      spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
   auto memorySemantics =
       op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
           .getValue();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 650b5448f0..a8b61000b6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,8 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
         parser.parseRSquare())
       return failure();
 
-    result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
-                        builder.getArrayAttr({trueWeight, falseWeight}));
+    result.addAttribute(
+        BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
+        builder.getArrayAttr({trueWeight, falseWeight}));
   }
 
   // Parse the true branch.
@@ -199,11 +200,14 @@ LogicalResult FunctionCallOp::verify() {
 }
 
 CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
-  return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
+  return (*this)->getAttrOfType<SymbolRefAttr>(
+      FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
 }
 
 void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
-  (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
+  (*this)->setAttr(
+      FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(),
+      callee.get<SymbolRefAttr>());
 }
 
 Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ea82f8442d..643b220ba4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -25,10 +25,12 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   spirv::Scope executionScope;
   GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
-  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
-                                                ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
-      spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
-                                                  GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
+  if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+          executionScope, parser, state,
+          ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
+      spirv::parseEnumStrAttr<GroupOperationAttr>(
+          groupOperation, parser, state,
+          GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -58,16 +60,22 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
 
 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                              OpAsmPrinter &printer) {
-  printer
-      << " \""
-      << stringifyScope(
-             groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
-                 .getValue())
-      << "\" \""
-      << stringifyGroupOperation(
-             groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
-                 .getValue())
-      << "\" " << groupOp->getOperand(0);
+  printer << " \""
+          << stringifyScope(groupOp
+                                ->getAttrOfType<spirv::ScopeAttr>(
+                                    ControlBarrierOp::getExecutionScopeAttrName(
+                                        groupOp->getName())
+                                        .strref())
+                                .getValue())
+          << "\" \""
+          << stringifyGroupOperation(
+                 groupOp
+                     ->getAttrOfType<GroupOperationAttr>(
+                         GroupFAddOp::getGroupOperationAttrName(
+                             groupOp->getName())
+                             .strref())
+                     .getValue())
+          << "\" " << groupOp->getOperand(0);
 
   if (groupOp->getNumOperands() > 1)
     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -76,14 +84,20 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
 
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
   spirv::Scope scope =
-      groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
+      groupOp
+          ->getAttrOfType<spirv::ScopeAttr>(
+              ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName())
+                  .strref())
           .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
   GroupOperation operation =
-      groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
+      groupOp
+          ->getAttrOfType<GroupOperationAttr>(
+              GroupFAddOp::getGroupOperationAttrName(groupOp->getName())
+                  .strref())
           .getValue();
   if (operation == GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 3abc436132..17c63d86e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,7 +33,8 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
   // ODS enforces that vector 1 and vector 2, and result and the accumulator
   // have the same types.
   Type factorTy = op->getOperand(0).getType();
-  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+  StringRef packedVectorFormatAttrName =
+      SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto packedVectorFormat =
         llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
@@ -100,7 +101,8 @@ getIntegerDotProductCapabilities(Operation *op) {
   SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
 
   Type factorTy = op->getOperand(0).getType();
-  StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+  StringRef packedVectorFormatAttrName =
+      SDotAccSatOp::getFormatAttrName(op->getName()).strref();
   if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
     auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
         op->getAttr(packedVectorFormatAttrName));
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index eb7900fd6e..8dd28cfb0c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,7 +38,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
+  StringRef memoryAccessAttrName =
+      CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
   if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
           memoryAccessAttr, parser, state, memoryAccessAttrName))
     return failure();
@@ -47,7 +48,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                 spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
     Attribute alignmentAttr;
-    StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
+    StringRef alignmentAttrName =
+        CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
         parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
@@ -182,7 +184,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
   auto *op = memoryOp.getOperation();
-  auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+  auto memAccessAttr =
+      op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c27df9e30f..535d4e0489 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,7 +457,8 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                              OperationState &result) {
   OpAsmParser::UnresolvedOperand compositeInfo;
   Attribute indicesAttr;
-  StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+  StringRef indicesAttrName =
+      spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
   Type compositeType;
   SMLoc attrLocation;
 
@@ -514,7 +515,8 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
   Type objectType, compositeType;
   Attribute indicesAttr;
-  StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+  StringRef indicesAttrName =
+      spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
 
   return failure(
@@ -561,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
   Attribute value;
-  StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+  StringRef valueAttrName =
+      spirv::ConstantOp::getValueAttrName(result.name).strref();
   if (parser.parseAttribute(value, valueAttrName, result.attributes))
     return failure();
 
@@ -825,8 +828,9 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
         }))
       return failure();
   }
-  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
-                      parser.getBuilder().getArrayAttr(interfaceVars));
+  result.addAttribute(
+      spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+      parser.getBuilder().getArrayAttr(interfaceVars));
   return success();
 }
 
@@ -878,7 +882,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
     }
     values.push_back(llvm::cast<IntegerAttr>(value).getInt());
   }
-  StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+  StringRef valuesAttrName =
+      spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
   result.addAttribute(valuesAttrName,
                       parser.getBuilder().getI32ArrayAttr(values));
   return success();
@@ -1154,7 +1159,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
   // Parse variable name.
   StringAttr nameAttr;
-  StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+  StringRef initializerAttrName =
+      spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes)) {
     return failure();
@@ -1175,7 +1181,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
   }
 
   Type type;
-  StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+  StringRef typeAttrName =
+      spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
   auto loc = parser.getCurrentLocation();
   if (parser.parseColonType(type)) {
     return failure();
@@ -1227,8 +1234,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
            << stringifyStorageClass(storageClass) << "'";
   }
 
-  if (auto init =
-          (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
+  if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+          this->getInitializerAttrName().strref())) {
     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
         (*this)->getParentOp(), init.getAttr());
     // TODO: Currently only variable initialization with specialization
@@ -1602,7 +1609,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
   StringAttr nameAttr;
   Attribute valueAttr;
-  StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+  StringRef defaultValueAttrName =
+      spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
 
   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                              result.attributes))
@@ -1618,8 +1626,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
   }
 
   if (parser.parseEqual() ||
-      parser.parseAttribute(valueAttr, defaultValueAttrName,
-                            result.attributes))
+      parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
     return failure();
 
   return success();
@@ -1792,7 +1799,9 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseRParen())
     return failure();
 
-  StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+  StringRef compositeSpecConstituentsName =
+      spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name)
+          .strref();
   result.addAttribute(compositeSpecConstituentsName,
                       parser.getBuilder().getArrayAttr(constituents));
 
@@ -1800,7 +1809,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
   if (parser.parseColonType(type))
     return failure();
 
-  StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+  StringRef typeAttrName =
+      spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
   result.addAttribute(typeAttrName, TypeAttr::get(type));
 
   return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index f48b0b93ab..7ddb0370ca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -40,9 +40,10 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
     Attribute alignmentAttr;
     Type i32Type = parser.getBuilder().getIntegerType(32);
     if (parser.parseComma() ||
-        parser.parseAttribute(alignmentAttr, i32Type,
-                              CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
-                              state.attributes)) {
+        parser.parseAttribute(
+            alignmentAttr, i32Type,
+            CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
+            state.attributes)) {
       return failure();
     }
   }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index d4c0deef6a..9349f63eb6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,11 +20,12 @@
 
 namespace mlir::spirv {
 namespace AttrNames {
-inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+inline constexpr char kMemoryAccessAttrName[] =
+    "memory_access"; // need hardcoded string below -- probably can change
 inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
-inline constexpr char kControl[] = "control"; // no ODS generation
-inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
-inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+inline constexpr char kControl[] = "control";          // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn";        // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id";   // no ODS generation
 
 // TODO: generate these strings using ODS.
 // inline constexpr char kAlignmentAttrName[] = "alignment";
@@ -37,18 +38,18 @@ inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
 // inline constexpr char kIndicesAttrName[] = "indices";
 // inline constexpr char kInitializerAttrName[] = "initializer";
 // inline constexpr char kInterfaceAttrName[] = "interface";
-// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-// inline constexpr char kPackedVectorFormatAttrName[] = "format";
-// inline constexpr char kSemanticsAttrName[] = "semantics";
-// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-// inline constexpr char kTypeAttrName[] = "type";
-// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-// inline constexpr char kValueAttrName[] = "value";
-// inline constexpr char kValuesAttrName[] = "values";
-// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] =
+// "matrix_layout"; inline constexpr char kMemoryOperandAttrName[] =
+// "memory_operand"; inline constexpr char kMemoryScopeAttrName[] =
+// "memory_scope"; inline constexpr char kPackedVectorFormatAttrName[] =
+// "format"; inline constexpr char kSemanticsAttrName[] = "semantics"; inline
+// constexpr char kSourceAlignmentAttrName[] = "source_alignment"; inline
+// constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; inline
+// constexpr char kTypeAttrName[] = "type"; inline constexpr char
+// kUnequalSemanticsAttrName[] = "unequal_semantics"; inline constexpr char
+// kValueAttrName[] = "value"; inline constexpr char kValuesAttrName[] =
+// "values"; inline constexpr char kCompositeSpecConstituentsName[] =
+// "constituents";
 } // namespace AttrNames
 
 template <typename Ty>

``````````

</details>


https://github.com/llvm/llvm-project/pull/81552


More information about the Mlir-commits mailing list