[Mlir-commits] [mlir] 245b299 - [mlir][OpenMP] Add custom parser and pretty printer for parallel construct

David Truby llvmlistbot at llvm.org
Tue Jun 16 05:35:53 PDT 2020


Author: David Truby
Date: 2020-06-16T13:35:42+01:00
New Revision: 245b299edc98c3e92c902205ffce1bf50ca95f9a

URL: https://github.com/llvm/llvm-project/commit/245b299edc98c3e92c902205ffce1bf50ca95f9a
DIFF: https://github.com/llvm/llvm-project/commit/245b299edc98c3e92c902205ffce1bf50ca95f9a.diff

LOG: [mlir][OpenMP] Add custom parser and pretty printer for parallel construct

Reviewers: jdoerfert

Subscribers: yaxunl, guansong, mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, sstefan1, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81264

Added: 
    mlir/test/Dialect/OpenMP/invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 78b56cac1353..3be6c97322b5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -91,6 +91,9 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
              OptionalAttr<ClauseProcBind>:$proc_bind_val);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseParallelOp(parser, result); }];
+  let printer = [{ return printParallelOp(p, *this); }];
 }
 
 def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 99c592c25b83..4467d3361f51 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -12,9 +12,14 @@
 
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/OperationSupport.h"
 
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
+#include <cstddef>
 
 #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
 
@@ -29,6 +34,245 @@ OpenMPDialect::OpenMPDialect(MLIRContext *context)
       >();
 }
 
+//===----------------------------------------------------------------------===//
+// ParallelOp
+//===----------------------------------------------------------------------===//
+
+/// Parses a parenthesized, comma-separated, non-empty list of `%value : type`
+/// entries, appending each operand to `operands` and its type to `types`
+/// (both vectors stay index-aligned).
+///
+/// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
+/// ssa-id-and-type-list ::= ssa-id-and-type |
+///                          ssa-id-and-type ',' ssa-id-and-type-list
+/// ssa-id-and-type ::= ssa-id `:` type
+static ParseResult
+parseOperandAndTypeList(OpAsmParser &parser,
+                        SmallVectorImpl<OpAsmParser::OperandType> &operands,
+                        SmallVectorImpl<Type> &types) {
+  // Parses a single `%value : type` entry and records both halves.
+  auto parseEntry = [&]() -> ParseResult {
+    OpAsmParser::OperandType entry;
+    Type entryType;
+    if (failed(parser.parseOperand(entry)) ||
+        failed(parser.parseColonType(entryType)))
+      return failure();
+    operands.push_back(entry);
+    types.push_back(entryType);
+    return success();
+  };
+
+  // At least one entry is required between the parentheses.
+  if (failed(parser.parseLParen()) || failed(parseEntry()))
+    return failure();
+  while (succeeded(parser.parseOptionalComma()))
+    if (failed(parseEntry()))
+      return failure();
+  return parser.parseRParen();
+}
+
+/// Prints a parallel operation in its custom assembly form. Present clauses
+/// are emitted in a fixed order -- if, num_threads, private, firstprivate,
+/// shared, copyin, default, proc_bind -- followed by the region.
+static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
+  p << "omp.parallel";
+
+  if (auto ifCond = op.if_expr_var())
+    p << " if(" << ifCond << ")";
+
+  if (auto threads = op.num_threads_var())
+    p << " num_threads(" << threads << " : " << threads.getType() << ")";
+
+  // Prints ` name(%v0 : t0, %v1 : t1, ...)`; prints nothing when `vars` is
+  // empty, i.e. when the clause is absent.
+  auto printDataVars = [&p](StringRef name, OperandRange vars) {
+    if (vars.empty())
+      return;
+    p << " " << name << "(";
+    const char *separator = "";
+    for (auto var : vars) {
+      p << separator << var << " : " << var.getType();
+      separator = ", ";
+    }
+    p << ")";
+  };
+  printDataVars("private", op.private_vars());
+  printDataVars("firstprivate", op.firstprivate_vars());
+  printDataVars("shared", op.shared_vars());
+  printDataVars("copyin", op.copyin_vars());
+
+  // The stored default value carries a 3-character "def" prefix (the parser
+  // prepends it); strip it so the printed keyword matches the input syntax.
+  if (auto def = op.default_val())
+    p << " default(" << def->drop_front(3) << ")";
+
+  if (auto bind = op.proc_bind_val())
+    p << " proc_bind(" << bind << ")";
+
+  p.printRegion(op.getRegion());
+}
+
+/// Emit an error if the same clause is present more than once on an operation.
+static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
+                               llvm::StringRef operation) {
+  return parser.emitError(parser.getNameLoc())
+         << " at most one " << clause << " clause can appear on the "
+         << operation << " operation";
+}
+
+/// Parses a parallel operation.
+///
+/// operation ::= `omp.parallel` clause-list region
+/// clause-list ::= clause*   (possibly empty)
+/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
+///            default | procBind
+/// if ::= `if` `(` ssa-id `)`
+/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
+/// private ::= `private` operand-and-type-list
+/// firstprivate ::= `firstprivate` operand-and-type-list
+/// shared ::= `shared` operand-and-type-list
+/// copyin ::= `copyin` operand-and-type-list
+/// default ::= `default` `(`
+///             (`private` | `firstprivate` | `shared` | `none`) `)`
+/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
+///
+/// Clauses may appear in any order, but each at most once in the clause-list.
+/// The `if` operand is resolved as `i1`; every other operand is resolved
+/// against the type written next to it. Returns failure (after a diagnostic)
+/// on a repeated or unknown clause, malformed syntax, or an operand that
+/// cannot be resolved.
+static ParseResult parseParallelOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> numThreads;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
+  llvm::SmallVector<Type, 4> privateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
+  llvm::SmallVector<Type, 4> firstprivateTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
+  llvm::SmallVector<Type, 4> sharedTypes;
+  llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
+  llvm::SmallVector<Type, 4> copyinTypes;
+  // Operand count per clause, in the order used for the
+  // `operand_segment_sizes` attribute. A non-zero entry also records that the
+  // clause was already parsed (operand lists are never empty).
+  std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
+  llvm::StringRef keyword;
+  bool defaultVal = false;
+  bool procBind = false;
+
+  const int ifClausePos = 0;
+  const int numThreadsClausePos = 1;
+  const int privateClausePos = 2;
+  const int firstprivateClausePos = 3;
+  const int sharedClausePos = 4;
+  const int copyinClausePos = 5;
+  const llvm::StringRef opName = result.name.getStringRef();
+
+  // Parses the operand-and-type-list of the clause `name` into `operands` /
+  // `types` and records its length in `segments[pos]`. Fails if the clause
+  // was already seen.
+  auto parseListClause =
+      [&](llvm::StringRef name, int pos,
+          llvm::SmallVectorImpl<OpAsmParser::OperandType> &operands,
+          llvm::SmallVectorImpl<Type> &types) -> ParseResult {
+    if (segments[pos])
+      return allowedOnce(parser, name, opName);
+    if (parseOperandAndTypeList(parser, operands, types))
+      return failure();
+    segments[pos] = operands.size();
+    return success();
+  };
+
+  while (succeeded(parser.parseOptionalKeyword(&keyword))) {
+    if (keyword == "if") {
+      // Fail if there was already another if clause.
+      if (segments[ifClausePos])
+        return allowedOnce(parser, "if", opName);
+      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
+          parser.parseRParen())
+        return failure();
+      segments[ifClausePos] = 1;
+    } else if (keyword == "num_threads") {
+      // Fail if there was already another num_threads clause.
+      if (segments[numThreadsClausePos])
+        return allowedOnce(parser, "num_threads", opName);
+      if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
+          parser.parseColonType(numThreads.second) || parser.parseRParen())
+        return failure();
+      segments[numThreadsClausePos] = 1;
+    } else if (keyword == "private") {
+      if (parseListClause("private", privateClausePos, privates, privateTypes))
+        return failure();
+    } else if (keyword == "firstprivate") {
+      if (parseListClause("firstprivate", firstprivateClausePos,
+                          firstprivates, firstprivateTypes))
+        return failure();
+    } else if (keyword == "shared") {
+      if (parseListClause("shared", sharedClausePos, shareds, sharedTypes))
+        return failure();
+    } else if (keyword == "copyin") {
+      if (parseListClause("copyin", copyinClausePos, copyins, copyinTypes))
+        return failure();
+    } else if (keyword == "default") {
+      // Fail if there was already another default clause.
+      if (defaultVal)
+        return allowedOnce(parser, "default", opName);
+      defaultVal = true;
+      llvm::StringRef defval;
+      if (parser.parseLParen() || parser.parseKeyword(&defval) ||
+          parser.parseRParen())
+        return failure();
+      llvm::SmallString<16> attrval;
+      // The def prefix is required for the attribute as "private" is a keyword
+      // in C++
+      attrval += "def";
+      attrval += defval;
+      auto attr = parser.getBuilder().getStringAttr(attrval);
+      result.addAttribute("default_val", attr);
+    } else if (keyword == "proc_bind") {
+      // Fail if there was already another proc_bind clause.
+      if (procBind)
+        return allowedOnce(parser, "proc_bind", opName);
+      procBind = true;
+      llvm::StringRef bind;
+      if (parser.parseLParen() || parser.parseKeyword(&bind) ||
+          parser.parseRParen())
+        return failure();
+      auto attr = parser.getBuilder().getStringAttr(bind);
+      result.addAttribute("proc_bind_val", attr);
+    } else {
+      return parser.emitError(parser.getNameLoc())
+             << keyword << " is not a valid clause for the " << opName
+             << " operation";
+    }
+  }
+
+  // Resolve the parsed operands in segment order. Each resolve call can fail
+  // and emit a diagnostic (e.g. on a type mismatch); propagate that instead
+  // of building an operation whose operands disagree with
+  // `operand_segment_sizes`.
+  if (segments[ifClausePos] &&
+      parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
+                            result.operands))
+    return failure();
+  if (segments[numThreadsClausePos] &&
+      parser.resolveOperand(numThreads.first, numThreads.second,
+                            result.operands))
+    return failure();
+  if (segments[privateClausePos] &&
+      parser.resolveOperands(privates, privateTypes, privates[0].location,
+                             result.operands))
+    return failure();
+  if (segments[firstprivateClausePos] &&
+      parser.resolveOperands(firstprivates, firstprivateTypes,
+                             firstprivates[0].location, result.operands))
+    return failure();
+  if (segments[sharedClausePos] &&
+      parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
+                             result.operands))
+    return failure();
+  if (segments[copyinClausePos] &&
+      parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
+                             result.operands))
+    return failure();
+
+  result.addAttribute("operand_segment_sizes",
+                      parser.getBuilder().getI32VectorAttr(segments));
+
+  // Parse the body region. The syntax declares no region arguments, so
+  // empty argument/type lists are passed.
+  Region *body = result.addRegion();
+  llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
+  llvm::SmallVector<Type, 4> regionArgTypes;
+  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
+    return failure();
+  return success();
+}
+
 namespace mlir {
 namespace omp {
 #define GET_OP_CLASSES

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
new file mode 100644
index 000000000000..00f9726b119d
--- /dev/null
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+// An unrecognized keyword in clause position is rejected.
+func @unknown_clause() {
+  // expected-error@+1 {{invalid is not a valid clause for the omp.parallel operation}}
+  omp.parallel invalid {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `if` clause is rejected.
+func @if_once(%n : i1) {
+  // expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
+  omp.parallel if(%n) if(%n) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `num_threads` clause is rejected.
+func @num_threads_once(%n : si32) {
+  // expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
+  omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `private` clause is rejected.
+func @private_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one private clause can appear on the omp.parallel operation}}
+  omp.parallel private(%n : memref<i32>) private(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `firstprivate` clause is rejected.
+func @firstprivate_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one firstprivate clause can appear on the omp.parallel operation}}
+  omp.parallel firstprivate(%n : memref<i32>) firstprivate(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `shared` clause is rejected.
+func @shared_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one shared clause can appear on the omp.parallel operation}}
+  omp.parallel shared(%n : memref<i32>) shared(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `copyin` clause is rejected.
+func @copyin_once(%n : memref<i32>) {
+  // expected-error@+1 {{at most one copyin clause can appear on the omp.parallel operation}}
+  omp.parallel copyin(%n : memref<i32>) copyin(%n : memref<i32>) {
+  }
+
+  return
+}
+
+// -----
+ 
+// A repeated `default` clause is rejected, even with a different value.
+func @default_once() {
+  // expected-error@+1 {{at most one default clause can appear on the omp.parallel operation}}
+  omp.parallel default(private) default(firstprivate) {
+  }
+
+  return
+}
+
+// -----
+
+// A repeated `proc_bind` clause is rejected, even with a different value.
+func @proc_bind_once() {
+  // expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
+  omp.parallel proc_bind(close) proc_bind(spread) {
+  }
+
+  return
+}

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e780cebd93fa..85343f985501 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
 func @omp_barrier() -> () {
   // CHECK: omp.barrier
@@ -51,11 +51,11 @@ func @omp_terminator() -> () {
 }
 
 func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
   "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var, %data_var, %data_var) ({
 
   // test without if condition
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[0,1,1,1,1,1]>: vector<6xi32>, default_val = "defshared"} : (si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
@@ -64,7 +64,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
     omp.barrier
 
   // test without num_threads
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel if(%{{.*}}) private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var, %data_var, %data_var) ({
       omp.terminator
     }) {operand_segment_sizes = dense<[1,0,1,1,1,1]> : vector<6xi32>} : (i1, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
@@ -73,10 +73,43 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   }) {operand_segment_sizes = dense<[1,1,1,1,1,1]> : vector<6xi32>, proc_bind_val = "spread"} : (i1, si32, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
-  // CHECK: omp.parallel
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var, %data_var, %data_var, %data_var) ({
     omp.terminator
   }) {operand_segment_sizes = dense<[0,0,1,2,1,1]> : vector<6xi32>} : (memref<i32>, memref<i32>, memref<i32>, memref<i32>, memref<i32>) -> ()
 
   return
 }
+
+// CHECK-LABEL: func @omp_parallel_pretty
+func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
+  // An op without clauses prints no clauses before its region.
+  // CHECK: omp.parallel {
+  omp.parallel {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel num_threads(%{{.*}} : si32)
+  omp.parallel num_threads(%num_threads : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel private(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>) firstprivate(%{{.*}} : memref<i32>)
+  omp.parallel private(%data_var : memref<i32>, %data_var : memref<i32>) firstprivate(%data_var : memref<i32>) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>)
+  omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
+    // CHECK: omp.parallel if(%{{.*}}) {
+    omp.parallel if(%if_cond) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // Clauses are printed in a fixed order regardless of the input order.
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
+  omp.parallel num_threads(%num_threads : si32) if(%if_cond)
+               private(%data_var : memref<i32>) proc_bind(close) {
+    omp.terminator
+  }
+
+  return
+}


        


More information about the Mlir-commits mailing list