[Mlir-commits] [flang] [mlir] [flang][OpenMP] Add basic support to lower `loop` directive to MLIR (PR #114199)

Kareem Ergawy llvmlistbot at llvm.org
Fri Nov 8 01:40:16 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/114199

>From b4b65697874b7643801eaaf64cbd2eed415dc442 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 30 Oct 2024 04:40:14 -0500
Subject: [PATCH] [flang][OpenMP] Add basic support to lower `loop` to MLIR

Adds initial support for lowering the `loop` directive to MLIR.

The PR includes basic support and testing for the following clauses:
 * `collapse`
 * `order`
 * `private`
 * `reduction`
---
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 48 +++++++++-
 .../test/Lower/OpenMP/Todo/loop-directive.f90 | 15 ----
 flang/test/Lower/OpenMP/loop-directive.f90    | 90 +++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  4 +
 .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp  | 14 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 11 +++
 6 files changed, 159 insertions(+), 23 deletions(-)
 delete mode 100644 flang/test/Lower/OpenMP/Todo/loop-directive.f90
 create mode 100644 flang/test/Lower/OpenMP/loop-directive.f90

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 4f9e2347308aa1..e3dc780321f89f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1176,6 +1176,26 @@ genLoopNestClauses(lower::AbstractConverter &converter,
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
 }
 
+/// Collects the clause operands of a standalone `loop` construct that belong
+/// to the `omp.loop` wrapper op.
+///
+/// `order` and `reduction` are recorded in \p clauseOps, with the reduction
+/// symbols collected in \p reductionSyms. `bind` and `lastprivate` are not
+/// lowered yet and are passed to processTODO. Private variables and loop-nest
+/// clauses are handled separately in genLoopOp.
+static void genLoopClauses(
+    lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
+    const List<Clause> &clauses, mlir::Location loc,
+    mlir::omp::LoopOperands &clauseOps,
+    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+  ClauseProcessor cp(converter, semaCtx, clauses);
+  cp.processOrder(clauseOps);
+  cp.processReduction(loc, clauseOps, reductionSyms);
+  // Presumably reports a "not yet implemented" error if either is present.
+  cp.processTODO<clause::Bind, clause::Lastprivate>(
+      loc, llvm::omp::Directive::OMPD_loop);
+}
+
 static void genMaskedClauses(lower::AbstractConverter &converter,
                              semantics::SemanticsContext &semaCtx,
                              lower::StatementContext &stmtCtx,
@@ -2051,6 +2063,52 @@ static void genStandaloneDo(lower::AbstractConverter &converter,
                 llvm::omp::Directive::OMPD_do, dsp);
 }
 
+/// Lowers a standalone `loop` construct to an `omp.loop` loop wrapper whose
+/// region holds the `omp.loop_nest` generated for the associated DO loop(s).
+/// Private and reduction variables are exposed as entry block arguments of
+/// the wrapper.
+static void genLoopOp(lower::AbstractConverter &converter,
+                      lower::SymMap &symTable,
+                      semantics::SemanticsContext &semaCtx,
+                      lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  // Clauses attached to the omp.loop wrapper itself (order, reduction).
+  mlir::omp::LoopOperands loopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> loopReductionSyms;
+  genLoopClauses(converter, semaCtx, item->clauses, loc, loopClauseOps,
+                 loopReductionSyms);
+
+  // Data-sharing: collect explicit and predetermined private symbols (e.g.
+  // the loop iteration variable) using delayed privatization. processStep1
+  // receives loopClauseOps, whose privateVars feed the wrapper below.
+  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&loopClauseOps);
+
+  // Operands for the inner omp.loop_nest and its induction-variable symbols.
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  // Pair each private/reduction symbol with its operand; these become the
+  // entry block arguments of the wrapper region.
+  EntryBlockArgs loopArgs;
+  loopArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  loopArgs.priv.vars = loopClauseOps.privateVars;
+  loopArgs.reduction.syms = loopReductionSyms;
+  loopArgs.reduction.vars = loopClauseOps.reductionVars;
+
+  // Create the wrapper first, then generate the loop nest inside it.
+  auto loopOp =
+      genWrapperOp<mlir::omp::LoopOp>(converter, loc, loopClauseOps, loopArgs);
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
+                loopNestClauseOps, iv, {{loopOp, loopArgs}},
+                llvm::omp::Directive::OMPD_loop, dsp);
+}
+
 static void genStandaloneParallel(lower::AbstractConverter &converter,
                                   lower::SymMap &symTable,
                                   semantics::SemanticsContext &semaCtx,
@@ -2479,7 +2525,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
     genStandaloneDo(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_loop:
-    TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+    genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_masked:
     genMaskedOp(converter, symTable, semaCtx, eval, loc, queue, item);
diff --git a/flang/test/Lower/OpenMP/Todo/loop-directive.f90 b/flang/test/Lower/OpenMP/Todo/loop-directive.f90
deleted file mode 100644
index f1aea70458aa6c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/loop-directive.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! This test checks lowering of OpenMP loop Directive.
-
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Unhandled directive loop
-subroutine test_loop()
-  integer :: i, j = 1
-  !$omp loop
-  do i=1,10
-   j = j + 1
-  end do
-  !$omp end loop
-end subroutine
-
diff --git a/flang/test/Lower/OpenMP/loop-directive.f90 b/flang/test/Lower/OpenMP/loop-directive.f90
new file mode 100644
index 00000000000000..0a27a18aa8b52f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/loop-directive.f90
@@ -0,0 +1,95 @@
+! This test checks lowering of the OpenMP loop directive.
+
+! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+
+! CHECK: omp.declare_reduction @[[RED:add_reduction_i32]] : i32
+! CHECK: omp.private {type = private} @[[DUMMY_PRIV:.*test_privateEdummy_private.*]] : !fir.ref<i32>
+! CHECK: omp.private {type = private} @[[I_PRIV:.*test_no_clausesEi.*]] : !fir.ref<i32>
+
+! No clauses: the iteration variable i is implicitly privatized on omp.loop.
+! CHECK-LABEL: func.func @_QPtest_no_clauses
+subroutine test_no_clauses()
+  integer :: i, j, dummy = 1
+
+  ! CHECK: omp.loop private(@[[I_PRIV]] %{{.*}}#0 -> %[[ARG:.*]] : !fir.ref<i32>) {
+  ! CHECK:   omp.loop_nest (%[[IV:.*]]) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[ARG_DECL:.*]]:2 = hlfir.declare %[[ARG]]
+  ! CHECK:     fir.store %[[IV]] to %[[ARG_DECL]]#1 : !fir.ref<i32>
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+! collapse(2): both iteration variables are privatized; loop_nest has two IVs.
+! CHECK-LABEL: func.func @_QPtest_collapse
+subroutine test_collapse()
+  integer :: i, j, dummy = 1
+  ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}}, @{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}, %{{.*}}) : i32 {{.*}} {
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop collapse(2)
+  do i=1,10
+    do j=2,20
+     dummy = dummy + 1
+    end do
+  end do
+  !$omp end loop
+end subroutine
+
+! private(dummy): dummy gets its own privatizer; the body uses the block argument.
+! CHECK-LABEL: func.func @_QPtest_private
+subroutine test_private()
+  integer :: i, dummy = 1
+  ! CHECK: omp.loop private(@[[DUMMY_PRIV]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]], @{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_privateEdummy"}
+  ! CHECK:     %{{.*}} = fir.load %[[DUMMY_DECL]]#0
+  ! CHECK:     hlfir.assign %{{.*}} to %[[DUMMY_DECL]]#0
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop private(dummy)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+
+! order(concurrent) is lowered with the reproducible order modifier.
+! CHECK-LABEL: func.func @_QPtest_order
+subroutine test_order()
+  integer :: i, dummy = 1
+  ! CHECK: omp.loop order(reproducible:concurrent) private(@{{.*}} %{{.*}}#0 -> %{{.*}} : {{.*}}) {
+  ! CHECK: }
+  !$omp loop order(concurrent)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
+
+! reduction(+:dummy): the loop body uses the reduction block argument.
+! CHECK-LABEL: func.func @_QPtest_reduction
+subroutine test_reduction()
+  integer :: i, dummy = 1
+
+  ! CHECK: omp.loop private(@{{.*}} %{{.*}}#0 -> %{{.*}} : !{{.*}}) reduction
+  ! CHECK-SAME:  (@[[RED]] %{{.*}}#0 -> %[[DUMMY_ARG:.*]] : !{{.*}}) {
+  ! CHECK:   omp.loop_nest (%{{.*}}) : i32 = (%{{.*}}) to (%{{.*}}) {{.*}} {
+  ! CHECK:     %[[DUMMY_DECL:.*]]:2 = hlfir.declare %[[DUMMY_ARG]] {uniq_name = "_QFtest_reductionEdummy"}
+  ! CHECK:     %{{.*}} = fir.load %[[DUMMY_DECL]]#0
+  ! CHECK:     hlfir.assign %{{.*}} to %[[DUMMY_DECL]]#0
+  ! CHECK:   }
+  ! CHECK: }
+  !$omp loop reduction(+:dummy)
+  do i=1,10
+   dummy = dummy + 1
+  end do
+  !$omp end loop
+end subroutine
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a0da3db124d1f4..b01fda993f5e78 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -422,6 +422,10 @@ def LoopOp : OpenMP_Op<"loop", traits = [
         $reduction_syms) attr-dict
   }];
 
+  let builders = [
+    OpBuilder<(ins CArg<"const LoopOperands &">:$clauses)>
+  ];
+
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index aa824a95b1574a..58fd3d565fce50 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -234,11 +234,11 @@ void mlir::configureOpenMPToLLVMConversionLegality(
   });
   target.addDynamicallyLegalOp<
       omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
-      omp::DistributeOp, omp::LoopNestOp, omp::MasterOp, omp::OrderedRegionOp,
-      omp::ParallelOp, omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp,
-      omp::SimdOp, omp::SingleOp, omp::TargetDataOp, omp::TargetOp,
-      omp::TaskgroupOp, omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
-      omp::WsloopOp>([&](Operation *op) {
+      omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
+      omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp,
+      omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp,
+      omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp,
+      omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) {
     return std::all_of(op->getRegions().begin(), op->getRegions().end(),
                        [&](Region &region) {
                          return typeConverter.isLegal(&region);
@@ -275,8 +275,8 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       RegionOpConversion<omp::AtomicCaptureOp>,
       RegionOpConversion<omp::CriticalOp>,
       RegionOpConversion<omp::DistributeOp>,
-      RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::MaskedOp>,
-      RegionOpConversion<omp::MasterOp>,
+      RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
+      RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
       RegionOpConversion<omp::OrderedRegionOp>,
       RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
       RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 228c2d034ad4ad..5809d06e6a9d39 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1952,6 +1952,21 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
 // LoopOp
 //===----------------------------------------------------------------------===//
 
+/// Builds an `omp.loop` op from the clause operands aggregated in \p clauses.
+/// The symbol lists and reduction by-ref flags are wrapped as attributes with
+/// makeArrayAttr / makeDenseBoolArrayAttr; the remaining arguments must stay
+/// in the order expected by the positional LoopOp::build overload.
+void LoopOp::build(OpBuilder &builder, OperationState &state,
+                   const LoopOperands &clauses) {
+  MLIRContext *ctx = builder.getContext();
+
+  LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
+                makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
+                clauses.orderMod, clauses.reductionVars,
+                makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                makeArrayAttr(ctx, clauses.reductionSyms));
+}
+
 LogicalResult LoopOp::verify() {
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());



More information about the Mlir-commits mailing list