[flang-commits] [flang] a7d80f4 - [Flang][OpenMP] Add support for OpenMP max reduction
Kiran Chandramohan via flang-commits
flang-commits at lists.llvm.org
Tue Mar 14 09:02:15 PDT 2023
Author: Kiran Chandramohan
Date: 2023-03-14T16:01:58Z
New Revision: a7d80f43cb8d3a2deb09f57d0904732d40020752
URL: https://github.com/llvm/llvm-project/commit/a7d80f43cb8d3a2deb09f57d0904732d40020752
DIFF: https://github.com/llvm/llvm-project/commit/a7d80f43cb8d3a2deb09f57d0904732d40020752.diff
LOG: [Flang][OpenMP] Add support for OpenMP max reduction
This patch adds support for reductions using the max intrinsic on scalar
types. Flang's default lowering flow lowers max as a compare followed by a
select; this pattern is matched and replaced with the OpenMP dialect
reduction operation.
Note: This is a temporary flow. The plan is to move to a flow where
the OpenMP reduction operation is inserted during lowering.
Reviewed By: do
Differential Revision: https://reviews.llvm.org/D145083
Added:
flang/test/Lower/OpenMP/wsloop-reduction-max.f90
Modified:
flang/lib/Lower/OpenMP.cpp
Removed:
flang/test/Lower/OpenMP/Todo/reduction-max.f90
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 1f6449eca301b..6f5fd87553802 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -1121,22 +1121,37 @@ static int getOperationIdentity(llvm::StringRef reductionOpName,
static Value getReductionInitValue(mlir::Location loc, mlir::Type type,
llvm::StringRef reductionOpName,
fir::FirOpBuilder &builder) {
- if (type.isa<FloatType>())
+  // Logical reductions are handled in the non-max branch below, so
+  // fir::LogicalType must be accepted alongside the builtin scalar types.
+  assert((type.isIntOrIndexOrFloat() || type.isa<fir::LogicalType>()) &&
+         "only integer, logical and float types are currently supported");
+  if (reductionOpName.contains("max")) {
+    // The identity of a max reduction is the least representable value of
+    // the type: the most negative finite float, or the signed integer minimum.
+    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
+      const llvm::fltSemantics &sem = ty.getFloatSemantics();
+      return builder.createRealConstant(
+          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
+    }
+    // getIntOrFloatBitWidth() is not defined for index/logical types.
+    assert(type.isa<mlir::IntegerType>() &&
+           "max reduction requires an integer or float type");
+    unsigned bits = type.getIntOrFloatBitWidth();
+    // Build the attribute directly from an APInt so integer kinds wider than
+    // 64 bits (e.g. i128) are not narrowed through an int64_t.
+    return builder.create<mlir::arith::ConstantOp>(
+        loc, type,
+        builder.getIntegerAttr(type, llvm::APInt::getSignedMinValue(bits)));
+  } else {
+    if (type.isa<FloatType>())
+      return builder.create<mlir::arith::ConstantOp>(
+          loc, type,
+          builder.getFloatAttr(
+              type, (double)getOperationIdentity(reductionOpName, loc)));
+
+    if (type.isa<fir::LogicalType>()) {
+      // Materialize the identity as i1, then convert to the Fortran logical.
+      Value intConst = builder.create<mlir::arith::ConstantOp>(
+          loc, builder.getI1Type(),
+          builder.getIntegerAttr(builder.getI1Type(),
+                                 getOperationIdentity(reductionOpName, loc)));
+      return builder.createConvert(loc, type, intConst);
+    }
+
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(
- type, (double)getOperationIdentity(reductionOpName, loc)));
-
- if (type.isa<fir::LogicalType>()) {
- Value intConst = builder.create<mlir::arith::ConstantOp>(
- loc, builder.getI1Type(),
- builder.getIntegerAttr(builder.getI1Type(),
+        builder.getIntegerAttr(type,
getOperationIdentity(reductionOpName, loc)));
- return builder.createConvert(loc, type, intConst);
}
- return builder.create<mlir::arith::ConstantOp>(
- loc, type,
- builder.getIntegerAttr(type, getOperationIdentity(reductionOpName, loc)));
}
template <typename FloatOp, typename IntegerOp>
@@ -1150,6 +1165,65 @@ static Value getReductionOperation(fir::FirOpBuilder &builder, mlir::Type type,
return builder.create<FloatOp>(loc, op1, op2);
}
+/// Creates an `omp.reduction.declare` named `reductionOpName` for `type` at
+/// module scope. The initializer region is fully populated: its single block
+/// takes one `type` argument and yields the value from getReductionInitValue.
+/// The combiner (reduction) region receives one block with two `type`
+/// arguments and no terminator; the caller is responsible for filling it in.
+static omp::ReductionDeclareOp
+createMinimalReductionDecl(fir::FirOpBuilder &builder,
+                           llvm::StringRef reductionOpName, mlir::Type type,
+                           mlir::Location loc) {
+  // Declarations are created in the module body, independent of where
+  // `builder` currently points.
+  mlir::OpBuilder moduleScope(builder.getModule().getBodyRegion());
+  auto decl =
+      moduleScope.create<omp::ReductionDeclareOp>(loc, reductionOpName, type);
+
+  // Initializer: one block taking a `type` argument, yielding the identity.
+  mlir::Region &initRegion = decl.getInitializerRegion();
+  mlir::Block *initBlock =
+      builder.createBlock(&initRegion, initRegion.end(), {type}, {loc});
+  builder.setInsertionPointToEnd(initBlock);
+  builder.create<omp::YieldOp>(
+      loc, getReductionInitValue(loc, type, reductionOpName, builder));
+
+  // Combiner: one block taking (lhs, rhs) of `type`, left empty here.
+  mlir::Region &combinerRegion = decl.getReductionRegion();
+  builder.createBlock(&combinerRegion, combinerRegion.end(), {type, type},
+                      {loc, loc});
+  return decl;
+}
+
+/// Creates an OpenMP reduction declaration named `reductionOpName` for the
+/// intrinsic procedure `procDesignator` and inserts it into the module, or
+/// returns the existing declaration if one with that name is already present.
+/// The initializer yields the identity from getReductionInitValue (via
+/// createMinimalReductionDecl) and the combiner applies the intrinsic to the
+/// two block arguments. Only the `max` intrinsic on integer and float types
+/// is currently supported; anything else reports a TODO.
+/// TODO: add atomic region.
+static omp::ReductionDeclareOp
+createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+                    const Fortran::parser::ProcedureDesignator &procDesignator,
+                    mlir::Type type, mlir::Location loc) {
+  OpBuilder::InsertionGuard guard(builder);
+  mlir::ModuleOp module = builder.getModule();
+
+  auto decl =
+      module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+  if (decl)
+    return decl;
+
+  // Reject unsupported designators before creating any IR. Previously a
+  // designator that was not a plain Name left `reductionOp` null and the
+  // combiner was terminated by a yield of a null value.
+  const auto *name{
+      Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)};
+  if (!name || name->source != "max")
+    TODO(loc, "Reduction of some intrinsic operators is not supported");
+
+  decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
+  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+  // max: arith.maxf for floats, arith.maxsi (signed) for integers.
+  Value reductionOp =
+      getReductionOperation<mlir::arith::MaxFOp, mlir::arith::MaxSIOp>(
+          builder, type, loc, op1, op2);
+  builder.create<omp::YieldOp>(loc, reductionOp);
+  return decl;
+}
+
/// Creates an OpenMP reduction declaration and inserts it into the provided
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the reduction combiner carried over from `reduce`.
@@ -1160,23 +1234,13 @@ static omp::ReductionDeclareOp createReductionDecl(
mlir::Type type, mlir::Location loc) {
OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
- mlir::OpBuilder modBuilder(module.getBodyRegion());
+
auto decl =
module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (!decl)
- decl =
- modBuilder.create<omp::ReductionDeclareOp>(loc, reductionOpName, type);
- else
+ if (decl)
return decl;
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- Value init = getReductionInitValue(loc, type, reductionOpName, builder);
- builder.create<omp::YieldOp>(loc, init);
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
+ decl = createMinimalReductionDecl(builder, reductionOpName, type, loc);
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
@@ -1284,6 +1348,13 @@ getSIMDModifier(const Fortran::parser::OmpScheduleClause &x) {
return mlir::omp::ScheduleModifier::none;
}
+/// Builds a reduction declaration symbol name of the form
+/// `<name>_<i|f>_<bitwidth>`, e.g. `max_i_32` or `max_f_32`.
+static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
+  std::string symbol = name.str();
+  symbol += ty.isIntOrIndex() ? "_i_" : "_f_";
+  symbol += std::to_string(ty.getIntOrFloatBitWidth());
+  return symbol;
+}
+
static std::string getReductionName(
Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
mlir::Type ty) {
@@ -1305,10 +1376,7 @@ static std::string getReductionName(
break;
}
- return (llvm::Twine(reductionName) +
- (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) +
- llvm::Twine(ty.getIntOrFloatBitWidth()))
- .str();
+ return getReductionName(reductionName, ty);
}
static void genOMP(Fortran::lower::AbstractConverter &converter,
@@ -1443,9 +1511,34 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
}
}
}
- } else {
- TODO(currentLocation,
- "Reduction of intrinsic procedures is not supported");
+ } else if (auto reductionIntrinsic =
+ std::get_if<Fortran::parser::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ reductionIntrinsic)}) {
+ if (name->source != "max") {
+ TODO(currentLocation,
+ "Reduction of intrinsic procedures is not supported");
+ }
+ for (const auto &ompObject : objectList.v) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ ompObject)}) {
+ if (const auto *symbol{name->symbol}) {
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder, getReductionName("max", redType),
+ *reductionIntrinsic, redType, currentLocation);
+ reductionDeclSymbols.push_back(SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ }
+ }
}
} else if (const auto &simdlenClause =
std::get_if<Fortran::parser::OmpClause::Simdlen>(
@@ -2104,6 +2197,21 @@ void Fortran::lower::genOpenMPDeclarativeConstruct(
ompDeclConstruct.u);
}
+/// Returns the comparison operation that feeds `reductionOp` and reads
+/// `loadVal` as its first or second operand. For `max`, `reductionOp` is
+/// expected to be the arith.select that Flang emits, and the result is its
+/// arith.cmpi / arith.cmpf condition. Returns nullptr when no operand of
+/// `reductionOp` is defined by an operation reading `loadVal`.
+static mlir::Operation *getCompareFromReductionOp(mlir::Operation *reductionOp,
+                                                  mlir::Value loadVal) {
+  for (mlir::Value reductionOperand : reductionOp->getOperands()) {
+    mlir::Operation *compareOp = reductionOperand.getDefiningOp();
+    // Skip block arguments and ops that cannot be a binary comparison;
+    // indexing operand 0/1 on them would be out of bounds.
+    if (!compareOp || compareOp->getNumOperands() < 2)
+      continue;
+    // Only an operation that actually reads the loaded value is a candidate.
+    // (Previously the missing braces made the return unconditional.)
+    if (compareOp->getOperand(0) != loadVal &&
+        compareOp->getOperand(1) != loadVal)
+      continue;
+    assert((mlir::isa<mlir::arith::CmpIOp>(compareOp) ||
+            mlir::isa<mlir::arith::CmpFOp>(compareOp)) &&
+           "Expected comparison not found in reduction intrinsic");
+    return compareOp;
+  }
+  return nullptr;
+}
+
// Generate an OpenMP reduction operation.
// TODO: Currently assumes it is either an integer addition/multiplication
// reduction, or a logical and reduction. Generalize this for various reduction
@@ -2170,6 +2278,40 @@ void Fortran::lower::genOpenMPReduction(
}
}
}
+ } else if (auto reductionIntrinsic =
+ std::get_if<Fortran::parser::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ reductionIntrinsic)}) {
+ if (name->source != "max") {
+ continue;
+ }
+ for (const auto &ompObject : objectList.v) {
+ if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(
+ ompObject)}) {
+ if (const auto *symbol{name->symbol}) {
+ mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
+ for (mlir::OpOperand &reductionValUse :
+ reductionVal.getUses()) {
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
+ reductionValUse.getOwner())) {
+ mlir::Value loadVal = loadOp.getRes();
+ // Max is lowered as a compare -> select.
+ // Match the pattern here.
+ mlir::Operation *reductionOp =
+ findReductionChain(loadVal, &reductionVal);
+ assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
+ "Selection Op not found in reduction intrinsic");
+ mlir::Operation *compareOp =
+ getCompareFromReductionOp(reductionOp, loadVal);
+ updateReduction(compareOp, firOpBuilder, loadVal,
+ reductionVal);
+ }
+ }
+ }
+ }
+ }
+ }
}
}
}
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-max.f90 b/flang/test/Lower/OpenMP/Todo/reduction-max.f90
deleted file mode 100644
index e965e6860712e..0000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-max.f90
+++ /dev/null
@@ -1,16 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction of intrinsic procedures is not supported
-subroutine reduction_max(y)
- integer :: x, y(:)
- x = 0
- !$omp parallel
- !$omp do reduction(max:x)
- do i=1, 100
- x = max(x, y(i))
- end do
- !$omp end do
- !$omp end parallel
- print *, x
-end subroutine
diff --git a/flang/test/Lower/OpenMP/wsloop-reduction-max.f90 b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
new file mode 100644
index 0000000000000..0db5b7f813d74
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-reduction-max.f90
@@ -0,0 +1,66 @@
+! RUN: bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+! Lowering test for `reduction(max:x)` on integer and real scalars. The
+! reduction declarations must start from the least value of the type
+! (-3.40282347E+38 for f32, -2147483648 for i32) and combine with
+! arith.maxf / arith.maxsi. The compare/select that `max` lowers to inside
+! the loop must be replaced by an omp.reduction operation.
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_F:.*]] : f32 init {
+!CHECK: %[[MINIMUM_VAL_F:.*]] = arith.constant -3.40282347E+38 : f32
+!CHECK: omp.yield(%[[MINIMUM_VAL_F]] : f32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_F:.*]]: f32, %[[ARG1_F:.*]]: f32):
+!CHECK: %[[COMB_VAL_F:.*]] = arith.maxf %[[ARG0_F]], %[[ARG1_F]] {{.*}}: f32
+!CHECK: omp.yield(%[[COMB_VAL_F]] : f32)
+
+!CHECK: omp.reduction.declare @[[MAX_DECLARE_I:.*]] : i32 init {
+!CHECK: %[[MINIMUM_VAL_I:.*]] = arith.constant -2147483648 : i32
+!CHECK: omp.yield(%[[MINIMUM_VAL_I]] : i32)
+!CHECK: combiner
+!CHECK: ^bb0(%[[ARG0_I:.*]]: i32, %[[ARG1_I:.*]]: i32):
+!CHECK: %[[COMB_VAL_I:.*]] = arith.maxsi %[[ARG0_I]], %[[ARG1_I]] : i32
+!CHECK: omp.yield(%[[COMB_VAL_I]] : i32)
+
+!CHECK-LABEL: @_QPreduction_max_int
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box<!fir.array<?xi32>>
+!CHECK: %[[X_REF:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFreduction_max_intEx"}
+!CHECK: omp.parallel
+!CHECK: omp.wsloop reduction(@[[MAX_DECLARE_I]] -> %[[X_REF]] : !fir.ref<i32>) for
+!CHECK: %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK: %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref<i32>
+!CHECK: omp.reduction %[[Y_I]], %[[X_REF]] : i32, !fir.ref<i32>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+
+!CHECK-LABEL: @_QPreduction_max_real
+!CHECK-SAME: %[[Y_BOX:.*]]: !fir.box<!fir.array<?xf32>>
+!CHECK: %[[X_REF:.*]] = fir.alloca f32 {bindc_name = "x", uniq_name = "_QFreduction_max_realEx"}
+!CHECK: omp.parallel
+!CHECK: omp.wsloop reduction(@[[MAX_DECLARE_F]] -> %[[X_REF]] : !fir.ref<f32>) for
+!CHECK: %[[Y_I_REF:.*]] = fir.coordinate_of %[[Y_BOX]]
+!CHECK: %[[Y_I:.*]] = fir.load %[[Y_I_REF]] : !fir.ref<f32>
+!CHECK: omp.reduction %[[Y_I]], %[[X_REF]] : f32, !fir.ref<f32>
+!CHECK: omp.yield
+!CHECK: omp.terminator
+
+subroutine reduction_max_int(y)
+ integer :: x, y(:)
+ x = 0
+ !$omp parallel
+ !$omp do reduction(max:x)
+ do i=1, 100
+ x = max(x, y(i))
+ end do
+ !$omp end do
+ !$omp end parallel
+ print *, x
+end subroutine
+
+! The real variant passes `x` as the second argument of max. This
+! presumably covers the reduction variable appearing in either operand
+! position of the generated compare.
+subroutine reduction_max_real(y)
+ real :: x, y(:)
+ x = 0.0
+ !$omp parallel
+ !$omp do reduction(max:x)
+ do i=1, 100
+ x = max(y(i), x)
+ end do
+ !$omp end do
+ !$omp end parallel
+ print *, x
+end subroutine
More information about the flang-commits
mailing list