[Mlir-commits] [mlir] 660832c - [OpenMP, MLIR] Translation of parallel operation: num_threads, if clauses 3/n

Kiran Chandramohan llvmlistbot at llvm.org
Fri Aug 7 13:55:44 PDT 2020


Author: Kiran Chandramohan
Date: 2020-08-07T20:54:24Z
New Revision: 660832c4e744108ecb45b697e51be72482cacd42

URL: https://github.com/llvm/llvm-project/commit/660832c4e744108ecb45b697e51be72482cacd42
DIFF: https://github.com/llvm/llvm-project/commit/660832c4e744108ecb45b697e51be72482cacd42.diff

LOG: [OpenMP,MLIR] Translation of parallel operation: num_threads, if clauses 3/n

This simple patch translates the num_threads and if clauses of the parallel
operation and adds test cases for them.
The parsing of the if clause was changed slightly: it now parses an operand of
AnyType together with its explicit type (instead of assuming i1), and the
printer emits that type. Existing test cases were updated accordingly.

Reviewed by: SouraVX
Differential Revision: https://reviews.llvm.org/D84798

Added: 
    

Modified: 
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9159e87509c6..217588289e85 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -69,7 +69,7 @@ static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
   p << "omp.parallel";
 
   if (auto ifCond = op.if_expr_var())
-    p << " if(" << ifCond << ")";
+    p << " if(" << ifCond << " : " << ifCond.getType() << ")";
 
   if (auto threads = op.num_threads_var())
     p << " num_threads(" << threads << " : " << threads.getType() << ")";
@@ -124,7 +124,7 @@ static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
 /// Note that each clause can only appear once in the clase-list.
 static ParseResult parseParallelOp(OpAsmParser &parser,
                                    OperationState &result) {
-  OpAsmParser::OperandType ifCond;
+  std::pair<OpAsmParser::OperandType, Type> ifCond;
   std::pair<OpAsmParser::OperandType, Type> numThreads;
   llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
   llvm::SmallVector<Type, 4> privateTypes;
@@ -152,8 +152,8 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       // Fail if there was already another if condition
       if (segments[ifClausePos])
         return allowedOnce(parser, "if", opName);
-      if (parser.parseLParen() || parser.parseOperand(ifCond) ||
-          parser.parseRParen())
+      if (parser.parseLParen() || parser.parseOperand(ifCond.first) ||
+          parser.parseColonType(ifCond.second) || parser.parseRParen())
         return failure();
       segments[ifClausePos] = 1;
     } else if (keyword == "num_threads") {
@@ -209,7 +209,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       auto attr = parser.getBuilder().getStringAttr(attrval);
       result.addAttribute("default_val", attr);
     } else if (keyword == "proc_bind") {
-      // fail if there was already another default clause
+      // fail if there was already another proc_bind clause
       if (procBind)
         return allowedOnce(parser, "proc_bind", opName);
       procBind = true;
@@ -228,8 +228,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
 
   // Add if parameter
   if (segments[ifClausePos]) {
-    parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
-                          result.operands);
+    parser.resolveOperand(ifCond.first, ifCond.second, result.operands);
   }
 
   // Add num_threads parameter

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index f58d02845ccd..b3b7e4c7afa5 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -454,7 +454,12 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
   // TODO: The various operands of parallel operation are not handled.
   // Parallel operation is created with some default options for now.
   llvm::Value *ifCond = nullptr;
+  if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
+    ifCond = valueMapping.lookup(ifExprVar);
   llvm::Value *numThreads = nullptr;
+  if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
+    numThreads = valueMapping.lookup(numThreadsVar);
+  // TODO: Is the Parallel construct cancellable?
   bool isCancellable = false;
   // TODO: Determine the actual alloca insertion point, e.g., the function
   // entry or the alloca insertion point as provided by the body callback

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 00f9726b119d..88f61e7f7916 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -12,7 +12,7 @@ func @unknown_clause() {
 
 func @if_once(%n : i1) {
   // expected-error at +1 {{at most one if clause can appear on the omp.parallel operation}}
-  omp.parallel if(%n) if(%n) {
+  omp.parallel if(%n : i1) if(%n : i1) {
   }
 
   return

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 85343f985501..e3e7afaff541 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -99,14 +99,14 @@ func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads :
 
   // CHECK omp.parallel shared(%{{.*}} : memref<i32>) copyin(%{{.*}} : memref<i32>, %{{.*}} : memref<i32>)
   omp.parallel shared(%data_var : memref<i32>) copyin(%data_var : memref<i32>, %data_var : memref<i32>) {
-    omp.parallel if(%if_cond) {
+    omp.parallel if(%if_cond: i1) {
       omp.terminator
     }
     omp.terminator
   }
 
   // CHECK omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) private(%{{.*}} : memref<i32>) proc_bind(close)
-  omp.parallel num_threads(%num_threads : si32) if(%if_cond) 
+  omp.parallel num_threads(%num_threads : si32) if(%if_cond: i1)
                private(%data_var : memref<i32>) proc_bind(close) {
     omp.terminator
   }

diff  --git a/mlir/test/Target/openmp-llvm.mlir b/mlir/test/Target/openmp-llvm.mlir
index c8acd8022b2b..60462fee3b97 100644
--- a/mlir/test/Target/openmp-llvm.mlir
+++ b/mlir/test/Target/openmp-llvm.mlir
@@ -78,3 +78,100 @@ llvm.func @test_omp_parallel_2() -> () {
   // CHECK-LABEL: omp.par.region2:
   // CHECK: call void @body(i64 43)
   // CHECK: br label %omp.par.pre_finalize
+
+// CHECK: define void @test_omp_parallel_num_threads_1(i32 %[[NUM_THREADS_VAR_1:.*]])
+llvm.func @test_omp_parallel_num_threads_1(%arg0: !llvm.i32) -> () {
+  // CHECK: %[[GTN_NUM_THREADS_VAR_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_1:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_1]], i32 %[[GTN_NUM_THREADS_VAR_1]], i32 %[[NUM_THREADS_VAR_1]])
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_1:.*]] to {{.*}}
+  omp.parallel num_threads(%arg0: !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_1]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_2()
+llvm.func @test_omp_parallel_num_threads_2() -> () {
+  %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_2:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_2]], i32 %[[GTN_NUM_THREADS_VAR_2]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_2:.*]] to {{.*}}
+  omp.parallel num_threads(%0: !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_2]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_num_threads_3()
+llvm.func @test_omp_parallel_num_threads_3() -> () {
+  %0 = llvm.mlir.constant(4 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_3_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_1:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_1]], i32 %[[GTN_NUM_THREADS_VAR_3_1]], i32 4)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_1:.*]] to {{.*}}
+  omp.parallel num_threads(%0: !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+  %1 = llvm.mlir.constant(8 : index) : !llvm.i32
+  // CHECK: %[[GTN_NUM_THREADS_VAR_3_2:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[GTN_SI_VAR_3_2:.*]])
+  // CHECK: call void @__kmpc_push_num_threads(%struct.ident_t* @[[GTN_SI_VAR_3_2]], i32 %[[GTN_NUM_THREADS_VAR_3_2]], i32 8)
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_NUM_THREADS_3_2:.*]] to {{.*}}
+  omp.parallel num_threads(%1: !llvm.i32) {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_2]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_NUM_THREADS_3_1]]
+  // CHECK: call void @__kmpc_barrier
+
+// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
+llvm.func @test_omp_parallel_if_1(%arg0: !llvm.i32) -> () {
+
+// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
+  %0 = llvm.mlir.constant(0 : index) : !llvm.i32
+  %1 = llvm.icmp "slt" %arg0, %0 : !llvm.i32
+
+// CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* @[[SI_VAR_IF_1:.*]])
+// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]]
+// CHECK: [[IF_COND_TRUE_BLOCK_1]]:
+// CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]]
+// CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]:
+// CHECK: call void {{.*}} @__kmpc_fork_call(%struct.ident_t* @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]] to void
+// CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_1]]:
+// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]]
+// CHECK: [[OUTLINED_EXIT_IF_2]]:
+// CHECK: br label %[[RETURN_BLOCK_IF_1:.*]]
+// CHECK: [[IF_COND_FALSE_BLOCK_1]]:
+// CHECK: call void @__kmpc_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]]
+// CHECK: call void @__kmpc_end_serialized_parallel(%struct.ident_t* @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
+// CHECK: br label %[[RETURN_BLOCK_IF_1]]
+  omp.parallel if(%1 : !llvm.i1) {
+    omp.barrier
+    omp.terminator
+  }
+
+// CHECK: [[RETURN_BLOCK_IF_1]]:
+// CHECK: ret void
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_IF_1]]
+  // CHECK: call void @__kmpc_barrier


        


More information about the Mlir-commits mailing list