[Mlir-commits] [mlir] a60fda5 - [mlir][OpenMP] Restrict types for omp.parallel args

Shraiysh Vaishay llvmlistbot at llvm.org
Mon May 2 01:47:46 PDT 2022


Author: Shraiysh Vaishay
Date: 2022-05-02T14:17:34+05:30
New Revision: a60fda59dc6b1dda25cad26214b02d1f630319e7

URL: https://github.com/llvm/llvm-project/commit/a60fda59dc6b1dda25cad26214b02d1f630319e7
DIFF: https://github.com/llvm/llvm-project/commit/a60fda59dc6b1dda25cad26214b02d1f630319e7.diff

LOG: [mlir][OpenMP] Restrict types for omp.parallel args

This patch restricts the value of the `if` clause expression to an I1 value.
It also restricts the value of the `num_threads` clause expression to an
integer-like value (`IntLikeType`; the tests exercise i32, i64 and index).

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D124142

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    flang/test/Lower/OpenMP/parallel.f90
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 93cb936e1e860..c3740123437f1 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -254,8 +254,10 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     if (const auto &ifClause =
             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
       auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-      ifClauseOperand = fir::getBase(
+      mlir::Value ifVal = fir::getBase(
           converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+      ifClauseOperand = firOpBuilder.createConvert(
+          currentLocation, firOpBuilder.getI1Type(), ifVal);
     } else if (const auto &numThreadsClause =
                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
                        &clause.u)) {

diff  --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90
index 849db5c3705f6..70dcee4ddd255 100644
--- a/flang/test/Lower/OpenMP/parallel.f90
+++ b/flang/test/Lower/OpenMP/parallel.f90
@@ -15,8 +15,13 @@ end subroutine parallel_simple
 !===============================================================================
 
 !FIRDialect-LABEL: func @_QPparallel_if
-subroutine parallel_if(alpha)
+subroutine parallel_if(alpha, beta, gamma)
    integer, intent(in) :: alpha
+   logical, intent(in) :: beta
+   logical(1) :: logical1
+   logical(2) :: logical2
+   logical(4) :: logical4
+   logical(8) :: logical8
 
    !OMPDialect: omp.parallel if(%{{.*}} : i1) {
    !$omp parallel if(alpha .le. 0)
@@ -46,6 +51,41 @@ subroutine parallel_if(alpha)
    !OMPDialect: omp.terminator
    !$omp end parallel
 
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(beta)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(logical1)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(logical2)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(logical4)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
+   !OMPDialect: omp.parallel if(%{{.*}} : i1) {
+   !$omp parallel if(logical8)
+   !FIRDialect: fir.call
+   call f1()
+   !OMPDialect: omp.terminator
+   !$omp end parallel
+
 end subroutine parallel_if
 
 !===============================================================================

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 39ccb70b41383..bc5a7ff89783f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -99,8 +99,8 @@ def ParallelOp : OpenMP_Op<"parallel", [
     of the parallel region.
   }];
 
-  let arguments = (ins Optional<AnyType>:$if_expr_var,
-             Optional<AnyType>:$num_threads_var,
+  let arguments = (ins Optional<I1>:$if_expr_var,
+             Optional<IntLikeType>:$num_threads_var,
              Variadic<AnyType>:$allocate_vars,
              Variadic<AnyType>:$allocators_vars,
              Variadic<OpenMP_PointerLikeType>:$reduction_vars,

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 57387b35abd04..15ec6f796b880 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -51,15 +51,15 @@ func.func @omp_terminator() -> () {
   omp.terminator
 }
 
-func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32) -> () {
-  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
 
   // test without if condition
-  // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
+  // CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
     "omp.parallel"(%num_threads, %data_var, %data_var) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref<i32>, memref<i32>) -> ()
+    }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (i32, memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -71,13 +71,13 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
     }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref<i32>, memref<i32>) -> ()
 
   // test without allocate
-  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32)
+  // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> ()
+    }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref<i32>, memref<i32>) -> ()
+  }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, i32, memref<i32>, memref<i32>) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
@@ -88,14 +88,26 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : s
   return
 }
 
-func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : si32, %allocator : si32) -> () {
+func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %allocator : si32) -> () {
  // CHECK: omp.parallel
  omp.parallel {
   omp.terminator
  }
 
- // CHECK: omp.parallel num_threads(%{{.*}} : si32)
- omp.parallel num_threads(%num_threads : si32) {
+ // CHECK: omp.parallel num_threads(%{{.*}} : i32)
+ omp.parallel num_threads(%num_threads : i32) {
+   omp.terminator
+ }
+
+ %n_index = arith.constant 2 : index
+ // CHECK: omp.parallel num_threads(%{{.*}} : index)
+ omp.parallel num_threads(%n_index : index) {
+   omp.terminator
+ }
+
+ %n_i64 = arith.constant 4 : i64
+ // CHECK: omp.parallel num_threads(%{{.*}} : i64)
+ omp.parallel num_threads(%n_i64 : i64) {
    omp.terminator
  }
 
@@ -113,8 +125,8 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
- // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
- omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1) proc_bind(close) {
+ // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) private(%{{.*}} : memref<i32>) proc_bind(close)
+ omp.parallel num_threads(%num_threads : i32) if(%if_cond: i1) proc_bind(close) {
    omp.terminator
  }
 
@@ -347,14 +359,14 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
 }
 
 // CHECK-LABEL: omp_target
-func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : si32) -> () {
+func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32) -> () {
 
     // Test with optional operands; if_expr, device, thread_limit, private, firstprivate and nowait.
     // CHECK: omp.target if({{.*}}) device({{.*}}) thread_limit({{.*}}) nowait
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, si32 ) -> ()
+    }) {nowait, operand_segment_sizes = dense<[1,1,1]>: vector<3xi32>} : ( i1, si32, i32 ) -> ()
 
     // CHECK: omp.barrier
     omp.barrier
@@ -363,14 +375,14 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : si32) -> ()
 }
 
 // CHECK-LABEL: omp_target_pretty
-func.func @omp_target_pretty(%if_cond : i1, %device : si32,  %num_threads : si32) -> () {
+func.func @omp_target_pretty(%if_cond : i1, %device : si32,  %num_threads : i32) -> () {
     // CHECK: omp.target if({{.*}}) device({{.*}})
     omp.target if(%if_cond) device(%device : si32) {
       omp.terminator
     }
 
     // CHECK: omp.target if({{.*}}) device({{.*}}) nowait
-    omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : si32) nowait {
+    omp.target if(%if_cond) device(%device : si32) thread_limit(%num_threads : i32) nowait {
       omp.terminator
     }
 


        


More information about the Mlir-commits mailing list