[Mlir-commits] [mlir] d2f0fe2 - [mlir][OpenMP] Added assemblyFormat for atomic and critical operations

Shraiysh Vaishay llvmlistbot at llvm.org
Tue Mar 1 21:52:25 PST 2022


Author: Shraiysh Vaishay
Date: 2022-03-02T11:22:09+05:30
New Revision: d2f0fe23d2375da1b8caf510cc3c481398694101

URL: https://github.com/llvm/llvm-project/commit/d2f0fe23d2375da1b8caf510cc3c481398694101
DIFF: https://github.com/llvm/llvm-project/commit/d2f0fe23d2375da1b8caf510cc3c481398694101.diff

LOG: [mlir][OpenMP] Added assemblyFormat for atomic and critical operations

This patch adds assemblyFormat for `omp.critical.declare`, `omp.atomic.read`,
`omp.atomic.write`, `omp.atomic.update` and `omp.atomic.capture`.

Also removes the clauses from `parseClauses` that are no longer needed,
since these operations are now handled by the new assemblyFormats.

Reviewed By: NimishMishra, rriddle

Differential Revision: https://reviews.llvm.org/D120248

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index cd3ea1f470c43..471dcf24667ba 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -105,7 +105,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
                 $allocate_vars, type($allocate_vars),
                 $allocators_vars, type($allocators_vars)
               ) `)`
-          | `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
+          | `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
     ) $region attr-dict
   }];
   let hasVerifier = 1;
@@ -317,6 +317,10 @@ def YieldOp : OpenMP_Op<"yield",
 
   let arguments = (ins Variadic<AnyType>:$results);
 
+  let builders = [
+    OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
+  ];
+
   let assemblyFormat = [{ ( `(` $results^ `:` type($results) `)` )? attr-dict}];
 }
 
@@ -421,10 +425,11 @@ def CriticalDeclareOp : OpenMP_Op<"critical.declare", [Symbol]> {
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name,
-                       DefaultValuedAttr<I64Attr, "0">:$hint);
+                       DefaultValuedAttr<I64Attr, "0">:$hint_val);
 
   let assemblyFormat = [{
-    $sym_name custom<SynchronizationHint>($hint) attr-dict
+    $sym_name oilist(`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+    attr-dict
   }];
   let hasVerifier = 1;
 }
@@ -555,7 +560,7 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
 // value of the clause) here decomposes handling of this construct into a
 // two-step process.
 
-def AtomicReadOp : OpenMP_Op<"atomic.read"> {
+def AtomicReadOp : OpenMP_Op<"atomic.read", [AllTypesMatch<["x", "v"]>]> {
 
   let summary = "performs an atomic read";
 
@@ -570,14 +575,19 @@ def AtomicReadOp : OpenMP_Op<"atomic.read"> {
     optimization.
 
     `memory_order` indicates the memory ordering behavior of the construct. It
-    can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+    can be one of `seq_cst`, `acquire` or `relaxed`.
   }];
 
   let arguments = (ins OpenMP_PointerLikeType:$x,
                        OpenMP_PointerLikeType:$v,
-                       DefaultValuedAttr<I64Attr, "0">:$hint,
-                       OptionalAttr<MemoryOrderKindAttr>:$memory_order);
-  let hasCustomAssemblyFormat = 1;
+                       DefaultValuedAttr<I64Attr, "0">:$hint_val,
+                       OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
+  let assemblyFormat = [{
+    $v `=` $x 
+    oilist( `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+          | `hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+    `:` type($x) attr-dict
+  }];
   let hasVerifier = 1;
 }
 
@@ -598,14 +608,20 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write"> {
     optimization.
 
     `memory_order` indicates the memory ordering behavior of the construct. It
-    can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+    can be one of `seq_cst`, `release` or `relaxed`.
   }];
 
   let arguments = (ins OpenMP_PointerLikeType:$address,
                        AnyType:$value,
-                       DefaultValuedAttr<I64Attr, "0">:$hint,
-                       OptionalAttr<MemoryOrderKindAttr>:$memory_order);
-  let hasCustomAssemblyFormat = 1;
+                       DefaultValuedAttr<I64Attr, "0">:$hint_val,
+                       OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
+  let assemblyFormat = [{
+    $address `=` $value
+    oilist( `hint` `(` custom<SynchronizationHint>($hint_val) `)`
+          | `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`)
+    `:` type($address) `,` type($value)
+    attr-dict
+  }];
   let hasVerifier = 1;
 }
 
@@ -625,7 +641,7 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
     time constant. As the name suggests, this is just a hint for optimization.
 
     `memory_order` indicates the memory ordering behavior of the construct. It
-    can be one of `seq_cst`, `acq_rel`, `release`, `acquire` or `relaxed`.
+    can be one of `seq_cst`, `release` or `relaxed`.
 
     The region describes how to update the value of `x`. It takes the value at
     `x` as an input and must yield the updated value. Only the update to `x` is
@@ -635,10 +651,14 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update",
   }];
 
   let arguments = (ins OpenMP_PointerLikeType:$x,
-                       DefaultValuedAttr<I64Attr, "0">:$hint,
-                       OptionalAttr<MemoryOrderKindAttr>:$memory_order);
+                       DefaultValuedAttr<I64Attr, "0">:$hint_val,
+                       OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
   let regions = (region SizedRegion<1>:$region);
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    oilist( `memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+          | `hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+    $x `:` type($x) $region attr-dict
+  }];
   let hasVerifier = 1;
 }
 
@@ -678,10 +698,14 @@ def AtomicCaptureOp : OpenMP_Op<"atomic.capture",
 
   }];
 
-  let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
-                       OptionalAttr<MemoryOrderKind>:$memory_order);
+  let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint_val,
+                       OptionalAttr<MemoryOrderKindAttr>:$memory_order_val);
   let regions = (region SizedRegion<1>:$region);
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    oilist(`memory_order` `(` custom<ClauseAttr>($memory_order_val) `)`
+          |`hint` `(` custom<SynchronizationHint>($hint_val) `)`)
+    $region attr-dict
+  }];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index de526482ae5f0..b4e08ac7b9a1c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -119,6 +119,30 @@ static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
   }
 }
 
+/// Parse a clause attribute (StringEnumAttr)
+template <typename ClauseAttr>
+static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
+  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
+  StringRef enumStr;
+  SMLoc loc = parser.getCurrentLocation();
+  if (parser.parseKeyword(&enumStr))
+    return failure();
+  if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
+    attr = ClauseAttr::get(parser.getContext(), *enumValue);
+    return success();
+  }
+  return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
+}
+
+template <typename ClauseAttr>
+void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
+  p << stringifyEnum(attr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for Procbind Clause
+//===----------------------------------------------------------------------===//
+
 ParseResult parseProcBindKind(OpAsmParser &parser,
                               omp::ClauseProcBindKindAttr &procBindAttr) {
   StringRef procBindStr;
@@ -193,7 +217,7 @@ static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
 }
 
 //===----------------------------------------------------------------------===//
-// Parser and printer for Schedule Clause
+// Parser, printer and verifier for Schedule Clause
 //===----------------------------------------------------------------------===//
 
 static ParseResult
@@ -379,15 +403,7 @@ static LogicalResult verifyReductionVarList(Operation *op,
 ///
 /// hint-clause = `hint` `(` hint-value `)`
 static ParseResult parseSynchronizationHint(OpAsmParser &parser,
-                                            IntegerAttr &hintAttr,
-                                            bool parseKeyword = true) {
-  if (parseKeyword && failed(parser.parseOptionalKeyword("hint"))) {
-    hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
-    return success();
-  }
-
-  if (failed(parser.parseLParen()))
-    return failure();
+                                            IntegerAttr &hintAttr) {
   StringRef hintKeyword;
   int64_t hint = 0;
   do {
@@ -405,8 +421,6 @@ static ParseResult parseSynchronizationHint(OpAsmParser &parser,
       return parser.emitError(parser.getCurrentLocation())
              << hintKeyword << " is not a valid hint";
   } while (succeeded(parser.parseOptionalComma()));
-  if (failed(parser.parseRParen()))
-    return failure();
   hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
   return success();
 }
@@ -437,9 +451,7 @@ static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
   if (speculative)
     hints.push_back("speculative");
 
-  p << "hint(";
   llvm::interleaveComma(hints, p);
-  p << ") ";
 }
 
 /// Verifies a synchronization hint clause
@@ -463,12 +475,7 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
 }
 
 enum ClauseType {
-  ifClause,
-  numThreadsClause,
-  deviceClause,
-  threadLimitClause,
   allocateClause,
-  procBindClause,
   reductionClause,
   nowaitClause,
   linearClause,
@@ -476,8 +483,6 @@ enum ClauseType {
   collapseClause,
   orderClause,
   orderedClause,
-  memoryOrderClause,
-  hintClause,
   COUNT
 };
 
@@ -485,35 +490,14 @@ enum ClauseType {
 // Parser for Clause List
 //===----------------------------------------------------------------------===//
 
-/// Parse a clause attribute `(` $value `)`.
-template <typename ClauseAttr>
-static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
-                                   StringRef attrName, StringRef name) {
-  using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
-  StringRef enumStr;
-  SMLoc loc = parser.getCurrentLocation();
-  if (parser.parseLParen() || parser.parseKeyword(&enumStr) ||
-      parser.parseRParen())
-    return failure();
-  if (Optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
-    auto attr = ClauseAttr::get(parser.getContext(), *enumValue);
-    state.addAttribute(attrName, attr);
-    return success();
-  }
-  return parser.emitError(loc, "invalid ") << name << " kind";
-}
-
 /// Parse a list of clauses. The clauses can appear in any order, but their
 /// operand segment indices are in the same order that they are passed in the
 /// `clauses` list. The operand segments are added over the prevSegments
 
 /// clause-list ::= clause clause-list | empty
-/// clause ::= if | num-threads | allocate | proc-bind | reduction | nowait
-///          | linear | schedule | collapse | order | ordered | inclusive
-/// if ::= `if` `(` ssa-id-and-type `)`
-/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
+/// clause ::= allocate | reduction | nowait | linear | schedule | collapse
+///          | order | ordered
 /// allocate ::= `allocate` `(` allocate-operand-list `)`
-/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
 /// reduction ::= `reduction` `(` reduction-entry-list `)`
 /// nowait ::= `nowait`
 /// linear ::= `linear` `(` linear-list `)`
@@ -521,7 +505,6 @@ static ParseResult parseClauseAttr(AsmParser &parser, OperationState &state,
 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
 /// order ::= `order` `(` `concurrent` `)`
 /// ordered ::= `ordered` `(` ssa-id-and-type `)`
-/// inclusive ::= `inclusive`
 ///
 /// Note that each clause can only appear once in the clase-list.
 static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
@@ -539,11 +522,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
   StringRef opName = result.name.getStringRef();
 
   // Containers for storing operands, types and attributes for various clauses
-  std::pair<OpAsmParser::OperandType, Type> ifCond;
-  std::pair<OpAsmParser::OperandType, Type> numThreads;
-  std::pair<OpAsmParser::OperandType, Type> device;
-  std::pair<OpAsmParser::OperandType, Type> threadLimit;
-
   SmallVector<OpAsmParser::OperandType> allocates, allocators;
   SmallVector<Type> allocateTypes, allocatorTypes;
 
@@ -566,9 +544,8 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
 
     // Skip the following clauses - they do not take any position in operand
     // segments
-    if (clause == procBindClause || clause == nowaitClause ||
-        clause == collapseClause || clause == orderClause ||
-        clause == orderedClause)
+    if (clause == nowaitClause || clause == collapseClause ||
+        clause == orderClause || clause == orderedClause)
       continue;
 
     pos[clause] = currPos++;
@@ -596,31 +573,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
   };
 
   while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
-    if (clauseKeyword == "if") {
-      if (checkAllowed(ifClause) || parser.parseLParen() ||
-          parser.parseOperand(ifCond.first) ||
-          parser.parseColonType(ifCond.second) || parser.parseRParen())
-        return failure();
-      clauseSegments[pos[ifClause]] = 1;
-    } else if (clauseKeyword == "num_threads") {
-      if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
-          parser.parseOperand(numThreads.first) ||
-          parser.parseColonType(numThreads.second) || parser.parseRParen())
-        return failure();
-      clauseSegments[pos[numThreadsClause]] = 1;
-    } else if (clauseKeyword == "device") {
-      if (checkAllowed(deviceClause) || parser.parseLParen() ||
-          parser.parseOperand(device.first) ||
-          parser.parseColonType(device.second) || parser.parseRParen())
-        return failure();
-      clauseSegments[pos[deviceClause]] = 1;
-    } else if (clauseKeyword == "thread_limit") {
-      if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
-          parser.parseOperand(threadLimit.first) ||
-          parser.parseColonType(threadLimit.second) || parser.parseRParen())
-        return failure();
-      clauseSegments[pos[threadLimitClause]] = 1;
-    } else if (clauseKeyword == "allocate") {
+    if (clauseKeyword == "allocate") {
       if (checkAllowed(allocateClause) || parser.parseLParen() ||
           parseAllocateAndAllocator(parser, allocates, allocateTypes,
                                     allocators, allocatorTypes) ||
@@ -628,11 +581,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
         return failure();
       clauseSegments[pos[allocateClause]] = allocates.size();
       clauseSegments[pos[allocateClause] + 1] = allocators.size();
-    } else if (clauseKeyword == "proc_bind") {
-      if (checkAllowed(procBindClause) ||
-          parseClauseAttr<ClauseProcBindKindAttr>(parser, result,
-                                                  "proc_bind_val", "proc bind"))
-        return failure();
     } else if (clauseKeyword == "reduction") {
       if (checkAllowed(reductionClause) || parser.parseLParen() ||
           parseReductionVarList(parser, reductionVars, reductionVarTypes,
@@ -680,51 +628,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
       }
       result.addAttribute("ordered_val", attr);
     } else if (clauseKeyword == "order") {
-      if (checkAllowed(orderClause) ||
-          parseClauseAttr<ClauseOrderKindAttr>(parser, result, "order_val",
-                                               "order"))
-        return failure();
-    } else if (clauseKeyword == "memory_order") {
-      if (checkAllowed(memoryOrderClause) ||
-          parseClauseAttr<ClauseMemoryOrderKindAttr>(
-              parser, result, "memory_order", "memory order"))
-        return failure();
-    } else if (clauseKeyword == "hint") {
-      IntegerAttr hint;
-      if (checkAllowed(hintClause) ||
-          parseSynchronizationHint(parser, hint, false))
+      ClauseOrderKindAttr order;
+      if (checkAllowed(orderClause) || parser.parseLParen() ||
+          parseClauseAttr<ClauseOrderKindAttr>(parser, order) ||
+          parser.parseRParen())
         return failure();
-      result.addAttribute("hint", hint);
+      result.addAttribute("order_val", order);
     } else {
       return parser.emitError(parser.getNameLoc())
              << clauseKeyword << " is not a valid clause";
     }
   }
 
-  // Add if parameter.
-  if (done[ifClause] && clauseSegments[pos[ifClause]] &&
-      failed(
-          parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
-    return failure();
-
-  // Add num_threads parameter.
-  if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
-      failed(parser.resolveOperand(numThreads.first, numThreads.second,
-                                   result.operands)))
-    return failure();
-
-  // Add device parameter.
-  if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
-      failed(
-          parser.resolveOperand(device.first, device.second, result.operands)))
-    return failure();
-
-  // Add thread_limit parameter.
-  if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
-      failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
-                                   result.operands)))
-    return failure();
-
   // Add allocate parameters.
   if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
       failed(parser.resolveOperands(allocates, allocateTypes,
@@ -1024,7 +939,7 @@ LogicalResult WsLoopOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult CriticalDeclareOp::verify() {
-  return verifySynchronizationHint(*this, hint());
+  return verifySynchronizationHint(*this, hint_val());
 }
 
 LogicalResult CriticalOp::verify() {
@@ -1079,41 +994,11 @@ LogicalResult OrderedRegionOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// AtomicReadOp
+// Verifier for AtomicReadOp
 //===----------------------------------------------------------------------===//
 
-/// Parser for AtomicReadOp
-///
-/// operation ::= `omp.atomic.read` atomic-clause-list address `->` result-type
-/// address ::= operand `:` type
-ParseResult AtomicReadOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType x, v;
-  Type addressType;
-  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
-  SmallVector<int> segments;
-
-  if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
-      parseClauses(parser, result, clauses, segments) ||
-      parser.parseColonType(addressType) ||
-      parser.resolveOperand(x, addressType, result.operands) ||
-      parser.resolveOperand(v, addressType, result.operands))
-    return failure();
-
-  return success();
-}
-
-void AtomicReadOp::print(OpAsmPrinter &p) {
-  p << " " << v() << " = " << x() << " ";
-  if (auto mo = memory_order())
-    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
-  if (hintAttr())
-    printSynchronizationHint(p << " ", *this, hintAttr());
-  p << ": " << x().getType();
-}
-
-/// Verifier for AtomicReadOp
 LogicalResult AtomicReadOp::verify() {
-  if (auto mo = memory_order()) {
+  if (auto mo = memory_order_val()) {
     if (*mo == ClauseMemoryOrderKind::acq_rel ||
         *mo == ClauseMemoryOrderKind::release) {
       return emitError(
@@ -1123,92 +1008,30 @@ LogicalResult AtomicReadOp::verify() {
   if (x() == v())
     return emitError(
         "read and write must not be to the same location for atomic reads");
-  return verifySynchronizationHint(*this, hint());
+  return verifySynchronizationHint(*this, hint_val());
 }
 
 //===----------------------------------------------------------------------===//
-// AtomicWriteOp
+// Verifier for AtomicWriteOp
 //===----------------------------------------------------------------------===//
 
-/// Parser for AtomicWriteOp
-///
-/// operation ::= `omp.atomic.write` atomic-clause-list operands
-/// operands ::= address `,` value
-/// address ::= operand `:` type
-/// value ::= operand `:` type
-ParseResult AtomicWriteOp::parse(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType address, value;
-  Type addrType, valueType;
-  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
-  SmallVector<int> segments;
-
-  if (parser.parseOperand(address) || parser.parseEqual() ||
-      parser.parseOperand(value) ||
-      parseClauses(parser, result, clauses, segments) ||
-      parser.parseColonType(addrType) || parser.parseComma() ||
-      parser.parseType(valueType) ||
-      parser.resolveOperand(address, addrType, result.operands) ||
-      parser.resolveOperand(value, valueType, result.operands))
-    return failure();
-  return success();
-}
-
-void AtomicWriteOp::print(OpAsmPrinter &p) {
-  p << " " << address() << " = " << value() << " ";
-  if (auto mo = memory_order())
-    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
-  if (hintAttr())
-    printSynchronizationHint(p, *this, hintAttr());
-  p << ": " << address().getType() << ", " << value().getType();
-}
-
-/// Verifier for AtomicWriteOp
 LogicalResult AtomicWriteOp::verify() {
-  if (auto mo = memory_order()) {
+  if (auto mo = memory_order_val()) {
     if (*mo == ClauseMemoryOrderKind::acq_rel ||
         *mo == ClauseMemoryOrderKind::acquire) {
       return emitError(
           "memory-order must not be acq_rel or acquire for atomic writes");
     }
   }
-  return verifySynchronizationHint(*this, hint());
+  return verifySynchronizationHint(*this, hint_val());
 }
 
 //===----------------------------------------------------------------------===//
-// AtomicUpdateOp
+// Verifier for AtomicUpdateOp
 //===----------------------------------------------------------------------===//
 
-/// Parser for AtomicUpdateOp
-///
-/// operation ::= `omp.atomic.update` atomic-clause-list ssa-id-and-type region
-ParseResult AtomicUpdateOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
-  SmallVector<int> segments;
-  OpAsmParser::OperandType x, expr;
-  Type xType;
-
-  if (parseClauses(parser, result, clauses, segments) ||
-      parser.parseOperand(x) || parser.parseColon() ||
-      parser.parseType(xType) ||
-      parser.resolveOperand(x, xType, result.operands) ||
-      parser.parseRegion(*result.addRegion()))
-    return failure();
-  return success();
-}
-
-void AtomicUpdateOp::print(OpAsmPrinter &p) {
-  p << " ";
-  if (auto mo = memory_order())
-    p << "memory_order(" << stringifyClauseMemoryOrderKind(*mo) << ") ";
-  if (hintAttr())
-    printSynchronizationHint(p, *this, hintAttr());
-  p << x() << " : " << x().getType();
-  p.printRegion(region());
-}
-
-/// Verifier for AtomicUpdateOp
 LogicalResult AtomicUpdateOp::verify() {
-  if (auto mo = memory_order()) {
+  if (auto mo = memory_order_val()) {
     if (*mo == ClauseMemoryOrderKind::acq_rel ||
         *mo == ClauseMemoryOrderKind::acquire) {
       return emitError(
@@ -1235,28 +1058,9 @@ LogicalResult AtomicUpdateOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// AtomicCaptureOp
+// Verifier for AtomicCaptureOp
 //===----------------------------------------------------------------------===//
 
-ParseResult AtomicCaptureOp::parse(OpAsmParser &parser,
-                                   OperationState &result) {
-  SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
-  SmallVector<int> segments;
-  if (parseClauses(parser, result, clauses, segments) ||
-      parser.parseRegion(*result.addRegion()))
-    return failure();
-  return success();
-}
-
-void AtomicCaptureOp::print(OpAsmPrinter &p) {
-  if (memory_order())
-    p << "memory_order(" << memory_order() << ") ";
-  if (hintAttr())
-    printSynchronizationHint(p, *this, hintAttr());
-  p.printRegion(region());
-}
-
-/// Verifier for AtomicCaptureOp
 LogicalResult AtomicCaptureOp::verify() {
   Block::OpListType &ops = region().front().getOperations();
   if (ops.size() != 3)

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d3b43b423c952..35fa843a553d5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -335,8 +335,9 @@ convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
     auto criticalDeclareOp =
         SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
                                                                      symbolRef);
-    hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
-                                  static_cast<int>(criticalDeclareOp.hint()));
+    hint =
+        llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
+                               static_cast<int>(criticalDeclareOp.hint_val()));
   }
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
@@ -910,7 +911,7 @@ convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
-  llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order());
+  llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order_val());
   llvm::Value *x = moduleTranslation.lookupValue(readOp.x());
   Type xTy = readOp.x().getType().cast<omp::PointerLikeType>().getElementType();
   llvm::Value *v = moduleTranslation.lookupValue(readOp.v());
@@ -931,7 +932,7 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-  llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order());
+  llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order_val());
   llvm::Value *expr = moduleTranslation.lookupValue(writeOp.value());
   llvm::Value *dest = moduleTranslation.lookupValue(writeOp.address());
   llvm::Type *ty = moduleTranslation.convertType(writeOp.value().getType());

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a991d5f20f6b7..2ac177f4da2cd 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -97,7 +97,7 @@ func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
 // -----
 
 func @order_value(%lb : index, %ub : index, %step : index) {
-  // expected-error @below {{invalid order kind}}
+  // expected-error @below {{invalid clause value: 'default'}}
   omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) {
     omp.yield
   }
@@ -106,7 +106,7 @@ func @order_value(%lb : index, %ub : index, %step : index) {
 // -----
 
 func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
-  // expected-error @below {{if is not a valid clause for the omp.wsloop operation}}
+  // expected-error @below {{if is not a valid clause}}
   omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) {
     omp.yield
   }
@@ -115,7 +115,7 @@ func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
 // -----
 
 func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) {
-  // expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}}
+  // expected-error @below {{num_threads is not a valid clause}}
   omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) {
     omp.yield
   }
@@ -124,7 +124,7 @@ func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var
 // -----
 
 func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) {
-  // expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}}
+  // expected-error @below {{proc_bind is not a valid clause}}
   omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) {
     omp.yield
   }
@@ -429,7 +429,7 @@ func @omp_atomic_read1(%x: memref<i32>, %v: memref<i32>) {
 // -----
 
 func @omp_atomic_read2(%x: memref<i32>, %v: memref<i32>) {
-  // expected-error @below {{invalid memory order kind}}
+  // expected-error @below {{invalid clause value: 'xyz'}}
   omp.atomic.read %v = %x memory_order(xyz) : memref<i32>
   return
 }
@@ -453,7 +453,7 @@ func @omp_atomic_read4(%x: memref<i32>, %v: memref<i32>) {
 // -----
 
 func @omp_atomic_read5(%x: memref<i32>, %v: memref<i32>) {
-  // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.read operation}}
+  // expected-error @below {{`memory_order` clause can appear at most once in the expansion of the oilist directive}}
   omp.atomic.read %v = %x memory_order(acquire) memory_order(relaxed) : memref<i32>
   return
 }
@@ -461,7 +461,7 @@ func @omp_atomic_read5(%x: memref<i32>, %v: memref<i32>) {
 // -----
 
 func @omp_atomic_read6(%x: memref<i32>, %v: memref<i32>) {
-  // expected-error @below {{at most one hint clause can appear on the omp.atomic.read operation}}
+  // expected-error @below {{`hint` clause can appear at most once in the expansion of the oilist directive}}
   omp.atomic.read %v =  %x hint(speculative) hint(contended) : memref<i32>
   return
 }
@@ -501,7 +501,7 @@ func @omp_atomic_write3(%addr : memref<i32>, %val : i32) {
 // -----
 
 func @omp_atomic_write4(%addr : memref<i32>, %val : i32) {
-  // expected-error @below {{at most one memory_order clause can appear on the omp.atomic.write operation}}
+  // expected-error @below {{`memory_order` clause can appear at most once in the expansion of the oilist directive}}
   omp.atomic.write  %addr = %val memory_order(release) memory_order(seq_cst) : memref<i32>, i32
   return
 }
@@ -509,7 +509,7 @@ func @omp_atomic_write4(%addr : memref<i32>, %val : i32) {
 // -----
 
 func @omp_atomic_write5(%addr : memref<i32>, %val : i32) {
-  // expected-error @below {{at most one hint clause can appear on the omp.atomic.write operation}}
+  // expected-error @below {{`hint` clause can appear at most once in the expansion of the oilist directive}}
   omp.atomic.write  %addr = %val hint(contended) hint(speculative) : memref<i32>, i32
   return
 }
@@ -517,7 +517,7 @@ func @omp_atomic_write5(%addr : memref<i32>, %val : i32) {
 // -----
 
 func @omp_atomic_write6(%addr : memref<i32>, %val : i32) {
-  // expected-error @below {{invalid memory order kind}}
+  // expected-error @below {{invalid clause value: 'xyz'}}
   omp.atomic.write  %addr = %val memory_order(xyz) : memref<i32>, i32
   return
 }
@@ -573,9 +573,8 @@ func @omp_atomic_update4(%x: memref<i32>, %expr: i32) {
 
 // -----
 
-// expected-note @below {{prior use here}}
 func @omp_atomic_update5(%x: memref<i32>, %expr: i32) {
-  // expected-error @below {{use of value '%x' expects different type than prior uses: 'i32' vs 'memref<i32>'}}
+  // expected-error @below {{invalid kind of type specified}}
   omp.atomic.update %x : i32 {
   ^bb0(%xval: i32):
     %newval = llvm.add %xval, %expr : i32

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e2cc900bf3787..3d6834f0d9345 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -568,14 +568,13 @@ func @omp_atomic_update(%x : memref<i32>, %expr : i32, %xBool : memref<i1>, %exp
 // CHECK-LABEL: omp_atomic_capture
 // CHECK-SAME: (%[[v:.*]]: memref<i32>, %[[x:.*]]: memref<i32>, %[[expr:.*]]: i32)
 func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
-  // CHECK: omp.atomic.capture{
+  // CHECK: omp.atomic.capture {
   // CHECK-NEXT: omp.atomic.update %[[x]] : memref<i32>
   // CHECK-NEXT: (%[[xval:.*]]: i32):
   // CHECK-NEXT:   %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32
   // CHECK-NEXT:   omp.yield(%[[newval]] : i32)
   // CHECK-NEXT: }
   // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
-  // CHECK-NEXT: omp.terminator
   // CHECK-NEXT: }
   omp.atomic.capture{
     omp.atomic.update %x : memref<i32> {
@@ -584,16 +583,14 @@ func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
       omp.yield(%newval : i32)
     }
     omp.atomic.read %v = %x : memref<i32>
-    omp.terminator
   }
-  // CHECK: omp.atomic.capture{
+  // CHECK: omp.atomic.capture {
   // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
   // CHECK-NEXT: omp.atomic.update %[[x]] : memref<i32>
   // CHECK-NEXT: (%[[xval:.*]]: i32):
   // CHECK-NEXT:   %[[newval:.*]] = llvm.add %[[xval]], %[[expr]] : i32
   // CHECK-NEXT:   omp.yield(%[[newval]] : i32)
   // CHECK-NEXT: }
-  // CHECK-NEXT: omp.terminator
   // CHECK-NEXT: }
   omp.atomic.capture{
     omp.atomic.read %v = %x : memref<i32>
@@ -602,17 +599,14 @@ func @omp_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
       %newval = llvm.add %xval, %expr : i32
       omp.yield(%newval : i32)
     }
-    omp.terminator
   }
-  // CHECK: omp.atomic.capture{
+  // CHECK: omp.atomic.capture {
   // CHECK-NEXT: omp.atomic.read %[[v]] = %[[x]] : memref<i32>
   // CHECK-NEXT: omp.atomic.write %[[x]] = %[[expr]] : memref<i32>, i32
-  // CHECK-NEXT: omp.terminator
   // CHECK-NEXT: }
   omp.atomic.capture{
     omp.atomic.read %v = %x : memref<i32>
     omp.atomic.write %x = %expr : memref<i32>, i32
-    omp.terminator
   }
   return
 }


        


More information about the Mlir-commits mailing list