[Mlir-commits] [mlir] [mlir][spirv] Use ODS generated attribute names for op definitions (PR #81552)
Lei Zhang
llvmlistbot at llvm.org
Sun Feb 25 16:06:43 PST 2024
https://github.com/antiagainst updated https://github.com/llvm/llvm-project/pull/81552
>From c794f24655d131fb698a7cf2310b1f50259b7d16 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Mon, 12 Feb 2024 00:31:59 -0500
Subject: [PATCH 1/5] Replace hard-coded attribute names in SPIR-V parsing code
---
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp | 3 +-
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 6 +-
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 12 ++--
.../Dialect/SPIRV/IR/IntegerDotProductOps.cpp | 8 ++-
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp | 30 +++++-----
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 41 +++++++++-----
.../Dialect/SPIRV/IR/SPIRVParsingUtils.cpp | 4 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h | 55 ++++++++++---------
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 +-
9 files changed, 90 insertions(+), 71 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index 7e33e91414e0bf..d84133d5933415 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -42,8 +42,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
+ StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
auto memorySemantics =
- op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+ op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 580782043c81b4..650b5448f041b5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,7 +87,7 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();
- result.addAttribute(kBranchWeightAttrName,
+ result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
builder.getArrayAttr({trueWeight, falseWeight}));
}
@@ -199,11 +199,11 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
+ return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+ (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ac29ab0db586cf..ea82f8442dcfe0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -26,9 +26,9 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- kExecutionScopeAttrName) ||
+ ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
- kGroupOperationAttrName) ||
+ GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
parser.parseOperand(valueInfo))
return failure();
@@ -61,11 +61,11 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
printer
<< " \""
<< stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
.getValue())
<< "\" \""
<< stringifyGroupOperation(
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
.getValue())
<< "\" " << groupOp->getOperand(0);
@@ -76,14 +76,14 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
- groupOp->getAttrOfType<GroupOperationAttr>(kGroupOperationAttrName)
+ groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 00fc2acf7f07d0..3abc4361321d04 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -33,10 +33,11 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
// ODS enforces that vector 1 and vector 2, and result and the accumulator
// have the same types.
Type factorTy = op->getOperand(0).getType();
+ StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto packedVectorFormat =
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (!packedVectorFormat)
return op->emitOpError("requires Packed Vector Format attribute for "
"integer vector operands");
@@ -50,7 +51,7 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
"integer vector operands to be 32-bits wide",
packedVectorFormat.getValue()));
} else {
- if (op->hasAttr(kPackedVectorFormatAttrName))
+ if (op->hasAttr(packedVectorFormatAttrName))
return op->emitOpError(llvm::formatv(
"with invalid format attribute for vector operands of type '{0}'",
factorTy));
@@ -99,9 +100,10 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
+ StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
- op->getAttr(kPackedVectorFormatAttrName));
+ op->getAttr(packedVectorFormatAttrName));
if (formatAttr.getValue() ==
spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
capabilities.push_back(dotProductInput4x8BitPackedCap);
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 0df59e42218b55..eb7900fd6ef8fb 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -38,17 +38,19 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
+ StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
- memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName))
+ memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
+ StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
+ parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
state.attributes)) {
return failure();
}
@@ -72,7 +74,7 @@ static void printSourceMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kSourceMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -80,7 +82,7 @@ static void printSourceMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kSourceAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
printer << ", " << *alignment;
}
}
@@ -98,7 +100,7 @@ static void printMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(kMemoryAccessAttrName);
+ elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -106,7 +108,7 @@ static void printMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(kAlignmentAttrName);
+ elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
printer << ", " << *alignment;
}
}
@@ -136,11 +138,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -157,11 +159,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -180,11 +182,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
+ auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -201,11 +203,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(kSourceAlignmentAttrName)) {
+ if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(kSourceAlignmentAttrName)) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 50035c917137e3..c27df9e30f24e3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,12 +457,13 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
+ StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
Type compositeType;
SMLoc attrLocation;
if (parser.parseOperand(compositeInfo) ||
parser.getCurrentLocation(&attrLocation) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(compositeType) ||
parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
return failure();
@@ -513,11 +514,12 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
Type objectType, compositeType;
Attribute indicesAttr;
+ StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
return failure(
parser.parseOperandList(operands, 2) ||
- parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
+ parser.parseAttribute(indicesAttr, indicesAttrName, result.attributes) ||
parser.parseColonType(objectType) ||
parser.parseKeywordType("into", compositeType) ||
parser.resolveOperands(operands, {objectType, compositeType}, loc,
@@ -559,7 +561,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
- if (parser.parseAttribute(value, kValueAttrName, result.attributes))
+ StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+ if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
Type type = NoneType::get(parser.getContext());
@@ -822,7 +825,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}))
return failure();
}
- result.addAttribute(kInterfaceAttrName,
+ result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -875,7 +878,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
- result.addAttribute(kValuesAttrName,
+ StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+ result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
}
@@ -1150,16 +1154,17 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse variable name.
StringAttr nameAttr;
+ StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
}
// Parse optional initializer
- if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
+ if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
- parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
+ parser.parseAttribute(initSymbol, Type(), initializerAttrName,
result.attributes) ||
parser.parseRParen())
return failure();
@@ -1170,6 +1175,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
}
Type type;
+ StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1177,7 +1183,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
if (!llvm::isa<spirv::PointerType>(type)) {
return parser.emitError(loc, "expected spirv.ptr type");
}
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
@@ -1191,15 +1197,17 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
printer.printSymbolName(getSymName());
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
+ StringRef initializerAttrName = this->getInitializerAttrName().strref();
// Print optional initializer
if (auto initializer = this->getInitializer()) {
- printer << " " << kInitializerAttrName << '(';
+ printer << " " << initializerAttrName << '(';
printer.printSymbolName(*initializer);
printer << ')';
- elidedAttrs.push_back(kInitializerAttrName);
+ elidedAttrs.push_back(initializerAttrName);
}
- elidedAttrs.push_back(kTypeAttrName);
+ StringRef typeAttrName = this->getTypeAttrName().strref();
+ elidedAttrs.push_back(typeAttrName);
spirv::printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
}
@@ -1220,7 +1228,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
}
if (auto init =
- (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
+ (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1594,6 +1602,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr nameAttr;
Attribute valueAttr;
+ StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -1609,7 +1618,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
}
if (parser.parseEqual() ||
- parser.parseAttribute(valueAttr, kDefaultValueAttrName,
+ parser.parseAttribute(valueAttr, defaultValueAttrName,
result.attributes))
return failure();
@@ -1783,14 +1792,16 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseRParen())
return failure();
- result.addAttribute(kCompositeSpecConstituentsName,
+ StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+ result.addAttribute(compositeSpecConstituentsName,
parser.getBuilder().getArrayAttr(constituents));
Type type;
if (parser.parseColonType(type))
return failure();
- result.addAttribute(kTypeAttrName, TypeAttr::get(type));
+ StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+ result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 27c373300aee8e..f48b0b93ab5fe8 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,6 +12,8 @@
#include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
#include "llvm/ADT/StringExtras.h"
using namespace mlir::spirv::AttrNames;
@@ -39,7 +41,7 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
parser.parseAttribute(alignmentAttr, i32Type,
- AttrNames::kAlignmentAttrName,
+ CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
state.attributes)) {
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 625c82f6e8e899..d4c0deef6aae3d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,34 +20,35 @@
namespace mlir::spirv {
namespace AttrNames {
+inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
+inline constexpr char kControl[] = "control"; // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+
// TODO: generate these strings using ODS.
-inline constexpr char kAlignmentAttrName[] = "alignment";
-inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-inline constexpr char kCallee[] = "callee";
-inline constexpr char kClusterSize[] = "cluster_size";
-inline constexpr char kControl[] = "control";
-inline constexpr char kDefaultValueAttrName[] = "default_value";
-inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-inline constexpr char kFnNameAttrName[] = "fn";
-inline constexpr char kGroupOperationAttrName[] = "group_operation";
-inline constexpr char kIndicesAttrName[] = "indices";
-inline constexpr char kInitializerAttrName[] = "initializer";
-inline constexpr char kInterfaceAttrName[] = "interface";
-inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-inline constexpr char kMemoryAccessAttrName[] = "memory_access";
-inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-inline constexpr char kPackedVectorFormatAttrName[] = "format";
-inline constexpr char kSemanticsAttrName[] = "semantics";
-inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-inline constexpr char kSpecIdAttrName[] = "spec_id";
-inline constexpr char kTypeAttrName[] = "type";
-inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-inline constexpr char kValueAttrName[] = "value";
-inline constexpr char kValuesAttrName[] = "values";
-inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
+// inline constexpr char kAlignmentAttrName[] = "alignment";
+// inline constexpr char kBranchWeightAttrName[] = "branch_weights";
+// inline constexpr char kCallee[] = "callee";
+// inline constexpr char kDefaultValueAttrName[] = "default_value";
+// inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
+// inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
+// inline constexpr char kGroupOperationAttrName[] = "group_operation";
+// inline constexpr char kIndicesAttrName[] = "indices";
+// inline constexpr char kInitializerAttrName[] = "initializer";
+// inline constexpr char kInterfaceAttrName[] = "interface";
+// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
+// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
+// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
+// inline constexpr char kPackedVectorFormatAttrName[] = "format";
+// inline constexpr char kSemanticsAttrName[] = "semantics";
+// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
+// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
+// inline constexpr char kTypeAttrName[] = "type";
+// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
+// inline constexpr char kValueAttrName[] = "value";
+// inline constexpr char kValuesAttrName[] = "values";
+// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
} // namespace AttrNames
template <typename Ty>
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 71326049af0579..d34b367e98e3d2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1142,7 +1142,7 @@ void OpEmitter::genAttrNameGetters() {
} else {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
- assert(name.getStringRef() == getOperationName() && "invalid operation name");
+ // assert(name.getStringRef() == getOperationName() && "invalid operation name");
assert(name.isRegistered() && "Operation isn't registered, missing a "
"dependent dialect loading?");
return name.getAttributeNames()[index];
>From 1cd7b881fe305cee24c0a0145494da4a6315815e Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Wed, 14 Feb 2024 14:25:56 -0500
Subject: [PATCH 2/5] Change SPIR-V parser to use ODS-generated attribute names
---
mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp | 29 ++--
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 14 +-
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 141 ++++++++++--------
.../Dialect/SPIRV/IR/IntegerDotProductOps.cpp | 9 +-
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp | 56 ++++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 39 +++--
.../Dialect/SPIRV/IR/SPIRVParsingUtils.cpp | 29 ----
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h | 35 +----
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 2 +-
9 files changed, 183 insertions(+), 171 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
index d84133d5933415..948d48980f2e84 100644
--- a/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp
@@ -33,7 +33,7 @@ StringRef stringifyTypeName<FloatType>() {
}
// Verifies an atomic update op.
-template <typename ExpectedElementType>
+template <typename AtomicOpTy, typename ExpectedElementType>
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
auto elementType = ptrType.getPointeeType();
@@ -42,7 +42,8 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
<< stringifyTypeName<ExpectedElementType>()
<< " value, found " << elementType;
- StringRef semanticsAttrName = spirv::AtomicAndOp::getSemanticsAttrName(op->getName()).strref();
+ StringAttr semanticsAttrName =
+ AtomicOpTy::getSemanticsAttrName(op->getName());
auto memorySemantics =
op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
.getValue();
@@ -57,7 +58,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
//===----------------------------------------------------------------------===//
LogicalResult AtomicAndOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -65,7 +66,7 @@ LogicalResult AtomicAndOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIAddOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -73,7 +74,7 @@ LogicalResult AtomicIAddOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult EXTAtomicFAddOp::verify() {
- return verifyAtomicUpdateOp<FloatType>(getOperation());
+ return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -81,7 +82,7 @@ LogicalResult EXTAtomicFAddOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIDecrementOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -89,7 +90,7 @@ LogicalResult AtomicIDecrementOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicIIncrementOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -97,7 +98,7 @@ LogicalResult AtomicIIncrementOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicISubOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -105,7 +106,7 @@ LogicalResult AtomicISubOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicOrOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -113,7 +114,7 @@ LogicalResult AtomicOrOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMaxOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -121,7 +122,7 @@ LogicalResult AtomicSMaxOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicSMinOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -129,7 +130,7 @@ LogicalResult AtomicSMinOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMaxOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -137,7 +138,7 @@ LogicalResult AtomicUMaxOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicUMinOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
}
//===----------------------------------------------------------------------===//
@@ -145,7 +146,7 @@ LogicalResult AtomicUMinOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult AtomicXorOp::verify() {
- return verifyAtomicUpdateOp<IntegerType>(getOperation());
+ return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
}
} // namespace mlir::spirv
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 650b5448f041b5..c441cd4e637b29 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -87,8 +87,11 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
parser.parseRSquare())
return failure();
- result.addAttribute(BranchConditionalOp::getBranchWeightsAttrName(result.name).strref(),
- builder.getArrayAttr({trueWeight, falseWeight}));
+ StringAttr branchWeightsAttrName =
+ BranchConditionalOp::getBranchWeightsAttrName(result.name);
+ result.addAttribute(
+ branchWeightsAttrName,
+ builder.getArrayAttr({trueWeight, falseWeight}));
}
// Parse the true branch.
@@ -199,11 +202,14 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref());
+ return (*this)->getAttrOfType<SymbolRefAttr>(
+ FunctionCallOp::getCalleeAttrName((*this)->getName()));
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()).strref(), callee.get<SymbolRefAttr>());
+ (*this)->setAttr(
+ FunctionCallOp::getCalleeAttrName((*this)->getName()),
+ callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index ea82f8442dcfe0..3da6870e4ae883 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -20,15 +20,18 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+template <typename GroupNonUniformArithmeticOpTy>
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
- if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
- ControlBarrierOp::getExecutionScopeAttrName(state.name).strref()) ||
- spirv::parseEnumStrAttr<GroupOperationAttr>(groupOperation, parser, state,
- GroupFAddOp::getGroupOperationAttrName(state.name).strref()) ||
+ if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
+ executionScope, parser, state,
+ GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(state.name)) ||
+ spirv::parseEnumStrAttr<GroupOperationAttr>(
+ groupOperation, parser, state,
+ GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(state.name)) ||
parser.parseOperand(valueInfo))
return failure();
@@ -56,34 +59,44 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
return parser.addTypeToList(resultType, state.types);
}
+template <typename GroupNonUniformArithmeticOpTy>
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
- printer
- << " \""
- << stringifyScope(
- groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
- .getValue())
- << "\" \""
- << stringifyGroupOperation(
- groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
- .getValue())
- << "\" " << groupOp->getOperand(0);
+ printer << " \""
+ << stringifyScope(groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
+ groupOp->getName()))
+ .getValue())
+ << "\" \""
+ << stringifyGroupOperation(
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
+ groupOp->getName()))
+ .getValue())
+ << "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
printer << " : " << groupOp->getResult(0).getType();
}
+template <typename GroupNonUniformArithmeticOpTy>
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
- groupOp->getAttrOfType<spirv::ScopeAttr>(ControlBarrierOp::getExecutionScopeAttrName(groupOp->getName()).strref())
+ groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(groupOp->getName()))
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
GroupOperation operation =
- groupOp->getAttrOfType<GroupOperationAttr>(GroupFAddOp::getGroupOperationAttrName(groupOp->getName()).strref())
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(groupOp->getName()))
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
@@ -206,16 +219,16 @@ LogicalResult GroupNonUniformElectOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this);
}
ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser, result);
}
void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -223,16 +236,16 @@ void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this);
}
ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser, result);
}
void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -240,16 +253,16 @@ void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this);
}
ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser, result);
}
void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -257,16 +270,16 @@ void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformFMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this);
}
ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser, result);
}
void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -274,16 +287,16 @@ void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIAddOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this);
}
ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser, result);
}
void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -291,16 +304,16 @@ void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformIMulOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this);
}
ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser, result);
}
void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -308,16 +321,16 @@ void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this);
}
ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser, result);
}
void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -325,16 +338,16 @@ void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformSMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this);
}
ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser, result);
}
void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -342,16 +355,16 @@ void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMaxOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this);
}
ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser, result);
}
void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -359,16 +372,16 @@ void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformUMinOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this);
}
ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser, result);
}
void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -376,16 +389,16 @@ void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseAndOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this);
}
ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser, result);
}
void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -393,16 +406,16 @@ void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseOrOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this);
}
ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser, result);
}
void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -410,16 +423,16 @@ void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformBitwiseXorOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this);
}
ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser, result);
}
void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -427,16 +440,16 @@ void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalAndOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this);
}
ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser, result);
}
void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -444,16 +457,16 @@ void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalOrOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this);
}
ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser, result);
}
void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this, p);
}
//===----------------------------------------------------------------------===//
@@ -461,16 +474,16 @@ void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//
LogicalResult GroupNonUniformLogicalXorOp::verify() {
- return verifyGroupNonUniformArithmeticOp(*this);
+ return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this);
}
ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser, result);
}
void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
- printGroupNonUniformArithmeticOp(*this, p);
+ printGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this, p);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 3abc4361321d04..0ef1e7bc0ca8be 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -25,6 +25,7 @@ namespace mlir::spirv {
// Integer Dot Product ops
//===----------------------------------------------------------------------===//
+template <typename IntegerDotProductOpTy>
static LogicalResult verifyIntegerDotProduct(Operation *op) {
assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
"Not an integer dot product op?");
@@ -33,7 +34,8 @@ static LogicalResult verifyIntegerDotProduct(Operation *op) {
// ODS enforces that vector 1 and vector 2, and result and the accumulator
// have the same types.
Type factorTy = op->getOperand(0).getType();
- StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+ StringAttr packedVectorFormatAttrName =
+ IntegerDotProductOpTy::getFormatAttrName(op->getName());
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto packedVectorFormat =
llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
@@ -100,7 +102,8 @@ getIntegerDotProductCapabilities(Operation *op) {
SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
Type factorTy = op->getOperand(0).getType();
- StringRef packedVectorFormatAttrName = SDotAccSatOp::getFormatAttrName(op->getName()).strref();
+ StringAttr packedVectorFormatAttrName =
+ SDotAccSatOp::getFormatAttrName(op->getName());
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
op->getAttr(packedVectorFormatAttrName));
@@ -122,7 +125,7 @@ getIntegerDotProductCapabilities(Operation *op) {
}
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
- LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
+ LogicalResult OpName::verify() { return verifyIntegerDotProduct<OpName>(*this); } \
SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
return getIntegerDotProductExtensions(); \
} \
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index eb7900fd6ef8fb..25afebf9f92efc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -25,10 +25,49 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
+/// Parses optional memory access (a.k.a. memory operand) attributes attached to
+/// a memory access operand/pointer. Specifically, parses the following syntax:
+/// (`[` memory-access `]`)?
+/// where:
+/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
+/// integer-literal | `"NonTemporal"`
+template <typename MemoryOpTy>
+ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
+ OperationState &state) {
+ // Parse an optional list of attributes starting with '['
+ if (parser.parseOptionalLSquare()) {
+ // Nothing to do
+ return success();
+ }
+
+ spirv::MemoryAccess memoryAccessAttr;
+ StringAttr memoryAccessAttrName = MemoryOpTy::getMemoryAccessAttrName(state.name);
+ if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
+ state, memoryAccessAttrName))
+ return failure();
+
+ if (spirv::bitEnumContainsAll(memoryAccessAttr,
+ spirv::MemoryAccess::Aligned)) {
+ // Parse integer attribute for alignment.
+ Attribute alignmentAttr;
+ StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
+ Type i32Type = parser.getBuilder().getIntegerType(32);
+ if (parser.parseComma() ||
+ parser.parseAttribute(
+ alignmentAttr, i32Type,
+ alignmentAttrName,
+ state.attributes)) {
+ return failure();
+ }
+ }
+ return parser.parseRSquare();
+}
+
// TODO Make sure to merge this and the previous function into one template
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
+template <typename MemoryOpTy>
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
OperationState &state) {
// Parse an optional list of attributes staring with '['
@@ -38,7 +77,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- StringRef memoryAccessAttrName = CopyMemoryOp::getMemoryAccessAttrName(state.name).strref();
+ StringAttr memoryAccessAttrName =
+ MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
@@ -47,7 +87,8 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
Attribute alignmentAttr;
- StringRef alignmentAttrName = CopyMemoryOp::getSourceAlignmentAttrName(state.name).strref();
+ StringAttr alignmentAttrName =
+ MemoryOpTy::getSourceAlignmentAttrName(state.name);
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
@@ -182,7 +223,8 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+ auto memAccessAttr =
+ op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
@@ -378,7 +420,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand ptrInfo;
Type elementType;
if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
- parseMemoryAccessAttributes(parser, result) ||
+ parseMemoryAccessAttributes<LoadOp>(parser, result) ||
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
@@ -427,7 +469,7 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
Type elementType;
if (parseEnumStrAttr(storageClass, parser) ||
parser.parseOperandList(operandInfo, 2) ||
- parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
+ parseMemoryAccessAttributes<StoreOp>(parser, result) || parser.parseColon() ||
parser.parseType(elementType)) {
return failure();
}
@@ -501,13 +543,13 @@ ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
parseEnumStrAttr(sourceStorageClass, parser) ||
parser.parseOperand(sourcePtrInfo) ||
- parseMemoryAccessAttributes(parser, result)) {
+ parseMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
return failure();
}
if (!parser.parseOptionalComma()) {
// Parse 2nd memory access attributes.
- if (parseSourceMemoryAccessAttributes(parser, result)) {
+ if (parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result)) {
return failure();
}
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index c27df9e30f24e3..5d6ddaa876575c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -457,7 +457,8 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
- StringRef indicesAttrName = spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+ StringRef indicesAttrName =
+ spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
Type compositeType;
SMLoc attrLocation;
@@ -514,7 +515,8 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
Type objectType, compositeType;
Attribute indicesAttr;
- StringRef indicesAttrName = spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+ StringRef indicesAttrName =
+ spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
return failure(
@@ -561,7 +563,8 @@ void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
- StringRef valueAttrName = spirv::ConstantOp::getValueAttrName(result.name).strref();
+ StringRef valueAttrName =
+ spirv::ConstantOp::getValueAttrName(result.name).strref();
if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
@@ -825,8 +828,9 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
}))
return failure();
}
- result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
- parser.getBuilder().getArrayAttr(interfaceVars));
+ result.addAttribute(
+ spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+ parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -878,7 +882,8 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
}
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
- StringRef valuesAttrName = spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+ StringRef valuesAttrName =
+ spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
@@ -1154,7 +1159,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
OperationState &result) {
// Parse variable name.
StringAttr nameAttr;
- StringRef initializerAttrName = spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+ StringRef initializerAttrName =
+ spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
@@ -1175,7 +1181,8 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
}
Type type;
- StringRef typeAttrName = spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+ StringRef typeAttrName =
+ spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1227,8 +1234,8 @@ LogicalResult spirv::GlobalVariableOp::verify() {
<< stringifyStorageClass(storageClass) << "'";
}
- if (auto init =
- (*this)->getAttrOfType<FlatSymbolRefAttr>(this->getInitializerAttrName().strref())) {
+ if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
+ this->getInitializerAttrName().strref())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1602,7 +1609,8 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
StringAttr nameAttr;
Attribute valueAttr;
- StringRef defaultValueAttrName = spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+ StringRef defaultValueAttrName =
+ spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -1618,8 +1626,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
}
if (parser.parseEqual() ||
- parser.parseAttribute(valueAttr, defaultValueAttrName,
- result.attributes))
+ parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
return failure();
return success();
@@ -1792,7 +1799,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseRParen())
return failure();
- StringRef compositeSpecConstituentsName = spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name).strref();
+ StringAttr compositeSpecConstituentsName =
+ spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name);
result.addAttribute(compositeSpecConstituentsName,
parser.getBuilder().getArrayAttr(constituents));
@@ -1800,7 +1808,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
if (parser.parseColonType(type))
return failure();
- StringRef typeAttrName = spirv::SpecConstantCompositeOp::getTypeAttrName(result.name).strref();
+ StringAttr typeAttrName =
+ spirv::SpecConstantCompositeOp::getTypeAttrName(result.name);
result.addAttribute(typeAttrName, TypeAttr::get(type));
return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index f48b0b93ab5fe8..1fd25f1ac92e06 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -20,35 +20,6 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
-ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
- OperationState &state,
- StringRef attrName) {
- // Parse an optional list of attributes staring with '['
- if (parser.parseOptionalLSquare()) {
- // Nothing to do
- return success();
- }
-
- spirv::MemoryAccess memoryAccessAttr;
- if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
- state, attrName))
- return failure();
-
- if (spirv::bitEnumContainsAll(memoryAccessAttr,
- spirv::MemoryAccess::Aligned)) {
- // Parse integer attribute for alignment.
- Attribute alignmentAttr;
- Type i32Type = parser.getBuilder().getIntegerType(32);
- if (parser.parseComma() ||
- parser.parseAttribute(alignmentAttr, i32Type,
- CopyMemoryOp::getAlignmentAttrName(state.name).strref(),
- state.attributes)) {
- return failure();
- }
- }
- return parser.parseRSquare();
-}
-
ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state) {
auto builtInName = llvm::convertToSnakeFromCamelCase(
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index d4c0deef6aae3d..18c9e25227845c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -20,35 +20,12 @@
namespace mlir::spirv {
namespace AttrNames {
-inline constexpr char kMemoryAccessAttrName[] = "memory_access"; // need hardcoded string below -- probably can change
+
inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
inline constexpr char kControl[] = "control"; // no ODS generation
inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
-// TODO: generate these strings using ODS.
-// inline constexpr char kAlignmentAttrName[] = "alignment";
-// inline constexpr char kBranchWeightAttrName[] = "branch_weights";
-// inline constexpr char kCallee[] = "callee";
-// inline constexpr char kDefaultValueAttrName[] = "default_value";
-// inline constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
-// inline constexpr char kExecutionScopeAttrName[] = "execution_scope";
-// inline constexpr char kGroupOperationAttrName[] = "group_operation";
-// inline constexpr char kIndicesAttrName[] = "indices";
-// inline constexpr char kInitializerAttrName[] = "initializer";
-// inline constexpr char kInterfaceAttrName[] = "interface";
-// inline constexpr char kKhrCooperativeMatrixLayoutAttrName[] = "matrix_layout";
-// inline constexpr char kMemoryOperandAttrName[] = "memory_operand";
-// inline constexpr char kMemoryScopeAttrName[] = "memory_scope";
-// inline constexpr char kPackedVectorFormatAttrName[] = "format";
-// inline constexpr char kSemanticsAttrName[] = "semantics";
-// inline constexpr char kSourceAlignmentAttrName[] = "source_alignment";
-// inline constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
-// inline constexpr char kTypeAttrName[] = "type";
-// inline constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
-// inline constexpr char kValueAttrName[] = "value";
-// inline constexpr char kValuesAttrName[] = "values";
-// inline constexpr char kCompositeSpecConstituentsName[] = "constituents";
} // namespace AttrNames
template <typename Ty>
@@ -144,16 +121,6 @@ parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
return success();
}
-/// Parses optional memory access (a.k.a. memory operand) attributes attached to
-/// a memory access operand/pointer. Specifically, parses the following syntax:
-/// (`[` memory-access `]`)?
-/// where:
-/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
-/// integer-literal | `"NonTemporal"`
-ParseResult parseMemoryAccessAttributes(
- OpAsmParser &parser, OperationState &state,
- StringRef attrName = AttrNames::kMemoryAccessAttrName);
-
ParseResult parseVariableDecorations(OpAsmParser &parser,
OperationState &state);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d34b367e98e3d2..71326049af0579 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1142,7 +1142,7 @@ void OpEmitter::genAttrNameGetters() {
} else {
const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
- // assert(name.getStringRef() == getOperationName() && "invalid operation name");
+ assert(name.getStringRef() == getOperationName() && "invalid operation name");
assert(name.isRegistered() && "Operation isn't registered, missing a "
"dependent dialect loading?");
return name.getAttributeNames()[index];
>From 3bd6a1b7c5a5d667df9b3c09aa6fe30198026eb6 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Wed, 14 Feb 2024 21:10:02 -0500
Subject: [PATCH 3/5] Fix wrong attribute-name lookups in SPIR-V parsers and apply clang-format
---
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 10 +--
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 90 ++++++++++++-------
.../Dialect/SPIRV/IR/IntegerDotProductOps.cpp | 9 +-
mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp | 47 +++++-----
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h | 6 +-
5 files changed, 92 insertions(+), 70 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index c441cd4e637b29..60dec9ce503061 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -89,9 +89,8 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser,
StringAttr branchWeightsAttrName =
BranchConditionalOp::getBranchWeightsAttrName(result.name);
- result.addAttribute(
- branchWeightsAttrName,
- builder.getArrayAttr({trueWeight, falseWeight}));
+ result.addAttribute(branchWeightsAttrName,
+ builder.getArrayAttr({trueWeight, falseWeight}));
}
// Parse the true branch.
@@ -207,9 +206,8 @@ CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(
- FunctionCallOp::getCalleeAttrName((*this)->getName()),
- callee.get<SymbolRefAttr>());
+ (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()),
+ callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 3da6870e4ae883..7c1d3c30f7f0d9 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -28,10 +28,12 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand valueInfo;
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
executionScope, parser, state,
- GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(state.name)) ||
+ GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
+ state.name)) ||
spirv::parseEnumStrAttr<GroupOperationAttr>(
groupOperation, parser, state,
- GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(state.name)) ||
+ GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
+ state.name)) ||
parser.parseOperand(valueInfo))
return failure();
@@ -62,20 +64,22 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
template <typename GroupNonUniformArithmeticOpTy>
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
- printer << " \""
- << stringifyScope(groupOp
- ->getAttrOfType<spirv::ScopeAttr>(
- GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
- groupOp->getName()))
- .getValue())
- << "\" \""
- << stringifyGroupOperation(
- groupOp
- ->getAttrOfType<GroupOperationAttr>(
- GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
- groupOp->getName()))
- .getValue())
- << "\" " << groupOp->getOperand(0);
+ printer
+ << " \""
+ << stringifyScope(
+ groupOp
+ ->getAttrOfType<spirv::ScopeAttr>(
+ GroupNonUniformArithmeticOpTy::getExecutionScopeAttrName(
+ groupOp->getName()))
+ .getValue())
+ << "\" \""
+ << stringifyGroupOperation(
+ groupOp
+ ->getAttrOfType<GroupOperationAttr>(
+ GroupNonUniformArithmeticOpTy::getGroupOperationAttrName(
+ groupOp->getName()))
+ .getValue())
+ << "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -87,7 +91,8 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
groupOp
->getAttrOfType<spirv::ScopeAttr>(
- GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(groupOp->getName()))
+ GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
+ groupOp->getName()))
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
@@ -96,7 +101,8 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
GroupOperation operation =
groupOp
->getAttrOfType<GroupOperationAttr>(
- GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(groupOp->getName()))
+ GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
+ groupOp->getName()))
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
@@ -224,7 +230,8 @@ LogicalResult GroupNonUniformFAddOp::verify() {
ParseResult GroupNonUniformFAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(parser,
+ result);
}
void GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
@@ -241,7 +248,8 @@ LogicalResult GroupNonUniformFMaxOp::verify() {
ParseResult GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(parser,
+ result);
}
void GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
@@ -258,7 +266,8 @@ LogicalResult GroupNonUniformFMinOp::verify() {
ParseResult GroupNonUniformFMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(parser,
+ result);
}
void GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
@@ -275,7 +284,8 @@ LogicalResult GroupNonUniformFMulOp::verify() {
ParseResult GroupNonUniformFMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(parser,
+ result);
}
void GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
@@ -292,7 +302,8 @@ LogicalResult GroupNonUniformIAddOp::verify() {
ParseResult GroupNonUniformIAddOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(parser,
+ result);
}
void GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
@@ -309,7 +320,8 @@ LogicalResult GroupNonUniformIMulOp::verify() {
ParseResult GroupNonUniformIMulOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(parser,
+ result);
}
void GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
@@ -326,7 +338,8 @@ LogicalResult GroupNonUniformSMaxOp::verify() {
ParseResult GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(parser,
+ result);
}
void GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
@@ -343,7 +356,8 @@ LogicalResult GroupNonUniformSMinOp::verify() {
ParseResult GroupNonUniformSMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(parser,
+ result);
}
void GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
@@ -360,7 +374,8 @@ LogicalResult GroupNonUniformUMaxOp::verify() {
ParseResult GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(parser,
+ result);
}
void GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
@@ -377,7 +392,8 @@ LogicalResult GroupNonUniformUMinOp::verify() {
ParseResult GroupNonUniformUMinOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(parser,
+ result);
}
void GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
@@ -394,7 +410,8 @@ LogicalResult GroupNonUniformBitwiseAndOp::verify() {
ParseResult GroupNonUniformBitwiseAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(parser,
+ result);
}
void GroupNonUniformBitwiseAndOp::print(OpAsmPrinter &p) {
@@ -411,7 +428,8 @@ LogicalResult GroupNonUniformBitwiseOrOp::verify() {
ParseResult GroupNonUniformBitwiseOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(parser,
+ result);
}
void GroupNonUniformBitwiseOrOp::print(OpAsmPrinter &p) {
@@ -428,7 +446,8 @@ LogicalResult GroupNonUniformBitwiseXorOp::verify() {
ParseResult GroupNonUniformBitwiseXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(parser,
+ result);
}
void GroupNonUniformBitwiseXorOp::print(OpAsmPrinter &p) {
@@ -445,7 +464,8 @@ LogicalResult GroupNonUniformLogicalAndOp::verify() {
ParseResult GroupNonUniformLogicalAndOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(parser,
+ result);
}
void GroupNonUniformLogicalAndOp::print(OpAsmPrinter &p) {
@@ -462,7 +482,8 @@ LogicalResult GroupNonUniformLogicalOrOp::verify() {
ParseResult GroupNonUniformLogicalOrOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(parser,
+ result);
}
void GroupNonUniformLogicalOrOp::print(OpAsmPrinter &p) {
@@ -479,7 +500,8 @@ LogicalResult GroupNonUniformLogicalXorOp::verify() {
ParseResult GroupNonUniformLogicalXorOp::parse(OpAsmParser &parser,
OperationState &result) {
- return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser, result);
+ return parseGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(parser,
+ result);
}
void GroupNonUniformLogicalXorOp::print(OpAsmPrinter &p) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
index 0ef1e7bc0ca8be..fd7cfa9b0d4bbf 100644
--- a/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/IntegerDotProductOps.cpp
@@ -87,6 +87,7 @@ getIntegerDotProductExtensions() {
return {extension};
}
+template <typename IntegerDotProductOpTy>
static SmallVector<ArrayRef<spirv::Capability>, 1>
getIntegerDotProductCapabilities(Operation *op) {
// Requires the the DotProduct capability and capabilities that depend on
@@ -103,7 +104,7 @@ getIntegerDotProductCapabilities(Operation *op) {
Type factorTy = op->getOperand(0).getType();
StringAttr packedVectorFormatAttrName =
- SDotAccSatOp::getFormatAttrName(op->getName());
+ IntegerDotProductOpTy::getFormatAttrName(op->getName());
if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
op->getAttr(packedVectorFormatAttrName));
@@ -125,12 +126,14 @@ getIntegerDotProductCapabilities(Operation *op) {
}
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
- LogicalResult OpName::verify() { return verifyIntegerDotProduct<OpName>(*this); } \
+ LogicalResult OpName::verify() { \
+ return verifyIntegerDotProduct<OpName>(*this); \
+ } \
SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
return getIntegerDotProductExtensions(); \
} \
SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
- return getIntegerDotProductCapabilities(*this); \
+ return getIntegerDotProductCapabilities<OpName>(*this); \
} \
std::optional<spirv::Version> OpName::getMinVersion() { \
return getIntegerDotProductMinVersion(); \
diff --git a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
index 25afebf9f92efc..a594ae9704f3c5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp
@@ -31,7 +31,7 @@ namespace mlir::spirv {
/// where:
/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
/// integer-literal | `"NonTemporal"`
-template<typename MemoryOpTy>
+template <typename MemoryOpTy>
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
OperationState &state) {
// Parse an optional list of attributes staring with '['
@@ -41,9 +41,10 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- StringAttr memoryAccessAttrName = MemoryOpTy::getMemoryAccessAttrName(state.name);
- if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser,
- state, memoryAccessAttrName))
+ StringAttr memoryAccessAttrName =
+ MemoryOpTy::getMemoryAccessAttrName(state.name);
+ if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
+ memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
if (spirv::bitEnumContainsAll(memoryAccessAttr,
@@ -53,10 +54,8 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
Type i32Type = parser.getBuilder().getIntegerType(32);
if (parser.parseComma() ||
- parser.parseAttribute(
- alignmentAttr, i32Type,
- alignmentAttrName,
- state.attributes)) {
+ parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
+ state.attributes)) {
return failure();
}
}
@@ -67,7 +66,7 @@ ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
// parameterized by memory access attribute name and alignment. Doing so now
// results in VS2017 in producing an internal error (at the call site) that's
// not detailed enough to understand what is happening.
-template<typename MemoryOpTy>
+template <typename MemoryOpTy>
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
OperationState &state) {
// Parse an optional list of attributes staring with '['
@@ -78,7 +77,7 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
spirv::MemoryAccess memoryAccessAttr;
StringRef memoryAccessAttrName =
- MemoryOpTy::getMemoryAccessAttrName(state.name).strref();
+ MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
if (spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
memoryAccessAttr, parser, state, memoryAccessAttrName))
return failure();
@@ -115,7 +114,7 @@ static void printSourceMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName().strref());
+ elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -123,7 +122,7 @@ static void printSourceMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName().strref());
+ elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
printer << ", " << *alignment;
}
}
@@ -141,7 +140,7 @@ static void printMemoryAccessAttribute(
// Print optional memory access attribute.
if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
: memoryOp.getMemoryAccess())) {
- elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName().strref());
+ elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
@@ -149,7 +148,7 @@ static void printMemoryAccessAttribute(
// Print integer alignment attribute.
if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
: memoryOp.getAlignment())) {
- elidedAttrs.push_back(memoryOp.getAlignmentAttrName().strref());
+ elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
printer << ", " << *alignment;
}
}
@@ -179,11 +178,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
// memory-access attribute is Aligned, then the alignment attribute must be
// present.
auto *op = memoryOp.getOperation();
- auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName().strref());
+ auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -200,11 +199,11 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+ if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(memoryOp.getAlignmentAttrName().strref())) {
+ if (op->getAttr(memoryOp.getAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -224,11 +223,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
// present.
auto *op = memoryOp.getOperation();
auto memAccessAttr =
- op->getAttr(memoryOp.getSourceMemoryAccessAttrName().strref());
+ op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
if (!memAccessAttr) {
// Alignment attribute shouldn't be present if memory access attribute is
// not present.
- if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification without aligned memory access "
"specification");
@@ -245,11 +244,11 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
if (spirv::bitEnumContainsAll(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
- if (!op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+ if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError("missing alignment value");
}
} else {
- if (op->getAttr(memoryOp.getSourceAlignmentAttrName().strref())) {
+ if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
return memoryOp.emitOpError(
"invalid alignment specification with non-aligned memory access "
"specification");
@@ -469,8 +468,8 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
Type elementType;
if (parseEnumStrAttr(storageClass, parser) ||
parser.parseOperandList(operandInfo, 2) ||
- parseMemoryAccessAttributes<StoreOp>(parser, result) || parser.parseColon() ||
- parser.parseType(elementType)) {
+ parseMemoryAccessAttributes<StoreOp>(parser, result) ||
+ parser.parseColon() || parser.parseType(elementType)) {
return failure();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
index 18c9e25227845c..858b94f7be8b08 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.h
@@ -22,9 +22,9 @@ namespace mlir::spirv {
namespace AttrNames {
inline constexpr char kClusterSize[] = "cluster_size"; // no ODS generation
-inline constexpr char kControl[] = "control"; // no ODS generation
-inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
-inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
+inline constexpr char kControl[] = "control"; // no ODS generation
+inline constexpr char kFnNameAttrName[] = "fn"; // no ODS generation
+inline constexpr char kSpecIdAttrName[] = "spec_id"; // no ODS generation
} // namespace AttrNames
>From 0c41cd1d73e38e438eda273246becba29b884d75 Mon Sep 17 00:00:00 2001
From: Tom Wilson <wilson.th at northeastern.edu>
Date: Thu, 15 Feb 2024 13:27:22 -0500
Subject: [PATCH 4/5] Drop redundant .strref() calls on ODS attribute-name getters
---
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 22 +++++++++++-----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 5d6ddaa876575c..7cdd3fa7a19fdd 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -458,7 +458,7 @@ ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand compositeInfo;
Attribute indicesAttr;
StringRef indicesAttrName =
- spirv::CompositeExtractOp::getIndicesAttrName(result.name).strref();
+ spirv::CompositeExtractOp::getIndicesAttrName(result.name);
Type compositeType;
SMLoc attrLocation;
@@ -516,7 +516,7 @@ ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
Type objectType, compositeType;
Attribute indicesAttr;
StringRef indicesAttrName =
- spirv::CompositeInsertOp::getIndicesAttrName(result.name).strref();
+ spirv::CompositeInsertOp::getIndicesAttrName(result.name);
auto loc = parser.getCurrentLocation();
return failure(
@@ -564,7 +564,7 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
OperationState &result) {
Attribute value;
StringRef valueAttrName =
- spirv::ConstantOp::getValueAttrName(result.name).strref();
+ spirv::ConstantOp::getValueAttrName(result.name);
if (parser.parseAttribute(value, valueAttrName, result.attributes))
return failure();
@@ -829,7 +829,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
return failure();
}
result.addAttribute(
- spirv::EntryPointOp::getInterfaceAttrName(result.name).strref(),
+ spirv::EntryPointOp::getInterfaceAttrName(result.name),
parser.getBuilder().getArrayAttr(interfaceVars));
return success();
}
@@ -883,7 +883,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
values.push_back(llvm::cast<IntegerAttr>(value).getInt());
}
StringRef valuesAttrName =
- spirv::ExecutionModeOp::getValuesAttrName(result.name).strref();
+ spirv::ExecutionModeOp::getValuesAttrName(result.name);
result.addAttribute(valuesAttrName,
parser.getBuilder().getI32ArrayAttr(values));
return success();
@@ -1160,7 +1160,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
// Parse variable name.
StringAttr nameAttr;
StringRef initializerAttrName =
- spirv::GlobalVariableOp::getInitializerAttrName(result.name).strref();
+ spirv::GlobalVariableOp::getInitializerAttrName(result.name);
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes)) {
return failure();
@@ -1182,7 +1182,7 @@ ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
Type type;
StringRef typeAttrName =
- spirv::GlobalVariableOp::getTypeAttrName(result.name).strref();
+ spirv::GlobalVariableOp::getTypeAttrName(result.name);
auto loc = parser.getCurrentLocation();
if (parser.parseColonType(type)) {
return failure();
@@ -1204,7 +1204,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
printer.printSymbolName(getSymName());
elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
- StringRef initializerAttrName = this->getInitializerAttrName().strref();
+ StringRef initializerAttrName = this->getInitializerAttrName();
// Print optional initializer
if (auto initializer = this->getInitializer()) {
printer << " " << initializerAttrName << '(';
@@ -1213,7 +1213,7 @@ void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
elidedAttrs.push_back(initializerAttrName);
}
- StringRef typeAttrName = this->getTypeAttrName().strref();
+ StringRef typeAttrName = this->getTypeAttrName();
elidedAttrs.push_back(typeAttrName);
spirv::printVariableDecorations(*this, printer, elidedAttrs);
printer << " : " << getType();
@@ -1235,7 +1235,7 @@ LogicalResult spirv::GlobalVariableOp::verify() {
}
if (auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
- this->getInitializerAttrName().strref())) {
+ this->getInitializerAttrName())) {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
@@ -1610,7 +1610,7 @@ ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
StringAttr nameAttr;
Attribute valueAttr;
StringRef defaultValueAttrName =
- spirv::SpecConstantOp::getDefaultValueAttrName(result.name).strref();
+ spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
result.attributes))
>From cb2215737478fa7030f106c8e8a07f1b1d804557 Mon Sep 17 00:00:00 2001
From: Lei Zhang <antiagainst at gmail.com>
Date: Sun, 25 Feb 2024 16:02:36 -0800
Subject: [PATCH 5/5] Improve a few callsites
---
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 6 ++----
mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp | 16 ++++++----------
mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp | 2 --
3 files changed, 8 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index c05e966268c6dd..3e317319b68fc5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -201,13 +201,11 @@ LogicalResult FunctionCallOp::verify() {
}
CallInterfaceCallable FunctionCallOp::getCallableForCallee() {
- return (*this)->getAttrOfType<SymbolRefAttr>(
- FunctionCallOp::getCalleeAttrName((*this)->getName()));
+ return (*this)->getAttrOfType<SymbolRefAttr>(getCalleeAttrName());
}
void FunctionCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
- (*this)->setAttr(FunctionCallOp::getCalleeAttrName((*this)->getName()),
- callee.get<SymbolRefAttr>());
+ (*this)->setAttr(getCalleeAttrName(), callee.get<SymbolRefAttr>());
}
Operation::operand_range FunctionCallOp::getArgOperands() {
diff --git a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
index 7c1d3c30f7f0d9..fcbef2be75f9a0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/GroupOps.cpp
@@ -20,7 +20,7 @@ using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
-template <typename GroupNonUniformArithmenticOpTy>
+template <typename OpTy>
static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OperationState &state) {
spirv::Scope executionScope;
@@ -28,12 +28,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand valueInfo;
if (spirv::parseEnumStrAttr<spirv::ScopeAttr>(
executionScope, parser, state,
- GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
- state.name)) ||
+ OpTy::getExecutionScopeAttrName(state.name)) ||
spirv::parseEnumStrAttr<GroupOperationAttr>(
groupOperation, parser, state,
- GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
- state.name)) ||
+ OpTy::getGroupOperationAttrName(state.name)) ||
parser.parseOperand(valueInfo))
return failure();
@@ -86,13 +84,12 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
printer << " : " << groupOp->getResult(0).getType();
}
-template <typename GroupNonUniformArithmenticOpTy>
+template <typename OpTy>
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope =
groupOp
->getAttrOfType<spirv::ScopeAttr>(
- GroupNonUniformArithmenticOpTy::getExecutionScopeAttrName(
- groupOp->getName()))
+ OpTy::getExecutionScopeAttrName(groupOp->getName()))
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
@@ -101,8 +98,7 @@ static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
GroupOperation operation =
groupOp
->getAttrOfType<GroupOperationAttr>(
- GroupNonUniformArithmenticOpTy::getGroupOperationAttrName(
- groupOp->getName()))
+ OpTy::getGroupOperationAttrName(groupOp->getName()))
.getValue();
if (operation == GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
index 1fd25f1ac92e06..726fb7ce9fdb74 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVParsingUtils.cpp
@@ -12,8 +12,6 @@
#include "SPIRVParsingUtils.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-
#include "llvm/ADT/StringExtras.h"
using namespace mlir::spirv::AttrNames;
More information about the Mlir-commits
mailing list