[Mlir-commits] [mlir] e9b4239 - [mlir][openmp] Custom syntax for `omp.target` operation
Kiran Chandramohan
llvmlistbot at llvm.org
Wed Jan 26 02:28:23 PST 2022
Author: Alexander Batashev
Date: 2022-01-26T10:26:19Z
New Revision: e9b4239fefa657362678c063c1ba81b0eed2cab3
URL: https://github.com/llvm/llvm-project/commit/e9b4239fefa657362678c063c1ba81b0eed2cab3
DIFF: https://github.com/llvm/llvm-project/commit/e9b4239fefa657362678c063c1ba81b0eed2cab3.diff
LOG: [mlir][openmp] Custom syntax for `omp.target` operation
Add a custom parser and printer for the `omp.target` operation.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D117539
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 311513f1682a6..d628027698d33 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -418,6 +418,9 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
UnitAttr:$nowait);
let regions = (region AnyRegion:$region);
+
+ let parser = [{ return parseTargetOp(parser, result); }];
+ let printer = [{ return printTargetOp(p, *this); }];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1eeee3f65f3ba..089d534bcdae4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -198,6 +198,24 @@ static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
p.printRegion(op.getRegion());
}
+/// Prints an `omp.target` operation in its custom assembly form:
+///
+///   omp.target [if(%v : type)] [device(%v : type)] [thread_limit(%v : type)]
+///              [nowait] region
+///
+/// Each operand clause is printed only when its operand is present on the op,
+/// `nowait` is printed only when that unit attribute is set, and every clause
+/// is followed by a single space so the region brace is separated from it.
+static void printTargetOp(OpAsmPrinter &p, TargetOp op) {
+  p << " ";
+  if (auto ifCond = op.if_expr())
+    p << "if(" << ifCond << " : " << ifCond.getType() << ") ";
+
+  if (auto device = op.device())
+    p << "device(" << device << " : " << device.getType() << ") ";
+
+  if (auto threads = op.thread_limit())
+    p << "thread_limit(" << threads << " : " << threads.getType() << ") ";
+
+  if (op.nowait())
+    p << "nowait ";
+
+  p.printRegion(op.getRegion());
+}
+
//===----------------------------------------------------------------------===//
// Parser and printer for Linear Clause
//===----------------------------------------------------------------------===//
@@ -523,6 +541,8 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
enum ClauseType {
ifClause,
numThreadsClause,
+ deviceClause,
+ threadLimitClause,
privateClause,
firstprivateClause,
lastprivateClause,
@@ -611,6 +631,8 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
// Containers for storing operands, types and attributes for various clauses
std::pair<OpAsmParser::OperandType, Type> ifCond;
std::pair<OpAsmParser::OperandType, Type> numThreads;
+ std::pair<OpAsmParser::OperandType, Type> device;
+ std::pair<OpAsmParser::OperandType, Type> threadLimit;
SmallVector<OpAsmParser::OperandType> privates, firstprivates, lastprivates,
shareds, copyins;
@@ -681,6 +703,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
parser.parseColonType(numThreads.second) || parser.parseRParen())
return failure();
clauseSegments[pos[numThreadsClause]] = 1;
+ } else if (clauseKeyword == "device") {
+ if (checkAllowed(deviceClause) || parser.parseLParen() ||
+ parser.parseOperand(device.first) ||
+ parser.parseColonType(device.second) || parser.parseRParen())
+ return failure();
+ clauseSegments[pos[deviceClause]] = 1;
+ } else if (clauseKeyword == "thread_limit") {
+ if (checkAllowed(threadLimitClause) || parser.parseLParen() ||
+ parser.parseOperand(threadLimit.first) ||
+ parser.parseColonType(threadLimit.second) || parser.parseRParen())
+ return failure();
+ clauseSegments[pos[threadLimitClause]] = 1;
} else if (clauseKeyword == "private") {
if (checkAllowed(privateClause) ||
parseOperandAndTypeList(parser, privates, privateTypes))
@@ -812,6 +846,18 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
result.operands)))
return failure();
+ // Add device parameter.
+ if (done[deviceClause] && clauseSegments[pos[deviceClause]] &&
+ failed(
+ parser.resolveOperand(device.first, device.second, result.operands)))
+ return failure();
+
+ // Add thread_limit parameter.
+ if (done[threadLimitClause] && clauseSegments[pos[threadLimitClause]] &&
+ failed(parser.resolveOperand(threadLimit.first, threadLimit.second,
+ result.operands)))
+ return failure();
+
// Add private parameters.
if (done[privateClause] && clauseSegments[pos[privateClause]] &&
failed(parser.resolveOperands(privates, privateTypes,
@@ -948,6 +994,33 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
return success();
}
+/// Parses a target operation.
+///
+/// operation ::= `omp.target` clause-list region
+/// clause-list ::= clause | clause clause-list
+/// clause ::= if | device | thread_limit | nowait
+///
+/// `segments` is filled in by parseClauses and attached as the operand
+/// segment size attribute required by AttrSizedOperandSegments. The region is
+/// parsed with no entry block arguments.
+static ParseResult parseTargetOp(OpAsmParser &parser, OperationState &result) {
+  SmallVector<ClauseType> clauses = {ifClause, deviceClause, threadLimitClause,
+                                     nowaitClause};
+
+  SmallVector<int> segments;
+  if (failed(parseClauses(parser, result, clauses, segments)))
+    return failure();
+
+  result.addAttribute(
+      TargetOp::AttrSizedOperandSegments::getOperandSegmentSizeAttr(),
+      parser.getBuilder().getI32VectorAttr(segments));
+
+  Region *body = result.addRegion();
+  return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
+}
+
//===----------------------------------------------------------------------===//
// Parser, printer and verifier for SectionsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 96a0b427123f3..3586acbd9dd2d 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -330,12 +330,12 @@ func @omp_wsloop_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, %lb2 : i3
// CHECK-LABEL: omp_target
func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
- // Test with optional operands; if_expr, device, thread_limit, and nowait.
- // CHECK: omp.target
+ // Test with optional operands; if_expr, device, thread_limit, and nowait.
+ // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> ()
+ }) {operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>, nowait } : ( i1, si32, si32 ) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -343,6 +343,21 @@ func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
return
}
+// CHECK-LABEL: omp_target_pretty
+func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
+  // CHECK: omp.target if({{.*}}) device({{.*}}) {
+  omp.target if(%if_cond : i1) device(%device : si32) {
+    omp.terminator
+  }
+
+  // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
+  omp.target if(%if_cond : i1) device(%device : si32) thread_limit(%num_threads : si32) nowait {
+    omp.terminator
+  }
+
+  return
+}
+
// CHECK: omp.reduction.declare
// CHECK-LABEL: @add_f32
// CHECK: : f32
More information about the Mlir-commits
mailing list