[Mlir-commits] [mlir] 3915171 - [mlir][OpenMP] Added assemblyFormat for ParallelOp

Shraiysh Vaishay llvmlistbot at llvm.org
Fri Feb 18 20:59:08 PST 2022


Author: Shraiysh Vaishay
Date: 2022-02-19T10:28:58+05:30
New Revision: 39151717dbb494463cda59fe5d776870816790ce

URL: https://github.com/llvm/llvm-project/commit/39151717dbb494463cda59fe5d776870816790ce
DIFF: https://github.com/llvm/llvm-project/commit/39151717dbb494463cda59fe5d776870816790ce.diff

LOG: [mlir][OpenMP] Added assemblyFormat for ParallelOp

This patch adds an assemblyFormat for the omp.parallel operation.

Some existing functions have been altered to fit the custom directive
in assemblyFormat. As a result, their call sites had to be modified too,
but those call sites will be removed in later patches, when the other
operations get their own assemblyFormat. Not all operations were changed
in a single patch, for ease of review.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D120157

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ec535edf81d9f..6ed13e6d8ff2c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -97,7 +97,17 @@ def ParallelOp : OpenMP_Op<"parallel", [
   let builders = [
     OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
   ];
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)`
+          | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
+          | `allocate` `(`
+              custom<AllocateAndAllocator>(
+                $allocate_vars, type($allocate_vars),
+                $allocators_vars, type($allocators_vars)
+              ) `)`
+          | `proc_bind` `(` custom<ProcBindKind>($proc_bind_val) `)`
+    ) $region attr-dict
+  }];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bc3b595483d78..babd71e85bd09 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -89,35 +89,53 @@ static ParseResult parseAllocateAndAllocator(
     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
     SmallVectorImpl<Type> &typesAllocator) {
 
-  return parser.parseCommaSeparatedList(
-      OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
-        OpAsmParser::OperandType operand;
-        Type type;
-        if (parser.parseOperand(operand) || parser.parseColonType(type))
-          return failure();
-        operandsAllocator.push_back(operand);
-        typesAllocator.push_back(type);
-        if (parser.parseArrow())
-          return failure();
-        if (parser.parseOperand(operand) || parser.parseColonType(type))
-          return failure();
+  return parser.parseCommaSeparatedList([&]() -> ParseResult {
+    OpAsmParser::OperandType operand;
+    Type type;
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
+    operandsAllocator.push_back(operand);
+    typesAllocator.push_back(type);
+    if (parser.parseArrow())
+      return failure();
+    if (parser.parseOperand(operand) || parser.parseColonType(type))
+      return failure();
 
-        operandsAllocate.push_back(operand);
-        typesAllocate.push_back(type);
-        return success();
-      });
+    operandsAllocate.push_back(operand);
+    typesAllocate.push_back(type);
+    return success();
+  });
 }
 
 /// Print allocate clause
-static void printAllocateAndAllocator(OpAsmPrinter &p,
+static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
                                       OperandRange varsAllocate,
-                                      OperandRange varsAllocator) {
-  p << "allocate(";
+                                      TypeRange typesAllocate,
+                                      OperandRange varsAllocator,
+                                      TypeRange typesAllocator) {
   for (unsigned i = 0; i < varsAllocate.size(); ++i) {
-    std::string separator = i == varsAllocate.size() - 1 ? ") " : ", ";
-    p << varsAllocator[i] << " : " << varsAllocator[i].getType() << " -> ";
-    p << varsAllocate[i] << " : " << varsAllocate[i].getType() << separator;
+    std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
+    p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
+    p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
+  }
+}
+
+ParseResult parseProcBindKind(OpAsmParser &parser,
+                              omp::ClauseProcBindKindAttr &procBindAttr) {
+  StringRef procBindStr;
+  if (parser.parseKeyword(&procBindStr))
+    return failure();
+  if (auto procBindVal = symbolizeClauseProcBindKind(procBindStr)) {
+    procBindAttr =
+        ClauseProcBindKindAttr::get(parser.getContext(), *procBindVal);
+    return success();
   }
+  return failure();
+}
+
+void printProcBindKind(OpAsmPrinter &p, Operation *op,
+                       omp::ClauseProcBindKindAttr procBindAttr) {
+  p << stringifyClauseProcBindKind(procBindAttr.getValue());
 }
 
 LogicalResult ParallelOp::verify() {
@@ -127,24 +145,6 @@ LogicalResult ParallelOp::verify() {
   return success();
 }
 
-void ParallelOp::print(OpAsmPrinter &p) {
-  p << " ";
-  if (auto ifCond = if_expr_var())
-    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
-
-  if (auto threads = num_threads_var())
-    p << "num_threads(" << threads << " : " << threads.getType() << ") ";
-
-  if (!allocate_vars().empty())
-    printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
-
-  if (auto bind = proc_bind_val())
-    p << "proc_bind(" << stringifyClauseProcBindKind(*bind) << ") ";
-
-  p << ' ';
-  p.printRegion(getRegion());
-}
-
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -626,9 +626,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
         return failure();
       clauseSegments[pos[threadLimitClause]] = 1;
     } else if (clauseKeyword == "allocate") {
-      if (checkAllowed(allocateClause) ||
+      if (checkAllowed(allocateClause) || parser.parseLParen() ||
           parseAllocateAndAllocator(parser, allocates, allocateTypes,
-                                    allocators, allocatorTypes))
+                                    allocators, allocatorTypes) ||
+          parser.parseRParen())
         return failure();
       clauseSegments[pos[allocateClause]] = allocates.size();
       clauseSegments[pos[allocateClause] + 1] = allocators.size();
@@ -803,32 +804,6 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
   return success();
 }
 
-/// Parses a parallel operation.
-///
-/// operation ::= `omp.parallel` clause-list
-/// clause-list ::= clause | clause clause-list
-/// clause ::= if | num-threads | allocate | proc-bind
-///
-ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<ClauseType> clauses = {ifClause, numThreadsClause, allocateClause,
-                                     procBindClause};
-
-  SmallVector<int> segments;
-
-  if (failed(parseClauses(parser, result, clauses, segments)))
-    return failure();
-
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  Region *body = result.addRegion();
-  SmallVector<OpAsmParser::OperandType> regionArgs;
-  SmallVector<Type> regionArgTypes;
-  if (parser.parseRegion(*body, regionArgs, regionArgTypes))
-    return failure();
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for SectionsOp
 //===----------------------------------------------------------------------===//
@@ -863,8 +838,12 @@ void SectionsOp::print(OpAsmPrinter &p) {
   if (!reduction_vars().empty())
     printReductionVarList(p, reductions(), reduction_vars());
 
-  if (!allocate_vars().empty())
-    printAllocateAndAllocator(p, allocate_vars(), allocators_vars());
+  if (!allocate_vars().empty()) {
+    printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
+                              allocate_vars().getTypes(), allocators_vars(),
+                              allocators_vars().getTypes());
+    p << ")";
+  }
 
   if (nowait())
     p << "nowait";

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 6646410183c74..8a5d50dd0fb96 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 func @unknown_clause() {
-  // expected-error@+1 {{invalid is not a valid clause}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel invalid {
   }
 
@@ -11,7 +11,7 @@ func @unknown_clause() {
 // -----
 
 func @if_once(%n : i1) {
-  // expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel if(%n : i1) if(%n : i1) {
   }
 
@@ -21,7 +21,7 @@ func @if_once(%n : i1) {
 // -----
 
 func @num_threads_once(%n : si32) {
-  // expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`num_threads` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
   }
 
@@ -31,7 +31,7 @@ func @num_threads_once(%n : si32) {
 // -----
 
 func @nowait_not_allowed(%n : memref<i32>) {
-  // expected-error@+1 {{nowait is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
   return
 }
@@ -39,7 +39,7 @@ func @nowait_not_allowed(%n : memref<i32>) {
 // -----
 
 func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
-  // expected-error@+1 {{linear is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel linear(%data_var = %linear_var : memref<i32>)  {}
   return
 }
@@ -47,7 +47,7 @@ func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
 // -----
 
 func @schedule_not_allowed() {
-  // expected-error@+1 {{schedule is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel schedule(static) {}
   return
 }
@@ -55,7 +55,7 @@ func @schedule_not_allowed() {
 // -----
 
 func @collapse_not_allowed() {
-  // expected-error@+1 {{collapse is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel collapse(3) {}
   return
 }
@@ -63,7 +63,7 @@ func @collapse_not_allowed() {
 // -----
 
 func @order_not_allowed() {
-  // expected-error@+1 {{order is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel order(concurrent) {}
   return
 }
@@ -71,14 +71,14 @@ func @order_not_allowed() {
 // -----
 
 func @ordered_not_allowed() {
-  // expected-error@+1 {{ordered is not a valid clause for the omp.parallel operation}}
+  // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel ordered(2) {}
 }
 
 // -----
 
 func @proc_bind_once() {
-  // expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
+  // expected-error@+1 {{`proc_bind` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel proc_bind(close) proc_bind(spread) {
   }
 

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 573b036f5746a..cbb8b1f550da4 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+    }) {num_threads, allocate, operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -68,22 +68,22 @@ func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32)
   // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%if_cond, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
+    }) {if, allocate, operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref<i32>, memref<i32>) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
+    }) {if, num_threads, operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> ()
 
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+  }) {if, num_threads, allocate, operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
+  }) {allocate, operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref<i32>, memref<i32>) -> ()
 
   return
 }


        


More information about the Mlir-commits mailing list