[Mlir-commits] [mlir] replaces hardcoded attribute names in SPIRV dialect parsing code (PR #81552)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 12 16:08:40 PST 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 07bf1ddb4eb0abfff20542fd4459bace1f72107f c794f24655d131fb698a7cf2310b1f50259b7d16 -- mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index d84133d593..2f2f510004 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,7 +42,8 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
- StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
+ StringRef semanticsAttrName =
+ spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
auto memorySemantics =
op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 650b5448f0..a8b61000b6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,8 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();
- result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
- builder.getArrayAttr({trueWeight, falseWeight}));
+ result.addAttribute(
+ BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
+ builder.getArrayAttr({trueWeight, falseWeight}));
}
// Parse the true branch.
@@ -199,11 +200,14 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
+ return (*this)->getAttrOfType<SymbolRefAttr>(
+ FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
+ (*this)->setAttr(
+ FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(),
+ callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ea82f8442d..643b220ba4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -25,10 +25,12 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
spirv::Scope executionScope;
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
- spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
- GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+ executionScope, parser, state,
+ ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
+ spirv::parseEnumStrAttr<GroupOperationAttr>(
+ groupOperation, parser, state,
+ GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
parser.parseOperand(valueInfo))
return failure();
@@ -58,16 +60,22 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
- printer
- << " \""
- << stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
- .getValue())
- << "\" \""
- << stringifyGroupOperation(
- groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
- .getValue())
- << "\" " << groupOp->getOperand(0);
+ printer << " \""
+ << stringifyScope(groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ ControlBarrierOp::getExecutionScopeAttrName(
+ groupOp->getName())
+ .strref())
+ .getValue())
+ << "\" \""
+ << stringifyGroupOperation(
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupFAddOp::getGroupOperationAttrName(
+ groupOp->getName())
+ .strref())
+ .getValue())
+ << "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -76,14 +84,20 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
+ groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName())
+ .strref())
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
- groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupFAddOp::getGroupOperationAttrName(groupOp->getName())
+ .strref())
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 3abc436132..17c63d86e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,7 +33,8 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
// ODS enforces that vector 1 and vector 2, and result and the accumulator
// have the same types.
Type factorTy = op->getOperand(0).getType();
- StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+ StringRef packedVectorFormatAttrName =
+ SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto packedVectorFormat =
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
@@ -100,7 +101,8 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
- StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+ StringRef packedVectorFormatAttrName =
+ SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
op->getAttr(packedVectorFormatAttrName));
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index eb7900fd6e..8dd28cfb0c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,7 +38,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
+ StringRef memoryAccessAttrName =
+ CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
@@ -47,7 +48,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
- StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
+ StringRef alignmentAttrName =
+ CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
@@ -182,7 +184,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+ auto memAccessAttr =
+ op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c27df9e30f..535d4e0489 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,7 +457,8 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
- StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+ StringRef indicesAttrName =
+ spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
Type compositeType;
SMLoc attrLocation;
@@ -514,7 +515,8 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
Type objectType, compositeType;
Attribute indicesAttr;
- StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+ StringRef indicesAttrName =
+ spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
return failure(
@@ -561,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
- StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+ StringRef valueAttrName =
+ spirv::ConstantOp::getValueAttrName(result.name).strref();
if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
@@ -825,8 +828,9 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}))
return failure();
}
- result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
- parser.getBuilder().getArrayAttr(interfaceVars));
+ result.addAttribute(
+ spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+ parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -878,7 +882,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
- StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+ StringRef valuesAttrName =
+ spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
@@ -1154,7 +1159,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse variable name.
StringAttr nameAttr;
- StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+ StringRef initializerAttrName =
+ spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
@@ -1175,7 +1181,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
}
Type type;
- StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+ StringRef typeAttrName =
+ spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1227,8 +1234,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
<< stringifyStorageClass(storageClass) << "'";
}
- if (auto init =
- (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
+ if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+ this->getInitializerAttrName().strref())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1602,7 +1609,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr nameAttr;
Attribute valueAttr;
- StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+ StringRef defaultValueAttrName =
+ spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -1618,8 +1626,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
}
if (parser.parseEqual() ||
- parser.parseAttribute(valueAttr, defaultValueAttrName,
- result.attributes))
+ parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
return failure();
return success();
@@ -1792,7 +1799,9 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseRParen())
return failure();
- StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+ StringRef compositeSpecConstituentsName =
+ spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name)
+ .strref();
result.addAttribute(compositeSpecConstituentsName,
parser.getBuilder().getArrayAttr(constituents));
@@ -1800,7 +1809,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseColonType(type))
return failure();
- StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+ StringRef typeAttrName =
+ spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index f48b0b93ab..7ddb0370ca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -40,9 +40,10 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
Attribute alignmentAttr;
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type,
- CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
- state.attributes)) {
+ parser.parseAttribute(
+ alignmentAttr, i32Type,
+ CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
+ state.attributes)) {
return failure();
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index d4c0deef6a..9349f63eb6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,11 +20,12 @@
namespace mlir::spirv {
namespace AttrNames {
-inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+inline constexpr char kMemoryAccessAttrName[] =
+ "memory_access"; // need hardcoded string below -- probably can change
inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
-inline constexpr char kControl[] = "control"; // no ODS generation
-inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
-inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+inline constexpr char kControl[] = "control"; // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
// TODO: generate these strings using ODS.
// inline constexpr char kAlignmentAttrName[] = "alignment";
@@ -37,18 +38,18 @@ inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
// inline constexpr char kIndicesAttrName[] = "indices";
// inline constexpr char kInitializerAttrName[] = "initializer";
// inline constexpr char kInterfaceAttrName[] = "interface";
-// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-// inline constexpr char kPackedVectorFormatAttrName[] = "format";
-// inline constexpr char kSemanticsAttrName[] = "semantics";
-// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-// inline constexpr char kTypeAttrName[] = "type";
-// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-// inline constexpr char kValueAttrName[] = "value";
-// inline constexpr char kValuesAttrName[] = "values";
-// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] =
+// "matrix_layout"; inline constexpr char kMemoryOperandAttrName[] =
+// "memory_operand"; inline constexpr char kMemoryScopeAttrName[] =
+// "memory_scope"; inline constexpr char kPackedVectorFormatAttrName[] =
+// "format"; inline constexpr char kSemanticsAttrName[] = "semantics"; inline
+// constexpr char kSourceAlignmentAttrName[] = "source_alignment"; inline
+// constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; inline
+// constexpr char kTypeAttrName[] = "type"; inline constexpr char
+// kUnequalSemanticsAttrName[] = "unequal_semantics"; inline constexpr char
+// kValueAttrName[] = "value"; inline constexpr char kValuesAttrName[] =
+// "values"; inline constexpr char kCompositeSpecConstituentsName[] =
+// "constituents";
} // namespace AttrNames
template <typename Ty>
``````````
</details>
https://github.com/llvm/llvm-project/pull/81552
More information about the Mlir-commits mailing list