[flang-commits] [flang] 5e3faa0 - [flang][openacc] Lower reduction for compute constructs

Razvan Lupusoru via flang-commits flang-commits at lists.llvm.org
Wed Jun 7 13:45:12 PDT 2023


Author: Razvan Lupusoru
Date: 2023-06-07T13:44:25-07:00
New Revision: 5e3faa05a0f15368fe3aa28380fa530a9a745c4a

URL: https://github.com/llvm/llvm-project/commit/5e3faa05a0f15368fe3aa28380fa530a9a745c4a
DIFF: https://github.com/llvm/llvm-project/commit/5e3faa05a0f15368fe3aa28380fa530a9a745c4a.diff

LOG: [flang][openacc] Lower reduction for compute constructs

The parallel and serial constructs support the reduction clause. Extend
the loop reduction clause support recently added in D151564 so that it
also covers these compute constructs.

Reviewed By: clementval, vzakhari

Differential Revision: https://reviews.llvm.org/D151955

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/test/Lower/OpenACC/acc-kernels-loop.f90
    flang/test/Lower/OpenACC/acc-loop.f90
    flang/test/Lower/OpenACC/acc-parallel-loop.f90
    flang/test/Lower/OpenACC/acc-parallel.f90
    flang/test/Lower/OpenACC/acc-serial-loop.f90
    flang/test/Lower/OpenACC/acc-serial.f90
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index c59be17f2c6e1..abdf320cdfd98 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -997,7 +997,7 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
 
   llvm::SmallVector<mlir::Value> reductionOperands, privateOperands,
       firstprivateOperands;
-  llvm::SmallVector<mlir::Attribute> privatizations;
+  llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
 
   // Async, wait and self clause have optional values but can be present with
   // no value as well. When there is no value, the op has an attribute to
@@ -1151,8 +1151,11 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       genObjectList(firstprivateClause->v, converter, semanticsContext, stmtCtx,
                     firstprivateOperands);
-    } else if (std::get_if<Fortran::parser::AccClause::Reduction>(&clause.u)) {
-      TODO(clauseLocation, "compute construct reduction clause lowering");
+    } else if (const auto *reductionClause =
+                   std::get_if<Fortran::parser::AccClause::Reduction>(
+                       &clause.u)) {
+      genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
+                    reductionOperands, reductionRecipes);
     }
   }
 
@@ -1194,6 +1197,9 @@ createComputeOp(Fortran::lower::AbstractConverter &converter,
     if (!privatizations.empty())
       computeOp.setPrivatizationsAttr(
           mlir::ArrayAttr::get(builder.getContext(), privatizations));
+    if (!reductionRecipes.empty())
+      computeOp.setReductionRecipesAttr(
+          mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
   }
 
   auto insPt = builder.saveInsertionPoint();

diff  --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index 6aad08b13499c..33c3a8c447cf4 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -16,6 +16,8 @@ subroutine acc_kernels_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -709,6 +711,20 @@ subroutine acc_kernels_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
+! CHECK-NEXT: }{{$}}
+
+  !$acc kernels loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.kernels {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
 end subroutine

diff  --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index ec8eb0f73b74e..5b84763e32d7b 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -17,6 +17,8 @@ program acc_loop
   integer :: gangStatic = 8
   integer :: vectorLength = 128
   integer, parameter :: tileSize = 2
+  integer :: reduction_i
+  real :: reduction_r
 
 
   !$acc loop
@@ -270,4 +272,15 @@ program acc_loop
 !CHECK:        acc.yield
 !CHECK-NEXT: }{{$}}
 
+  !$acc loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        fir.do_loop
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end program

diff  --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index 38df6228acc83..b295a905bfd85 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -23,6 +23,8 @@ subroutine acc_parallel_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -729,6 +731,20 @@ subroutine acc_parallel_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
+  !$acc parallel loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
 end subroutine acc_parallel_loop

diff  --git a/flang/test/Lower/OpenACC/acc-parallel.f90 b/flang/test/Lower/OpenACC/acc-parallel.f90
index d1c9d80c1fbb6..acfab91f46710 100644
--- a/flang/test/Lower/OpenACC/acc-parallel.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel.f90
@@ -21,6 +21,8 @@ subroutine acc_parallel
   logical :: ifCondition = .TRUE.
   real, dimension(10, 10) :: a, b, c
   real, pointer :: d, e
+  integer :: reduction_i
+  real :: reduction_r
 
 !CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
 !CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
@@ -302,4 +304,11 @@ subroutine acc_parallel
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
+!$acc parallel reduction(+:reduction_r) reduction(*:reduction_i)
+!$acc end parallel
+
+! CHECK:      acc.parallel reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end subroutine acc_parallel

diff  --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 2e26da8bb2c63..bf83af8bf55fd 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -23,6 +23,8 @@ subroutine acc_serial_loop
   real, dimension(n) :: a, b, c
   real, dimension(n, n) :: d, e
   real, pointer :: f, g
+  integer :: reduction_i
+  real :: reduction_r
 
   integer :: gangNum = 8
   integer :: gangStatic = 8
@@ -645,6 +647,20 @@ subroutine acc_serial_loop
 ! CHECK:          acc.yield
 ! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
+  !$acc serial loop reduction(+:reduction_r) reduction(*:reduction_i)
+  do i = 1, n
+    reduction_r = reduction_r + a(i)
+    reduction_i = 1
+  end do
+
+! CHECK:      acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.loop reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:          fir.do_loop
+! CHECK:          acc.yield
+! CHECK-NEXT:   }{{$}}
+! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
 end subroutine acc_serial_loop

diff  --git a/flang/test/Lower/OpenACC/acc-serial.f90 b/flang/test/Lower/OpenACC/acc-serial.f90
index d10a3ab7a0c4f..4d17d58c24100 100644
--- a/flang/test/Lower/OpenACC/acc-serial.f90
+++ b/flang/test/Lower/OpenACC/acc-serial.f90
@@ -21,6 +21,8 @@ subroutine acc_serial
   logical :: ifCondition = .TRUE.
   real, dimension(10, 10) :: a, b, c
   real, pointer :: d, e
+  integer :: reduction_i
+  real :: reduction_r
 
 ! CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
 ! CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
@@ -245,4 +247,11 @@ subroutine acc_serial
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
+!$acc serial reduction(+:reduction_r) reduction(*:reduction_i)
+!$acc end serial
+
+! CHECK:      acc.serial reduction(@reduction_add_f32 -> %{{.*}} : !fir.ref<f32>, @reduction_mul_i32 -> %{{.*}} : !fir.ref<i32>) {
+! CHECK:        acc.yield
+! CHECK-NEXT: }{{$}}
+
 end subroutine

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index da5a2856aec21..b2998b736f991 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -558,7 +558,7 @@ LogicalResult acc::ParallelOp::verify() {
     return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
@@ -586,7 +586,7 @@ LogicalResult acc::SerialOp::verify() {
     return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
-          "reductions")))
+          "reductions", false)))
     return failure();
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }


        


More information about the flang-commits mailing list