[flang] [llvm] [mlir] [flang][OpenMP][MLIR] Add MLIR op for loop directive (PR #113911)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 07:05:45 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-mlir-openmp

Author: Kareem Ergawy (ergawy)

<details>
<summary>Changes</summary>

Adds an MLIR op that corresponds to the `loop` directive.

Parent PR #<!-- -->113662.

---
Full diff: https://github.com/llvm/llvm-project/pull/113911.diff


10 Files Affected:

- (modified) flang/include/flang/Parser/dump-parse-tree.h (+2) 
- (modified) flang/include/flang/Parser/parse-tree.h (+7) 
- (modified) flang/lib/Parser/openmp-parsers.cpp (+8) 
- (modified) flang/test/Parser/OpenMP/target-loop-unparse.f90 (+14-2) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMP.td (+12) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+25) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+44) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+24) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+30) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+40) 


``````````diff
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index ccdfe980f6f38c..ea02ae27347a1c 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -551,6 +551,8 @@ class ParseTreeDumper {
   NODE_ENUM(OmpGrainsizeClause, Prescriptiveness)
   NODE(parser, OmpNumTasksClause)
   NODE_ENUM(OmpNumTasksClause, Prescriptiveness)
+  NODE(parser, OmpBindClause)
+  NODE_ENUM(OmpBindClause, Type)
   NODE(parser, OmpProcBindClause)
   NODE_ENUM(OmpProcBindClause, Type)
   NODE_ENUM(OmpReductionClause, ReductionModifier)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 2a312e29a3a44d..a08e269003aaab 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3702,6 +3702,13 @@ struct OmpNumTasksClause {
   std::tuple<std::optional<Prescriptiveness>, ScalarIntExpr> t;
 };
 
+// OMP 5.2 11.7.1 bind-clause ->
+//                  BIND( PARALLEL | TEAMS | THREAD )
+struct OmpBindClause {
+  ENUM_CLASS(Type, Parallel, Teams, Thread)
+  WRAPPER_CLASS_BOILERPLATE(OmpBindClause, Type);
+};
+
 // OpenMP Clauses
 struct OmpClause {
   UNION_CLASS_BOILERPLATE(OmpClause);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index ae0c351fed56d1..06520d3fc2a56c 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -427,6 +427,12 @@ TYPE_PARSER(construct<OmpLastprivateClause>(
         pure(OmpLastprivateClause::LastprivateModifier::Conditional) / ":"),
     Parser<OmpObjectList>{}))
 
+// OMP 5.2 11.7.1 BIND ( PARALLEL | TEAMS | THREAD )
+TYPE_PARSER(construct<OmpBindClause>(
+    "PARALLEL" >> pure(OmpBindClause::Type::Parallel) ||
+    "TEAMS" >> pure(OmpBindClause::Type::Teams) ||
+    "THREAD" >> pure(OmpBindClause::Type::Thread)))
+
 TYPE_PARSER(
     "ACQUIRE" >> construct<OmpClause>(construct<OmpClause::Acquire>()) ||
     "ACQ_REL" >> construct<OmpClause>(construct<OmpClause::AcqRel>()) ||
@@ -441,6 +447,8 @@ TYPE_PARSER(
     "ATOMIC_DEFAULT_MEM_ORDER" >>
         construct<OmpClause>(construct<OmpClause::AtomicDefaultMemOrder>(
             parenthesized(Parser<OmpAtomicDefaultMemOrderClause>{}))) ||
+    "BIND" >> construct<OmpClause>(construct<OmpClause::Bind>(
+                  parenthesized(Parser<OmpBindClause>{}))) ||
     "COLLAPSE" >> construct<OmpClause>(construct<OmpClause::Collapse>(
                       parenthesized(scalarIntConstantExpr))) ||
     "COPYIN" >> construct<OmpClause>(construct<OmpClause::Copyin>(
diff --git a/flang/test/Parser/OpenMP/target-loop-unparse.f90 b/flang/test/Parser/OpenMP/target-loop-unparse.f90
index 3ee2fcef075a37..b2047070496527 100644
--- a/flang/test/Parser/OpenMP/target-loop-unparse.f90
+++ b/flang/test/Parser/OpenMP/target-loop-unparse.f90
@@ -1,6 +1,8 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp -fopenmp-version=50 %s | \
+! RUN:   FileCheck --ignore-case %s
 
-! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
-! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp -fopenmp-version=50 %s | \
+! RUN:   FileCheck --check-prefix="PARSE-TREE" %s
 
 ! Check for parsing of loop directive
 
@@ -14,6 +16,16 @@ subroutine test_loop
    j = j + 1
   end do
   !$omp end loop
+
+  !PARSE-TREE: OmpBeginLoopDirective
+  !PARSE-TREE-NEXT: OmpLoopDirective -> llvm::omp::Directive = loop
+  !PARSE-TREE-NEXT: OmpClauseList -> OmpClause -> Bind -> OmpBindClause -> Type = Thread
+  !CHECK: !$omp loop
+  !$omp loop bind(thread)
+  do i=1,10
+   j = j + 1
+  end do
+  !$omp end loop
 end subroutine
 
 subroutine test_target_loop
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 70179bab475779..9668ab1e90ecf3 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -71,9 +71,21 @@ def OMPC_AtomicDefaultMemOrder : Clause<"atomic_default_mem_order"> {
   let clangClass = "OMPAtomicDefaultMemOrderClause";
   let flangClass = "OmpAtomicDefaultMemOrderClause";
 }
+
+def OMP_BIND_parallel : ClauseVal<"parallel",1,1> {}
+def OMP_BIND_teams : ClauseVal<"teams",2,1> {}
+def OMP_BIND_thread : ClauseVal<"thread",3,1> { let isDefault = true; }
 def OMPC_Bind : Clause<"bind"> {
   let clangClass = "OMPBindClause";
+  let flangClass = "OmpBindClause";
+  let enumClauseValue = "BindKind";
+  let allowedClauseValues = [
+    OMP_BIND_parallel,
+    OMP_BIND_teams,
+    OMP_BIND_thread
+  ];
 }
+
 def OMP_CANCELLATION_CONSTRUCT_Parallel : ClauseVal<"parallel", 1, 1> {}
 def OMP_CANCELLATION_CONSTRUCT_Loop : ClauseVal<"loop", 2, 1> {}
 def OMP_CANCELLATION_CONSTRUCT_Sections : ClauseVal<"sections", 3, 1> {}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..855deab94b2f16 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -107,6 +107,31 @@ class OpenMP_CancelDirectiveNameClauseSkip<
 
 def OpenMP_CancelDirectiveNameClause : OpenMP_CancelDirectiveNameClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [11.7.1] `bind` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_BindClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    OptionalAttr<BindKindAttr>:$bind_kind
+  );
+
+  let optAssemblyFormat = [{
+    `bind` `(` custom<ClauseAttr>($bind_kind) `)`
+  }];
+
+  let description = [{
+    The `bind` clause specifies the binding region of the construct on which it
+    appears.
+  }];
+}
+
+def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [5.7.2] `copyprivate` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..4f16108a7a585f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -382,6 +382,50 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 // 2.9.2 Workshare Loop Construct
 //===----------------------------------------------------------------------===//
 
+def LoopOp : OpenMP_Op<"loop", traits = [
+    DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator, SingleBlock,
+    AttrSizedOperandSegments
+  ], clauses = [
+    OpenMP_BindClause, OpenMP_PrivateClause, OpenMP_OrderClause,
+    OpenMP_ReductionClause
+  ], singleRegion = true> {
+  let summary = "loop construct";
+  let description = [{
+    A loop construct specifies that the logical iterations of the associated loops
+    may execute concurrently and permits the encountering threads to execute the
+    loop accordingly. A loop construct is a worksharing construct if its binding
+    region is the innermost enclosing parallel region. Otherwise it is not a
+    worksharing region. The directive asserts that the iterations of the associated
+    loops may execute in any order, including concurrently. Each logical iteration
+    is executed once per instance of the loop region that is encountered by exactly
+    one thread that is a member of the binding thread set.
+
+    The body region can only contain a single block which must contain a single
+    operation; this operation must be an `omp.loop_nest`.
+
+    ```
+    omp.loop <clauses> {
+      omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+        %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+        %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+        %sum = arith.addf %a, %b : f32
+        store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+        omp.yield
+      }
+    }
+    ```
+  }] # clausesDescription;
+
+  let assemblyFormat = clausesAssemblyFormat # [{
+    custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
+        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $reduction_syms) attr-dict
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+}
+
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..a420faa97a28c0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1948,6 +1948,30 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LoopOp::verify() {
+  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
+                                getReductionByref());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+    Region& region = getRegion();
+
+    // Minimal amount of checks to verify the only nested op is an
+    // `omp.loop_nest`. A more extensive verification is done by the
+    // `LoopWrapperInterface` trait but the difference is that `omp.loop` cannot
+    // have another nested `LoopWrapperInterface`.
+    if (range_size(region.getOps()) != 1 ||
+        !isa<LoopNestOp>(*region.op_begin()))
+      return emitError() << "`omp.loop` expected to have a single nested "
+                            "operation which is a `omp.loop_nest`";
+
+    return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index fd89ec31c64a60..194daf8c2edf34 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2577,3 +2577,33 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
   } {omp.composite}
   return
 }
+
+// -----
+
+func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {
+
+  // expected-error @below {{`omp.loop` expected to have a single nested operation which is a `omp.loop_nest`}}
+  omp.loop {
+    omp.simd {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    } {omp.composite}
+  }
+
+  return
+}
+
+// -----
+
+func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {
+
+  // expected-error @below {{custom op 'omp.loop' invalid clause value: 'dummy_value'}}
+  omp.loop bind(dummy_value) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6f11b451fa00a3..47b8890b851c10 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2749,3 +2749,43 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   return
 }
+
+// CHECK-LABEL: omp_loop
+func.func @omp_loop(%lb : index, %ub : index, %step : index) {
+  // CHECK: omp.loop {
+  omp.loop {
+    // CHECK: omp.loop_nest {{.*}} {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      // CHECK: omp.yield
+      omp.yield
+    // CHECK: }
+    }
+  // CHECK: }
+  }
+
+  // CHECK: omp.loop bind(teams) {
+  omp.loop bind(teams) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  // CHECK: }
+  }
+
+  // CHECK: omp.loop bind(parallel) {
+  omp.loop bind(parallel) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  // CHECK: }
+  }
+
+  // CHECK: omp.loop bind(thread) {
+  omp.loop bind(thread) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  // CHECK: }
+  }
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/113911


More information about the llvm-commits mailing list