[Mlir-commits] [mlir] 1d5e3b2 - [mlir][spirv] Use ODS generated attribute names for op definitions (#81552)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 25 16:47:29 PST 2024
Author: tw-ilson
Date: 2024-02-25T16:47:25-08:00
New Revision: 1d5e3b2d6559a853c544099e4cf1d46f44f83368
URL: https://github.com/llvm/llvm-project/commit/1d5e3b2d6559a853c544099e4cf1d46f44f83368
DIFF: https://github.com/llvm/llvm-project/commit/1d5e3b2d6559a853c544099e4cf1d46f44f83368.diff
LOG: [mlir][spirv] Use ODS generated attribute names for op definitions (#81552)
Since ODS generates getter functions for SPIR-V operations' attribute
names, we replace the hardcoded attribute-name strings in the SPIR-V
dialect's op parsers/printers/verifiers with calls to those getters for consistency.
Fixes https://github.com/llvm/llvm-project/issues/77627
---------
Co-authored-by: Lei Zhang <antiagainst at gmail.com>
Added:
Modified:
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..948d48980f2e84 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
}
// Verifies an atomic update op.
-template <typename ExpectedElementType>
+template <typename AtomicOpTy, typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
@@ -42,8 +42,10 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
+ StringAttr semanticsAttrName =
+ AtomicOpTy::getSemanticsAttrName(op->getName());
auto memorySemantics =
- op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+ op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
@@ -56,7 +58,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
//===----------------------------------------------------------------------===//
LogicalResult AtomicAndOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -64,7 +66,7 @@ LogicalResult AtomicAndOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIAddOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -72,7 +74,7 @@ LogicalResult AtomicIAddOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult EXTAtomicFAddOp::verify() {
- return verifyAtomicUpdateOp<FloatType>(getOperation());
+ return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -80,7 +82,7 @@ LogicalResult EXTAtomicFAddOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIDecrementOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -88,7 +90,7 @@ LogicalResult AtomicIDecrementOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIIncrementOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -96,7 +98,7 @@ LogicalResult AtomicIIncrementOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicISubOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -104,7 +106,7 @@ LogicalResult AtomicISubOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicOrOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -112,7 +114,7 @@ LogicalResult AtomicOrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMaxOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -120,7 +122,7 @@ LogicalResult AtomicSMaxOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMinOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -128,7 +130,7 @@ LogicalResult AtomicSMinOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMaxOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -136,7 +138,7 @@ LogicalResult AtomicUMaxOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMinOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -144,7 +146,7 @@ LogicalResult AtomicUMinOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicXorOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
}
} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 7170a899069ee3..3e317319b68fc5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();
- result.addAttribute(kBranchWeightAttrName,
+ StringAttr branchWeightsAttrName =
+ BranchConditionalOp::getBranchWeightsAttrName(result.name);
+ result.addAttribute(branchWeightsAttrName,
builder.getArrayAttr({trueWeight, falseWeight}));
}
@@ -199,11 +201,11 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+ return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+ (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..fcbef2be75f9a0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -20,15 +20,18 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+template <typename OpTy>
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- kExecutionScopeAttrName) ||
- spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
- kGroupOperationAttrName) ||
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+ executionScope, parser, state,
+ OpTy::getExecutionScopeAttrName(state.name)) ||
+ spirv::parseEnumStrAttr<GroupOperationAttr>(
+ groupOperation, parser, state,
+ OpTy::getGroupOperationAttrName(state.name)) ||
parser.parseOperand(valueInfo))
return failure();
@@ -56,16 +59,23 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
return parser.addTypeToList(resultType, state.types);
}
+template <typename GroupNonUniformArithmeticOpTy>
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
printer
<< " \""
<< stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
+ groupOp->getName()))
.getValue())
<< "\" \""
<< stringifyGroupOperation(
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
+ groupOp->getName()))
.getValue())
<< "\" " << groupOp->getOperand(0);
@@ -74,16 +84,21 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
printer << " : " << groupOp->getResult(0).getType();
}
+template <typename OpTy>
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ OpTy::getExecutionScopeAttrName(groupOp->getName()))
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ OpTy::getGroupOperationAttrName(groupOp->getName()))
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
@@ -206,16 +221,17 @@ LogicalResult GroupNonUniformElectOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
}
ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
+ result);
}
void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -223,16 +239,17 @@ void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
}
ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
+ result);
}
void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -240,16 +257,17 @@ void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
}
ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
+ result);
}
void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -257,16 +275,17 @@ void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
}
ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
+ result);
}
void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -274,16 +293,17 @@ void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
}
ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
+ result);
}
void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -291,16 +311,17 @@ void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
}
ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
+ result);
}
void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -308,16 +329,17 @@ void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
}
ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
+ result);
}
void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -325,16 +347,17 @@ void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
}
ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
+ result);
}
void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -342,16 +365,17 @@ void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
}
ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
+ result);
}
void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -359,16 +383,17 @@ void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
}
ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
+ result);
}
void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -376,16 +401,17 @@ void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseAndOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
}
ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
+ result);
}
void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -393,16 +419,17 @@ void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseOrOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
}
ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
+ result);
}
void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -410,16 +437,17 @@ void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseXorOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
}
ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
+ result);
}
void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -427,16 +455,17 @@ void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalAndOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
}
ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
+ result);
}
void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -444,16 +473,17 @@ void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalOrOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
}
ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
+ result);
}
void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -461,16 +491,17 @@ void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalXorOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
}
ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
+ result);
}
void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this, p);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..f5676f36a0f5f5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -25,6 +25,7 @@ namespace mlir::spirv {
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
+template <typename IntegerDotProductOpTy>
static LogicalResult verifyIntegerDotProduct(Operation *op) {
assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
"Not an integer dot product op?");
@@ -33,10 +34,12 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
// ODS enforces that vector 1 and vector 2, and result and the accumulator
// have the same types.
Type factorTy = op->getOperand(0).getType();
+ StringAttr packedVectorFormatAttrName =
+ IntegerDotProductOpTy::getFormatAttrName(op->getName());
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto packedVectorFormat =
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (!packedVectorFormat)
return op->emitOpError("requires Packed Vector Format attribute for "
"integer vector operands");
@@ -50,7 +53,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
"integer vector operands to be 32-bits wide",
packedVectorFormat.getValue()));
} else {
- if (op->hasAttr(kPackedVectorFormatAttrName))
+ if (op->hasAttr(packedVectorFormatAttrName))
return op->emitOpError(llvm::formatv(
"with invalid format attribute for vector operands of type '{0}'",
factorTy));
@@ -84,6 +87,7 @@ getIntegerDotProductExtensions() {
return {extension};
}
+template <typename IntegerDotProductOpTy>
static SmallVector<ArrayRef<spirv::Capability>, 1>
getIntegerDotProductCapabilities(Operation *op) {
// Requires the the DotProduct capability and capabilities that depend on
@@ -99,9 +103,11 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
+ StringAttr packedVectorFormatAttrName =
+ IntegerDotProductOpTy::getFormatAttrName(op->getName());
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (formatAttr.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
capabilities.push_back(dotProductInput4x8BitPackedCap);
@@ -120,12 +126,14 @@ getIntegerDotProductCapabilities(Operation *op) {
}
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
- LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
+ LogicalResult OpName::verify() { \
+ return verifyIntegerDotProduct<OpName>(*this); \
+ } \
SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
return getIntegerDotProductExtensions(); \
} \
SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
- return getIntegerDotProductCapabilities(*this); \
+ return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
return getIntegerDotProductMinVersion(); \
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..c4c7ff722175dc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -25,10 +25,48 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
+/// (`[` memory-access `]`)?
+/// where (attribute names come from MemoryOpTy's ODS-generated getters):
+/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` `,`
+/// integer-literal | `"NonTemporal"`
+template <typename MemoryOpTy>
+ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+ OperationState &state) {
+ // Parse an optional list of attributes starting with '['
+ if (parser.parseOptionalLSquare()) {
+ // No '[' present: the memory access attributes are optional, so succeed.
+ return success();
+ }
+
+ spirv::MemoryAccess memoryAccessAttr;
+ StringAttr memoryAccessAttrName =
+ MemoryOpTy::getMemoryAccessAttrName(state.name); // ODS getter, not a literal.
+ if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+ memoryAccessAttr, parser, state, memoryAccessAttrName))
+ return failure();
+
+ if (spirv::bitEnumContainsAll(memoryAccessAttr,
+ spirv::MemoryAccess::Aligned)) {
+ // An "Aligned" access must be followed by `, <alignment>` parsed as i32.
+ Attribute alignmentAttr;
+ StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
+ Type i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseComma() ||
+ parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
+ state.attributes)) {
+ return failure();
+ }
+ }
+ return parser.parseRSquare(); // Consume the closing ']'.
+}
+
// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
OperationState &state) {
// Parse an optional list of attributes staring with '['
@@ -38,17 +76,21 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
+ StringRef memoryAccessAttrName =
+ MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
- memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+ memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
+ StringAttr alignmentAttrName =
+ MemoryOpTy::getSourceAlignmentAttrName(state.name);
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+ parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
state.attributes)) {
return failure();
}
@@ -72,7 +114,7 @@ static void printSourceMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -80,7 +122,7 @@ static void printSourceMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kSourceAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
printer << ", " << *alignment;
}
}
@@ -98,7 +140,7 @@ static void printMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -106,7 +148,7 @@ static void printMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
printer << ", " << *alignment;
}
}
@@ -136,11 +178,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -157,11 +199,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -180,11 +222,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -201,11 +243,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kSourceAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -376,7 +418,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand ptrInfo;
Type elementType;
if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
- parseMemoryAccessAttributes(parser, result) ||
+ parseMemoryAccessAttributes<LoadOp>(parser, result) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
@@ -425,8 +467,8 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
Type elementType;
if (parseEnumStrAttr(storageClass, parser) ||
parser.parseOperandList(operandInfo, 2) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
- parser.parseType(elementType)) {
+ parseMemoryAccessAttributes<StoreOp>(parser, result) ||
+ parser.parseColon() || parser.parseType(elementType)) {
return failure();
}
@@ -499,13 +541,13 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
parseEnumStrAttr(sourceStorageClass, parser) ||
parser.parseOperand(sourcePtrInfo) ||
- parseMemoryAccessAttributes(parser, result)) {
+ parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
return failure();
}
if (!parser.parseOptionalComma()) {
// Parse 2nd memory access attributes.
- if (parseSourceMemoryAccessAttributes(parser, result)) {
+ if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
return failure();
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..38415620fd4f96 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,14 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
+ StringRef indicesAttrName =
+ spirv::CompositeExtractOp::getIndicesAttrName(result.name);
Type compositeType;
SMLoc attrLocation;
if (parser.parseOperand(compositeInfo) ||
parser.getCurrentLocation(&attrLocation) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(compositeType) ||
parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
return failure();
@@ -513,11 +515,13 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
Type objectType, compositeType;
Attribute indicesAttr;
+ StringRef indicesAttrName =
+ spirv::CompositeInsertOp::getIndicesAttrName(result.name);
auto loc = parser.getCurrentLocation();
return failure(
parser.parseOperandList(operands, 2) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(objectType) ||
parser.parseKeywordType("into", compositeType) ||
parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
- if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+ StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name);
+ if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
Type type = NoneType::get(parser.getContext());
@@ -822,7 +827,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}))
return failure();
}
- result.addAttribute(kInterfaceAttrName,
+ result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -875,7 +880,9 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
- result.addAttribute(kValuesAttrName,
+ StringRef valuesAttrName =
+ spirv::ExecutionModeOp::getValuesAttrName(result.name);
+ result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
}
@@ -1150,16 +1157,18 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse variable name.
StringAttr nameAttr;
+ StringRef initializerAttrName =
+ spirv::GlobalVariableOp::getInitializerAttrName(result.name);
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
}
// Parse optional initializer
- if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+ if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
- parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+ parser.parseAttribute(initSymbol, Type(), initializerAttrName,
result.attributes) ||
parser.parseRParen())
return failure();
@@ -1170,6 +1179,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
}
Type type;
+ StringRef typeAttrName =
+ spirv::GlobalVariableOp::getTypeAttrName(result.name);
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1177,7 +1188,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
if (!llvm::isa<spirv::PointerType>(type)) {
return parser.emitError(loc, "expected spirv.ptr type");
}
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
@@ -1191,15 +1202,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
printer.printSymbolName(getSymName());
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+ StringRef initializerAttrName = this->getInitializerAttrName();
// Print optional initializer
if (auto initializer = this->getInitializer()) {
- printer << " " << kInitializerAttrName << '(';
+ printer << " " << initializerAttrName << '(';
printer.printSymbolName(*initializer);
printer << ')';
- elidedAttrs.push_back(kInitializerAttrName);
+ elidedAttrs.push_back(initializerAttrName);
}
- elidedAttrs.push_back(kTypeAttrName);
+ StringRef typeAttrName = this->getTypeAttrName();
+ elidedAttrs.push_back(typeAttrName);
spirv::printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
}
@@ -1219,8 +1232,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
<< stringifyStorageClass(storageClass) << "'";
}
- if (auto init =
- (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+ if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+ this->getInitializerAttrName())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1594,6 +1607,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr nameAttr;
Attribute valueAttr;
+ StringRef defaultValueAttrName =
+ spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -1609,8 +1624,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
}
if (parser.parseEqual() ||
- parser.parseAttribute(valueAttr, kDefaultValueAttrName,
- result.attributes))
+ parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
return failure();
return success();
@@ -1783,14 +1797,18 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseRParen())
return failure();
- result.addAttribute(kCompositeSpecConstituentsName,
+ StringAttr compositeSpecConstituentsName =
+ spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
+ result.addAttribute(compositeSpecConstituentsName,
parser.getBuilder().getArrayAttr(constituents));
Type type;
if (parser.parseColonType(type))
return failure();
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ StringAttr typeAttrName =
+ spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 27c373300aee8e..726fb7ce9fdb74 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -18,35 +18,6 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
-ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
- OperationState &state,
- StringRef attrName) {
- // Parse an optional list of attributes staring with '['
- if (parser.parseOptionalLSquare()) {
- // Nothing to do
- return success();
- }
-
- spirv::MemoryAccess memoryAccessAttr;
- if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
- state, attrName))
- return failure();
-
- if (spirv::bitEnumContainsAll(memoryAccessAttr,
- spirv::MemoryAccess::Aligned)) {
- // Parse integer attribute for alignment.
- Attribute alignmentAttr;
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type,
- AttrNames::kAlignmentAttrName,
- state.attributes)) {
- return failure();
- }
- }
- return parser.parseRSquare();
-}
-
ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state) {
auto builtInName = llvm::convertToSnakeFromCamelCase(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 625c82f6e8e899..858b94f7be8b08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,34 +20,12 @@
namespace mlir::spirv {
namespace AttrNames {
-// TODO: generate these strings using ODS.
-inline constexpr char kAlignmentAttrName[] = "alignment";
-inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-inline constexpr char kCallee[] = "callee";
-inline constexpr char kClusterSize[] = "cluster_size";
-inline constexpr char kControl[] = "control";
-inline constexpr char kDefaultValueAttrName[] = "default_value";
-inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-inline constexpr char kFnNameAttrName[] = "fn";
-inline constexpr char kGroupOperationAttrName[] = "group_operation";
-inline constexpr char kIndicesAttrName[] = "indices";
-inline constexpr char kInitializerAttrName[] = "initializer";
-inline constexpr char kInterfaceAttrName[] = "interface";
-inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-inline constexpr char kMemoryAccessAttrName[] = "memory_access";
-inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-inline constexpr char kPackedVectorFormatAttrName[] = "format";
-inline constexpr char kSemanticsAttrName[] = "semantics";
-inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-inline constexpr char kSpecIdAttrName[] = "spec_id";
-inline constexpr char kTypeAttrName[] = "type";
-inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-inline constexpr char kValueAttrName[] = "value";
-inline constexpr char kValuesAttrName[] = "values";
-inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+
+inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
+inline constexpr char kControl[] = "control"; // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+
} // namespace AttrNames
template <typename Ty>
@@ -143,16 +121,6 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
return success();
}
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-/// (`[` memory-access `]`)?
-/// where:
-/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-/// integer-literal | `"NonTemporal"`
-ParseResult parseMemoryAccessAttributes(
- OpAsmParser &parser, OperationState &state,
- StringRef attrName = AttrNames::kMemoryAccessAttrName);
-
ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state);
More information about the Mlir-commits
mailing list