[Mlir-commits] [mlir] e9b4239 - [mlir][openmp] Custom syntax for `omp.target` operation

Kiran Chandramohan llvmlistbot at llvm.org
Wed Jan 26 02:28:23 PST 2022


Author: Alexander Batashev
Date: 2022-01-26T10:26:19Z
New Revision: e9b4239fefa657362678c063c1ba81b0eed2cab3

URL: https://github.com/llvm/llvm-project/commit/e9b4239fefa657362678c063c1ba81b0eed2cab3
DIFF: https://github.com/llvm/llvm-project/commit/e9b4239fefa657362678c063c1ba81b0eed2cab3.diff

LOG: [mlir][openmp] Custom syntax for `omp.target` operation

Add a custom parser and printer for the `omp.target` operation.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D117539

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 311513f1682a6..d628027698d33 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -418,6 +418,9 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
                        UnitAttr:$nowait);
 
   let regions = (region AnyRegion:$region);
+
+  let parser = [{ return parseTargetOp(parser, result); }];
+  let printer = [{ return printTargetOp(p, *this); }];
 }
 
 

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1eeee3f65f3ba..089d534bcdae4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -198,6 +198,24 @@ static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
   p.printRegion(op.getRegion());
 }
 
+/// Prints the optional clauses of an `omp.target` operation, then its region.
+static void printTargetOp(OpAsmPrinter &p, TargetOp op) {
+  p << " ";
+  if (auto ifCond = op.if_expr())
+    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
+
+  if (auto device = op.device())
+    p << "device(" << device << " : " << device.getType() << ") ";
+
+  if (auto threads = op.thread_limit())
+    p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
+
+  if (op.nowait())
+    p << "nowait ";
+
+  p.printRegion(op.getRegion());
+}
+
 //===----------------------------------------------------------------------===//
 // Parser and printer for Linear Clause
 //===----------------------------------------------------------------------===//
@@ -523,6 +541,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
 enum ClauseType {
   ifClause,
   numThreadsClause,
+  deviceClause,
+  threadLimitClause,
   privateClause,
   firstprivateClause,
   lastprivateClause,
@@ -611,6 +631,8 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
   // Containers for storing operands, types and attributes for various clauses
   std::pair<OpAsmParser::OperandType, Type> ifCond;
   std::pair<OpAsmParser::OperandType, Type> numThreads;
+  std::pair<OpAsmParser::OperandType, Type> device;
+  std::pair<OpAsmParser::OperandType, Type> threadLimit;
 
   SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
       shareds, copyins;
@@ -681,6 +703,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
           parser.parseColonType(numThreads.second) || parser.parseRParen())
         return failure();
       clauseSegments[pos[numThreadsClause]] = 1;
+    } else if (clauseKeyword == "device") {
+      if (checkAllowed(deviceClause) || parser.parseLParen() ||
+          parser.parseOperand(device.first) ||
+          parser.parseColonType(device.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[deviceClause]] = 1;
+    } else if (clauseKeyword == "thread_limit") {
+      if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
+          parser.parseOperand(threadLimit.first) ||
+          parser.parseColonType(threadLimit.second) || parser.parseRParen())
+        return failure();
+      clauseSegments[pos[threadLimitClause]] = 1;
     } else if (clauseKeyword == "private") {
       if (checkAllowed(privateClause) ||
           parseOperandAndTypeList(parser, privates, privateTypes))
@@ -812,6 +846,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
                                    result.operands)))
     return failure();
 
+  // Add device parameter.
+  if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
+      failed(
+          parser.resolveOperand(device.first, device.second, result.operands)))
+    return failure();
+
+  // Add thread_limit parameter.
+  if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
+      failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
+                                   result.operands)))
+    return failure();
+
   // Add private parameters.
   if (done[privateClause] && clauseSegments[pos[privateClause]] &&
       failed(parser.resolveOperands(privates, privateTypes,
@@ -948,6 +994,33 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
   return success();
 }
 
+/// Parses a target operation.
+///
+/// operation ::= `omp.target` clause-list region
+/// clause-list ::= clause | clause clause-list
+/// clause ::= if | device | thread_limit | nowait
+///
+static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
+                                     nowaitClause};
+
+  SmallVector<int> segments;
+
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
+  // Record the per-clause operand counts required by AttrSizedOperandSegments.
+  result.addAttribute(
+      TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getI32VectorAttr(segments));
+
+  // The region is parsed without any entry block arguments.
+  Region *body = result.addRegion();
+  SmallVector<OpAsmParser::OperandType> regionArgs;
+  SmallVector<Type> regionArgTypes;
+  return parser.parseRegion(*body, regionArgs, regionArgTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // Parser, printer and verifier for SectionsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 96a0b427123f3..3586acbd9dd2d 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -330,12 +330,12 @@ func @omp_wsloop_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, %lb2 : i3
 // CHECK-LABEL: omp_target
 func @omp_target(%if_cond : i1, %device : si32,  %num_threads : si32) -> () {
 
-    // Test with optional operands; if_expr, device, thread_limit, and nowait.
-    // CHECK: omp.target
+    // Test the generic form with all optional operands: if_expr, device, thread_limit, and nowait.
+    // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32  ) -> ()
+    }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> ()
 
     // CHECK: omp.barrier
     omp.barrier
@@ -343,6 +343,21 @@ func @omp_target(%if_cond : i1, %device : si32,  %num_threads : si32) -> () {
     return
 }
 
+// CHECK-LABEL: omp_target_pretty
+func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
+    // CHECK: omp.target if({{.*}}) device({{.*}})
+    omp.target if(%if_cond : i1) device(%device : si32) {
+      omp.terminator
+    }
+
+    // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
+    omp.target if(%if_cond : i1) device(%device : si32) thread_limit(%num_threads : si32) nowait {
+      omp.terminator
+    }
+
+    return
+}
+
 // CHECK: omp.reduction.declare
 // CHECK-LABEL: @add_f32
 // CHECK: : f32


        


More information about the Mlir-commits mailing list