[flang-commits] [flang] 86077c4 - [flang][OpenMP] Rewrite min/max with more than 2 arguments (#146423)

via flang-commits flang-commits at lists.llvm.org
Tue Jul 1 07:55:02 PDT 2025


Author: Krzysztof Parzyszek
Date: 2025-07-01T09:54:58-05:00
New Revision: 86077c41a7899fb3a3ce4654bdb373e7cd954f49

URL: https://github.com/llvm/llvm-project/commit/86077c41a7899fb3a3ce4654bdb373e7cd954f49
DIFF: https://github.com/llvm/llvm-project/commit/86077c41a7899fb3a3ce4654bdb373e7cd954f49.diff

LOG: [flang][OpenMP] Rewrite min/max with more than 2 arguments (#146423)

Given an atomic operation `w = max(w, x1, x2, ...)`, rewrite it as
`w = max(w, max(x1, x2, ...))`. This avoids unnecessary non-atomic
comparisons inside the atomic operation (min/max are expanded
inline).

In particular, if some of the x_i's are optional dummy arguments of the
containing procedure, this avoids any presence tests within the
atomic operation.

Fixes https://github.com/llvm/llvm-project/issues/144838

Added: 
    flang/test/Lower/OpenMP/minmax-optional-parameters.f90

Modified: 
    flang/lib/Lower/OpenMP/Atomic.cpp
    flang/test/Lower/OpenMP/atomic-update.f90

Removed: 
    


################################################################################
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 33a743f8f9dda..2ab91b239a3cc 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -11,6 +11,8 @@
 #include "flang/Evaluate/expression.h"
 #include "flang/Evaluate/fold.h"
 #include "flang/Evaluate/tools.h"
+#include "flang/Evaluate/traverse.h"
+#include "flang/Evaluate/type.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
@@ -41,6 +43,179 @@ namespace omp {
 using namespace Fortran::lower::omp;
 }
 
+namespace {
+// Placeholder used only to compute the result type of WithType visitors.
+using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
+
+// Traversal returning the first ProcedureDesignator found, or nullptr.
+struct GetProc
+    : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
+                                false> {
+  using Result = const evaluate::ProcedureDesignator *;
+  using Base = evaluate::Traverse<GetProc, Result, false>;
+  GetProc() : Base(*this) {}
+
+  using Base::operator();
+
+  static Result Default() { return nullptr; }
+
+  Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
+  static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
+};
+
+// Invokes a visitor with llvm::type_identity<evaluate::Type<Cat, Kind>>{}.
+struct WithType {
+  explicit WithType(const evaluate::DynamicType &t) : type(t) {
+    assert(type.category() != common::TypeCategory::Derived &&
+           "Type cannot be a derived type");
+  }
+
+  template <typename VisitorTy> //
+  auto visit(VisitorTy &&visitor) const
+      -> std::invoke_result_t<VisitorTy, llvm::type_identity<SomeArgType>> {
+    switch (type.category()) {
+    case common::TypeCategory::Integer:
+      switch (type.kind()) {
+      case 1:
+        return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
+      case 8:
+        return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
+      case 16:
+        return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
+      }
+      break;
+    case common::TypeCategory::Unsigned:
+      switch (type.kind()) {
+      case 1:
+        return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
+      case 8:
+        return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
+      case 16:
+        return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
+      }
+      break;
+    case common::TypeCategory::Real:
+      switch (type.kind()) {
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
+      case 3:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
+      case 8:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
+      case 10:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
+      case 16:
+        return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
+      }
+      break;
+    case common::TypeCategory::Complex:
+      switch (type.kind()) {
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
+      case 3:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
+      case 8:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
+      case 10:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
+      case 16:
+        return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
+      }
+      break;
+    case common::TypeCategory::Logical:
+      switch (type.kind()) {
+      case 1:
+        return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
+      case 8:
+        return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
+      }
+      break;
+    case common::TypeCategory::Character:
+      switch (type.kind()) {
+      case 1:
+        return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
+      case 2:
+        return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
+      case 4:
+        return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
+      }
+      break;
+    case common::TypeCategory::Derived:
+      break;
+    }
+    llvm_unreachable("Unhandled type");
+  }
+
+  const evaluate::DynamicType &type;
+
+private:
+  // Shorter names.
+  static constexpr auto Character = common::TypeCategory::Character;
+  static constexpr auto Complex = common::TypeCategory::Complex;
+  static constexpr auto Integer = common::TypeCategory::Integer;
+  static constexpr auto Logical = common::TypeCategory::Logical;
+  static constexpr auto Real = common::TypeCategory::Real;
+  static constexpr auto Unsigned = common::TypeCategory::Unsigned;
+};
+
+// Yields an rvalue: lvalue arguments are copied, rvalue arguments moved.
+template <typename T, typename U = std::remove_const_t<T>>
+U AsRvalue(T &t) {
+  U copy{t};
+  return copy; // Not std::move(copy): that would inhibit NRVO.
+}
+
+template <typename T>
+T &&AsRvalue(T &&t) {
+  return std::move(t);
+}
+
+// Overwrites the arguments of the FunctionRef found in an expression.
+struct ArgumentReplacer
+    : public evaluate::Traverse<ArgumentReplacer, bool, false> {
+  using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
+  using Result = bool;
+
+  Result Default() const { return false; }
+
+  explicit ArgumentReplacer(evaluate::ActualArguments &&newArgs)
+      : Base(*this), args_(std::move(newArgs)) {}
+
+  using Base::operator();
+
+  template <typename T>
+  Result operator()(const evaluate::FunctionRef<T> &x) {
+    assert(!done_);
+    auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
+    mut.arguments() = args_;
+    done_ = true;
+    return true;
+  }
+
+  Result Combine(Result &&a, Result &&b) { return a || b; }
+
+private:
+  bool done_{false};
+  evaluate::ActualArguments args_; // Owned: a reference member could dangle.
+};
+} // namespace
+
 [[maybe_unused]] static void
 dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
   auto whatStr = [](int k) {
@@ -237,6 +412,85 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
   return nullptr;
 }
 
+// Replace, in place, the arguments of the function call in "expr" with
+// "newArgs". Returns true if a call was found.
+static bool replaceArgs(semantics::SomeExpr &expr,
+                        evaluate::ActualArguments &&newArgs) {
+  return ArgumentReplacer(std::move(newArgs))(expr);
+}
+
+// Build the call "proc(args)" as an expression whose result type is "type".
+static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
+                                    const evaluate::ProcedureDesignator &proc,
+                                    const evaluate::ActualArguments &args) {
+  return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
+    using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
+    return evaluate::AsGenericExpr(
+        evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
+  });
+}
+
+// Return the procedure designator of the first call found in "call".
+static const evaluate::ProcedureDesignator &
+getProcedureDesignator(const semantics::SomeExpr &call) {
+  const evaluate::ProcedureDesignator *proc = GetProc{}(call);
+  assert(proc && "Call has no procedure designator");
+  return *proc;
+}
+
+static semantics::SomeExpr //
+genReducedMinMax(const semantics::SomeExpr &orig,
+                 const semantics::SomeExpr *atomArg,
+                 const std::vector<semantics::SomeExpr> &args) {
+  // Given the min/max arguments [a0, a1, ...], one of which (a_t) is
+  // atomArg, build tmp = min/max(a0, a1, ... [except a_t]) and return
+  // "orig" with its arguments replaced by (a_t, tmp).
+
+  // With at most two arguments there is nothing to reduce.
+  if (args.size() <= 2)
+    return orig;
+
+  evaluate::ActualArguments nonAtoms;
+
+  auto AsActual = [](const semantics::SomeExpr &x) {
+    semantics::SomeExpr copy = x;
+    return evaluate::ActualArgument(std::move(copy));
+  };
+  // Semantic checks guarantee that the "atom" appears exactly once in the
+  // argument list (with potential conversions around it). min/max have two
+  // mandatory arguments, the rest are optional. If "atom" is one of the first
+  // two, replace it with a copy of the other mandatory argument, so that an
+  // optional argument is never promoted to a mandatory position in "tmp".
+  if (atomArg == &args[0]) {
+    // (atom, x, y...) -> (x, x, y...)
+    nonAtoms.push_back(AsActual(args[1]));
+    nonAtoms.push_back(AsActual(args[1]));
+  } else if (atomArg == &args[1]) {
+    // (x, atom, y...) -> (x, x, y...)
+    nonAtoms.push_back(AsActual(args[0]));
+    nonAtoms.push_back(AsActual(args[0]));
+  } else {
+    // (x, y, z...) -> unchanged
+    nonAtoms.push_back(AsActual(args[0]));
+    nonAtoms.push_back(AsActual(args[1]));
+  }
+
+  // The remaining arguments are optional, so "atom" can simply be skipped.
+  for (size_t i = 2, e = args.size(); i != e; ++i) {
+    if (atomArg != &args[i])
+      nonAtoms.push_back(AsActual(args[i]));
+  }
+
+  // The type of the intermediate min/max is the same as the type of its
+  // arguments, which may be different from the type of the original
+  // expression. The original expression may have additional conversions.
+  auto tmp =
+      makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
+  semantics::SomeExpr call = orig;
+  replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
+  return call;
+}
+
 static mlir::Operation * //
 genAtomicRead(lower::AbstractConverter &converter,
               semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -350,10 +604,29 @@ genAtomicUpdate(lower::AbstractConverter &converter,
   mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
 
   // This must exist by now.
-  semantics::SomeExpr input = *evaluate::GetConvertInput(assign.rhs);
-  std::vector<semantics::SomeExpr> args =
-      evaluate::GetTopLevelOperation(input).second;
+  semantics::SomeExpr rhs = assign.rhs;
+  semantics::SomeExpr input = *evaluate::GetConvertInput(rhs);
+  auto [opcode, args] = evaluate::GetTopLevelOperation(input);
   assert(!args.empty() && "Update operation without arguments");
+
+  // Pass args as an argument to avoid capturing a structured binding.
+  const semantics::SomeExpr *atomArg = [&](auto &args) {
+    for (const semantics::SomeExpr &e : args) {
+      if (evaluate::IsSameOrConvertOf(e, atom))
+        return &e;
+    }
+    llvm_unreachable("Atomic variable not in argument list");
+  }(args);
+
+  if (opcode == evaluate::operation::Operator::Min ||
+      opcode == evaluate::operation::Operator::Max) {
+    // Min and max operations are expanded inline, so reduce them to
+    // operations with exactly two (non-optional) arguments.
+    rhs = genReducedMinMax(rhs, atomArg, args);
+    input = *evaluate::GetConvertInput(rhs);
+    std::tie(opcode, args) = evaluate::GetTopLevelOperation(input);
+    atomArg = nullptr; // No longer valid.
+  }
   for (auto &arg : args) {
     if (!evaluate::IsSameOrConvertOf(arg, atom)) {
       mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
@@ -372,7 +645,7 @@ genAtomicUpdate(lower::AbstractConverter &converter,
 
   converter.overrideExprValues(&overrides);
   mlir::Value updated =
-      fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
+      fir::getBase(converter.genExprValue(rhs, stmtCtx, &loc));
   mlir::Value converted = builder.createConvert(loc, atomType, updated);
   builder.create<mlir::omp::YieldOp>(loc, converted);
   converter.resetExprOverrides();

diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index 3f840acefa6e8..f88bbea6fca85 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -107,8 +107,6 @@ program OmpAtomicUpdate
 !CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#0 : !fir.ref<i32> {
 !CHECK: ^bb0(%[[ARG:.*]]: i32):
 !CHECK:  {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32
-!CHECK:  {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32
-!CHECK:  {{.*}} = arith.cmpi sgt, {{.*}}
 !CHECK:  %[[TEMP:.*]] = arith.select {{.*}} : i32
 !CHECK:  omp.yield(%[[TEMP]] : i32)
 !CHECK: }
@@ -177,13 +175,9 @@ program OmpAtomicUpdate
 !CHECK:   %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref<i32>
 !CHECK:   omp.atomic.update %[[VAL_W_DECLARE]]#0 : !fir.ref<i32> {
 !CHECK:   ^bb0(%[[ARG_W:.*]]: i32):
-!CHECK:     %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32
-!CHECK:     %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32
-!CHECK:     %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
-!CHECK:     %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
-!CHECK:     %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
-!CHECK:     %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
-!CHECK:     omp.yield(%[[WXYZ_MIN]] : i32)
+!CHECK:     %[[W_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], {{.*}} : i32
+!CHECK:     %[[WXYZ_MAX:.*]] = arith.select %[[W_CMP]], %[[ARG_W]], {{.*}} : i32
+!CHECK:     omp.yield(%[[WXYZ_MAX]] : i32)
 !CHECK:   }
   !$omp atomic update
     w = max(w,x,y,z)

diff --git a/flang/test/Lower/OpenMP/minmax-optional-parameters.f90 b/flang/test/Lower/OpenMP/minmax-optional-parameters.f90
new file mode 100644
index 0000000000000..418a3cad8cdaf
--- /dev/null
+++ b/flang/test/Lower/OpenMP/minmax-optional-parameters.f90
@@ -0,0 +1,68 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+! Check that the presence tests are done outside of the atomic update
+! construct.
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[VAL_A:[0-9]+]]:2 = hlfir.declare %arg0 dummy_scope %0
+!CHECK: %[[VAL_X:[0-9]+]]:2 = hlfir.declare %arg1 dummy_scope %0
+!CHECK: %[[VAL_Y:[0-9]+]]:2 = hlfir.declare %arg2 dummy_scope %0
+!CHECK: %[[V4:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<f32>
+!CHECK: %[[V5:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<f32>
+!CHECK: %[[V6:[0-9]+]] = fir.is_present %[[VAL_Y]]#0 : (!fir.ref<f32>) -> i1
+!CHECK: %[[V7:[0-9]+]] = arith.cmpf ogt, %[[V4]], %[[V5]] fastmath<contract> : f32
+!CHECK: %[[V8:[0-9]+]] = arith.select %[[V7]], %[[V4]], %[[V5]] : f32
+!CHECK: %[[V9:[0-9]+]] = fir.if %[[V6]] -> (f32) {
+!CHECK:   %[[V10:[0-9]+]] = fir.load %[[VAL_Y]]#0 : !fir.ref<f32>
+!CHECK:   %[[V11:[0-9]+]] = arith.cmpf ogt, %[[V8]], %[[V10]] fastmath<contract> : f32
+!CHECK:   %[[V12:[0-9]+]] = arith.select %[[V11]], %[[V8]], %[[V10]] : f32
+!CHECK:   fir.result %[[V12]] : f32
+!CHECK: } else {
+!CHECK:   fir.result %[[V8]] : f32
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_A]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:[a-z0-9]+]]: f32):
+!CHECK:   %[[V10:[0-9]+]] = arith.cmpf ogt, %[[ARG]], %[[V9]] fastmath<contract> : f32
+!CHECK:   %[[V11:[0-9]+]] = arith.select %[[V10]], %[[ARG]], %[[V9]] : f32
+!CHECK:   omp.yield(%[[V11]] : f32)
+!CHECK: }
+
+subroutine f00(a, x, y)
+  real :: a
+  real, optional :: x, y
+  !$omp atomic update
+  a = max(x, a, y)  ! atom "a" is the 2nd argument; y may be absent
+end
+
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[VAL_A:[0-9]+]]:2 = hlfir.declare %arg0 dummy_scope %0
+!CHECK: %[[VAL_X:[0-9]+]]:2 = hlfir.declare %arg1 dummy_scope %0
+!CHECK: %[[VAL_Y:[0-9]+]]:2 = hlfir.declare %arg2 dummy_scope %0
+!CHECK: %[[V4:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<i32>
+!CHECK: %[[V5:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<i32>
+!CHECK: %[[V6:[0-9]+]] = fir.is_present %[[VAL_Y]]#0 : (!fir.ref<i32>) -> i1
+!CHECK: %[[V7:[0-9]+]] = arith.cmpi slt, %[[V4]], %[[V5]] : i32
+!CHECK: %[[V8:[0-9]+]] = arith.select %[[V7]], %[[V4]], %[[V5]] : i32
+!CHECK: %[[V9:[0-9]+]] = fir.if %[[V6]] -> (i32) {
+!CHECK:   %[[V10:[0-9]+]] = fir.load %[[VAL_Y]]#0 : !fir.ref<i32>
+!CHECK:   %[[V11:[0-9]+]] = arith.cmpi slt, %[[V8]], %[[V10]] : i32
+!CHECK:   %[[V12:[0-9]+]] = arith.select %[[V11]], %[[V8]], %[[V10]] : i32
+!CHECK:   fir.result %[[V12]] : i32
+!CHECK: } else {
+!CHECK:   fir.result %[[V8]] : i32
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_A]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:[a-z0-9]+]]: i32):
+!CHECK:   %[[V10:[0-9]+]] = arith.cmpi slt, %[[ARG]], %[[V9]] : i32
+!CHECK:   %[[V11:[0-9]+]] = arith.select %[[V10]], %[[ARG]], %[[V9]] : i32
+!CHECK:   omp.yield(%[[V11]] : i32)
+!CHECK: }
+
+subroutine f01(a, x, y)
+  integer :: a
+  integer, optional :: x, y
+  !$omp atomic update
+  a = min(x, a, y)  ! atom "a" is the 2nd argument; y may be absent
+end
+


        


More information about the flang-commits mailing list