[Mlir-commits] [mlir] c1e4e01 - [mlir][OpenMP] Added assemblyFormat for SectionsOp

Shraiysh Vaishay llvmlistbot at llvm.org
Sun Feb 20 23:32:01 PST 2022


Author: Shraiysh Vaishay
Date: 2022-02-21T13:01:49+05:30
New Revision: c1e4e019454b38e3890589be977a3c2c445fefd1

URL: https://github.com/llvm/llvm-project/commit/c1e4e019454b38e3890589be977a3c2c445fefd1
DIFF: https://github.com/llvm/llvm-project/commit/c1e4e019454b38e3890589be977a3c2c445fefd1.diff

LOG: [mlir][OpenMP] Added assemblyFormat for SectionsOp

This patch adds an assemblyFormat for the omp.sections operation.

Some existing functions have been altered to fit the custom directive
in assemblyFormat. This required modifying their call sites as well,
but those call sites will be removed in later patches, when the other
operations get their assemblyFormat. Not all operations were converted
in one patch, for ease of review.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D120176

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 6ed13e6d8ff2c..d316ca6314b53 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -188,7 +188,20 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
 
   let regions = (region SizedRegion<1>:$region);
 
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    oilist( `reduction` `(`
+              custom<ReductionVarList>(
+                $reduction_vars, type($reduction_vars), $reductions
+              ) `)`
+          | `allocate` `(`
+              custom<AllocateAndAllocator>(
+                $allocate_vars, type($allocate_vars),
+                $allocators_vars, type($allocators_vars)
+              ) `)`
+          | `nowait`
+    ) $region attr-dict
+  }];
+
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index babd71e85bd09..4fa4e5819b339 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -77,7 +77,6 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 
 /// Parse an allocate clause with allocators and a list of operands with types.
 ///
-/// allocate ::= `allocate` `(` allocate-operand-list `)`
 /// allocate-operand-list :: = allocate-operand |
 ///                            allocator-operand `,` allocate-operand-list
 /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
@@ -300,39 +299,35 @@ static void printScheduleClause(OpAsmPrinter &p, ClauseScheduleKind sched,
 // Parser, printer and verifier for ReductionVarList
 //===----------------------------------------------------------------------===//
 
-/// reduction ::= `reduction` `(` reduction-entry-list `)`
 /// reduction-entry-list ::= reduction-entry
 ///                        | reduction-entry-list `,` reduction-entry
 /// reduction-entry ::= symbol-ref `->` ssa-id `:` type
-static ParseResult
-parseReductionVarList(OpAsmParser &parser,
-                      SmallVectorImpl<SymbolRefAttr> &symbols,
-                      SmallVectorImpl<OpAsmParser::OperandType> &operands,
-                      SmallVectorImpl<Type> &types) {
-  if (failed(parser.parseLParen()))
-    return failure();
-
+static ParseResult parseReductionVarList(
+    OpAsmParser &parser, SmallVectorImpl<OpAsmParser::OperandType> &operands,
+    SmallVectorImpl<Type> &types, ArrayAttr &redcuctionSymbols) {
+  SmallVector<SymbolRefAttr> reductionVec;
   do {
-    if (parser.parseAttribute(symbols.emplace_back()) || parser.parseArrow() ||
-        parser.parseOperand(operands.emplace_back()) ||
+    if (parser.parseAttribute(reductionVec.emplace_back()) ||
+        parser.parseArrow() || parser.parseOperand(operands.emplace_back()) ||
         parser.parseColonType(types.emplace_back()))
       return failure();
   } while (succeeded(parser.parseOptionalComma()));
-  return parser.parseRParen();
+  SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
+  redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+  return success();
 }
 
 /// Print Reduction clause
-static void printReductionVarList(OpAsmPrinter &p,
-                                  Optional<ArrayAttr> reductions,
-                                  OperandRange reductionVars) {
-  p << "reduction(";
+static void printReductionVarList(OpAsmPrinter &p, Operation *op,
+                                  OperandRange reductionVars,
+                                  TypeRange reductionTypes,
+                                  Optional<ArrayAttr> reductions) {
   for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
     if (i != 0)
       p << ", ";
     p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
       << reductionVars[i].getType();
   }
-  p << ") ";
 }
 
 /// Verifies Reduction Clause
@@ -552,7 +547,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
   SmallVector<OpAsmParser::OperandType> allocates, allocators;
   SmallVector<Type> allocateTypes, allocatorTypes;
 
-  SmallVector<SymbolRefAttr> reductionSymbols;
+  ArrayAttr reductions;
   SmallVector<OpAsmParser::OperandType> reductionVars;
   SmallVector<Type> reductionVarTypes;
 
@@ -639,9 +634,10 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
                                                   "proc_bind_val", "proc bind"))
         return failure();
     } else if (clauseKeyword == "reduction") {
-      if (checkAllowed(reductionClause) ||
-          parseReductionVarList(parser, reductionSymbols, reductionVars,
-                                reductionVarTypes))
+      if (checkAllowed(reductionClause) || parser.parseLParen() ||
+          parseReductionVarList(parser, reductionVars, reductionVarTypes,
+                                reductions) ||
+          parser.parseRParen())
         return failure();
       clauseSegments[pos[reductionClause]] = reductionVars.size();
     } else if (clauseKeyword == "nowait") {
@@ -746,11 +742,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
     if (failed(parser.resolveOperands(reductionVars, reductionVarTypes,
                                       parser.getNameLoc(), result.operands)))
       return failure();
-
-    SmallVector<Attribute> reductions(reductionSymbols.begin(),
-                                      reductionSymbols.end());
-    result.addAttribute("reductions",
-                        parser.getBuilder().getArrayAttr(reductions));
+    result.addAttribute("reductions", reductions);
   }
 
   // Add linear parameters
@@ -805,53 +797,9 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
 }
 
 //===----------------------------------------------------------------------===//
-// Parser, printer and verifier for SectionsOp
+// Verifier for SectionsOp
 //===----------------------------------------------------------------------===//
 
-/// Parses an OpenMP Sections operation
-///
-/// sections ::= `omp.sections` clause-list
-/// clause-list ::= clause clause-list | empty
-/// clause ::= reduction | allocate | nowait
-ParseResult SectionsOp::parse(OpAsmParser &parser, OperationState &result) {
-  SmallVector<ClauseType> clauses = {reductionClause, allocateClause,
-                                     nowaitClause};
-
-  SmallVector<int> segments;
-
-  if (failed(parseClauses(parser, result, clauses, segments)))
-    return failure();
-
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  // Now parse the body.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body))
-    return failure();
-  return success();
-}
-
-void SectionsOp::print(OpAsmPrinter &p) {
-  p << " ";
-
-  if (!reduction_vars().empty())
-    printReductionVarList(p, reductions(), reduction_vars());
-
-  if (!allocate_vars().empty()) {
-    printAllocateAndAllocator(p << "allocate(", *this, allocate_vars(),
-                              allocate_vars().getTypes(), allocators_vars(),
-                              allocators_vars().getTypes());
-    p << ")";
-  }
-
-  if (nowait())
-    p << "nowait";
-
-  p << ' ';
-  p.printRegion(region());
-}
-
 LogicalResult SectionsOp::verify() {
   if (allocate_vars().size() != allocators_vars().size())
     return emitError(
@@ -960,8 +908,11 @@ void WsLoopOp::print(OpAsmPrinter &p) {
   if (auto order = order_val())
     p << "order(" << stringifyClauseOrderKind(*order) << ") ";
 
-  if (!reduction_vars().empty())
-    printReductionVarList(p, reductions(), reduction_vars());
+  if (!reduction_vars().empty()) {
+    printReductionVarList(p << "reduction(", *this, reduction_vars(),
+                          reduction_vars().getTypes(), reductions());
+    p << ")";
+  }
 
   p << ' ';
   p.printRegion(region(), /*printEntryBlockArgs=*/false);

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8a5d50dd0fb96..a991d5f20f6b7 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -793,7 +793,7 @@ func @omp_sections(%data_var : memref<i32>) -> () {
 // -----
 
 func @omp_sections(%cond : i1) {
-  // expected-error @below {{if is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections if(%cond) {
     omp.terminator
   }
@@ -803,7 +803,7 @@ func @omp_sections(%cond : i1) {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{num_threads is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections num_threads(10) {
     omp.terminator
   }
@@ -813,7 +813,7 @@ func @omp_sections() {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{proc_bind is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections proc_bind(close) {
     omp.terminator
   }
@@ -823,7 +823,7 @@ func @omp_sections() {
 // -----
 
 func @omp_sections(%data_var : memref<i32>, %linear_var : i32) {
-  // expected-error @below {{linear is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections linear(%data_var = %linear_var : memref<i32>) {
     omp.terminator
   }
@@ -833,7 +833,7 @@ func @omp_sections(%data_var : memref<i32>, %linear_var : i32) {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{schedule is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections schedule(static, none) {
     omp.terminator
   }
@@ -843,7 +843,7 @@ func @omp_sections() {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{collapse is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections collapse(3) {
     omp.terminator
   }
@@ -853,7 +853,7 @@ func @omp_sections() {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{ordered is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections ordered(2) {
     omp.terminator
   }
@@ -863,7 +863,7 @@ func @omp_sections() {
 // -----
 
 func @omp_sections() {
-  // expected-error @below {{order is not a valid clause for the omp.sections operation}}
+  // expected-error @below {{expected '{' to begin a region}}
   omp.sections order(concurrent) {
     omp.terminator
   }

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index cbb8b1f550da4..e2cc900bf3787 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -624,13 +624,13 @@ func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
   "omp.sections" (%data_var1, %data_var1) ({
     // CHECK: omp.terminator
     omp.terminator
-  }) {operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref<i32>, memref<i32>) -> ()
+  }) {allocate, operand_segment_sizes = dense<[0,1,1]> : vector<3xi32>} : (memref<i32>, memref<i32>) -> ()
 
     // CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr<f32>)
   "omp.sections" (%redn_var) ({
     // CHECK: omp.terminator
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr<f32>) -> ()
+  }) {reduction, operand_segment_sizes = dense<[1,0,0]> : vector<3xi32>, reductions=[@add_f32]} : (!llvm.ptr<f32>) -> ()
 
   // CHECK: omp.sections nowait {
   omp.sections nowait {


        


More information about the Mlir-commits mailing list