[Mlir-commits] [mlir] d576f45 - [MLIR][OpenMP] Added parseClauses

Shraiysh Vaishay llvmlistbot at llvm.org
Tue Oct 19 05:01:47 PDT 2021


Author: Shraiysh Vaishay
Date: 2021-10-19T17:31:36+05:30
New Revision: d576f4501439860faa95e4f3b782cd6da5123ef1

URL: https://github.com/llvm/llvm-project/commit/d576f4501439860faa95e4f3b782cd6da5123ef1
DIFF: https://github.com/llvm/llvm-project/commit/d576f4501439860faa95e4f3b782cd6da5123ef1.diff

LOG: [MLIR][OpenMP] Added parseClauses

Reorganized the code in OpenMPDialect.cpp so that all functions corresponding to an operation are grouped together.

Added a parseClauses function to avoid code duplication when parsing clauses in OpenMP operations. Also added printers and verifiers for clauses that are used by multiple operations.

Reviewed By: kiranchandramohan, peixin

Differential Revision: https://reviews.llvm.org/D110903

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d093fa498b52e..37dcaf04c0089 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -211,8 +211,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
              Variadic<AnyType>:$linear_vars,
              Variadic<AnyType>:$linear_step_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,
-             OptionalAttr<TypedArrayAttrBase<SymbolRefAttr,
-                            "array of symbol references">>:$reductions,
+             OptionalAttr<SymbolRefArrayAttr>:$reductions,
              OptionalAttr<ScheduleKind>:$schedule_val,
              Optional<AnyType>:$schedule_chunk_var,
              Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$collapse_val,

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 32faac397eb16..ebd0e0ba8bf93 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -67,6 +68,10 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   state.addAttributes(attributes);
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for Operand and type list
+//===----------------------------------------------------------------------===//
+
 /// Parse a list of operands with types.
 ///
 /// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
@@ -89,9 +94,30 @@ parseOperandAndTypeList(OpAsmParser &parser,
       });
 }
 
+/// Print an operand and type list with parentheses
+static void printOperandAndTypeList(OpAsmPrinter &p, OperandRange operands) {
+  p << "(";
+  llvm::interleaveComma(
+      operands, p, [&](const Value &v) { p << v << " : " << v.getType(); });
+  p << ") ";
+}
+
+/// Print data variables corresponding to a data-sharing clause `name`
+static void printDataVars(OpAsmPrinter &p, OperandRange operands,
+                          StringRef name) {
+  if (operands.size()) {
+    p << name;
+    printOperandAndTypeList(p, operands);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for Allocate Clause
+//===----------------------------------------------------------------------===//
+
 /// Parse an allocate clause with allocators and a list of operands with types.
 ///
-/// operand-and-type-list ::= `(` allocate-operand-list `)`
+/// allocate ::= `allocate` `(` allocate-operand-list `)`
 /// allocate-operand-list :: = allocate-operand |
 ///                            allocator-operand `,` allocate-operand-list
 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
@@ -122,6 +148,21 @@ static ParseResult parseAllocateAndAllocator(
       });
 }
 
+/// Print allocate clause
+static void printAllocateAndAllocator(OpAsmPrinter &p,
+                                      OperandRange varsAllocate,
+                                      OperandRange varsAllocator) {
+  if (varsAllocate.empty())
+    return;
+
+  p << "allocate(";
+  for (unsigned i = 0; i < varsAllocate.size(); ++i) {
+    std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
+    p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
+    p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
+  }
+}
+
 static LogicalResult verifyParallelOp(ParallelOp op) {
   if (op.allocate_vars().size() != op.allocators_vars().size())
     return op.emitError(
@@ -130,250 +171,31 @@ static LogicalResult verifyParallelOp(ParallelOp op) {
 }
 
 static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
+  p << " ";
   if (auto ifCond = op.if_expr_var())
-    p << " if(" << ifCond << " : " << ifCond.getType() << ")";
+    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
 
   if (auto threads = op.num_threads_var())
-    p << " num_threads(" << threads << " : " << threads.getType() << ")";
-
-  // Print private, firstprivate, shared and copyin parameters
-  auto printDataVars = [&p](StringRef name, OperandRange vars) {
-    if (vars.size()) {
-      p << " " << name << "(";
-      for (unsigned i = 0; i < vars.size(); ++i) {
-        std::string separator = i == vars.size() - 1 ? ")" : ", ";
-        p << vars[i] << " : " << vars[i].getType() << separator;
-      }
-    }
-  };
+    p << "num_threads(" << threads << " : " << threads.getType() << ") ";
 
-  // Print allocator and allocate parameters
-  auto printAllocateAndAllocator = [&p](OperandRange varsAllocate,
-                                        OperandRange varsAllocator) {
-    if (varsAllocate.empty())
-      return;
-
-    p << " allocate(";
-    for (unsigned i = 0; i < varsAllocate.size(); ++i) {
-      std::string separator = i == varsAllocate.size() - 1 ? ")" : ", ";
-      p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
-      p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
-    }
-  };
-
-  printDataVars("private", op.private_vars());
-  printDataVars("firstprivate", op.firstprivate_vars());
-  printDataVars("shared", op.shared_vars());
-  printDataVars("copyin", op.copyin_vars());
-  printAllocateAndAllocator(op.allocate_vars(), op.allocators_vars());
+  printDataVars(p, op.private_vars(), "private");
+  printDataVars(p, op.firstprivate_vars(), "firstprivate");
+  printDataVars(p, op.shared_vars(), "shared");
+  printDataVars(p, op.copyin_vars(), "copyin");
+  printAllocateAndAllocator(p, op.allocate_vars(), op.allocators_vars());
 
   if (auto def = op.default_val())
-    p << " default(" << def->drop_front(3) << ")";
+    p << "default(" << def->drop_front(3) << ") ";
 
   if (auto bind = op.proc_bind_val())
-    p << " proc_bind(" << bind << ")";
+    p << "proc_bind(" << bind << ") ";
 
   p.printRegion(op.getRegion());
 }
 
-/// Emit an error if the same clause is present more than once on an operation.
-static ParseResult allowedOnce(OpAsmParser &parser, StringRef clause,
-                               StringRef operation) {
-  return parser.emitError(parser.getNameLoc())
-         << " at most one " << clause << " clause can appear on the "
-         << operation << " operation";
-}
-
-/// Parses a parallel operation.
-///
-/// operation ::= `omp.parallel` clause-list
-/// clause-list ::= clause | clause clause-list
-/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
-///            default | procBind
-/// if ::= `if` `(` ssa-id `)`
-/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
-/// private ::= `private` operand-and-type-list
-/// firstprivate ::= `firstprivate` operand-and-type-list
-/// shared ::= `shared` operand-and-type-list
-/// copyin ::= `copyin` operand-and-type-list
-/// allocate ::= `allocate` operand-and-type `->` operand-and-type-list
-/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
-/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
-///
-/// Note that each clause can only appear once in the clase-list.
-static ParseResult parseParallelOp(OpAsmParser &parser,
-                                   OperationState &result) {
-  std::pair<OpAsmParser::OperandType, Type> ifCond;
-  std::pair<OpAsmParser::OperandType, Type> numThreads;
-  SmallVector<OpAsmParser::OperandType, 4> privates;
-  SmallVector<Type, 4> privateTypes;
-  SmallVector<OpAsmParser::OperandType, 4> firstprivates;
-  SmallVector<Type, 4> firstprivateTypes;
-  SmallVector<OpAsmParser::OperandType, 4> shareds;
-  SmallVector<Type, 4> sharedTypes;
-  SmallVector<OpAsmParser::OperandType, 4> copyins;
-  SmallVector<Type, 4> copyinTypes;
-  SmallVector<OpAsmParser::OperandType, 4> allocates;
-  SmallVector<Type, 4> allocateTypes;
-  SmallVector<OpAsmParser::OperandType, 4> allocators;
-  SmallVector<Type, 4> allocatorTypes;
-  std::array<int, 8> segments{0, 0, 0, 0, 0, 0, 0, 0};
-  StringRef keyword;
-  bool defaultVal = false;
-  bool procBind = false;
-
-  const int ifClausePos = 0;
-  const int numThreadsClausePos = 1;
-  const int privateClausePos = 2;
-  const int firstprivateClausePos = 3;
-  const int sharedClausePos = 4;
-  const int copyinClausePos = 5;
-  const int allocateClausePos = 6;
-  const int allocatorPos = 7;
-  const StringRef opName = result.name.getStringRef();
-
-  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
-    if (keyword == "if") {
-      // Fail if there was already another if condition.
-      if (segments[ifClausePos])
-        return allowedOnce(parser, "if", opName);
-      if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
-          parser.parseColonType(ifCond.second) || parser.parseRParen())
-        return failure();
-      segments[ifClausePos] = 1;
-    } else if (keyword == "num_threads") {
-      // Fail if there was already another num_threads clause.
-      if (segments[numThreadsClausePos])
-        return allowedOnce(parser, "num_threads", opName);
-      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
-          parser.parseColonType(numThreads.second) || parser.parseRParen())
-        return failure();
-      segments[numThreadsClausePos] = 1;
-    } else if (keyword == "private") {
-      // Fail if there was already another private clause.
-      if (segments[privateClausePos])
-        return allowedOnce(parser, "private", opName);
-      if (parseOperandAndTypeList(parser, privates, privateTypes))
-        return failure();
-      segments[privateClausePos] = privates.size();
-    } else if (keyword == "firstprivate") {
-      // Fail if there was already another firstprivate clause.
-      if (segments[firstprivateClausePos])
-        return allowedOnce(parser, "firstprivate", opName);
-      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
-        return failure();
-      segments[firstprivateClausePos] = firstprivates.size();
-    } else if (keyword == "shared") {
-      // Fail if there was already another shared clause.
-      if (segments[sharedClausePos])
-        return allowedOnce(parser, "shared", opName);
-      if (parseOperandAndTypeList(parser, shareds, sharedTypes))
-        return failure();
-      segments[sharedClausePos] = shareds.size();
-    } else if (keyword == "copyin") {
-      // Fail if there was already another copyin clause.
-      if (segments[copyinClausePos])
-        return allowedOnce(parser, "copyin", opName);
-      if (parseOperandAndTypeList(parser, copyins, copyinTypes))
-        return failure();
-      segments[copyinClausePos] = copyins.size();
-    } else if (keyword == "allocate") {
-      // Fail if there was already another allocate clause.
-      if (segments[allocateClausePos])
-        return allowedOnce(parser, "allocate", opName);
-      if (parseAllocateAndAllocator(parser, allocates, allocateTypes,
-                                    allocators, allocatorTypes))
-        return failure();
-      segments[allocateClausePos] = allocates.size();
-      segments[allocatorPos] = allocators.size();
-    } else if (keyword == "default") {
-      // Fail if there was already another default clause.
-      if (defaultVal)
-        return allowedOnce(parser, "default", opName);
-      defaultVal = true;
-      StringRef defval;
-      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
-          parser.parseRParen())
-        return failure();
-      // The def prefix is required for the attribute as "private" is a keyword
-      // in C++.
-      auto attr = parser.getBuilder().getStringAttr("def" + defval);
-      result.addAttribute("default_val", attr);
-    } else if (keyword == "proc_bind") {
-      // Fail if there was already another proc_bind clause.
-      if (procBind)
-        return allowedOnce(parser, "proc_bind", opName);
-      procBind = true;
-      StringRef bind;
-      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
-          parser.parseRParen())
-        return failure();
-      auto attr = parser.getBuilder().getStringAttr(bind);
-      result.addAttribute("proc_bind_val", attr);
-    } else {
-      return parser.emitError(parser.getNameLoc())
-             << keyword << " is not a valid clause for the " << opName
-             << " operation";
-    }
-  }
-
-  // Add if parameter.
-  if (segments[ifClausePos] &&
-      parser.resolveOperand(ifCond.first, ifCond.second, result.operands))
-    return failure();
-
-  // Add num_threads parameter.
-  if (segments[numThreadsClausePos] &&
-      parser.resolveOperand(numThreads.first, numThreads.second,
-                            result.operands))
-    return failure();
-
-  // Add private parameters.
-  if (segments[privateClausePos] &&
-      parser.resolveOperands(privates, privateTypes, privates[0].location,
-                             result.operands))
-    return failure();
-
-  // Add firstprivate parameters.
-  if (segments[firstprivateClausePos] &&
-      parser.resolveOperands(firstprivates, firstprivateTypes,
-                             firstprivates[0].location, result.operands))
-    return failure();
-
-  // Add shared parameters.
-  if (segments[sharedClausePos] &&
-      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
-                             result.operands))
-    return failure();
-
-  // Add copyin parameters.
-  if (segments[copyinClausePos] &&
-      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
-                             result.operands))
-    return failure();
-
-  // Add allocate parameters.
-  if (segments[allocateClausePos] &&
-      parser.resolveOperands(allocates, allocateTypes, allocates[0].location,
-                             result.operands))
-    return failure();
-
-  // Add allocator parameters.
-  if (segments[allocatorPos] &&
-      parser.resolveOperands(allocators, allocatorTypes, allocators[0].location,
-                             result.operands))
-    return failure();
-
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  Region *body = result.addRegion();
-  SmallVector<OpAsmParser::OperandType, 4> regionArgs;
-  SmallVector<Type, 4> regionArgTypes;
-  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
-    return failure();
-  return success();
-}
+//===----------------------------------------------------------------------===//
+// Parser and printer for Linear Clause
+//===----------------------------------------------------------------------===//
 
 /// linear ::= `linear` `(` linear-list `)`
 /// linear-list := linear-val | linear-val linear-list
@@ -405,6 +227,24 @@ parseLinearClause(OpAsmParser &parser,
   return success();
 }
 
+/// Print Linear Clause
+static void printLinearClause(OpAsmPrinter &p, OperandRange linearVars,
+                              OperandRange linearStepVars) {
+  size_t linearVarsSize = linearVars.size();
+  p << "(";
+  for (unsigned i = 0; i < linearVarsSize; ++i) {
+    std::string separator = i == linearVarsSize - 1 ? ") " : ", ";
+    p << linearVars[i];
+    if (linearStepVars.size() > i)
+      p << " = " << linearStepVars[i];
+    p << " : " << linearVars[i].getType() << separator;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for Schedule Clause
+//===----------------------------------------------------------------------===//
+
 /// schedule ::= `schedule` `(` sched-list `)`
 /// sched-list ::= sched-val | sched-val sched-list
 /// sched-val ::= sched-with-chunk | sched-wo-chunk
@@ -442,7 +282,21 @@ parseScheduleClause(OpAsmParser &parser, SmallString<8> &schedule,
   return success();
 }
 
-/// reduction-init ::= `reduction` `(` reduction-entry-list `)`
+/// Print schedule clause
+static void printScheduleClause(OpAsmPrinter &p, StringRef &sched,
+                                Value scheduleChunkVar) {
+  std::string schedLower = sched.lower();
+  p << "(" << schedLower;
+  if (scheduleChunkVar)
+    p << " = " << scheduleChunkVar;
+  p << ") ";
+}
+
+//===----------------------------------------------------------------------===//
+// Parser, printer and verifier for ReductionVarList
+//===----------------------------------------------------------------------===//
+
+/// reduction ::= `reduction` `(` reduction-entry-list `)`
 /// reduction-entry-list ::= reduction-entry
 ///                        | reduction-entry-list `,` reduction-entry
 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
@@ -463,209 +317,392 @@ parseReductionVarList(OpAsmParser &parser,
   return parser.parseRParen();
 }
 
-/// Parses an OpenMP Workshare Loop operation
-///
-/// operation ::= `omp.wsloop` loop-control clause-list
-/// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
-/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
-/// steps := `step` `(`ssa-id-list`)`
-/// clause-list ::= clause | empty | clause-list
-/// clause ::= private | firstprivate | lastprivate | linear | schedule |
-//             collapse | nowait | ordered | order | inclusive
-/// private ::= `private` `(` ssa-id-and-type-list `)`
-/// firstprivate ::= `firstprivate` `(` ssa-id-and-type-list `)`
-/// lastprivate ::= `lastprivate` `(` ssa-id-and-type-list `)`
+/// Print Reduction clause
+static void printReductionVarList(OpAsmPrinter &p,
+                                  Optional<ArrayAttr> reductions,
+                                  OperandRange reduction_vars) {
+  for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << (*reductions)[i] << " -> " << reduction_vars[i] << " : "
+      << reduction_vars[i].getType();
+  }
+  p << ") ";
+}
+
+/// Verifies Reduction Clause
+static LogicalResult verifyReductionVarList(Operation *op,
+                                            Optional<ArrayAttr> reductions,
+                                            OperandRange reduction_vars) {
+  if (reduction_vars.size() != 0) {
+    if (!reductions || reductions->size() != reduction_vars.size())
+      return op->emitOpError()
+             << "expected as many reduction symbol references "
+                "as reduction variables";
+  } else {
+    if (reductions)
+      return op->emitOpError() << "unexpected reduction symbol references";
+    return success();
+  }
+
+  DenseSet<Value> accumulators;
+  for (auto args : llvm::zip(reduction_vars, *reductions)) {
+    Value accum = std::get<0>(args);
+
+    if (!accumulators.insert(accum).second)
+      return op->emitOpError() << "accumulator variable used more than once";
+
+    Type varType = accum.getType().cast<PointerLikeType>();
+    auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+    auto decl =
+        SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
+    if (!decl)
+      return op->emitOpError() << "expected symbol reference " << symbolRef
+                               << " to point to a reduction declaration";
+
+    if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
+      return op->emitOpError()
+             << "expected accumulator (" << varType
+             << ") to be the same type as reduction declaration ("
+             << decl.getAccumulatorType() << ")";
+  }
+
+  return success();
+}
+
+enum ClauseType {
+  ifClause,
+  numThreadsClause,
+  privateClause,
+  firstprivateClause,
+  lastprivateClause,
+  sharedClause,
+  copyinClause,
+  allocateClause,
+  defaultClause,
+  procBindClause,
+  reductionClause,
+  nowaitClause,
+  linearClause,
+  scheduleClause,
+  collapseClause,
+  orderClause,
+  orderedClause,
+  inclusiveClause,
+  COUNT
+};
+
+//===----------------------------------------------------------------------===//
+// Parser for Clause List
+//===----------------------------------------------------------------------===//
+
+/// Parse a list of clauses. The clauses can appear in any order, but their
+/// operand segment indices are in the same order that they are passed in the
+/// `clauses` list. The operand segments are added over the prevSegments.
+///
+/// clause-list ::= clause clause-list | empty
+/// clause ::= if | num-threads | private | firstprivate | lastprivate |
+///            shared | copyin | allocate | default | proc-bind | reduction |
+///            nowait | linear | schedule | collapse | order | ordered |
+///            inclusive
+/// if ::= `if` `(` ssa-id-and-type `)`
+/// num-threads ::= `num_threads` `(` ssa-id-and-type `)`
+/// private ::= `private` operand-and-type-list
+/// firstprivate ::= `firstprivate` operand-and-type-list
+/// lastprivate ::= `lastprivate` operand-and-type-list
+/// shared ::= `shared` operand-and-type-list
+/// copyin ::= `copyin` operand-and-type-list
+/// allocate ::= `allocate` `(` allocate-operand-list `)`
+/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
+/// proc-bind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
+/// reduction ::= `reduction` `(` reduction-entry-list `)`
+/// nowait ::= `nowait`
 /// linear ::= `linear` `(` linear-list `)`
 /// schedule ::= `schedule` `(` sched-list `)`
 /// collapse ::= `collapse` `(` ssa-id-and-type `)`
-/// nowait ::= `nowait`
-/// ordered ::= `ordered` `(` ssa-id-and-type `)`
 /// order ::= `order` `(` `concurrent` `)`
+/// ordered ::= `ordered` `(` ssa-id-and-type `)`
 /// inclusive ::= `inclusive`
 ///
-static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
-  Type loopVarType;
-  int numIVs;
+/// Note that each clause can only appear once in the clause-list.
+static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
+                                SmallVectorImpl<ClauseType> &clauses,
+                                SmallVectorImpl<int> &segments) {
 
-  // Parse an opening `(` followed by induction variables followed by `)`
-  SmallVector<OpAsmParser::OperandType> ivs;
-  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
-                                     OpAsmParser::Delimiter::Paren))
-    return failure();
+  // Check done[clause] to see if it has been parsed already
+  llvm::BitVector done(ClauseType::COUNT, false);
 
-  numIVs = static_cast<int>(ivs.size());
+  // See pos[clause] to get position of clause in operand segments
+  SmallVector<int> pos(ClauseType::COUNT, -1);
 
-  if (parser.parseColonType(loopVarType))
-    return failure();
+  // Stores the last parsed clause keyword
+  StringRef clauseKeyword;
+  StringRef opName = result.name.getStringRef();
 
-  // Parse loop bounds.
-  SmallVector<OpAsmParser::OperandType> lower;
-  if (parser.parseEqual() ||
-      parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(lower, loopVarType, result.operands))
-    return failure();
+  // Containers for storing operands, types and attributes for various clauses
+  std::pair<OpAsmParser::OperandType, Type> ifCond;
+  std::pair<OpAsmParser::OperandType, Type> numThreads;
 
-  SmallVector<OpAsmParser::OperandType> upper;
-  if (parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(upper, loopVarType, result.operands))
-    return failure();
+  SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
+      shareds, copyins;
+  SmallVector<Type> privateTypes, firstprivateTypes, lastprivateTypes,
+      sharedTypes, copyinTypes;
 
-  // Parse step values.
-  SmallVector<OpAsmParser::OperandType> steps;
-  if (parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(steps, loopVarType, result.operands))
-    return failure();
+  SmallVector<OpAsmParser::OperandType> allocates, allocators;
+  SmallVector<Type> allocateTypes, allocatorTypes;
 
-  SmallVector<OpAsmParser::OperandType> privates;
-  SmallVector<Type> privateTypes;
-  SmallVector<OpAsmParser::OperandType> firstprivates;
-  SmallVector<Type> firstprivateTypes;
-  SmallVector<OpAsmParser::OperandType> lastprivates;
-  SmallVector<Type> lastprivateTypes;
-  SmallVector<OpAsmParser::OperandType> linears;
-  SmallVector<Type> linearTypes;
-  SmallVector<OpAsmParser::OperandType> linearSteps;
   SmallVector<SymbolRefAttr> reductionSymbols;
   SmallVector<OpAsmParser::OperandType> reductionVars;
   SmallVector<Type> reductionVarTypes;
+
+  SmallVector<OpAsmParser::OperandType> linears;
+  SmallVector<Type> linearTypes;
+  SmallVector<OpAsmParser::OperandType> linearSteps;
+
   SmallString<8> schedule;
   Optional<OpAsmParser::OperandType> scheduleChunkSize;
 
-  const StringRef opName = result.name.getStringRef();
-  StringRef keyword;
+  // Compute the position of clauses in operand segments
+  int currPos = 0;
+  for (ClauseType clause : clauses) {
 
-  enum SegmentPos {
-    lbPos = 0,
-    ubPos,
-    stepPos,
-    privateClausePos,
-    firstprivateClausePos,
-    lastprivateClausePos,
-    linearClausePos,
-    linearStepPos,
-    reductionVarPos,
-    scheduleClausePos,
+    // Skip the following clauses - they do not take any position in operand
+    // segments
+    if (clause == defaultClause || clause == procBindClause ||
+        clause == nowaitClause || clause == collapseClause ||
+        clause == orderClause || clause == orderedClause ||
+        clause == inclusiveClause)
+      continue;
+
+    pos[clause] = currPos++;
+
+    // For the following clauses, two positions are reserved in the operand
+    // segments
+    if (clause == allocateClause || clause == linearClause)
+      currPos++;
+  }
+
+  SmallVector<int> clauseSegments(currPos);
+
+  // Helper function to check if a clause is allowed/repeated or not
+  auto checkAllowed = [&](ClauseType clause,
+                          bool allowRepeat = false) -> ParseResult {
+    if (!llvm::is_contained(clauses, clause))
+      return parser.emitError(parser.getCurrentLocation())
+             << clauseKeyword << "is not a valid clause for the " << opName
+             << " operation";
+    if (done[clause] && !allowRepeat)
+      return parser.emitError(parser.getCurrentLocation())
+             << "at most one " << clauseKeyword << " clause can appear on the "
+             << opName << " operation";
+    done[clause] = true;
+    return success();
   };
-  std::array<int, 10> segments{numIVs, numIVs, numIVs, 0, 0, 0, 0, 0, 0, 0};
 
-  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
-    if (keyword == "private") {
-      if (segments[privateClausePos])
-        return allowedOnce(parser, "private", opName);
-      if (parseOperandAndTypeList(parser, privates, privateTypes))
+  while (succeeded(parser.parseOptionalKeyword(&clauseKeyword))) {
+    if (clauseKeyword == "if") {
+      if (checkAllowed(ifClause) || parser.parseLParen() ||
+          parser.parseOperand(ifCond.first) ||
+          parser.parseColonType(ifCond.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[ifClause]] = 1;
+    } else if (clauseKeyword == "num_threads") {
+      if (checkAllowed(numThreadsClause) || parser.parseLParen() ||
+          parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[numThreadsClause]] = 1;
+    } else if (clauseKeyword == "private") {
+      if (checkAllowed(privateClause) ||
+          parseOperandAndTypeList(parser, privates, privateTypes))
+        return failure();
+      clauseSegments[pos[privateClause]] = privates.size();
+    } else if (clauseKeyword == "firstprivate") {
+      if (checkAllowed(firstprivateClause) ||
+          parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+        return failure();
+      clauseSegments[pos[firstprivateClause]] = firstprivates.size();
+    } else if (clauseKeyword == "lastprivate") {
+      if (checkAllowed(lastprivateClause) ||
+          parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
+        return failure();
+      clauseSegments[pos[lastprivateClause]] = lastprivates.size();
+    } else if (clauseKeyword == "shared") {
+      if (checkAllowed(sharedClause) ||
+          parseOperandAndTypeList(parser, shareds, sharedTypes))
+        return failure();
+      clauseSegments[pos[sharedClause]] = shareds.size();
+    } else if (clauseKeyword == "copyin") {
+      if (checkAllowed(copyinClause) ||
+          parseOperandAndTypeList(parser, copyins, copyinTypes))
+        return failure();
+      clauseSegments[pos[copyinClause]] = copyins.size();
+    } else if (clauseKeyword == "allocate") {
+      if (checkAllowed(allocateClause) ||
+          parseAllocateAndAllocator(parser, allocates, allocateTypes,
+                                    allocators, allocatorTypes))
         return failure();
-      segments[privateClausePos] = privates.size();
-    } else if (keyword == "firstprivate") {
-      // fail if there was already another firstprivate clause
-      if (segments[firstprivateClausePos])
-        return allowedOnce(parser, "firstprivate", opName);
-      if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
+      clauseSegments[pos[allocateClause]] = allocates.size();
+      clauseSegments[pos[allocateClause] + 1] = allocators.size();
+    } else if (clauseKeyword == "default") {
+      StringRef defval;
+      if (checkAllowed(defaultClause) || parser.parseLParen() ||
+          parser.parseKeyword(&defval) || parser.parseRParen())
         return failure();
-      segments[firstprivateClausePos] = firstprivates.size();
-    } else if (keyword == "lastprivate") {
-      // fail if there was already another shared clause
-      if (segments[lastprivateClausePos])
-        return allowedOnce(parser, "lastprivate", opName);
-      if (parseOperandAndTypeList(parser, lastprivates, lastprivateTypes))
+      // The def prefix is required for the attribute as "private" is a keyword
+      // in C++.
+      auto attr = parser.getBuilder().getStringAttr("def" + defval);
+      result.addAttribute("default_val", attr);
+    } else if (clauseKeyword == "proc_bind") {
+      StringRef bind;
+      if (checkAllowed(procBindClause) || parser.parseLParen() ||
+          parser.parseKeyword(&bind) || parser.parseRParen())
         return failure();
-      segments[lastprivateClausePos] = lastprivates.size();
-    } else if (keyword == "linear") {
-      // fail if there was already another linear clause
-      if (segments[linearClausePos])
-        return allowedOnce(parser, "linear", opName);
-      if (parseLinearClause(parser, linears, linearTypes, linearSteps))
+      auto attr = parser.getBuilder().getStringAttr(bind);
+      result.addAttribute("proc_bind_val", attr);
+    } else if (clauseKeyword == "reduction") {
+      if (checkAllowed(reductionClause) ||
+          parseReductionVarList(parser, reductionSymbols, reductionVars,
+                                reductionVarTypes))
+        return failure();
+      clauseSegments[pos[reductionClause]] = reductionVars.size();
+    } else if (clauseKeyword == "nowait") {
+      if (checkAllowed(nowaitClause))
+        return failure();
+      auto attr = UnitAttr::get(parser.getBuilder().getContext());
+      result.addAttribute("nowait", attr);
+    } else if (clauseKeyword == "linear") {
+      if (checkAllowed(linearClause) ||
+          parseLinearClause(parser, linears, linearTypes, linearSteps))
         return failure();
-      segments[linearClausePos] = linears.size();
-      segments[linearStepPos] = linearSteps.size();
-    } else if (keyword == "schedule") {
-      if (!schedule.empty())
-        return allowedOnce(parser, "schedule", opName);
-      if (parseScheduleClause(parser, schedule, scheduleChunkSize))
+      clauseSegments[pos[linearClause]] = linears.size();
+      clauseSegments[pos[linearClause] + 1] = linearSteps.size();
+    } else if (clauseKeyword == "schedule") {
+      if (checkAllowed(scheduleClause) ||
+          parseScheduleClause(parser, schedule, scheduleChunkSize))
         return failure();
       if (scheduleChunkSize) {
-        segments[scheduleClausePos] = 1;
+        clauseSegments[pos[scheduleClause]] = 1;
       }
-    } else if (keyword == "collapse") {
+    } else if (clauseKeyword == "collapse") {
       auto type = parser.getBuilder().getI64Type();
       mlir::IntegerAttr attr;
-      if (parser.parseLParen() || parser.parseAttribute(attr, type) ||
-          parser.parseRParen())
+      if (checkAllowed(collapseClause) || parser.parseLParen() ||
+          parser.parseAttribute(attr, type) || parser.parseRParen())
         return failure();
       result.addAttribute("collapse_val", attr);
-    } else if (keyword == "nowait") {
-      auto attr = UnitAttr::get(parser.getContext());
-      result.addAttribute("nowait", attr);
-    } else if (keyword == "ordered") {
+    } else if (clauseKeyword == "ordered") {
       mlir::IntegerAttr attr;
+      if (checkAllowed(orderedClause))
+        return failure();
       if (succeeded(parser.parseOptionalLParen())) {
         auto type = parser.getBuilder().getI64Type();
-        if (parser.parseAttribute(attr, type))
-          return failure();
-        if (parser.parseRParen())
+        if (parser.parseAttribute(attr, type) || parser.parseRParen())
           return failure();
       } else {
         // Use 0 to represent no ordered parameter was specified
         attr = parser.getBuilder().getI64IntegerAttr(0);
       }
       result.addAttribute("ordered_val", attr);
-    } else if (keyword == "order") {
+    } else if (clauseKeyword == "order") {
       StringRef order;
-      if (parser.parseLParen() || parser.parseKeyword(&order) ||
-          parser.parseRParen())
+      if (checkAllowed(orderClause) || parser.parseLParen() ||
+          parser.parseKeyword(&order) || parser.parseRParen())
         return failure();
       auto attr = parser.getBuilder().getStringAttr(order);
       result.addAttribute("order", attr);
-    } else if (keyword == "inclusive") {
-      auto attr = UnitAttr::get(parser.getContext());
-      result.addAttribute("inclusive", attr);
-    } else if (keyword == "reduction") {
-      if (segments[reductionVarPos])
-        return allowedOnce(parser, "reduction", opName);
-      if (failed(parseReductionVarList(parser, reductionSymbols, reductionVars,
-                                       reductionVarTypes)))
+    } else if (clauseKeyword == "inclusive") {
+      if (checkAllowed(inclusiveClause))
         return failure();
-      segments[reductionVarPos] = reductionVars.size();
+      auto attr = UnitAttr::get(parser.getBuilder().getContext());
+      result.addAttribute("inclusive", attr);
+    } else {
+      return parser.emitError(parser.getNameLoc())
+             << clauseKeyword << " is not a valid clause";
     }
   }
 
-  if (segments[privateClausePos]) {
-    parser.resolveOperands(privates, privateTypes, privates[0].location,
-                           result.operands);
-  }
+  // Add if parameter.
+  if (done[ifClause] && clauseSegments[pos[ifClause]] &&
+      failed(
+          parser.resolveOperand(ifCond.first, ifCond.second, result.operands)))
+    return failure();
 
-  if (segments[firstprivateClausePos]) {
-    parser.resolveOperands(firstprivates, firstprivateTypes,
-                           firstprivates[0].location, result.operands);
-  }
+  // Add num_threads parameter.
+  if (done[numThreadsClause] && clauseSegments[pos[numThreadsClause]] &&
+      failed(parser.resolveOperand(numThreads.first, numThreads.second,
+                                   result.operands)))
+    return failure();
 
-  if (segments[lastprivateClausePos]) {
-    parser.resolveOperands(lastprivates, lastprivateTypes,
-                           lastprivates[0].location, result.operands);
-  }
+  // Add private parameters.
+  if (done[privateClause] && clauseSegments[pos[privateClause]] &&
+      failed(parser.resolveOperands(privates, privateTypes,
+                                    privates[0].location, result.operands)))
+    return failure();
 
-  if (segments[linearClausePos]) {
-    parser.resolveOperands(linears, linearTypes, linears[0].location,
-                           result.operands);
-    auto linearStepType = parser.getBuilder().getI32Type();
-    SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
-    parser.resolveOperands(linearSteps, linearStepTypes,
-                           linearSteps[0].location, result.operands);
-  }
+  // Add firstprivate parameters.
+  if (done[firstprivateClause] && clauseSegments[pos[firstprivateClause]] &&
+      failed(parser.resolveOperands(firstprivates, firstprivateTypes,
+                                    firstprivates[0].location,
+                                    result.operands)))
+    return failure();
+
+  // Add lastprivate parameters.
+  if (done[lastprivateClause] && clauseSegments[pos[lastprivateClause]] &&
+      failed(parser.resolveOperands(lastprivates, lastprivateTypes,
+                                    lastprivates[0].location, result.operands)))
+    return failure();
 
-  if (segments[reductionVarPos]) {
+  // Add shared parameters.
+  if (done[sharedClause] && clauseSegments[pos[sharedClause]] &&
+      failed(parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                                    result.operands)))
+    return failure();
+
+  // Add copyin parameters.
+  if (done[copyinClause] && clauseSegments[pos[copyinClause]] &&
+      failed(parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                                    result.operands)))
+    return failure();
+
+  // Add allocate parameters.
+  if (done[allocateClause] && clauseSegments[pos[allocateClause]] &&
+      failed(parser.resolveOperands(allocates, allocateTypes,
+                                    allocates[0].location, result.operands)))
+    return failure();
+
+  // Add allocator parameters.
+  if (done[allocateClause] && clauseSegments[pos[allocateClause] + 1] &&
+      failed(parser.resolveOperands(allocators, allocatorTypes,
+                                    allocators[0].location, result.operands)))
+    return failure();
+
+  // Add reduction parameters and symbols.
+  if (done[reductionClause] && clauseSegments[pos[reductionClause]]) {
     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
-                                      parser.getNameLoc(), result.operands))) {
+                                      parser.getNameLoc(), result.operands)))
       return failure();
-    }
+
     SmallVector<Attribute> reductions(reductionSymbols.begin(),
                                       reductionSymbols.end());
     result.addAttribute("reductions",
                         parser.getBuilder().getArrayAttr(reductions));
   }
 
-  if (!schedule.empty()) {
+  // Add linear and linear-step parameters.
+  if (done[linearClause] && clauseSegments[pos[linearClause]]) {
+    auto linearStepType = parser.getBuilder().getI32Type();
+    SmallVector<Type> linearStepTypes(linearSteps.size(), linearStepType);
+    if (failed(parser.resolveOperands(linears, linearTypes, linears[0].location,
+                                      result.operands)) ||
+        failed(parser.resolveOperands(linearSteps, linearStepTypes,
+                                      linearSteps[0].location,
+                                      result.operands)))
+      return failure();
+  }
+
+  // Add schedule parameters.
+  if (done[scheduleClause] && !schedule.empty()) {
     schedule[0] = llvm::toUpper(schedule[0]);
     auto attr = parser.getBuilder().getStringAttr(schedule);
     result.addAttribute("schedule_val", attr);
@@ -675,6 +712,91 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
     }
   }
 
+  segments.insert(segments.end(), clauseSegments.begin(), clauseSegments.end());
+
+  return success();
+}
+
+/// Parses a parallel operation.
+///
+/// operation ::= `omp.parallel` clause-list
+/// clause-list ::= clause | clause clause-list
+/// clause ::= if | num-threads | private | firstprivate | shared | copyin |
+///            allocate | default | proc-bind
+///
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  SmallVector<ClauseType> clauses = {
+      ifClause,           numThreadsClause, privateClause,
+      firstprivateClause, sharedClause,     copyinClause,
+      allocateClause,     defaultClause,    procBindClause};
+
+  SmallVector<int> segments;
+
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  SmallVector<OpAsmParser::OperandType> regionArgs;
+  SmallVector<Type> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
+/// Parses an OpenMP Workshare Loop operation.
+///
+/// wsloop ::= `omp.wsloop` loop-control clause-list
+/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
+/// loop-bounds ::= `(` ssa-id-list `)` `to` `(` ssa-id-list `)` steps
+/// steps ::= `step` `(` ssa-id-list `)`
+/// clause-list ::= clause clause-list | empty
+/// clause ::= private | firstprivate | lastprivate | linear | schedule |
+///            collapse | nowait | ordered | order | inclusive | reduction
+static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
+
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::OperandType> ivs;
+  if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
+                                     OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  int numIVs = static_cast<int>(ivs.size());
+  Type loopVarType;
+  if (parser.parseColonType(loopVarType))
+    return failure();
+
+  // Parse loop bounds.
+  SmallVector<OpAsmParser::OperandType> lower;
+  if (parser.parseEqual() ||
+      parser.parseOperandList(lower, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(lower, loopVarType, result.operands))
+    return failure();
+
+  SmallVector<OpAsmParser::OperandType> upper;
+  if (parser.parseKeyword("to") ||
+      parser.parseOperandList(upper, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(upper, loopVarType, result.operands))
+    return failure();
+
+  // Parse step values.
+  SmallVector<OpAsmParser::OperandType> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, numIVs, OpAsmParser::Delimiter::Paren) ||
+      parser.resolveOperands(steps, loopVarType, result.operands))
+    return failure();
+
+  SmallVector<ClauseType> clauses = {
+      privateClause,   firstprivateClause, lastprivateClause, linearClause,
+      reductionClause, collapseClause,     orderClause,       orderedClause,
+      nowaitClause,    scheduleClause};
+  SmallVector<int> segments{numIVs, numIVs, numIVs};
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
   result.addAttribute("operand_segment_sizes",
                       parser.getBuilder().getI32VectorAttr(segments));
 
@@ -690,69 +812,38 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
   auto args = op.getRegion().front().getArguments();
   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
-    << ") to (" << op.upperBound() << ") step (" << op.step() << ")";
+    << ") to (" << op.upperBound() << ") step (" << op.step() << ") ";
 
-  // Print private, firstprivate, shared and copyin parameters
-  auto printDataVars = [&p](StringRef name, OperandRange vars) {
-    if (vars.empty())
-      return;
+  printDataVars(p, op.private_vars(), "private");
+  printDataVars(p, op.firstprivate_vars(), "firstprivate");
+  printDataVars(p, op.lastprivate_vars(), "lastprivate");
 
-    p << " " << name << "(";
-    llvm::interleaveComma(
-        vars, p, [&](const Value &v) { p << v << " : " << v.getType(); });
-    p << ")";
-  };
-  printDataVars("private", op.private_vars());
-  printDataVars("firstprivate", op.firstprivate_vars());
-  printDataVars("lastprivate", op.lastprivate_vars());
-
-  auto linearVars = op.linear_vars();
-  auto linearVarsSize = linearVars.size();
-  if (linearVarsSize) {
-    p << " "
-      << "linear"
-      << "(";
-    for (unsigned i = 0; i < linearVarsSize; ++i) {
-      std::string separator = i == linearVarsSize - 1 ? ")" : ", ";
-      p << linearVars[i];
-      if (op.linear_step_vars().size() > i)
-        p << " = " << op.linear_step_vars()[i];
-      p << " : " << linearVars[i].getType() << separator;
-    }
+  if (op.linear_vars().size()) {
+    p << "linear";
+    printLinearClause(p, op.linear_vars(), op.linear_step_vars());
   }
 
   if (auto sched = op.schedule_val()) {
-    auto schedLower = sched->lower();
-    p << " schedule(" << schedLower;
-    if (auto chunk = op.schedule_chunk_var()) {
-      p << " = " << chunk;
-    }
-    p << ")";
+    p << "schedule";
+    printScheduleClause(p, sched.getValue(), op.schedule_chunk_var());
   }
 
   if (auto collapse = op.collapse_val())
-    p << " collapse(" << collapse << ")";
+    p << "collapse(" << collapse << ") ";
 
   if (op.nowait())
-    p << " nowait";
+    p << "nowait ";
 
-  if (auto ordered = op.ordered_val()) {
-    p << " ordered(" << ordered << ")";
-  }
+  if (auto ordered = op.ordered_val())
+    p << "ordered(" << ordered << ") ";
 
   if (!op.reduction_vars().empty()) {
-    p << " reduction(";
-    for (unsigned i = 0, e = op.getNumReductionVars(); i < e; ++i) {
-      if (i != 0)
-        p << ", ";
-      p << (*op.reductions())[i] << " -> " << op.reduction_vars()[i] << " : "
-        << op.reduction_vars()[i].getType();
-    }
-    p << ")";
+    p << "reduction(";
+    printReductionVarList(p, op.reductions(), op.reduction_vars());
   }
 
   if (op.inclusive()) {
-    p << " inclusive";
+    p << "inclusive ";
   }
 
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
@@ -921,42 +1012,7 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &result,
 }
 
 static LogicalResult verifyWsLoopOp(WsLoopOp op) {
-  if (op.getNumReductionVars() != 0) {
-    if (!op.reductions() ||
-        op.reductions()->size() != op.getNumReductionVars()) {
-      return op.emitOpError() << "expected as many reduction symbol references "
-                                 "as reduction variables";
-    }
-  } else {
-    if (op.reductions())
-      return op.emitOpError() << "unexpected reduction symbol references";
-    return success();
-  }
-
-  DenseSet<Value> accumulators;
-  for (auto args : llvm::zip(op.reduction_vars(), *op.reductions())) {
-    Value accum = std::get<0>(args);
-    if (!accumulators.insert(accum).second) {
-      return op.emitOpError() << "accumulator variable used more than once";
-    }
-    Type varType = accum.getType().cast<PointerLikeType>();
-    auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
-    auto decl =
-        SymbolTable::lookupNearestSymbolFrom<ReductionDeclareOp>(op, symbolRef);
-    if (!decl) {
-      return op.emitOpError() << "expected symbol reference " << symbolRef
-                              << " to point to a reduction declaration";
-    }
-
-    if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) {
-      return op.emitOpError()
-             << "expected accumulator (" << varType
-             << ") to be the same type as reduction declaration ("
-             << decl.getAccumulatorType() << ")";
-    }
-  }
-
-  return success();
+  return verifyReductionVarList(op, op.reductions(), op.reduction_vars());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e07d98e6125e3..36d40ad455d06 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 func @unknown_clause() {
-  // expected-error at +1 {{invalid is not a valid clause for the omp.parallel operation}}
+  // expected-error at +1 {{invalid is not a valid clause}}
   omp.parallel invalid {
   }
 


        


More information about the Mlir-commits mailing list