[flang-commits] [flang] [Flang][OpenMP] Minor changes in reduction to work with HLFIR (PR #65775)

via flang-commits flang-commits at lists.llvm.org
Fri Sep 8 09:13:06 PDT 2023


https://github.com/kiranchandramohan created https://github.com/llvm/llvm-project/pull/65775:

These changes make OpenMP reduction lowering work correctly in the presence of hlfir.declare and hlfir.assign (which HLFIR emits instead of fir.store).

From ee0c5f8b51055a3f8f696bab1e50d5cb7a824b34 Mon Sep 17 00:00:00 2001
From: Kiran Chandramohan <kiran.chandramohan at arm.com>
Date: Fri, 8 Sep 2023 16:04:01 +0000
Subject: [PATCH] [Flang][OpenMP] Minor changes in reduction to work with HLFIR

Changes are to work correctly in the presence of hlfir.declare,
and hlfir.assign (instead of fir.store).
---
 flang/lib/Lower/OpenMP.cpp                    | 21 +++++++++
 .../OpenMP/wsloop-reduction-add-hlfir.f90     | 43 +++++++++++++++++++
 .../OpenMP/wsloop-reduction-max-hlfir.f90     | 36 ++++++++++++++++
 3 files changed, 100 insertions(+)
 create mode 100644 flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90
 create mode 100644 flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e56b26a243a4423..aef9352e70ea32f 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1123,6 +1123,8 @@ addReductionDecl(mlir::Location currentLocation,
               Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
         if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
           mlir::Value symVal = converter.getSymbolAddress(*symbol);
+          if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+            symVal = declOp.getBase();
           mlir::Type redType =
               symVal.getType().cast<fir::ReferenceType>().getEleTy();
           reductionVars.push_back(symVal);
@@ -1160,6 +1162,8 @@ addReductionDecl(mlir::Location currentLocation,
                 Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
           if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
             mlir::Value symVal = converter.getSymbolAddress(*symbol);
+            if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+              symVal = declOp.getBase();
             mlir::Type redType =
                 symVal.getType().cast<fir::ReferenceType>().getEleTy();
             reductionVars.push_back(symVal);
@@ -3746,6 +3750,8 @@ void Fortran::lower::genOpenMPReduction(
                   Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
             if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
               mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+              if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
+                reductionVal = declOp.getBase();
               mlir::Type reductionType =
                   reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
               if (!reductionType.isa<fir::LogicalType>()) {
@@ -3789,6 +3795,9 @@ void Fortran::lower::genOpenMPReduction(
                     ompObject)}) {
               if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
                 mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+                if (auto declOp =
+                        reductionVal.getDefiningOp<hlfir::DeclareOp>())
+                  reductionVal = declOp.getBase();
                 for (const mlir::OpOperand &reductionValUse :
                      reductionVal.getUses()) {
                   if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
@@ -3844,6 +3853,13 @@ mlir::Operation *Fortran::lower::findReductionChain(mlir::Value loadVal,
             return reductionOp;
           }
         }
+        if (auto assign =
+                mlir::dyn_cast<hlfir::AssignOp>(reductionOperand.getOwner())) {
+          if (assign.getLhs() == *reductionVal) {
+            assign.erase();
+            return reductionOp;
+          }
+        }
       }
     }
   }
@@ -3899,6 +3915,11 @@ void Fortran::lower::removeStoreOp(mlir::Operation *reductionOp,
           if (storeOp.getMemref() == symVal)
             storeOp.erase();
         }
+        if (auto assignOp =
+                mlir::dyn_cast<hlfir::AssignOp>(convertReductionUse)) {
+          if (assignOp.getLhs() == symVal)
+            assignOp.erase();
+        }
       }
     }
   }
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90
new file mode 100644
index 000000000000000..97ee665442e3a8f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90
@@ -0,0 +1,43 @@
+! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK-LABEL: omp.reduction.declare
+!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
+!CHECK: ^bb0(%{{.*}}: i32):
+!CHECK:  %[[C0_1:.*]] = arith.constant 0 : i32
+!CHECK:  omp.yield(%[[C0_1]] : i32)
+!CHECK: } combiner {
+!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
+!CHECK:  %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
+!CHECK:  omp.yield(%[[RES]] : i32)
+!CHECK: }
+
+!CHECK-LABEL: func.func @_QPsimple_int_reduction
+!CHECK:  %[[XREF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_int_reductionEx"}
+!CHECK:  %[[XDECL:.*]]:2 = hlfir.declare %[[XREF]] {uniq_name = "_QFsimple_int_reductionEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:  %[[C0_2:.*]] = arith.constant 0 : i32
+!CHECK:  hlfir.assign %[[C0_2]] to %[[XDECL]]#0 : i32, !fir.ref<i32>
+!CHECK:  omp.parallel
+!CHECK:    %[[I_PVT_REF:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
+!CHECK:    %[[I_PVT_DECL:.*]]:2 = hlfir.declare %[[I_PVT_REF]] {uniq_name = "_QFsimple_int_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:    %[[C1_1:.*]] = arith.constant 1 : i32
+!CHECK:    %[[C100:.*]] = arith.constant 100 : i32
+!CHECK:    %[[C1_2:.*]] = arith.constant 1 : i32
+!CHECK:    omp.wsloop   reduction(@[[RED_I32_NAME]] -> %[[XDECL]]#0 : !fir.ref<i32>) for  (%[[IVAL:.*]]) : i32 = (%[[C1_1]]) to (%[[C100]]) inclusive step (%[[C1_2]])
+!CHECK:      fir.store %[[IVAL]] to %[[I_PVT_DECL]]#1 : !fir.ref<i32>
+!CHECK:      %[[I_PVT_VAL:.*]] = fir.load %[[I_PVT_DECL]]#0 : !fir.ref<i32>
+!CHECK:      omp.reduction %[[I_PVT_VAL]], %[[XDECL]]#0 : i32, !fir.ref<i32>
+!CHECK:      omp.yield
+!CHECK:    omp.terminator
+!CHECK:  return
+subroutine simple_int_reduction
+  integer :: x
+  x = 0
+  !$omp parallel
+  !$omp do reduction(+:x)
+  do i=1, 100
+    x = x + i
+  end do
+  !$omp end do
+  !$omp end parallel
+end subroutine
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90
new file mode 100644
index 000000000000000..0c5d99226600bfd
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90
@@ -0,0 +1,36 @@
+! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_I:.*]] : i32 init {
+!CHECK:   %[[MINIMUM_VAL_I:.*]] = arith.constant -2147483648 : i32
+!CHECK:   omp.yield(%[[MINIMUM_VAL_I]] : i32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_I:.*]]: i32, %[[ARG1_I:.*]]: i32):
+!CHECK:   %[[COMB_VAL_I:.*]] = arith.maxsi %[[ARG0_I]], %[[ARG1_I]] : i32
+!CHECK:   omp.yield(%[[COMB_VAL_I]] : i32)
+
+!CHECK-LABEL: @_QPreduction_max_int
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box<!fir.array<?xi32>>
+!CHECK:   %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFreduction_max_intEx"}
+!CHECK:   %[[X_DECL:.*]]:2 = hlfir.declare %[[X_REF]] {uniq_name = "_QFreduction_max_intEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:   %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y_BOX]] {uniq_name = "_QFreduction_max_intEy"} : (!fir.box<!fir.array<?xi32>>) -> (!fir.box<!fir.array<?xi32>>, !fir.box<!fir.array<?xi32>>)
+!CHECK:   omp.parallel
+!CHECK:     omp.wsloop reduction(@[[MAX_DECLARE_I]] -> %[[X_DECL]]#0 : !fir.ref<i32>) for
+!CHECK:       %[[Y_I_REF:.*]] = hlfir.designate %[[Y_DECL]]#0 ({{.*}}) : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+!CHECK:       %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref<i32>
+!CHECK:       omp.reduction %[[Y_I]], %[[X_DECL]]#0 : i32, !fir.ref<i32>
+!CHECK:       omp.yield
+!CHECK:     omp.terminator
+
+subroutine reduction_max_int(y)
+  integer :: x, y(:)
+  x = 0
+  !$omp parallel
+  !$omp do reduction(max:x)
+  do i=1, 100
+    x = max(x, y(i))
+  end do
+  !$omp end do
+  !$omp end parallel
+  print *, x
+end subroutine



More information about the flang-commits mailing list