[Mlir-commits] [mlir] a60fda5 - [mlir][OpenMP] Restrict types for omp.parallel args
Shraiysh Vaishay
llvmlistbot at llvm.org
Mon May 2 01:47:46 PDT 2022
Author: Shraiysh Vaishay
Date: 2022-05-02T14:17:34+05:30
New Revision: a60fda59dc6b1dda25cad26214b02d1f630319e7
URL: https://github.com/llvm/llvm-project/commit/a60fda59dc6b1dda25cad26214b02d1f630319e7
DIFF: https://github.com/llvm/llvm-project/commit/a60fda59dc6b1dda25cad26214b02d1f630319e7.diff
LOG: [mlir][OpenMP] Restrict types for omp.parallel args
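This patch restricts the value of the `if` clause expression to an `i1` value.
It also restricts the value of the `num_threads` clause expression to an
integer-like value (`IntLikeType`; the updated tests use `i32`, `i64`, and
`index`). Flang lowering now converts the `if` clause expression, of any
LOGICAL kind, to `i1` before passing it to `omp.parallel`.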
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D124142
Added:
Modified:
flang/lib/Lower/OpenMP.cpp
flang/test/Lower/OpenMP/parallel.f90
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 93cb936e1e860..c3740123437f1 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -254,8 +254,10 @@ genOMP(Fortran::lower::AbstractConverter &converter,
if (const auto &ifClause =
std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
- ifClauseOperand = fir::getBase(
+ mlir::Value ifVal = fir::getBase(
converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+ ifClauseOperand = firOpBuilder.createConvert(
+ currentLocation, firOpBuilder.getI1Type(), ifVal);
} else if (const auto &numThreadsClause =
std::get_if<Fortran::parser::OmpClause::NumThreads>(
&clause.u)) {
diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90
index 849db5c3705f6..70dcee4ddd255 100644
--- a/flang/test/Lower/OpenMP/parallel.f90
+++ b/flang/test/Lower/OpenMP/parallel.f90
@@ -15,8 +15,13 @@ end subroutine parallel_simple
!===============================================================================
!FIRDialect-LABEL: func @_QPparallel_if
-subroutine parallel_if(alpha)
+subroutine parallel_if(alpha, beta, gamma)
integer, intent(in) :: alpha
+ logical, intent(in) :: beta
+ logical(1) :: logical1
+ logical(2) :: logical2
+ logical(4) :: logical4
+ logical(8) :: logical8
!OMPDialect: omp.parallel if(%{{.*}} : i1) {
!$omp parallel if(alpha .le. 0)
@@ -46,6 +51,41 @@ subroutine parallel_if(alpha)
!OMPDialect: omp.terminator
!$omp end parallel
+ !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+ !$omp parallel if(beta)
+ !FIRDialect: fir.call
+ call f1()
+ !OMPDialect: omp.terminator
+ !$omp end parallel
+
+ !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+ !$omp parallel if(logical1)
+ !FIRDialect: fir.call
+ call f1()
+ !OMPDialect: omp.terminator
+ !$omp end parallel
+
+ !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+ !$omp parallel if(logical2)
+ !FIRDialect: fir.call
+ call f1()
+ !OMPDialect: omp.terminator
+ !$omp end parallel
+
+ !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+ !$omp parallel if(logical4)
+ !FIRDialect: fir.call
+ call f1()
+ !OMPDialect: omp.terminator
+ !$omp end parallel
+
+ !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+ !$omp parallel if(logical8)
+ !FIRDialect: fir.call
+ call f1()
+ !OMPDialect: omp.terminator
+ !$omp end parallel
+
end subroutine parallel_if
!===============================================================================
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 39ccb70b41383..bc5a7ff89783f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -99,8 +99,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
of the parallel region.
}];
- let arguments = (ins Optional<AnyType>:$if_expr_var,
- Optional<AnyType>:$num_threads_var,
+ let arguments = (ins Optional<I1>:$if_expr_var,
+ Optional<IntLikeType>:$num_threads_var,
Variadic<AnyType>:$allocate_vars,
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 57387b35abd04..15ec6f796b880 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -51,15 +51,15 @@ func.func @omp_terminator() -> () {
omp.terminator
}
-func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
- // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
+ // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
// test without if condition
- // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+ // CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
- }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (i32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -71,13 +71,13 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
}) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
- // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
+ // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
+ }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, i32) -> ()
omp.terminator
- }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+ }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, i32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
@@ -88,14 +88,26 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
return
}
-func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () {
+func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %allocator : si32) -> () {
// CHECK: omp.parallel
omp.parallel {
omp.terminator
}
- // CHECK: omp.parallel num_threads(%{{.*}} : si32)
- omp.parallel num_threads(%num_threads : si32) {
+ // CHECK: omp.parallel num_threads(%{{.*}} : i32)
+ omp.parallel num_threads(%num_threads : i32) {
+ omp.terminator
+ }
+
+ %n_index = arith.constant 2 : index
+ // CHECK: omp.parallel num_threads(%{{.*}} : index)
+ omp.parallel num_threads(%n_index : index) {
+ omp.terminator
+ }
+
+ %n_i64 = arith.constant 4 : i64
+ // CHECK: omp.parallel num_threads(%{{.*}} : i64)
+ omp.parallel num_threads(%n_i64 : i64) {
omp.terminator
}
@@ -113,8 +125,8 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
- // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
- omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1) proc_bind(close) {
+ // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) private(%{{.*}} : memref<i32>) proc_bind(close)
+ omp.parallel num_threads(%num_threads : i32) if(%if_cond: i1) proc_bind(close) {
omp.terminator
}
@@ -347,14 +359,14 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
}
// CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
// Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
// CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, si32 ) -> ()
+ }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, i32 ) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -363,14 +375,14 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : si32) -> ()
}
// CHECK-LABEL: omp_target_pretty
-func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : si32) -> () {
+func.func @omp_target_pretty(%if_cond : i1, %device : si32, %num_threads : i32) -> () {
// CHECK: omp.target if({{.*}}) device({{.*}})
omp.target if(%if_cond) device(%device : si32) {
omp.terminator
}
// CHECK: omp.target if({{.*}}) device({{.*}}) nowait
- omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : si32) nowait {
+ omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : i32) nowait {
omp.terminator
}
More information about the Mlir-commits
mailing list