[Mlir-commits] [flang] [llvm] [mlir] [flang][OpenMP] WIP: Rewrite `omp.loop` to semantically equivalent ops (PR #115443)

Kareem Ergawy llvmlistbot at llvm.org
Fri Nov 8 01:37:04 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/115443

>From 095b143d08d0f1d0e576015d603cb5443d2801f7 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 28 Oct 2024 00:04:15 -0500
Subject: [PATCH 1/4] [flang][OpenMP][MLIR] Add MLIR op for `loop` directive

Adds an MLIR op that corresponds to the `loop` directive.
---
 llvm/include/llvm/Frontend/OpenMP/OMP.td      | 11 +++++
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 25 ++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 44 ++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 17 +++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 46 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 40 ++++++++++++++++
 6 files changed, 183 insertions(+)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 0fc0f066c2c43c..d1cc753b7daf02 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -71,10 +71,21 @@ def OMPC_AtomicDefaultMemOrder : Clause<"atomic_default_mem_order"> {
   let clangClass = "OMPAtomicDefaultMemOrderClause";
   let flangClass = "OmpAtomicDefaultMemOrderClause";
 }
+
+def OMP_BIND_parallel : ClauseVal<"parallel",1,1> {}
+def OMP_BIND_teams : ClauseVal<"teams",2,1> {}
+def OMP_BIND_thread : ClauseVal<"thread",3,1> { let isDefault = true; }
 def OMPC_Bind : Clause<"bind"> {
   let clangClass = "OMPBindClause";
   let flangClass = "OmpBindClause";
+  let enumClauseValue = "BindKind";
+  let allowedClauseValues = [
+    OMP_BIND_parallel,
+    OMP_BIND_teams,
+    OMP_BIND_thread
+  ];
 }
+
 def OMP_CANCELLATION_CONSTRUCT_Parallel : ClauseVal<"parallel", 1, 1> {}
 def OMP_CANCELLATION_CONSTRUCT_Loop : ClauseVal<"loop", 2, 1> {}
 def OMP_CANCELLATION_CONSTRUCT_Sections : ClauseVal<"sections", 3, 1> {}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..855deab94b2f16 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -107,6 +107,31 @@ class OpenMP_CancelDirectiveNameClauseSkip<
 
 def OpenMP_CancelDirectiveNameClause : OpenMP_CancelDirectiveNameClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [11.7.1] `bind` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_BindClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    OptionalAttr<BindKindAttr>:$bind_kind
+  );
+
+  let optAssemblyFormat = [{
+    `bind` `(` custom<ClauseAttr>($bind_kind) `)`
+  }];
+
+  let description = [{
+    The `bind` clause specifies the binding region of the construct on which it
+    appears.
+  }];
+}
+
+def OpenMP_BindClause : OpenMP_BindClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [5.7.2] `copyprivate` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5fd8184fe0e0f7..a0da3db124d1f4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -382,6 +382,50 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
 // 2.9.2 Workshare Loop Construct
 //===----------------------------------------------------------------------===//
 
+def LoopOp : OpenMP_Op<"loop", traits = [
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    NoTerminator, SingleBlock
+  ], clauses = [
+    OpenMP_BindClause, OpenMP_PrivateClause, OpenMP_OrderClause,
+    OpenMP_ReductionClause
+  ], singleRegion = true> {
+  let summary = "loop construct";
+  let description = [{
+    A loop construct specifies that the logical iterations of the associated loops
+    may execute concurrently and permits the encountering threads to execute the
+    loop accordingly. A loop construct can have 3 different types of binding:
+      1. teams: in which case the binding region is the innermost enclosing `teams`
+         region.
+      2. parallel: in which case the binding region is the innermost enclosing `parallel`
+         region.
+      3. thread: in which case the binding region is not defined.
+
+    The body region can only contain a single block which must contain a single
+    operation; this operation must be an `omp.loop_nest`.
+
+    ```
+    omp.loop <clauses> {
+      omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+        %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+        %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+        %sum = arith.addf %a, %b : f32
+        store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+        omp.yield
+      }
+    }
+    ```
+  }] # clausesDescription;
+
+  let assemblyFormat = clausesAssemblyFormat # [{
+    custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
+        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $reduction_syms) attr-dict
+  }];
+
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+}
+
 def WsloopOp : OpenMP_Op<"wsloop", traits = [
     AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>,
     DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4a27a5ed8eb74b..228c2d034ad4ad 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1948,6 +1948,23 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult LoopOp::verify() {
+  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
+                                getReductionByref());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+  if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
+      getNestedWrapper())
+    return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index db941d401d52dc..aa41eea44f3ef4 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2577,3 +2577,49 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
   } {omp.composite}
   return
 }
+
+// -----
+
+func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {
+
+  // expected-error @below {{`omp.loop` expected to be a standalone loop wrapper}}
+  omp.loop {
+    omp.simd {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    } {omp.composite}
+  }
+
+  return
+}
+
+// -----
+
+func.func @omp_loop_invalid_nesting2(%lb : index, %ub : index, %step : index) {
+
+  omp.simd {
+    // expected-error @below {{`omp.loop` expected to be a standalone loop wrapper}}
+    omp.loop {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        omp.yield
+      }
+    } {omp.composite}
+  }
+
+  return
+}
+
+// -----
+
+func.func @omp_loop_invalid_binding(%lb : index, %ub : index, %step : index) {
+
+  // expected-error @below {{custom op 'omp.loop' invalid clause value: 'dummy_value'}}
+  omp.loop bind(dummy_value) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b606f9eb708cf3..4f5cc696cada81 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2749,3 +2749,43 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   return
 }
+
+// CHECK-LABEL: omp_loop
+func.func @omp_loop(%lb : index, %ub : index, %step : index) {
+  // CHECK: omp.loop {
+  omp.loop {
+    // CHECK: omp.loop_nest {{.*}} {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      // CHECK: omp.yield
+      omp.yield
+    }
+    // CHECK: }
+  }
+  // CHECK: }
+
+  // CHECK: omp.loop bind(teams) {
+  omp.loop bind(teams) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  // CHECK: }
+
+  // CHECK: omp.loop bind(parallel) {
+  omp.loop bind(parallel) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  // CHECK: }
+
+  // CHECK: omp.loop bind(thread) {
+  omp.loop bind(thread) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  // CHECK: }
+
+  return
+}

>From 8df37d37d1dc16558c20ed0004298c93f182eb30 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 30 Oct 2024 04:40:14 -0500
Subject: [PATCH 2/4] [flang][OpenMP] Add basic support to lower `loop` to MLIR

Adds initial support for lowering the `loop` directive to MLIR.

The PR includes basic support and testing for the following clauses:
 * `collapse`
 * `order`
 * `private`
 * `reduction`
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 48 +++++++++-
 .../test/Lower/OpenMP/Todo/loop-directive.f90 | 15 ----
 flang/test/Lower/OpenMP/loop-directive.f90    | 90 +++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  4 +
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  | 14 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 11 +++
 6 files changed, 159 insertions(+), 23 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/loop-directive.f90
 create mode 100644 flang/test/Lower/OpenMP/loop-directive.f90

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 4f9e2347308aa1..e3dc780321f89f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1176,6 +1176,18 @@ genLoopNestClauses(lower::AbstractConverter &converter,
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
 }
 
+static void genLoopClauses(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    const List<Clause> &clauses, mlir::Location loc,
+    mlir::omp::LoopOperands &clauseOps,
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processOrder(clauseOps);
+  cp.processReduction(loc, clauseOps, reductionSyms);
+  cp.processTODO<clause::Bind, clause::Lastprivate>(
+      loc, llvm::omp::Directive::OMPD_loop);
+}
+
 static void genMaskedClauses(lower::AbstractConverter &converter,
                              semantics::SemanticsContext &semaCtx,
                              lower::StatementContext &stmtCtx,
@@ -2051,6 +2063,40 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
                 llvm::omp::Directive::OMPD_do, dsp);
 }
 
+static void genLoopOp(lower::AbstractConverter &converter,
+                      lower::SymMap &symTable,
+                      semantics::SemanticsContext &semaCtx,
+                      lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  mlir::omp::LoopOperands loopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> loopReductionSyms;
+  genLoopClauses(converter, semaCtx, item->clauses, loc, loopClauseOps,
+                 loopReductionSyms);
+
+  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&loopClauseOps);
+
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  EntryBlockArgs loopArgs;
+  loopArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  loopArgs.priv.vars = loopClauseOps.privateVars;
+  loopArgs.reduction.syms = loopReductionSyms;
+  loopArgs.reduction.vars = loopClauseOps.reductionVars;
+
+  auto loopOp =
+      genWrapperOp<mlir::omp::LoopOp>(converter, loc, loopClauseOps, loopArgs);
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                loopNestClauseOps, iv, {{loopOp, loopArgs}},
+                llvm::omp::Directive::OMPD_loop, dsp);
+}
+
 static void genStandaloneParallel(lower::AbstractConverter &converter,
                                   lower::SymMap &symTable,
                                   semantics::SemanticsContext &semaCtx,
@@ -2479,7 +2525,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genStandaloneDo(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_loop:
-    TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+    genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_masked:
     genMaskedOp(converter, symTable, semaCtx, eval, loc, queue, item);
diff --git a/flang/test/Lower/OpenMP/Todo/loop-directive.f90 b/flang/test/Lower/OpenMP/Todo/loop-directive.f90
deleted file mode 100644
index f1aea70458aa6c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/loop-directive.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! This test checks lowering of OpenMP loop Directive.
-
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unhandled directive loop
-subroutine test_loop()
-  integer :: i, j = 1
-  !$omp loop
-  do i=1,10
-   j = j + 1
-  end do
-  !$omp end loop
-end subroutine
-
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
new file mode 100644
index 00000000000000..0a27a18aa8b52f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -0,0 +1,90 @@
+! This test checks lowering of OpenMP loop Directive.
+
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! CHECK: omp.declare_reduction @[[RED:add_reduction_i32]] : i32
+! CHECK: omp.private {type = private} @[[DUMMY_PRIV:.*test_privateEdummy_private.*]] : !fir.ref<i32>
+! CHECK: omp.private {type = private} @[[I_PRIV:.*test_no_clausesEi.*]] : !fir.ref<i32>
+
+! CHECK-LABEL: func.func @_QPtest_no_clauses
+subroutine test_no_clauses()
+  integer :: i, j, dummy = 1
+
+  ! CHECK: omp.loop private(@[[I_PRIV]] %{{.*}}#0 -> %[[ARG:.*]] : !fir.ref<i32>) {
+  ! CHECK:   omp.loop_nest (%[[IV:.*]]) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[ARG_DECL:.*]]:2 = hlfir.declare %[[ARG]]
+  ! CHECK:     fir.store %[[IV]] to %[[ARG_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_collapse
+subroutine test_collapse()
+  integer :: i, j, dummy = 1
+  ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}}, @{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}, %{{.*}}) : i32 {{.*}} {
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop collapse(2)
+  do i=1,10
+    do j=2,20
+     dummy = dummy + 1
+    end do
+  end do
+  !$omp end loop
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_private
+subroutine test_private()
+  integer :: i, dummy = 1
+  ! CHECK: omp.loop private(@[[DUMMY_PRIV]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]], @{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_privateEdummy"}
+  ! CHECK:     %{{.*}} = fir.load %[[DUMMY_DECL]]#0
+  ! CHECK:     hlfir.assign %{{.*}} to %[[DUMMY_DECL]]#0
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop private(dummy)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+
+! CHECK-LABEL: func.func @_QPtest_order
+subroutine test_order()
+  integer :: i, dummy = 1
+  ! CHECK: omp.loop order(reproducible:concurrent) private(@{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK: }
+  !$omp loop order(concurrent)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_reduction
+subroutine test_reduction()
+  integer :: i, dummy = 1
+
+  ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
+  ! CHECK-SAME:  (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
+  ! CHECK:     %{{.*}} = fir.load %[[DUMMY_DECL]]#0
+  ! CHECK:     hlfir.assign %{{.*}} to %[[DUMMY_DECL]]#0
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop reduction(+:dummy)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a0da3db124d1f4..b01fda993f5e78 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -422,6 +422,10 @@ def LoopOp : OpenMP_Op<"loop", traits = [
         $reduction_syms) attr-dict
   }];
 
+  let builders = [
+    OpBuilder<(ins CArg<"const LoopOperands &">:$clauses)>
+  ];
+
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index aa824a95b1574a..58fd3d565fce50 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -234,11 +234,11 @@ void mlir::configureOpenMPToLLVMConversionLegality(
   });
   target.addDynamicallyLegalOp<
       omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
-      omp::DistributeOp, omp::LoopNestOp, omp::MasterOp, omp::OrderedRegionOp,
-      omp::ParallelOp, omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp,
-      omp::SimdOp, omp::SingleOp, omp::TargetDataOp, omp::TargetOp,
-      omp::TaskgroupOp, omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
-      omp::WsloopOp>([&](Operation *op) {
+      omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
+      omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp,
+      omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp,
+      omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp,
+      omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) {
     return std::all_of(op->getRegions().begin(), op->getRegions().end(),
                        [&](Region &region) {
                          return typeConverter.isLegal(&region);
@@ -275,8 +275,8 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionOpConversion<omp::AtomicCaptureOp>,
       RegionOpConversion<omp::CriticalOp>,
       RegionOpConversion<omp::DistributeOp>,
-      RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::MaskedOp>,
-      RegionOpConversion<omp::MasterOp>,
+      RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
+      RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
       RegionOpConversion<omp::OrderedRegionOp>,
       RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
       RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 228c2d034ad4ad..5809d06e6a9d39 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1952,6 +1952,17 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
 // LoopOp
 //===----------------------------------------------------------------------===//
 
+void LoopOp::build(OpBuilder &builder, OperationState &state,
+                   const LoopOperands &clauses) {
+  MLIRContext *ctx = builder.getContext();
+
+  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
+                makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
+                clauses.orderMod, clauses.reductionVars,
+                makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                makeArrayAttr(ctx, clauses.reductionSyms));
+}
+
 LogicalResult LoopOp::verify() {
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());

>From 3afcb7e1293914c58e099f55e90bb9973df3e57f Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 30 Oct 2024 07:33:39 -0500
Subject: [PATCH 3/4] [flang][OpenMP] Add MLIR lowering for `loop ... bind`

Extends MLIR lowering support for the `loop` directive by adding
lowering support for the `bind` clause.
---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 28 ++++++++++++++++++++++
 flang/lib/Lower/OpenMP/ClauseProcessor.h   |  1 +
 flang/lib/Lower/OpenMP/Clauses.cpp         | 15 ++++++++++--
 flang/lib/Lower/OpenMP/OpenMP.cpp          |  4 ++--
 flang/test/Lower/OpenMP/loop-directive.f90 | 12 ++++++++++
 5 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e768c1cbc0784a..0191669e100b45 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -98,6 +98,25 @@ genAllocateClause(lower::AbstractConverter &converter,
   genObjectList(objects, converter, allocateOperands);
 }
 
+static mlir::omp::ClauseBindKindAttr
+genBindKindAttr(fir::FirOpBuilder &firOpBuilder,
+                const omp::clause::Bind &clause) {
+  mlir::omp::ClauseBindKind bindKind;
+  switch (clause.v) {
+  case omp::clause::Bind::Binding::Teams:
+    bindKind = mlir::omp::ClauseBindKind::Teams;
+    break;
+  case omp::clause::Bind::Binding::Parallel:
+    bindKind = mlir::omp::ClauseBindKind::Parallel;
+    break;
+  case omp::clause::Bind::Binding::Thread:
+    bindKind = mlir::omp::ClauseBindKind::Thread;
+    break;
+  }
+  return mlir::omp::ClauseBindKindAttr::get(firOpBuilder.getContext(),
+                                            bindKind);
+}
+
 static mlir::omp::ClauseProcBindKindAttr
 genProcBindKindAttr(fir::FirOpBuilder &firOpBuilder,
                     const omp::clause::ProcBind &clause) {
@@ -204,6 +223,15 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
 // ClauseProcessor unique clauses
 //===----------------------------------------------------------------------===//
 
+bool ClauseProcessor::processBind(mlir::omp::BindClauseOps &result) const {
+  if (auto *clause = findUniqueClause<omp::clause::Bind>()) {
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    result.bindKind = genBindKindAttr(firOpBuilder, *clause);
+    return true;
+  }
+  return false;
+}
+
 bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
     mlir::omp::LoopRelatedClauseOps &result,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index f34121c70d0b44..772e5415d2da27 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -53,6 +53,7 @@ class ClauseProcessor {
       : converter(converter), semaCtx(semaCtx), clauses(clauses) {}
 
   // 'Unique' clauses: They can appear at most once in the clause list.
+  bool processBind(mlir::omp::BindClauseOps &result) const;
   bool
   processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
                   mlir::omp::LoopRelatedClauseOps &result,
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index 46caafeef8e4a8..ff4af737b27fdb 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -484,8 +484,19 @@ AtomicDefaultMemOrder make(const parser::OmpClause::AtomicDefaultMemOrder &inp,
 
 Bind make(const parser::OmpClause::Bind &inp,
           semantics::SemanticsContext &semaCtx) {
-  // inp -> empty
-  llvm_unreachable("Empty: bind");
+  // inp.v -> parser::OmpBindClause
+  using wrapped = parser::OmpBindClause;
+
+  CLAUSET_ENUM_CONVERT( //
+      convert, wrapped::Type, Bind::Binding,
+      // clang-format off
+      MS(Teams, Teams)
+      MS(Parallel, Parallel)
+      MS(Thread, Thread)
+      // clang-format on
+  );
+
+  return Bind{/*Binding=*/convert(inp.v.v)};
 }
 
 // CancellationConstructType: empty
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index e3dc780321f89f..4c733bbbd7b2a3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1182,10 +1182,10 @@ static void genLoopClauses(
     mlir::omp::LoopOperands &clauseOps,
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processBind(clauseOps);
   cp.processOrder(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
-  cp.processTODO<clause::Bind, clause::Lastprivate>(
-      loc, llvm::omp::Directive::OMPD_loop);
+  cp.processTODO<clause::Lastprivate>(loc, llvm::omp::Directive::OMPD_loop);
 }
 
 static void genMaskedClauses(lower::AbstractConverter &converter,
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
index 0a27a18aa8b52f..6ad2393920e9b4 100644
--- a/flang/test/Lower/OpenMP/loop-directive.f90
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -88,3 +88,15 @@ subroutine test_reduction()
   end do
   !$omp end loop
 end subroutine
+
+! CHECK-LABEL: func.func @_QPtest_bind
+subroutine test_bind()
+  integer :: i, dummy = 1
+  ! CHECK: omp.loop bind(thread) private(@{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK: }
+  !$omp loop bind(thread)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine

>From 39cac1a61200c9c1211d7836aa5a9eac29eeaa03 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 7 Nov 2024 04:26:57 -0600
Subject: [PATCH 4/4] [flang][OpenMP] WIP: Rewrite `omp.loop` to semantically
 equivalent ops

Introduces a new conversion pass that rewrites `omp.loop` ops to their
semantically equivalent op nests based on the surrounding/binding
context of the `loop` op. Not all forms of `omp.loop` are supported yet.
---
 flang/include/flang/Common/OpenMP-utils.h     |  68 +++++++
 .../include/flang/Optimizer/OpenMP/Passes.td  |  12 ++
 flang/lib/Common/CMakeLists.txt               |   4 +
 flang/lib/Common/OpenMP-utils.cpp             |  47 +++++
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 105 +----------
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   2 +
 .../OpenMP/GenericLoopConversion.cpp          | 171 ++++++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   1 +
 .../Lower/OpenMP/generic-loop-rewriting.f90   |  40 ++++
 9 files changed, 354 insertions(+), 96 deletions(-)
 create mode 100644 flang/include/flang/Common/OpenMP-utils.h
 create mode 100644 flang/lib/Common/OpenMP-utils.cpp
 create mode 100644 flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
 create mode 100644 flang/test/Lower/OpenMP/generic-loop-rewriting.f90

diff --git a/flang/include/flang/Common/OpenMP-utils.h b/flang/include/flang/Common/OpenMP-utils.h
new file mode 100644
index 00000000000000..7dbb0f612b19cd
--- /dev/null
+++ b/flang/include/flang/Common/OpenMP-utils.h
@@ -0,0 +1,68 @@
+//===-- include/flang/Common/OpenMP-utils.h --------------------*- C++ -*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_COMMON_OPENMP_UTILS_H_
+#define FORTRAN_COMMON_OPENMP_UTILS_H_
+
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace Fortran::openmp::common {
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated with a single clause.
+struct EntryBlockArgsEntry {
+  llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
+  llvm::ArrayRef<mlir::Value> vars;
+
+  bool isValid() const {
+    // This check allows specifying a smaller number of symbols than values
+    // because in some cases a single symbol generates multiple block
+    // arguments.
+    return syms.size() <= vars.size();
+  }
+};
+
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated with all clauses that can define them.
+struct EntryBlockArgs {
+  EntryBlockArgsEntry inReduction;
+  EntryBlockArgsEntry map;
+  EntryBlockArgsEntry priv;
+  EntryBlockArgsEntry reduction;
+  EntryBlockArgsEntry taskReduction;
+  EntryBlockArgsEntry useDeviceAddr;
+  EntryBlockArgsEntry useDevicePtr;
+
+  bool isValid() const {
+    return inReduction.isValid() && map.isValid() && priv.isValid() &&
+        reduction.isValid() && taskReduction.isValid() &&
+        useDeviceAddr.isValid() && useDevicePtr.isValid();
+  }
+
+  auto getSyms() const {
+    return llvm::concat<const Fortran::semantics::Symbol *const>(
+        inReduction.syms, map.syms, priv.syms, reduction.syms,
+        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
+  }
+
+  auto getVars() const {
+    return llvm::concat<const mlir::Value>(inReduction.vars, map.vars,
+        priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars,
+        useDevicePtr.vars);
+  }
+};
+
+mlir::Block *genEntryBlock(
+    mlir::OpBuilder &builder, const EntryBlockArgs &args, mlir::Region &region);
+} // namespace Fortran::openmp::common
+
+#endif // FORTRAN_COMMON_OPENMP_UTILS_H_
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index c070bc22ff20cc..bb5ad4b4e0f7bb 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -50,4 +50,16 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
   ];
 }
 
+def GenericLoopConversionPass
+    : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
+  let summary = "Converts OpenMP generic `loop` directive to semantically "
+                "equivalent OpenMP ops";
+  let description = [{
+     Rewrites `loop` ops to their semantically equivalent nest of ops. The
+     rewrite depends on the nesting/combination structure of the `loop` op
+     within its surrounding context as well as its `bind` clause value.
+  }];
+  let dependentDialects = ["mlir::omp::OpenMPDialect"];
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/lib/Common/CMakeLists.txt b/flang/lib/Common/CMakeLists.txt
index be72391847f3dd..de6bea396f3cbe 100644
--- a/flang/lib/Common/CMakeLists.txt
+++ b/flang/lib/Common/CMakeLists.txt
@@ -40,9 +40,13 @@ add_flang_library(FortranCommon
   default-kinds.cpp
   idioms.cpp
   LangOptions.cpp
+  OpenMP-utils.cpp
   Version.cpp
   ${version_inc}
 
   LINK_COMPONENTS
   Support
+
+  LINK_LIBS
+  MLIRIR
 )
diff --git a/flang/lib/Common/OpenMP-utils.cpp b/flang/lib/Common/OpenMP-utils.cpp
new file mode 100644
index 00000000000000..32df2be01e5484
--- /dev/null
+++ b/flang/lib/Common/OpenMP-utils.cpp
@@ -0,0 +1,48 @@
+//===-- lib/Common/OpenMP-utils.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace Fortran::openmp::common {
+/// Create an entry block for the given region, including the clause-defined
+/// arguments specified.
+///
+/// \param [in,out] builder - Builder used to create the block. As with
+///                           `mlir::OpBuilder::createBlock`, its insertion
+///                           point is left at the end of the new block.
+/// \param [in]        args - Entry block arguments information for the given
+///                           operation. Must satisfy `args.isValid()`.
+/// \param [in]      region - Empty region in which to create the entry block.
+/// \returns the newly created entry block.
+mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,
+    mlir::Region &region) {
+  assert(args.isValid() && "invalid args");
+  assert(region.empty() && "non-empty region");
+
+  llvm::SmallVector<mlir::Type> types;
+  llvm::SmallVector<mlir::Location> locs;
+  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
+      args.priv.vars.size() + args.reduction.vars.size() +
+      args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() +
+      args.useDevicePtr.vars.size();
+  types.reserve(numVars);
+  locs.reserve(numVars);
+
+  // `args.getVars()` yields the clause variables in clause name alphabetical
+  // order, which is the order expected by the BlockArgOpenMPOpInterface and
+  // the same order in which `args.getSyms()` lists the associated symbols.
+  for (mlir::Value var : args.getVars()) {
+    types.push_back(var.getType());
+    locs.push_back(var.getLoc());
+  }
+
+  return builder.createBlock(&region, {}, types, locs);
+}
+} // namespace Fortran::openmp::common
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 4c733bbbd7b2a3..d397c9c2e4e059 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -19,6 +19,7 @@
 #include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
+#include "flang/Common/OpenMP-utils.h"
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
@@ -40,57 +41,12 @@
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 
 using namespace Fortran::lower::omp;
+using namespace Fortran::openmp::common;
 
 //===----------------------------------------------------------------------===//
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to a single clause.
-struct EntryBlockArgsEntry {
-  llvm::ArrayRef<const semantics::Symbol *> syms;
-  llvm::ArrayRef<mlir::Value> vars;
-
-  bool isValid() const {
-    // This check allows specifying a smaller number of symbols than values
-    // because in some case cases a single symbol generates multiple block
-    // arguments.
-    return syms.size() <= vars.size();
-  }
-};
-
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to all clauses that can define them.
-struct EntryBlockArgs {
-  EntryBlockArgsEntry inReduction;
-  EntryBlockArgsEntry map;
-  EntryBlockArgsEntry priv;
-  EntryBlockArgsEntry reduction;
-  EntryBlockArgsEntry taskReduction;
-  EntryBlockArgsEntry useDeviceAddr;
-  EntryBlockArgsEntry useDevicePtr;
-
-  bool isValid() const {
-    return inReduction.isValid() && map.isValid() && priv.isValid() &&
-           reduction.isValid() && taskReduction.isValid() &&
-           useDeviceAddr.isValid() && useDevicePtr.isValid();
-  }
-
-  auto getSyms() const {
-    return llvm::concat<const semantics::Symbol *const>(
-        inReduction.syms, map.syms, priv.syms, reduction.syms,
-        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
-  }
-
-  auto getVars() const {
-    return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
-        taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
-  }
-};
-} // namespace
-
 static void genOMPDispatch(lower::AbstractConverter &converter,
                            lower::SymMap &symTable,
                            semantics::SemanticsContext &semaCtx,
@@ -622,50 +578,6 @@ static void genLoopVars(
   firOpBuilder.setInsertionPointAfter(storeOp);
 }
 
-/// Create an entry block for the given region, including the clause-defined
-/// arguments specified.
-///
-/// \param [in] converter - PFT to MLIR conversion interface.
-/// \param [in]      args - entry block arguments information for the given
-///                         operation.
-/// \param [in]    region - Empty region in which to create the entry block.
-static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
-                                  const EntryBlockArgs &args,
-                                  mlir::Region &region) {
-  assert(args.isValid() && "invalid args");
-  assert(region.empty() && "non-empty region");
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  llvm::SmallVector<mlir::Type> types;
-  llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
-  types.reserve(numVars);
-  locs.reserve(numVars);
-
-  auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
-    llvm::transform(vals, std::back_inserter(types),
-                    [](mlir::Value v) { return v.getType(); });
-    llvm::transform(vals, std::back_inserter(locs),
-                    [](mlir::Value v) { return v.getLoc(); });
-  };
-
-  // Populate block arguments in clause name alphabetical order to match
-  // expected order by the BlockArgOpenMPOpInterface.
-  extractTypeLoc(args.inReduction.vars);
-  extractTypeLoc(args.map.vars);
-  extractTypeLoc(args.priv.vars);
-  extractTypeLoc(args.reduction.vars);
-  extractTypeLoc(args.taskReduction.vars);
-  extractTypeLoc(args.useDeviceAddr.vars);
-  extractTypeLoc(args.useDevicePtr.vars);
-
-  return firOpBuilder.createBlock(&region, {}, types, locs);
-}
-
 static void
 markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
                   mlir::omp::DeclareTargetCaptureClause captureClause,
@@ -918,7 +830,7 @@ static void genBodyOfTargetDataOp(
     ConstructQueue::const_iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  genEntryBlock(converter, args, dataOp.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, dataOp.getRegion());
   bindEntryBlockArgs(converter, dataOp, args);
 
   // Insert dummy instruction to remember the insertion position. The
@@ -995,7 +907,8 @@ static void genBodyOfTargetOp(
   auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
 
   mlir::Region &region = targetOp.getRegion();
-  mlir::Block *entryBlock = genEntryBlock(converter, args, region);
+  mlir::Block *entryBlock =
+      genEntryBlock(converter.getFirOpBuilder(), args, region);
   bindEntryBlockArgs(converter, targetOp, args);
 
   // Check if cloning the bounds introduced any dependency on the outer region.
@@ -1121,7 +1034,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter,
   auto op = firOpBuilder.create<OpTy>(loc, clauseOps);
 
   // Create entry block with arguments.
-  genEntryBlock(converter, args, op.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, op.getRegion());
 
   return op;
 }
@@ -1544,7 +1457,7 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const EntryBlockArgs &args, DataSharingProcessor *dsp,
               bool isComposite = false) {
   auto genRegionEntryCB = [&](mlir::Operation *op) {
-    genEntryBlock(converter, args, op->getRegion(0));
+    genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
     bindEntryBlockArgs(
         converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
     return llvm::to_vector(args.getSyms());
@@ -1617,12 +1530,12 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   args.reduction.syms = reductionSyms;
   args.reduction.vars = clauseOps.reductionVars;
 
-  genEntryBlock(converter, args, sectionsOp.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, sectionsOp.getRegion());
   mlir::Operation *terminator =
       lower::genOpenMPTerminator(builder, sectionsOp, loc);
 
   auto genRegionEntryCB = [&](mlir::Operation *op) {
-    genEntryBlock(converter, args, op->getRegion(0));
+    genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
     bindEntryBlockArgs(
         converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
     return llvm::to_vector(args.getSyms());
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 035d0d5ca46c76..f9d284b8d77820 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 
 add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
+  GenericLoopConversion.cpp
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
@@ -24,4 +25,5 @@ add_flang_library(FlangOpenMPTransforms
   HLFIRDialect
   MLIRIR
   MLIRPass
+  MLIRTransformUtils
 )
diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
new file mode 100644
index 00000000000000..ecd945830c23a3
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -0,0 +1,171 @@
+//===- GenericLoopConversion.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <memory>
+
+namespace flangomp {
+#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+namespace {
+
+/// Rewrites a supported `omp.loop` op into a composite `omp.parallel` >
+/// `omp.distribute` > `omp.wsloop` nest wrapping a clone of its loop nest.
+class GenericLoopConversionPattern
+    : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
+public:
+  enum class GenericLoopCombinedInfo {
+    None,
+    TargetTeamsLoop,
+    TargetParallelLoop
+  };
+
+  using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    assert(isLoopConversionSupported(loopOp));
+
+    rewriteToDistributeParallelDo(loopOp, rewriter);
+    rewriter.eraseOp(loopOp);
+    return mlir::success();
+  }
+
+  /// Classifies the combined construct (if any) that \p loopOp is part of.
+  static GenericLoopCombinedInfo
+  findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
+    mlir::Operation *parentOp = loopOp->getParentOp();
+    GenericLoopCombinedInfo result = GenericLoopCombinedInfo::None;
+
+    if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
+      if (llvm::isa_and_present<mlir::omp::TargetOp>(teamsOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetTeamsLoop;
+
+    if (auto parallelOp =
+            mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
+      if (llvm::isa_and_present<mlir::omp::TargetOp>(parallelOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetParallelLoop;
+
+    return result;
+  }
+
+  /// Returns true if this pattern knows how to rewrite \p loopOp.
+  static bool isLoopConversionSupported(mlir::omp::LoopOp loopOp) {
+    GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
+
+    // TODO Support standalone `loop` ops and other forms of combined `loop` op
+    // nests.
+    if (combinedInfo != GenericLoopCombinedInfo::TargetTeamsLoop)
+      return false;
+
+    // TODO Support other clauses.
+    if (loopOp.getBindKind() || loopOp.getOrder() ||
+        !loopOp.getReductionVars().empty())
+      return false;
+
+    // TODO For `target teams loop`, check similar constraints to what is
+    // checked by `TeamsLoopChecker` in SemaOpenMP.cpp.
+    return true;
+  }
+
+  /// Emits `parallel { distribute { wsloop { loop_nest } } }` for \p loopOp.
+  void rewriteToDistributeParallelDo(
+      mlir::omp::LoopOp loopOp,
+      mlir::ConversionPatternRewriter &rewriter) const {
+    mlir::omp::ParallelOperands parallelClauseOps;
+    parallelClauseOps.privateVars = loopOp.getPrivateVars();
+
+    if (loopOp.getPrivateSyms())
+      parallelClauseOps.privateSyms = llvm::SmallVector<mlir::Attribute>(
+          loopOp.getPrivateSyms()->getAsRange<mlir::Attribute>());
+
+    Fortran::openmp::common::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelClauseOps.privateVars;
+    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
+                                                             parallelClauseOps);
+    mlir::Block *parallelBlock =
+        genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
+    parallelOp.setComposite(true);
+    rewriter.setInsertionPoint(
+        rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
+
+    mlir::omp::DistributeOperands distributeClauseOps;
+    auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
+        loopOp.getLoc(), distributeClauseOps);
+    distributeOp.setComposite(true);
+    rewriter.createBlock(&distributeOp.getRegion());
+
+    mlir::omp::WsloopOperands wsloopClauseOps;
+    auto wsloopOp =
+        rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
+    wsloopOp.setComposite(true);
+    rewriter.createBlock(&wsloopOp.getRegion());
+
+    mlir::IRMapping mapper;
+    mlir::Block &loopBlock = *loopOp.getRegion().begin();
+    for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
+             loopBlock.getArguments(), parallelBlock->getArguments()))
+      mapper.map(loopOpArg, parallelOpArg);
+
+    rewriter.clone(*loopOp.begin(), mapper);
+
+    // TODO we probably need to move the `loop_nest` bounds ops from the `teams`
+    // region to the `parallel` region to avoid making these values `shared`. We
+    // can find the backward slices of these bounds that are within the `teams`
+    // region and move these slices to the `parallel` op.
+  }
+};
+
+/// Function pass that rewrites every `omp.loop` op supported by
+/// `GenericLoopConversionPattern` into an equivalent nest of OpenMP ops.
+class GenericLoopConversionPass
+    : public flangomp::impl::GenericLoopConversionPassBase<
+          GenericLoopConversionPass> {
+public:
+  GenericLoopConversionPass() = default;
+
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+
+    if (func.isDeclaration())
+      return;
+
+    mlir::MLIRContext *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<GenericLoopConversionPattern>(context);
+    // Only the `omp.loop` ops the pattern supports are illegal; all other ops,
+    // including unsupported `omp.loop` ops, are left as they are.
+    mlir::ConversionTarget target(*context);
+    target.markUnknownOpDynamicallyLegal(
+        [](mlir::Operation *) { return true; });
+    target.addDynamicallyLegalOp<mlir::omp::LoopOp>([](mlir::omp::LoopOp op) {
+      return !GenericLoopConversionPattern::isLoopConversionSupported(op);
+    });
+
+    if (mlir::failed(mlir::applyFullConversion(func, target,
+                                               std::move(patterns)))) {
+      mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index a9144079915912..c3b9891b7bc854 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -245,6 +245,7 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm, bool isTargetDevice) {
   pm.addPass(flangomp::createMapInfoFinalizationPass());
   pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
   pm.addPass(flangomp::createMarkDeclareTargetPass());
+  pm.addPass(flangomp::createGenericLoopConversionPass());
   if (isTargetDevice)
     pm.addPass(flangomp::createFunctionFilteringPass());
 }
diff --git a/flang/test/Lower/OpenMP/generic-loop-rewriting.f90 b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
new file mode 100644
index 00000000000000..ba3363e8483646
--- /dev/null
+++ b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
@@ -0,0 +1,40 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine target_teams_loop
+    implicit none
+    integer :: x, i
+
+    !$omp target teams loop
+    do i = 0, 10
+      x = x + i
+    end do
+end subroutine target_teams_loop
+
+!CHECK-LABEL: func.func @_QPtarget_teams_loop
+!CHECK:         omp.target map_entries(
+!CHECK-SAME:      %{{.*}} -> %[[I_ARG:[^[:space:]]+]],
+!CHECK-SAME:      %{{.*}} -> %[[X_ARG:[^[:space:]]+]] : {{.*}}) {
+
+!CHECK:           %[[I_DECL:.*]]:2 = hlfir.declare %[[I_ARG]]
+!CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]]
+
+!CHECK:           omp.teams {
+!CHECK:             %[[LB:.*]] = arith.constant 0 : i32
+!CHECK:             %[[UB:.*]] = arith.constant 10 : i32
+!CHECK:             %[[STEP:.*]] = arith.constant 1 : i32
+
+!CHECK:             omp.parallel private(@{{.*}} %[[I_DECL]]#0 
+!CHECK-SAME:          -> %[[I_PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
+!CHECK:               omp.distribute {
+!CHECK:                 omp.wsloop {
+
+!CHECK:                   omp.loop_nest (%{{.*}}) : i32 = 
+!CHECK-SAME:                (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+!CHECK:                     %[[I_PRIV_DECL:.*]]:2 = hlfir.declare %[[I_PRIV_ARG]]
+!CHECK:                     fir.store %{{.*}} to %[[I_PRIV_DECL]]#1 : !fir.ref<i32>
+!CHECK:                   }
+!CHECK:                 }
+!CHECK:               }
+!CHECK:             }
+!CHECK:           }
+!CHECK:         }



More information about the Mlir-commits mailing list