[Mlir-commits] [mlir] 660832c - [OpenMP, MLIR] Translation of parallel operation: num_threads, if clauses 3/n
Kiran Chandramohan
llvmlistbot at llvm.org
Fri Aug 7 13:55:44 PDT 2020
Author: Kiran Chandramohan
Date: 2020-08-07T20:54:24Z
New Revision: 660832c4e744108ecb45b697e51be72482cacd42
URL: https://github.com/llvm/llvm-project/commit/660832c4e744108ecb45b697e51be72482cacd42
DIFF: https://github.com/llvm/llvm-project/commit/660832c4e744108ecb45b697e51be72482cacd42.diff
LOG: [OpenMP,MLIR] Translation of parallel operation: num_threads, if clauses 3/n
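This simple patch translates the num_threads and if clauses of the parallel
operation to LLVM IR, and adds test cases for both.
It also makes a minor change to the if clause syntax: the condition is now
written with an explicit type, `if(%cond : type)`. The parser accepts an
operand of any type (AnyType) and resolves it with the parsed type instead of
assuming i1, and the printer emits the type. Existing test cases are updated
accordingly.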
Reviewed by: SouraVX
Differential Revision: https://reviews.llvm.org/D84798
Added:
Modified:
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/openmp-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9159e87509c6..217588289e85 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -69,7 +69,7 @@ static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
p << "omp.parallel";
if (auto ifCond = op.if_expr_var())
- p << " if(" << ifCond << ")";
+ p << " if(" << ifCond << " : " << ifCond.getType() << ")";
if (auto threads = op.num_threads_var())
p << " num_threads(" << threads << " : " << threads.getType() << ")";
@@ -124,7 +124,7 @@ static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
/// Note that each clause can only appear once in the clase-list.
static ParseResult parseParallelOp(OpAsmParser &parser,
OperationState &result) {
- OpAsmParser::OperandType ifCond;
+ std::pair<OpAsmParser::OperandType, Type> ifCond;
std::pair<OpAsmParser::OperandType, Type> numThreads;
llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
llvm::SmallVector<Type, 4> privateTypes;
@@ -152,8 +152,8 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
// Fail if there was already another if condition
if (segments[ifClausePos])
return allowedOnce(parser, "if", opName);
- if (parser.parseLParen() || parser.parseOperand(ifCond) ||
- parser.parseRParen())
+ if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
+ parser.parseColonType(ifCond.second) || parser.parseRParen())
return failure();
segments[ifClausePos] = 1;
} else if (keyword == "num_threads") {
@@ -209,7 +209,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
auto attr = parser.getBuilder().getStringAttr(attrval);
result.addAttribute("default_val", attr);
} else if (keyword == "proc_bind") {
- // fail if there was already another default clause
+ // fail if there was already another proc_bind clause
if (procBind)
return allowedOnce(parser, "proc_bind", opName);
procBind = true;
@@ -228,8 +228,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
// Add if parameter
if (segments[ifClausePos]) {
- parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
- result.operands);
+ parser.resolveOperand(ifCond.first, ifCond.second, result.operands);
}
// Add num_threads parameter
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f58d02845ccd..b3b7e4c7afa5 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -454,7 +454,12 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
// TODO: The various operands of parallel operation are not handled.
// Parallel operation is created with some default options for now.
llvm::Value *ifCond = nullptr;
+ if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
+ ifCond = valueMapping.lookup(ifExprVar);
llvm::Value *numThreads = nullptr;
+ if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
+ numThreads = valueMapping.lookup(numThreadsVar);
+ // TODO: Is the Parallel construct cancellable?
bool isCancellable = false;
// TODO: Determine the actual alloca insertion point, e.g., the function
// entry or the alloca insertion point as provided by the body callback
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 00f9726b119d..88f61e7f7916 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -12,7 +12,7 @@ func @unknown_clause() {
func @if_once(%n : i1) {
// expected-error at +1 {{at most one if clause can appear on the omp.parallel operation}}
- omp.parallel if(%n) if(%n) {
+ omp.parallel if(%n : i1) if(%n : i1) {
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 85343f985501..e3e7afaff541 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -99,14 +99,14 @@ func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads :
// CHECK omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>)
omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
- omp.parallel if(%if_cond) {
+ omp.parallel if(%if_cond: i1) {
omp.terminator
}
omp.terminator
}
// CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
- omp.parallel num_threads(%num_threads : si32) if(%if_cond)
+ omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1)
private(%data_var : memref<i32>) proc_bind(close) {
omp.terminator
}
diff --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
index c8acd8022b2b..60462fee3b97 100644
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -78,3 +78,100 @@ llvm.func @test_omp_parallel_2() -> () {
// CHECK-LABEL: omp.par.region2:
// CHECK: call void @body(i64 43)
// CHECK: br label %omp.par.pre_finalize
+
+// CHECK: define void @test_omp_parallel_num_threads_1(i32 %[[NUM_THREADS_VAR_1:.*]])
+llvm.func @test_omp_parallel_num_threads_1(%arg0: !llvm.i32) -> () {
+ // CHECK: %[[GTN_NUM_THREADS_VAR_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_1:.*]])
+ // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_1]], i32 %[[GTN_NUM_THREADS_VAR_1]], i32 %[[NUM_THREADS_VAR_1]])
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_1:.*]] to {{.*}}
+ omp.parallel num_threads(%arg0: !llvm.i32) {
+ omp.barrier
+ omp.terminator
+ }
+
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_1]]
+ // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_2()
+llvm.func @test_omp_parallel_num_threads_2() -> () {
+ %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+ // CHECK: %[[GTN_NUM_THREADS_VAR_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_2:.*]])
+ // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_2]], i32 %[[GTN_NUM_THREADS_VAR_2]], i32 4)
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_2:.*]] to {{.*}}
+ omp.parallel num_threads(%0: !llvm.i32) {
+ omp.barrier
+ omp.terminator
+ }
+
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_2]]
+ // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_3()
+llvm.func @test_omp_parallel_num_threads_3() -> () {
+ %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+ // CHECK: %[[GTN_NUM_THREADS_VAR_3_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_1:.*]])
+ // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_1]], i32 %[[GTN_NUM_THREADS_VAR_3_1]], i32 4)
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_1:.*]] to {{.*}}
+ omp.parallel num_threads(%0: !llvm.i32) {
+ omp.barrier
+ omp.terminator
+ }
+ %1 = llvm.mlir.constant(8 : index) : !llvm.i32
+ // CHECK: %[[GTN_NUM_THREADS_VAR_3_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_2:.*]])
+ // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_2]], i32 %[[GTN_NUM_THREADS_VAR_3_2]], i32 8)
+ // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_2:.*]] to {{.*}}
+ omp.parallel num_threads(%1: !llvm.i32) {
+ omp.barrier
+ omp.terminator
+ }
+
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_2]]
+ // CHECK: call void @__kmpc_barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_1]]
+ // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
+llvm.func @test_omp_parallel_if_1(%arg0: !llvm.i32) -> () {
+
+// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
+ %0 = llvm.mlir.constant(0 : index) : !llvm.i32
+ %1 = llvm.icmp "slt" %arg0, %0 : !llvm.i32
+
+// CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[SI_VAR_IF_1:.*]])
+// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]]
+// CHECK: [[IF_COND_TRUE_BLOCK_1]]:
+// CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]]
+// CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]:
+// CHECK: call void {{.*}} @__kmpc_fork_call(%struct.ident_t* @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]] to void
+// CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_1]]:
+// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_2]]:
+// CHECK: br label %[[RETURN_BLOCK_IF_1:.*]]
+// CHECK: [[IF_COND_FALSE_BLOCK_1]]:
+// CHECK: call void @__kmpc_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]]
+// CHECK: call void @__kmpc_end_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: br label %[[RETURN_BLOCK_IF_1]]
+ omp.parallel if(%1 : !llvm.i1) {
+ omp.barrier
+ omp.terminator
+ }
+
+// CHECK: [[RETURN_BLOCK_IF_1]]:
+// CHECK: ret void
+ llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
+ // CHECK: call void @__kmpc_barrier
More information about the Mlir-commits mailing list