[flang-commits] [flang] [flang][OpenMP] Common lowering flow for atomic update (PR #69866)

via flang-commits flang-commits at lists.llvm.org
Sat Oct 21 21:35:42 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: None (NimishMishra)

<details>
<summary>Changes</summary>

Offers a common lowering flow for scalar/non-scalar atomic variables. Fixes https://github.com/llvm/llvm-project/issues/68384

TODOs:
1. Lower intrinsic procedures
2. Use correct operations for AND, OR, EQV, NEQV
3. Discuss whether multiply/divide are signed or unsigned

---
Full diff: https://github.com/llvm/llvm-project/pull/69866.diff


2 Files Affected:

- (modified) flang/lib/Lower/DirectivesCommon.h (+86-102) 
- (added) flang/test/Lower/OpenMP/common-atomic-lowering.f90 (+74) 


``````````diff
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index ed44598bc925212..56d79de2d31995d 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -204,76 +204,62 @@ static inline void genOmpAccAtomicUpdateStatement(
   // Generate `omp.atomic.update` operation for atomic assignment statements
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Location currentLocation = converter.getCurrentLocation();
-
-  const auto *varDesignator =
-      std::get_if<Fortran::common::Indirection<Fortran::parser::Designator>>(
-          &assignmentStmtVariable.u);
-  assert(varDesignator && "Variable designator for atomic update assignment "
-                          "statement does not exist");
-  const Fortran::parser::Name *name =
-      Fortran::semantics::getDesignatorNameIfDataRef(varDesignator->value());
-  if (!name)
-    TODO(converter.getCurrentLocation(),
-         "Array references as atomic update variable");
-  assert(name && name->symbol &&
-         "No symbol attached to atomic update variable");
-  if (Fortran::semantics::IsAllocatableOrPointer(name->symbol->GetUltimate()))
-    converter.bindSymbol(*name->symbol, lhsAddr);
-
-  //  Lowering is in two steps :
-  //  subroutine sb
-  //    integer :: a, b
-  //    !$omp atomic update
-  //      a = a + b
-  //  end subroutine
-  //
-  //  1. Lower to scf.execute_region_op
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = scf.execute_region -> i32 {
-  //      %3 = fir.load %0 : !fir.ref<i32>
-  //      %4 = fir.load %1 : !fir.ref<i32>
-  //      %5 = arith.addi %3, %4 : i32
-  //      scf.yield %5 : i32
-  //    }
-  //    return
-  //  }
-  auto tempOp =
-      firOpBuilder.create<mlir::scf::ExecuteRegionOp>(currentLocation, varType);
-  firOpBuilder.createBlock(&tempOp.getRegion());
-  mlir::Block &block = tempOp.getRegion().back();
-  firOpBuilder.setInsertionPointToEnd(&block);
   Fortran::lower::StatementContext stmtCtx;
-  mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
-      *Fortran::semantics::GetExpr(assignmentStmtExpr), stmtCtx));
-  mlir::Value convertResult =
-      firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
-  // Insert the terminator: YieldOp.
-  firOpBuilder.create<mlir::scf::YieldOp>(currentLocation, convertResult);
-  firOpBuilder.setInsertionPointToStart(&block);
-
-  //  2. Create the omp.atomic.update Operation using the Operations in the
-  //     temporary scf.execute_region Operation.
-  //
-  //  func.func @_QPsb() {
-  //    %0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFsbEa"}
-  //    %1 = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFsbEb"}
-  //    %2 = fir.load %1 : !fir.ref<i32>
-  //    omp.atomic.update   %0 : !fir.ref<i32> {
-  //    ^bb0(%arg0: i32):
-  //      %3 = fir.load %1 : !fir.ref<i32>
-  //      %4 = arith.addi %arg0, %3 : i32
-  //      omp.yield(%3 : i32)
-  //    }
-  //    return
-  //  }
-  mlir::Value updateVar = converter.getSymbolAddress(*name->symbol);
-  if (auto decl = updateVar.getDefiningOp<hlfir::DeclareOp>())
-    updateVar = decl.getBase();
-
-  firOpBuilder.setInsertionPointAfter(tempOp);
+  mlir::Value convertRhs = nullptr;
+
+  auto lowerExpression = [&](const auto &intrinsicBinaryExpr) {
+    const auto &variableName{assignmentStmtVariable.GetSource().ToString()};
+    const auto &exprLeft{std::get<0>(intrinsicBinaryExpr.t)};
+    if (exprLeft.value().source.ToString() == variableName) {
+      // Update statement is of form `x = x op expr`
+      const auto &exprToLower{std::get<1>(intrinsicBinaryExpr.t)};
+      mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(exprToLower), stmtCtx));
+      convertRhs =
+          firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    } else {
+      // Update statement is of form `x = expr op x`
+      const auto &exprToLower{std::get<0>(intrinsicBinaryExpr.t)};
+      mlir::Value rhsExpr = fir::getBase(converter.genExprValue(
+          *Fortran::semantics::GetExpr(exprToLower), stmtCtx));
+      convertRhs =
+          firOpBuilder.createConvert(currentLocation, varType, rhsExpr);
+    }
+  };
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const common::Indirection<parser::FunctionReference> &x) {
+            TODO(converter.getCurrentLocation(),
+                 "Not yet implemented: intrinsic procedure in atomic update "
+                 "expressions");
+          },
+          [&](const Fortran::parser::Expr::Add &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::Subtract &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::Multiply &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::Divide &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::AND &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::OR &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::EQV &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const Fortran::parser::Expr::NEQV &intrinsicBinaryExpr) {
+            lowerExpression(intrinsicBinaryExpr);
+          },
+          [&](const auto &) {},
+      },
+      assignmentStmtExpr.u);
 
   mlir::Operation *atomicUpdateOp = nullptr;
   if constexpr (std::is_same<AtomicListT,
@@ -289,10 +275,10 @@ static inline void genOmpAccAtomicUpdateStatement(
       genOmpAtomicHintAndMemoryOrderClauses(converter, *rightHandClauseList,
                                             hint, memoryOrder);
     atomicUpdateOp = firOpBuilder.create<mlir::omp::AtomicUpdateOp>(
-        currentLocation, updateVar, hint, memoryOrder);
+        currentLocation, lhsAddr, hint, memoryOrder);
   } else {
     atomicUpdateOp = firOpBuilder.create<mlir::acc::AtomicUpdateOp>(
-        currentLocation, updateVar);
+        currentLocation, lhsAddr);
   }
 
   llvm::SmallVector<mlir::Type> varTys = {varType};
@@ -301,38 +287,36 @@ static inline void genOmpAccAtomicUpdateStatement(
   mlir::Value val =
       fir::getBase(atomicUpdateOp->getRegion(0).front().getArgument(0));
 
-  llvm::SmallVector<mlir::Operation *> ops;
-  for (mlir::Operation &op : tempOp.getRegion().getOps())
-    ops.push_back(&op);
-
-  // SCF Yield is converted to OMP Yield. All other operations are copied
-  for (mlir::Operation *op : ops) {
-    if (auto y = mlir::dyn_cast<mlir::scf::YieldOp>(op)) {
-      firOpBuilder.setInsertionPointToEnd(
-          &atomicUpdateOp->getRegion(0).front());
-      if constexpr (std::is_same<AtomicListT,
-                                 Fortran::parser::OmpAtomicClauseList>()) {
-        firOpBuilder.create<mlir::omp::YieldOp>(currentLocation,
-                                                y.getResults());
-      } else {
-        firOpBuilder.create<mlir::acc::YieldOp>(currentLocation,
-                                                y.getResults());
-      }
-      op->erase();
-    } else {
-      op->remove();
-      atomicUpdateOp->getRegion(0).front().push_back(op);
-    }
-  }
-
-  // Remove the load and replace all uses of load with the block argument
-  for (mlir::Operation &op : atomicUpdateOp->getRegion(0).getOps()) {
-    fir::LoadOp y = mlir::dyn_cast<fir::LoadOp>(&op);
-    if (y && y.getMemref() == updateVar)
-      y.getRes().replaceAllUsesWith(val);
+  mlir::Value op = nullptr;
+  if (std::get_if<Fortran::parser::Expr::Add>(&assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::AddIOp>(currentLocation, val,
+                                                  convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::Subtract>(
+                 &assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::SubIOp>(currentLocation, val,
+                                                  convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::Multiply>(
+                 &assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::MulIOp>(currentLocation, val,
+                                                  convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::Divide>(
+                 &assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::DivUIOp>(currentLocation, val,
+                                                   convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::AND>(&assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::AndIOp>(currentLocation, val,
+                                                  convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::OR>(&assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::OrIOp>(currentLocation, val,
+                                                 convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::EQV>(&assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::CmpIOp>(
+        currentLocation, mlir::arith::CmpIPredicate::eq, val, convertRhs);
+  } else if (std::get_if<Fortran::parser::Expr::NEQV>(&assignmentStmtExpr.u)) {
+    op = firOpBuilder.create<mlir::arith::CmpIOp>(
+        currentLocation, mlir::arith::CmpIPredicate::ne, val, convertRhs);
   }
-
-  tempOp.erase();
+  firOpBuilder.create<mlir::omp::YieldOp>(currentLocation, op);
 }
 
 /// Processes an atomic construct with write clause.
diff --git a/flang/test/Lower/OpenMP/common-atomic-lowering.f90 b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
new file mode 100644
index 000000000000000..7da30243e676c00
--- /dev/null
+++ b/flang/test/Lower/OpenMP/common-atomic-lowering.f90
@@ -0,0 +1,74 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+!CHECK: func.func @_QQmain() attributes {fir.bindc_name = "sample"} {
+!CHECK: %[[val_0:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFEa"}
+!CHECK: %[[val_1:.*]]:2 = hlfir.declare %[[val_0]] {uniq_name = "_QFEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_2:.*]] = fir.alloca i32 {bindc_name = "b", uniq_name = "_QFEb"}
+!CHECK: %[[val_3:.*]]:2 = hlfir.declare %[[val_2]] {uniq_name = "_QFEb"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_4:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFEx"}
+!CHECK: %[[val_5:.*]]:2 = hlfir.declare %[[val_4]] {uniq_name = "_QFEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[val_c5:.*]] = arith.constant 5 : index
+!CHECK: %[[val_6:.*]] = fir.alloca !fir.array<5xi32> {bindc_name = "y", uniq_name = "_QFEy"}
+!CHECK: %[[val_7:.*]] = fir.shape %[[val_c5]] : (index) -> !fir.shape<1>
+!CHECK: %[[val_8:.*]]:2 = hlfir.declare %[[val_6]](%[[val_7]]) {uniq_name = "_QFEy"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>)
+!CHECK: %[[val_c2:.*]] = arith.constant 2 : index
+!CHECK: %[[val_9:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_10:.*]] = fir.load %[[val_5]]#0 : !fir.ref<i32>
+!CHECK: %[[val_11:.*]] = arith.addi %[[val_c8]], %[[val_10]] : i32
+!CHECK: %[[val_12:.*]] = hlfir.no_reassoc %[[val_11]] : i32
+!CHECK: omp.atomic.update %[[val_9]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.muli %[[ARG]], %[[val_12]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c2_0:.*]] = arith.constant 2 : index
+!CHECK: %[[val_13:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c2_0]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_c8_1:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_13:.*]] : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.divui %[[ARG]], %[[val_c8_1]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_2:.*]] = arith.constant 8 : i32
+!CHECK: %[[val_c4:.*]] = arith.constant 4 : index
+!CHECK: %[[val_14:.*]] = hlfir.designate %[[val_8]]#0 (%[[val_c4]])  : (!fir.ref<!fir.array<5xi32>>, index) -> !fir.ref<i32>
+!CHECK: %[[val_15:.*]] = fir.load %[[val_14]] : !fir.ref<i32>
+!CHECK: %[[val_16:.*]] = arith.addi %[[val_c8_2]], %[[val_15]] : i32
+!CHECK: %[[val_17:.*]] = hlfir.no_reassoc %[[val_16]] : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG:.*]]: i32):
+!CHECK:      %[[val_18:.*]] = arith.addi %[[ARG]], %[[val_17]] : i32
+!CHECK:      omp.yield(%[[val_18]] : i32)
+!CHECK: }
+!CHECK: %[[val_c8_3:.*]] = arith.constant 8 : i32
+!CHECK: omp.atomic.update %[[val_5]]#1 : !fir.ref<i32> {
+!CHECK:   ^bb0(%[[ARG]]: i32):
+!CHECK:     %[[val_18:.*]] = arith.subi %[[ARG]], %[[val_c8_3]] : i32
+!CHECK:     omp.yield(%[[val_18]] : i32)
+!CHECK:   }
+!CHECK: return
+!CHECK: }
+program sample
+
+  integer :: x
+  integer, dimension(5) :: y
+  integer :: a, b
+
+  !$omp atomic update
+    y(2) =  (8 + x) * y(2)
+  !$omp end atomic
+
+  !$omp atomic update
+    y(2) =  y(2) / 8
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  (8 + y(4)) + x
+  !$omp end atomic
+
+  !$omp atomic update
+    x =  8 - x
+  !$omp end atomic
+
+end program sample

``````````

</details>


https://github.com/llvm/llvm-project/pull/69866


More information about the flang-commits mailing list