[Mlir-commits] [mlir] d2f0fe2 - [mlir][OpenMP] Added assemblyFormat for atomic and critical operations
Shraiysh Vaishay
llvmlistbot at llvm.org
Tue Mar 1 21:52:25 PST 2022
Author: Shraiysh Vaishay
Date: 2022-03-02T11:22:09+05:30
New Revision: d2f0fe23d2375da1b8caf510cc3c481398694101
URL: https://github.com/llvm/llvm-project/commit/d2f0fe23d2375da1b8caf510cc3c481398694101
DIFF: https://github.com/llvm/llvm-project/commit/d2f0fe23d2375da1b8caf510cc3c481398694101.diff
LOG: [mlir][OpenMP] Added assemblyFormat for atomic and critical operations
This patch adds assemblyFormat for `omp.critical.declare`, `omp.atomic.read`,
`omp.atomic.write`, `omp.atomic.update` and `omp.atomic.capture`.
Also remove the clauses from `parseClauses` that are no longer needed
now that these operations use the new assemblyFormats.
Reviewed By: NimishMishra, rriddle
Differential Revision: https://reviews.llvm.org/D120248
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index cd3ea1f470c43..471dcf24667ba 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -105,7 +105,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocate_vars, type($allocate_vars),
$allocators_vars, type($allocators_vars)
) `)`
- | `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
+ | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) $region attr-dict
}];
let hasVerifier = 1;
@@ -317,6 +317,10 @@ def YieldOp : OpenMP_Op<"yield",
let arguments = (ins Variadic<AnyType>:$results);
+ let builders = [
+ OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
+ ];
+
let assemblyFormat = [{ ( `(` $results^ `:` type($results) `)` )? attr-dict}];
}
@@ -421,10 +425,11 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
}];
let arguments = (ins SymbolNameAttr:$sym_name,
- DefaultValuedAttr<I64Attr, "0">:$hint);
+ DefaultValuedAttr<I64Attr, "0">:$hint_val);
let assemblyFormat = [{
- $sym_name custom<SynchronizationHint>($hint) attr-dict
+ $sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+ attr-dict
}];
let hasVerifier = 1;
}
@@ -555,7 +560,7 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
// value of the clause) here decomposes handling of this construct into a
// two-step process.
-def AtomicReadOp : OpenMP_Op<"atomic.read"> {
+def AtomicReadOp : OpenMP_Op<"atomic.read", [AllTypesMatch<["x", "v"]>]> {
let summary = "performs an atomic read";
@@ -570,14 +575,19 @@ def AtomicReadOp : OpenMP_Op<"atomic.read"> {
optimization.
`memory_order` indicates the memory ordering behavior of the construct. It
- can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+ can be one of `seq_cst`, `acquire` or `relaxed`.
}];
let arguments = (ins OpenMP_PointerLikeType:$x,
OpenMP_PointerLikeType:$v,
- DefaultValuedAttr<I64Attr, "0">:$hint,
- OptionalAttr<MemoryOrderKindAttr>:$memory_order);
- let hasCustomAssemblyFormat = 1;
+ DefaultValuedAttr<I64Attr, "0">:$hint_val,
+ OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
+ let assemblyFormat = [{
+ $v `=` $x
+ oilist( `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+ | `hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+ `:` type($x) attr-dict
+ }];
let hasVerifier = 1;
}
@@ -598,14 +608,20 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> {
optimization.
`memory_order` indicates the memory ordering behavior of the construct. It
- can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+ can be one of `seq_cst`, `release` or `relaxed`.
}];
let arguments = (ins OpenMP_PointerLikeType:$address,
AnyType:$value,
- DefaultValuedAttr<I64Attr, "0">:$hint,
- OptionalAttr<MemoryOrderKindAttr>:$memory_order);
- let hasCustomAssemblyFormat = 1;
+ DefaultValuedAttr<I64Attr, "0">:$hint_val,
+ OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
+ let assemblyFormat = [{
+ $address `=` $value
+ oilist( `hint` `(` custom<SynchronizationHint>($hint_val) `)`
+ | `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`)
+ `:` type($address) `,` type($value)
+ attr-dict
+ }];
let hasVerifier = 1;
}
@@ -625,7 +641,7 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
time constant. As the name suggests, this is just a hint for optimization.
`memory_order` indicates the memory ordering behavior of the construct. It
- can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+ can be one of `seq_cst`, `release` or `relaxed`.
The region describes how to update the value of `x`. It takes the value at
`x` as an input and must yield the updated value. Only the update to `x` is
@@ -635,10 +651,14 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
}];
let arguments = (ins OpenMP_PointerLikeType:$x,
- DefaultValuedAttr<I64Attr, "0">:$hint,
- OptionalAttr<MemoryOrderKindAttr>:$memory_order);
+ DefaultValuedAttr<I64Attr, "0">:$hint_val,
+ OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
let regions = (region SizedRegion<1>:$region);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ oilist( `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+ | `hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+ $x `:` type($x) $region attr-dict
+ }];
let hasVerifier = 1;
}
@@ -678,10 +698,14 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture",
}];
- let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
- OptionalAttr<MemoryOrderKind>:$memory_order);
+ let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint_val,
+ OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
let regions = (region SizedRegion<1>:$region);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ oilist(`memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+ |`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+ $region attr-dict
+ }];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index de526482ae5f0..b4e08ac7b9a1c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -119,6 +119,30 @@ static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
}
}
+/// Parse a clause attribute (StringEnumAttr)
+template <typename ClauseAttr>
+static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
+ using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
+ StringRef enumStr;
+ SMLoc loc = parser.getCurrentLocation();
+ if (parser.parseKeyword(&enumStr))
+ return failure();
+ if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
+ attr = ClauseAttr::get(parser.getContext(), *enumValue);
+ return success();
+ }
+ return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
+}
+
+template <typename ClauseAttr>
+void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
+ p << stringifyEnum(attr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for Procbind Clause
+//===----------------------------------------------------------------------===//
+
ParseResult parseProcBindKind(OpAsmParser &parser,
omp::ClauseProcBindKindAttr &procBindAttr) {
StringRef procBindStr;
@@ -193,7 +217,7 @@ static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
}
//===----------------------------------------------------------------------===//
-// Parser and printer for Schedule Clause
+// Parser, printer and verifier for Schedule Clause
//===----------------------------------------------------------------------===//
static ParseResult
@@ -379,15 +403,7 @@ static LogicalResult verifyReductionVarList(Operation *op,
///
/// hint-clause = `hint` `(` hint-value `)`
static ParseResult parseSynchronizationHint(OpAsmParser &parser,
- IntegerAttr &hintAttr,
- bool parseKeyword = true) {
- if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
- hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
- return success();
- }
-
- if (failed(parser.parseLParen()))
- return failure();
+ IntegerAttr &hintAttr) {
StringRef hintKeyword;
int64_t hint = 0;
do {
@@ -405,8 +421,6 @@ static ParseResult parseSynchronizationHint(OpAsmParser &parser,
return parser.emitError(parser.getCurrentLocation())
<< hintKeyword << " is not a valid hint";
} while (succeeded(parser.parseOptionalComma()));
- if (failed(parser.parseRParen()))
- return failure();
hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
return success();
}
@@ -437,9 +451,7 @@ static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
if (speculative)
hints.push_back("speculative");
- p << "hint(";
llvm::interleaveComma(hints, p);
- p << ") ";
}
/// Verifies a synchronization hint clause
@@ -463,12 +475,7 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
}
enum ClauseType {
- ifClause,
- numThreadsClause,
- deviceClause,
- threadLimitClause,
allocateClause,
- procBindClause,
reductionClause,
nowaitClause,
linearClause,
@@ -476,8 +483,6 @@ enum ClauseType {
collapseClause,
orderClause,
orderedClause,
- memoryOrderClause,
- hintClause,
COUNT
};
@@ -485,35 +490,14 @@ enum ClauseType {
// Parser for Clause List
//===----------------------------------------------------------------------===//
-/// Parse a clause attribute `(` $value `)`.
-template <typename ClauseAttr>
-static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
- StringRef attrName, StringRef name) {
- using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
- StringRef enumStr;
- SMLoc loc = parser.getCurrentLocation();
- if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
- parser.parseRParen())
- return failure();
- if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
- auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
- state.addAttribute(attrName, attr);
- return success();
- }
- return parser.emitError(loc, "invalid ") << name << " kind";
-}
-
/// Parse a list of clauses. The clauses can appear in any order, but their
/// operand segment indices are in the same order that they are passed in the
/// `clauses` list. The operand segments are added over the prevSegments
/// clause-list ::= clause clause-list | empty
-/// clause ::= if | num-threads | allocate | proc-bind | reduction | nowait
-/// | linear | schedule | collapse | order | ordered | inclusive
-/// if ::= `if` `(` ssa-id-and-type `)`
-/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
+/// clause ::= allocate | reduction | nowait | linear | schedule | collapse
+/// | order | ordered
/// allocate ::= `allocate` `(` allocate-operand-list `)`
-/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
/// reduction ::= `reduction` `(` reduction-entry-list `)`
/// nowait ::= `nowait`
/// linear ::= `linear` `(` linear-list `)`
@@ -521,7 +505,6 @@ static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
/// collapse ::= `collapse` `(` ssa-id-and-type `)`
/// order ::= `order` `(` `concurrent` `)`
/// ordered ::= `ordered` `(` ssa-id-and-type `)`
-/// inclusive ::= `inclusive`
///
/// Note that each clause can only appear once in the clase-list.
static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
@@ -539,11 +522,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
StringRef opName = result.name.getStringRef();
// Containers for storing operands, types and attributes for various clauses
- std::pair<OpAsmParser::OperandType, Type> ifCond;
- std::pair<OpAsmParser::OperandType, Type> numThreads;
- std::pair<OpAsmParser::OperandType, Type> device;
- std::pair<OpAsmParser::OperandType, Type> threadLimit;
-
SmallVector<OpAsmParser::OperandType> allocates, allocators;
SmallVector<Type> allocateTypes, allocatorTypes;
@@ -566,9 +544,8 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
// Skip the following clauses - they do not take any position in operand
// segments
- if (clause == procBindClause || clause == nowaitClause ||
- clause == collapseClause || clause == orderClause ||
- clause == orderedClause)
+ if (clause == nowaitClause || clause == collapseClause ||
+ clause == orderClause || clause == orderedClause)
continue;
pos[clause] = currPos++;
@@ -596,31 +573,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
};
while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
- if (clauseKeyword == "if") {
- if (checkAllowed(ifClause) || parser.parseLParen() ||
- parser.parseOperand(ifCond.first) ||
- parser.parseColonType(ifCond.second) || parser.parseRParen())
- return failure();
- clauseSegments[pos[ifClause]] = 1;
- } else if (clauseKeyword == "num_threads") {
- if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
- parser.parseOperand(numThreads.first) ||
- parser.parseColonType(numThreads.second) || parser.parseRParen())
- return failure();
- clauseSegments[pos[numThreadsClause]] = 1;
- } else if (clauseKeyword == "device") {
- if (checkAllowed(deviceClause) || parser.parseLParen() ||
- parser.parseOperand(device.first) ||
- parser.parseColonType(device.second) || parser.parseRParen())
- return failure();
- clauseSegments[pos[deviceClause]] = 1;
- } else if (clauseKeyword == "thread_limit") {
- if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
- parser.parseOperand(threadLimit.first) ||
- parser.parseColonType(threadLimit.second) || parser.parseRParen())
- return failure();
- clauseSegments[pos[threadLimitClause]] = 1;
- } else if (clauseKeyword == "allocate") {
+ if (clauseKeyword == "allocate") {
if (checkAllowed(allocateClause) || parser.parseLParen() ||
parseAllocateAndAllocator(parser, allocates, allocateTypes,
allocators, allocatorTypes) ||
@@ -628,11 +581,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
return failure();
clauseSegments[pos[allocateClause]] = allocates.size();
clauseSegments[pos[allocateClause] + 1] = allocators.size();
- } else if (clauseKeyword == "proc_bind") {
- if (checkAllowed(procBindClause) ||
- parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
- "proc_bind_val", "proc bind"))
- return failure();
} else if (clauseKeyword == "reduction") {
if (checkAllowed(reductionClause) || parser.parseLParen() ||
parseReductionVarList(parser, reductionVars, reductionVarTypes,
@@ -680,51 +628,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
}
result.addAttribute("ordered_val", attr);
} else if (clauseKeyword == "order") {
- if (checkAllowed(orderClause) ||
- parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
- "order"))
- return failure();
- } else if (clauseKeyword == "memory_order") {
- if (checkAllowed(memoryOrderClause) ||
- parseClauseAttr<ClauseMemoryOrderKindAttr>(
- parser, result, "memory_order", "memory order"))
- return failure();
- } else if (clauseKeyword == "hint") {
- IntegerAttr hint;
- if (checkAllowed(hintClause) ||
- parseSynchronizationHint(parser, hint, false))
+ ClauseOrderKindAttr order;
+ if (checkAllowed(orderClause) || parser.parseLParen() ||
+ parseClauseAttr<ClauseOrderKindAttr>(parser, order) ||
+ parser.parseRParen())
return failure();
- result.addAttribute("hint", hint);
+ result.addAttribute("order_val", order);
} else {
return parser.emitError(parser.getNameLoc())
<< clauseKeyword << " is not a valid clause";
}
}
- // Add if parameter.
- if (done[ifClause] && clauseSegments[pos[ifClause]] &&
- failed(
- parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
- return failure();
-
- // Add num_threads parameter.
- if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
- failed(parser.resolveOperand(numThreads.first, numThreads.second,
- result.operands)))
- return failure();
-
- // Add device parameter.
- if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
- failed(
- parser.resolveOperand(device.first, device.second, result.operands)))
- return failure();
-
- // Add thread_limit parameter.
- if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
- failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
- result.operands)))
- return failure();
-
// Add allocate parameters.
if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
failed(parser.resolveOperands(allocates, allocateTypes,
@@ -1024,7 +939,7 @@ LogicalResult WsLoopOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult CriticalDeclareOp::verify() {
- return verifySynchronizationHint(*this, hint());
+ return verifySynchronizationHint(*this, hint_val());
}
LogicalResult CriticalOp::verify() {
@@ -1079,41 +994,11 @@ LogicalResult OrderedRegionOp::verify() {
}
//===----------------------------------------------------------------------===//
-// AtomicReadOp
+// Verifier for AtomicReadOp
//===----------------------------------------------------------------------===//
-/// Parser for AtomicReadOp
-///
-/// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
-/// address ::= operand `:` type
-ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType x, v;
- Type addressType;
- SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
- SmallVector<int> segments;
-
- if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
- parseClauses(parser, result, clauses, segments) ||
- parser.parseColonType(addressType) ||
- parser.resolveOperand(x, addressType, result.operands) ||
- parser.resolveOperand(v, addressType, result.operands))
- return failure();
-
- return success();
-}
-
-void AtomicReadOp::print(OpAsmPrinter &p) {
- p << " " << v() << " = " << x() << " ";
- if (auto mo = memory_order())
- p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
- if (hintAttr())
- printSynchronizationHint(p << " ", *this, hintAttr());
- p << ": " << x().getType();
-}
-
-/// Verifier for AtomicReadOp
LogicalResult AtomicReadOp::verify() {
- if (auto mo = memory_order()) {
+ if (auto mo = memory_order_val()) {
if (*mo == ClauseMemoryOrderKind::acq_rel ||
*mo == ClauseMemoryOrderKind::release) {
return emitError(
@@ -1123,92 +1008,30 @@ LogicalResult AtomicReadOp::verify() {
if (x() == v())
return emitError(
"read and write must not be to the same location for atomic reads");
- return verifySynchronizationHint(*this, hint());
+ return verifySynchronizationHint(*this, hint_val());
}
//===----------------------------------------------------------------------===//
-// AtomicWriteOp
+// Verifier for AtomicWriteOp
//===----------------------------------------------------------------------===//
-/// Parser for AtomicWriteOp
-///
-/// operation ::= `omp.atomic.write` atomic-clause-list operands
-/// operands ::= address `,` value
-/// address ::= operand `:` type
-/// value ::= operand `:` type
-ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) {
- OpAsmParser::OperandType address, value;
- Type addrType, valueType;
- SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
- SmallVector<int> segments;
-
- if (parser.parseOperand(address) || parser.parseEqual() ||
- parser.parseOperand(value) ||
- parseClauses(parser, result, clauses, segments) ||
- parser.parseColonType(addrType) || parser.parseComma() ||
- parser.parseType(valueType) ||
- parser.resolveOperand(address, addrType, result.operands) ||
- parser.resolveOperand(value, valueType, result.operands))
- return failure();
- return success();
-}
-
-void AtomicWriteOp::print(OpAsmPrinter &p) {
- p << " " << address() << " = " << value() << " ";
- if (auto mo = memory_order())
- p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
- if (hintAttr())
- printSynchronizationHint(p, *this, hintAttr());
- p << ": " << address().getType() << ", " << value().getType();
-}
-
-/// Verifier for AtomicWriteOp
LogicalResult AtomicWriteOp::verify() {
- if (auto mo = memory_order()) {
+ if (auto mo = memory_order_val()) {
if (*mo == ClauseMemoryOrderKind::acq_rel ||
*mo == ClauseMemoryOrderKind::acquire) {
return emitError(
"memory-order must not be acq_rel or acquire for atomic writes");
}
}
- return verifySynchronizationHint(*this, hint());
+ return verifySynchronizationHint(*this, hint_val());
}
//===----------------------------------------------------------------------===//
-// AtomicUpdateOp
+// Verifier for AtomicUpdateOp
//===----------------------------------------------------------------------===//
-/// Parser for AtomicUpdateOp
-///
-/// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region
-ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) {
- SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
- SmallVector<int> segments;
- OpAsmParser::OperandType x, expr;
- Type xType;
-
- if (parseClauses(parser, result, clauses, segments) ||
- parser.parseOperand(x) || parser.parseColon() ||
- parser.parseType(xType) ||
- parser.resolveOperand(x, xType, result.operands) ||
- parser.parseRegion(*result.addRegion()))
- return failure();
- return success();
-}
-
-void AtomicUpdateOp::print(OpAsmPrinter &p) {
- p << " ";
- if (auto mo = memory_order())
- p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
- if (hintAttr())
- printSynchronizationHint(p, *this, hintAttr());
- p << x() << " : " << x().getType();
- p.printRegion(region());
-}
-
-/// Verifier for AtomicUpdateOp
LogicalResult AtomicUpdateOp::verify() {
- if (auto mo = memory_order()) {
+ if (auto mo = memory_order_val()) {
if (*mo == ClauseMemoryOrderKind::acq_rel ||
*mo == ClauseMemoryOrderKind::acquire) {
return emitError(
@@ -1235,28 +1058,9 @@ LogicalResult AtomicUpdateOp::verify() {
}
//===----------------------------------------------------------------------===//
-// AtomicCaptureOp
+// Verifier for AtomicCaptureOp
//===----------------------------------------------------------------------===//
-ParseResult AtomicCaptureOp::parse(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
- SmallVector<int> segments;
- if (parseClauses(parser, result, clauses, segments) ||
- parser.parseRegion(*result.addRegion()))
- return failure();
- return success();
-}
-
-void AtomicCaptureOp::print(OpAsmPrinter &p) {
- if (memory_order())
- p << "memory_order(" << memory_order() << ") ";
- if (hintAttr())
- printSynchronizationHint(p, *this, hintAttr());
- p.printRegion(region());
-}
-
-/// Verifier for AtomicCaptureOp
LogicalResult AtomicCaptureOp::verify() {
Block::OpListType &ops = region().front().getOperations();
if (ops.size() != 3)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d3b43b423c952..35fa843a553d5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -335,8 +335,9 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
auto criticalDeclareOp =
SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
symbolRef);
- hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
- static_cast<int>(criticalDeclareOp.hint()));
+ hint =
+ llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
+ static_cast<int>(criticalDeclareOp.hint_val()));
}
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
@@ -910,7 +911,7 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order());
+ llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order_val());
llvm::Value *x = moduleTranslation.lookupValue(readOp.x());
Type xTy = readOp.x().getType().cast<omp::PointerLikeType>().getElementType();
llvm::Value *v = moduleTranslation.lookupValue(readOp.v());
@@ -931,7 +932,7 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order());
+ llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order_val());
llvm::Value *expr = moduleTranslation.lookupValue(writeOp.value());
llvm::Value *dest = moduleTranslation.lookupValue(writeOp.address());
llvm::Type *ty = moduleTranslation.convertType(writeOp.value().getType());
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a991d5f20f6b7..2ac177f4da2cd 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -97,7 +97,7 @@ func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
// -----
func @order_value(%lb : index, %ub : index, %step : index) {
- // expected-error @below {{invalid order kind}}
+ // expected-error @below {{invalid clause value: 'default'}}
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) {
omp.yield
}
@@ -106,7 +106,7 @@ func @order_value(%lb : index, %ub : index, %step : index) {
// -----
func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
- // expected-error @below {{if is not a valid clause for the omp.wsloop operation}}
+ // expected-error @below {{if is not a valid clause}}
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) {
omp.yield
}
@@ -115,7 +115,7 @@ func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
// -----
func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) {
- // expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}}
+ // expected-error @below {{num_threads is not a valid clause}}
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) {
omp.yield
}
@@ -124,7 +124,7 @@ func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var
// -----
func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) {
- // expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}}
+ // expected-error @below {{proc_bind is not a valid clause}}
omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) {
omp.yield
}
@@ -429,7 +429,7 @@ func @omp_atomic_read1(%x: memref<i32>, %v: memref<i32>) {
// -----
func @omp_atomic_read2(%x: memref<i32>, %v: memref<i32>) {
- // expected-error @below {{invalid memory order kind}}
+ // expected-error @below {{invalid clause value: 'xyz'}}
omp.atomic.read %v = %x memory_order(xyz) : memref<i32>
return
}
@@ -453,7 +453,7 @@ func @omp_atomic_read4(%x: memref<i32>, %v: memref<i32>) {
// -----
func @omp_atomic_read5(%x: memref<i32>, %v: memref<i32>) {
- // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.read operation}}
+ // expected-error @below {{`memory_order` clause can appear at most once in the expansion of the oilist directive}}
omp.atomic.read %v = %x memory_order(acquire) memory_order(relaxed) : memref<i32>
return
}
@@ -461,7 +461,7 @@ func @omp_atomic_read5(%x: memref<i32>, %v: memref<i32>) {
// -----
func @omp_atomic_read6(%x: memref<i32>, %v: memref<i32>) {
- // expected-error @below {{at most one hint clause can appear on the omp.atomic.read operation}}
+ // expected-error @below {{`hint` clause can appear at most once in the expansion of the oilist directive}}
omp.atomic.read %v = %x hint(speculative) hint(contended) : memref<i32>
return
}
@@ -501,7 +501,7 @@ func @omp_atomic_write3(%addr : memref<i32>, %val : i32) {
// -----
func @omp_atomic_write4(%addr : memref<i32>, %val : i32) {
- // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.write operation}}
+ // expected-error @below {{`memory_order` clause can appear at most once in the expansion of the oilist directive}}
omp.atomic.write %addr = %val memory_order(release) memory_order(seq_cst) : memref<i32>, i32
return
}
@@ -509,7 +509,7 @@ func @omp_atomic_write4(%addr : memref<i32>, %val : i32) {
// -----
func @omp_atomic_write5(%addr : memref<i32>, %val : i32) {
- // expected-error @below {{at most one hint clause can appear on the omp.atomic.write operation}}
+ // expected-error @below {{`hint` clause can appear at most once in the expansion of the oilist directive}}
omp.atomic.write %addr = %val hint(contended) hint(speculative) : memref<i32>, i32
return
}
@@ -517,7 +517,7 @@ func @omp_atomic_write5(%addr : memref<i32>, %val : i32) {
// -----
func @omp_atomic_write6(%addr : memref<i32>, %val : i32) {
- // expected-error @below {{invalid memory order kind}}
+ // expected-error @below {{invalid clause value: 'xyz'}}
omp.atomic.write %addr = %val memory_order(xyz) : memref<i32>, i32
return
}
@@ -573,9 +573,8 @@ func @omp_atomic_update4(%x: memref<i32>, %expr: i32) {
// -----
-// expected-note @below {{prior use here}}
func @omp_atomic_update5(%x: memref<i32>, %expr: i32) {
- // expected-error @below {{use of value '%x' expects different type than prior uses: 'i32' vs 'memref<i32>'}}
+ // expected-error @below {{invalid kind of type specified}}
omp.atomic.update %x : i32 {
^bb0(%xval: i32):
%newval = llvm.add %xval, %expr : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e2cc900bf3787..3d6834f0d9345 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -568,14 +568,13 @@ func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, %exp
// CHECK-LABEL: omp_atomic_capture
// CHECK-SAME: (%[[v:.*]]: memref<i32>, %[[x:.*]]: memref<i32>, %[[expr:.*]]: i32)
func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
- // CHECK: omp.atomic.capture{
+ // CHECK: omp.atomic.capture {
// CHECK-NEXT: omp.atomic.update %[[x]] : memref<i32>
// CHECK-NEXT: (%[[xval:.*]]: i32):
// CHECK-NEXT: %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32
// CHECK-NEXT: omp.yield(%[[newval]] : i32)
// CHECK-NEXT: }
// CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
- // CHECK-NEXT: omp.terminator
// CHECK-NEXT: }
omp.atomic.capture{
omp.atomic.update %x : memref<i32> {
@@ -584,16 +583,14 @@ func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
omp.yield(%newval : i32)
}
omp.atomic.read %v = %x : memref<i32>
- omp.terminator
}
- // CHECK: omp.atomic.capture{
+ // CHECK: omp.atomic.capture {
// CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
// CHECK-NEXT: omp.atomic.update %[[x]] : memref<i32>
// CHECK-NEXT: (%[[xval:.*]]: i32):
// CHECK-NEXT: %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32
// CHECK-NEXT: omp.yield(%[[newval]] : i32)
// CHECK-NEXT: }
- // CHECK-NEXT: omp.terminator
// CHECK-NEXT: }
omp.atomic.capture{
omp.atomic.read %v = %x : memref<i32>
@@ -602,17 +599,14 @@ func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
%newval = llvm.add %xval, %expr : i32
omp.yield(%newval : i32)
}
- omp.terminator
}
- // CHECK: omp.atomic.capture{
+ // CHECK: omp.atomic.capture {
// CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
// CHECK-NEXT: omp.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
- // CHECK-NEXT: omp.terminator
// CHECK-NEXT: }
omp.atomic.capture{
omp.atomic.read %v = %x : memref<i32>
omp.atomic.write %x = %expr : memref<i32>, i32
- omp.terminator
}
return
}
More information about the Mlir-commits
mailing list