[Mlir-commits] [mlir] replaces attribute names in SPIRV dialect parsing code (PR #81552)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 12 16:06:34 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-spirv
Author: None (tw-ilson)
<details>
<summary>Changes</summary>
---
Patch is 25.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81552.diff
9 Files Affected:
- (modified) mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp (+2-1)
- (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+3-3)
- (modified) mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp (+6-6)
- (modified) mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp (+5-3)
- (modified) mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp (+16-14)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+26-15)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp (+3-1)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h (+28-27)
- (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+1-1)
``````````diff
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..d84133d5933415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
+ StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
auto memorySemantics =
- op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+ op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..650b5448f041b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,7 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();
- result.addAttribute(kBranchWeightAttrName,
+ result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
builder.getArrayAttr({trueWeight, falseWeight}));
}
@@ -199,11 +199,11 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+ return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+ (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..ea82f8442dcfe0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -26,9 +26,9 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- kExecutionScopeAttrName) ||
+ ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
- kGroupOperationAttrName) ||
+ GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
parser.parseOperand(valueInfo))
return failure();
@@ -61,11 +61,11 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
printer
<< " \""
<< stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
.getValue())
<< "\" \""
<< stringifyGroupOperation(
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
.getValue())
<< "\" " << groupOp->getOperand(0);
@@ -76,14 +76,14 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..3abc4361321d04 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,10 +33,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
// ODS enforces that vector 1 and vector 2, and result and the accumulator
// have the same types.
Type factorTy = op->getOperand(0).getType();
+ StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto packedVectorFormat =
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (!packedVectorFormat)
return op->emitOpError("requires Packed Vector Format attribute for "
"integer vector operands");
@@ -50,7 +51,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
"integer vector operands to be 32-bits wide",
packedVectorFormat.getValue()));
} else {
- if (op->hasAttr(kPackedVectorFormatAttrName))
+ if (op->hasAttr(packedVectorFormatAttrName))
return op->emitOpError(llvm::formatv(
"with invalid format attribute for vector operands of type '{0}'",
factorTy));
@@ -99,9 +100,10 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
+ StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (formatAttr.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
capabilities.push_back(dotProductInput4x8BitPackedCap);
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..eb7900fd6ef8fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,17 +38,19 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
+ StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
- memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+ memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
+ StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+ parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
state.attributes)) {
return failure();
}
@@ -72,7 +74,7 @@ static void printSourceMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -80,7 +82,7 @@ static void printSourceMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kSourceAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
printer << ", " << *alignment;
}
}
@@ -98,7 +100,7 @@ static void printMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -106,7 +108,7 @@ static void printMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
printer << ", " << *alignment;
}
}
@@ -136,11 +138,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -157,11 +159,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -180,11 +182,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -201,11 +203,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kSourceAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..c27df9e30f24e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,13 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
+ StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
Type compositeType;
SMLoc attrLocation;
if (parser.parseOperand(compositeInfo) ||
parser.getCurrentLocation(&attrLocation) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(compositeType) ||
parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
return failure();
@@ -513,11 +514,12 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
Type objectType, compositeType;
Attribute indicesAttr;
+ StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
return failure(
parser.parseOperandList(operands, 2) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(objectType) ||
parser.parseKeywordType("into", compositeType) ||
parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +561,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
- if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+ StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+ if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
Type type = NoneType::get(parser.getContext());
@@ -822,7 +825,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}))
return failure();
}
- result.addAttribute(kInterfaceAttrName,
+ result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -875,7 +878,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
- result.addAttribute(kValuesAttrName,
+ StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+ result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
}
@@ -1150,16 +1154,17 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse variable name.
StringAttr nameAttr;
+ StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
}
// Parse optional initializer
- if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+ if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
- parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+ parser.parseAttribute(initSymbol, Type(), initializerAttrName,
result.attributes) ||
parser.parseRParen())
return failure();
@@ -1170,6 +1175,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
}
Type type;
+ StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1177,7 +1183,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
if (!llvm::isa<spirv::PointerType>(type)) {
return parser.emitError(loc, "expected spirv.ptr type");
}
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
@@ -1191,15 +1197,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
printer.printSymbolName(getSymName());
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+ StringRef initializerAttrName = this->getInitializerAttrName().strref();
// Print optional initializer
if (auto initializer = this->getInitializer()) {
- printer << " " << kInitializerAttrName << '(';
+ printer << " " << initializerAttrName << '(';
printer.printSymbolName(*initializer);
printer << ')';
- elidedAttrs.push_back(kInitializerAttrName);
+ elidedAttrs.push_back(initializerAttrName);
}
- elidedAttrs.push_back(kTypeAttrName);
+ StringRef typeAttrName = this->getTypeAttrName().strref();
+ elidedAttrs.push_back(typeAttrName);
spirv::printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
}
@@ -1220,7 +1228,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
}
if (auto init =
- (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+ (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1594,6 +1602,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr nameAttr;
Attribute valueAttr;
+ StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -1609,7 +1618,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
}
if (parser.parseEqual() ||
- parser.parseAttribute(valueAttr, kDefaultValueAttrName,
+ parser.parseAttribute(valueAttr, defaultValueAttrName,
result.attributes))
return failure();
@@ -1783,14 +1792,16 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseRParen())
return failure();
- result.addAttribute(kCompositeSpecConstituentsName,
+ StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+ result.addAttribute(compositeSpecConstituentsName,
parser.getBuilder().getArrayAttr(constituents));
Type type;
if (parser.parseColonType(type))
return failure();
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
diff --git a/mlir/...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81552
More information about the Mlir-commits mailing list