[flang-commits] [flang] 86077c4 - [flang][OpenMP] Rewrite min/max with more than 2 arguments (#146423)
via flang-commits
flang-commits at lists.llvm.org
Tue Jul 1 07:55:02 PDT 2025
Author: Krzysztof Parzyszek
Date: 2025-07-01T09:54:58-05:00
New Revision: 86077c41a7899fb3a3ce4654bdb373e7cd954f49
URL: https://github.com/llvm/llvm-project/commit/86077c41a7899fb3a3ce4654bdb373e7cd954f49
DIFF: https://github.com/llvm/llvm-project/commit/86077c41a7899fb3a3ce4654bdb373e7cd954f49.diff
LOG: [flang][OpenMP] Rewrite min/max with more than 2 arguments (#146423)
Given an atomic operation `w = max(w, x1, x2, ...)` rewrite it as `w =
max(w, max(x1, x2, ...))`. This will avoid unnecessary non-atomic
comparisons inside of the atomic operation (min/max are expanded
inline).
In particular, if some of the x_i's are optional dummy parameters in the
containing function, this will avoid any presence tests within the
atomic operation.
Fixes https://github.com/llvm/llvm-project/issues/144838
Added:
flang/test/Lower/OpenMP/minmax-optional-parameters.f90
Modified:
flang/lib/Lower/OpenMP/Atomic.cpp
flang/test/Lower/OpenMP/atomic-update.f90
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp
index 33a743f8f9dda..2ab91b239a3cc 100644
--- a/flang/lib/Lower/OpenMP/Atomic.cpp
+++ b/flang/lib/Lower/OpenMP/Atomic.cpp
@@ -11,6 +11,8 @@
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/fold.h"
#include "flang/Evaluate/tools.h"
+#include "flang/Evaluate/traverse.h"
+#include "flang/Evaluate/type.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
@@ -41,6 +43,179 @@ namespace omp {
using namespace Fortran::lower::omp;
}
+namespace {
+// An example of a type that can be used to get the return value from
+// the visitor:
+// visitor(type_identity<Xyz>) -> result_type
+using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
+
+struct GetProc
+ : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
+ false> {
+ using Result = const evaluate::ProcedureDesignator *;
+ using Base = evaluate::Traverse<GetProc, Result, false>;
+ GetProc() : Base(*this) {}
+
+ using Base::operator();
+
+ static Result Default() { return nullptr; }
+
+ Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
+ static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
+};
+
+struct WithType {
+ WithType(const evaluate::DynamicType &t) : type(t) {
+ assert(type.category() != common::TypeCategory::Derived &&
+ "Type cannot be a derived type");
+ }
+
+ template <typename VisitorTy> //
+ auto visit(VisitorTy &&visitor) const
+ -> std::invoke_result_t<VisitorTy, SomeArgType> {
+ switch (type.category()) {
+ case common::TypeCategory::Integer:
+ switch (type.kind()) {
+ case 1:
+ return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
+ case 8:
+ return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
+ case 16:
+ return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
+ }
+ break;
+ case common::TypeCategory::Unsigned:
+ switch (type.kind()) {
+ case 1:
+ return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
+ case 8:
+ return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
+ case 16:
+ return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
+ }
+ break;
+ case common::TypeCategory::Real:
+ switch (type.kind()) {
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
+ case 3:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
+ case 8:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
+ case 10:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
+ case 16:
+ return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
+ }
+ break;
+ case common::TypeCategory::Complex:
+ switch (type.kind()) {
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
+ case 3:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
+ case 8:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
+ case 10:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
+ case 16:
+ return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
+ }
+ break;
+ case common::TypeCategory::Logical:
+ switch (type.kind()) {
+ case 1:
+ return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
+ case 8:
+ return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
+ }
+ break;
+ case common::TypeCategory::Character:
+ switch (type.kind()) {
+ case 1:
+ return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
+ case 2:
+ return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
+ case 4:
+ return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
+ }
+ break;
+ case common::TypeCategory::Derived:
+ (void)Derived;
+ break;
+ }
+ llvm_unreachable("Unhandled type");
+ }
+
+ const evaluate::DynamicType &type;
+
+private:
+ // Shorter names.
+ static constexpr auto Character = common::TypeCategory::Character;
+ static constexpr auto Complex = common::TypeCategory::Complex;
+ static constexpr auto Derived = common::TypeCategory::Derived;
+ static constexpr auto Integer = common::TypeCategory::Integer;
+ static constexpr auto Logical = common::TypeCategory::Logical;
+ static constexpr auto Real = common::TypeCategory::Real;
+ static constexpr auto Unsigned = common::TypeCategory::Unsigned;
+};
+
+template <typename T, typename U = std::remove_const_t<T>>
+U AsRvalue(T &t) {
+ U copy{t};
+ return std::move(copy);
+}
+
+template <typename T>
+T &&AsRvalue(T &&t) {
+ return std::move(t);
+}
+
+struct ArgumentReplacer
+ : public evaluate::Traverse<ArgumentReplacer, bool, false> {
+ using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
+ using Result = bool;
+
+ Result Default() const { return false; }
+
+ ArgumentReplacer(evaluate::ActualArguments &&newArgs)
+ : Base(*this), args_(std::move(newArgs)) {}
+
+ using Base::operator();
+
+ template <typename T>
+ Result operator()(const evaluate::FunctionRef<T> &x) {
+ assert(!done_);
+ auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
+ mut.arguments() = args_;
+ done_ = true;
+ return true;
+ }
+
+ Result Combine(Result &&a, Result &&b) { return a || b; }
+
+private:
+ bool done_{false};
+ evaluate::ActualArguments &&args_;
+};
+} // namespace
+
[[maybe_unused]] static void
dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
auto whatStr = [](int k) {
@@ -237,6 +412,85 @@ makeMemOrderAttr(lower::AbstractConverter &converter,
return nullptr;
}
+static bool replaceArgs(semantics::SomeExpr &expr,
+ evaluate::ActualArguments &&newArgs) {
+ return ArgumentReplacer(std::move(newArgs))(expr);
+}
+
+static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
+ const evaluate::ProcedureDesignator &proc,
+ const evaluate::ActualArguments &args) {
+ return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
+ using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
+ return evaluate::AsGenericExpr(
+ evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
+ });
+}
+
+static const evaluate::ProcedureDesignator &
+getProcedureDesignator(const semantics::SomeExpr &call) {
+ const evaluate::ProcedureDesignator *proc = GetProc{}(call);
+ assert(proc && "Call has no procedure designator");
+ return *proc;
+}
+
+static semantics::SomeExpr //
+genReducedMinMax(const semantics::SomeExpr &orig,
+ const semantics::SomeExpr *atomArg,
+ const std::vector<semantics::SomeExpr> &args) {
+ // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
+ // One of the a_i's, say a_t, must be atomArg.
+ // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
+ // call = min/max(a_t, tmp).
+ // Return "call".
+
+ // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
+ // Make sure that the "tmp = min/max(...)" doesn't promote an optional
+ // argument to a non-optional position. This could happen if a_t is at
+ // position 0 or 1.
+ if (args.size() <= 2)
+ return orig;
+
+ evaluate::ActualArguments nonAtoms;
+
+ auto AsActual = [](const semantics::SomeExpr &x) {
+ semantics::SomeExpr copy = x;
+ return evaluate::ActualArgument(std::move(copy));
+ };
+ // Semantic checks guarantee that the "atom" shows exactly once in the
+ // argument list (with potential conversions around it).
+ // For the first two (non-optional) arguments, if "atom" is among them,
+ // replace it with another occurrence of the other non-optional argument.
+ if (atomArg == &args[0]) {
+ // (atom, x, y...) -> (x, x, y...)
+ nonAtoms.push_back(AsActual(args[1]));
+ nonAtoms.push_back(AsActual(args[1]));
+ } else if (atomArg == &args[1]) {
+ // (x, atom, y...) -> (x, x, y...)
+ nonAtoms.push_back(AsActual(args[0]));
+ nonAtoms.push_back(AsActual(args[0]));
+ } else {
+ // (x, y, z...) -> unchanged
+ nonAtoms.push_back(AsActual(args[0]));
+ nonAtoms.push_back(AsActual(args[1]));
+ }
+
+ // The rest of arguments are optional, so we can just skip "atom".
+ for (size_t i = 2, e = args.size(); i != e; ++i) {
+ if (atomArg != &args[i])
+ nonAtoms.push_back(AsActual(args[i]));
+ }
+
+ // The type of the intermediate min/max is the same as the type of its
+ // arguments, which may be different from the type of the original
+ // expression. The original expression may have additional converts.
+ auto tmp =
+ makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
+ semantics::SomeExpr call = orig;
+ replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
+ return call;
+}
+
static mlir::Operation * //
genAtomicRead(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
@@ -350,10 +604,29 @@ genAtomicUpdate(lower::AbstractConverter &converter,
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must exist by now.
- semantics::SomeExpr input = *evaluate::GetConvertInput(assign.rhs);
- std::vector<semantics::SomeExpr> args =
- evaluate::GetTopLevelOperation(input).second;
+ semantics::SomeExpr rhs = assign.rhs;
+ semantics::SomeExpr input = *evaluate::GetConvertInput(rhs);
+ auto [opcode, args] = evaluate::GetTopLevelOperation(input);
assert(!args.empty() && "Update operation without arguments");
+
+ // Pass args as an argument to avoid capturing a structured binding.
+ const semantics::SomeExpr *atomArg = [&](auto &args) {
+ for (const semantics::SomeExpr &e : args) {
+ if (evaluate::IsSameOrConvertOf(e, atom))
+ return &e;
+ }
+ llvm_unreachable("Atomic variable not in argument list");
+ }(args);
+
+ if (opcode == evaluate::operation::Operator::Min ||
+ opcode == evaluate::operation::Operator::Max) {
+ // Min and max operations are expanded inline, so reduce them to
+ // operations with exactly two (non-optional) arguments.
+ rhs = genReducedMinMax(rhs, atomArg, args);
+ input = *evaluate::GetConvertInput(rhs);
+ std::tie(opcode, args) = evaluate::GetTopLevelOperation(input);
+ atomArg = nullptr; // No longer valid.
+ }
for (auto &arg : args) {
if (!evaluate::IsSameOrConvertOf(arg, atom)) {
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
@@ -372,7 +645,7 @@ genAtomicUpdate(lower::AbstractConverter &converter,
converter.overrideExprValues(&overrides);
mlir::Value updated =
- fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
+ fir::getBase(converter.genExprValue(rhs, stmtCtx, &loc));
mlir::Value converted = builder.createConvert(loc, atomType, updated);
builder.create<mlir::omp::YieldOp>(loc, converted);
converter.resetExprOverrides();
diff --git a/flang/test/Lower/OpenMP/atomic-update.f90 b/flang/test/Lower/OpenMP/atomic-update.f90
index 3f840acefa6e8..f88bbea6fca85 100644
--- a/flang/test/Lower/OpenMP/atomic-update.f90
+++ b/flang/test/Lower/OpenMP/atomic-update.f90
@@ -107,8 +107,6 @@ program OmpAtomicUpdate
!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_Y_DECLARE]]#0 : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG:.*]]: i32):
!CHECK: {{.*}} = arith.cmpi sgt, %[[ARG]], {{.*}} : i32
-!CHECK: {{.*}} = arith.select {{.*}}, %[[ARG]], {{.*}} : i32
-!CHECK: {{.*}} = arith.cmpi sgt, {{.*}}
!CHECK: %[[TEMP:.*]] = arith.select {{.*}} : i32
!CHECK: omp.yield(%[[TEMP]] : i32)
!CHECK: }
@@ -177,13 +175,9 @@ program OmpAtomicUpdate
!CHECK: %[[VAL_Z_LOADED:.*]] = fir.load %[[VAL_Z_DECLARE]]#0 : !fir.ref<i32>
!CHECK: omp.atomic.update %[[VAL_W_DECLARE]]#0 : !fir.ref<i32> {
!CHECK: ^bb0(%[[ARG_W:.*]]: i32):
-!CHECK: %[[WX_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], %[[VAL_X_LOADED]] : i32
-!CHECK: %[[WX_MIN:.*]] = arith.select %[[WX_CMP]], %[[ARG_W]], %[[VAL_X_LOADED]] : i32
-!CHECK: %[[WXY_CMP:.*]] = arith.cmpi sgt, %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
-!CHECK: %[[WXY_MIN:.*]] = arith.select %[[WXY_CMP]], %[[WX_MIN]], %[[VAL_Y_LOADED]] : i32
-!CHECK: %[[WXYZ_CMP:.*]] = arith.cmpi sgt, %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
-!CHECK: %[[WXYZ_MIN:.*]] = arith.select %[[WXYZ_CMP]], %[[WXY_MIN]], %[[VAL_Z_LOADED]] : i32
-!CHECK: omp.yield(%[[WXYZ_MIN]] : i32)
+!CHECK: %[[W_CMP:.*]] = arith.cmpi sgt, %[[ARG_W]], {{.*}} : i32
+!CHECK: %[[WXYZ_MAX:.*]] = arith.select %[[W_CMP]], %[[ARG_W]], {{.*}} : i32
+!CHECK: omp.yield(%[[WXYZ_MAX]] : i32)
!CHECK: }
!$omp atomic update
w = max(w,x,y,z)
diff --git a/flang/test/Lower/OpenMP/minmax-optional-parameters.f90 b/flang/test/Lower/OpenMP/minmax-optional-parameters.f90
new file mode 100644
index 0000000000000..418a3cad8cdaf
--- /dev/null
+++ b/flang/test/Lower/OpenMP/minmax-optional-parameters.f90
@@ -0,0 +1,68 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=52 %s -o - | FileCheck %s
+
+! Check that the presence tests are done outside of the atomic update
+! construct.
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[VAL_A:[0-9]+]]:2 = hlfir.declare %arg0 dummy_scope %0
+!CHECK: %[[VAL_X:[0-9]+]]:2 = hlfir.declare %arg1 dummy_scope %0
+!CHECK: %[[VAL_Y:[0-9]+]]:2 = hlfir.declare %arg2 dummy_scope %0
+!CHECK: %[[V4:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<f32>
+!CHECK: %[[V5:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<f32>
+!CHECK: %[[V6:[0-9]+]] = fir.is_present %[[VAL_Y]]#0 : (!fir.ref<f32>) -> i1
+!CHECK: %[[V7:[0-9]+]] = arith.cmpf ogt, %[[V4]], %[[V5]] fastmath<contract> : f32
+!CHECK: %[[V8:[0-9]+]] = arith.select %[[V7]], %[[V4]], %[[V5]] : f32
+!CHECK: %[[V9:[0-9]+]] = fir.if %[[V6]] -> (f32) {
+!CHECK: %[[V10:[0-9]+]] = fir.load %[[VAL_Y]]#0 : !fir.ref<f32>
+!CHECK: %[[V11:[0-9]+]] = arith.cmpf ogt, %[[V8]], %[[V10]] fastmath<contract> : f32
+!CHECK: %[[V12:[0-9]+]] = arith.select %[[V11]], %[[V8]], %[[V10]] : f32
+!CHECK: fir.result %[[V12]] : f32
+!CHECK: } else {
+!CHECK: fir.result %[[V8]] : f32
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_A]]#0 : !fir.ref<f32> {
+!CHECK: ^bb0(%[[ARG:[a-z0-9]+]]: f32):
+!CHECK: %[[V10:[0-9]+]] = arith.cmpf ogt, %[[ARG]], %[[V9]] fastmath<contract> : f32
+!CHECK: %[[V11:[0-9]+]] = arith.select %[[V10]], %[[ARG]], %[[V9]] : f32
+!CHECK: omp.yield(%[[V11]] : f32)
+!CHECK: }
+
+subroutine f00(a, x, y)
+ real :: a
+ real, optional :: x, y
+ !$omp atomic update
+ a = max(x, a, y)
+end
+
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[VAL_A:[0-9]+]]:2 = hlfir.declare %arg0 dummy_scope %0
+!CHECK: %[[VAL_X:[0-9]+]]:2 = hlfir.declare %arg1 dummy_scope %0
+!CHECK: %[[VAL_Y:[0-9]+]]:2 = hlfir.declare %arg2 dummy_scope %0
+!CHECK: %[[V4:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<i32>
+!CHECK: %[[V5:[0-9]+]] = fir.load %[[VAL_X]]#0 : !fir.ref<i32>
+!CHECK: %[[V6:[0-9]+]] = fir.is_present %[[VAL_Y]]#0 : (!fir.ref<i32>) -> i1
+!CHECK: %[[V7:[0-9]+]] = arith.cmpi slt, %[[V4]], %[[V5]] : i32
+!CHECK: %[[V8:[0-9]+]] = arith.select %[[V7]], %[[V4]], %[[V5]] : i32
+!CHECK: %[[V9:[0-9]+]] = fir.if %[[V6]] -> (i32) {
+!CHECK: %[[V10:[0-9]+]] = fir.load %[[VAL_Y]]#0 : !fir.ref<i32>
+!CHECK: %[[V11:[0-9]+]] = arith.cmpi slt, %[[V8]], %[[V10]] : i32
+!CHECK: %[[V12:[0-9]+]] = arith.select %[[V11]], %[[V8]], %[[V10]] : i32
+!CHECK: fir.result %[[V12]] : i32
+!CHECK: } else {
+!CHECK: fir.result %[[V8]] : i32
+!CHECK: }
+!CHECK: omp.atomic.update memory_order(relaxed) %[[VAL_A]]#0 : !fir.ref<i32> {
+!CHECK: ^bb0(%[[ARG:[a-z0-9]+]]: i32):
+!CHECK: %[[V10:[0-9]+]] = arith.cmpi slt, %[[ARG]], %[[V9]] : i32
+!CHECK: %[[V11:[0-9]+]] = arith.select %[[V10]], %[[ARG]], %[[V9]] : i32
+!CHECK: omp.yield(%[[V11]] : i32)
+!CHECK: }
+
+subroutine f01(a, x, y)
+ integer :: a
+ integer, optional :: x, y
+ !$omp atomic update
+ a = min(x, a, y)
+end
+
More information about the flang-commits
mailing list