[flang-commits] [flang] [mlir] Changes for invoking scan Op (PR #123254)
Anchu Rajendran S via flang-commits
flang-commits at lists.llvm.org
Wed Jan 22 11:21:46 PST 2025
https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/123254
From fccf0c69a2f8209591de081b7236deb33ca63203 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 16 Jan 2025 17:13:56 -0600
Subject: [PATCH] Changes for invoking scan Op
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 39 +++++++++++++++++--
flang/lib/Lower/OpenMP/ClauseProcessor.h | 4 ++
flang/lib/Lower/OpenMP/Clauses.cpp | 8 ++--
flang/lib/Lower/OpenMP/OpenMP.cpp | 22 ++++++++++-
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 32 +++++++++++++--
flang/lib/Lower/OpenMP/ReductionProcessor.h | 6 ++-
.../Lower/OpenMP/Todo/reduction-inscan.f90 | 15 -------
.../Lower/OpenMP/Todo/reduction-modifiers.f90 | 14 -------
.../test/Lower/OpenMP/Todo/reduction-task.f90 | 2 +-
flang/test/Lower/OpenMP/scan.f90 | 34 ++++++++++++++++
.../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 3 +-
11 files changed, 134 insertions(+), 45 deletions(-)
delete mode 100644 flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
delete mode 100644 flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
create mode 100644 flang/test/Lower/OpenMP/scan.f90
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..98028f581e3871 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -344,6 +344,22 @@ bool ClauseProcessor::processDistSchedule(
return false;
}
+bool ClauseProcessor::processExclusive(
+ mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Exclusive>(
+ [&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ semantics::Symbol *sym = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*sym);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
+ symVal = declOp.getBase();
+ }
+ result.exclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +396,22 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}
+bool ClauseProcessor::processInclusive(
+ mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Inclusive>(
+ [&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ const semantics::Symbol *symbol = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ // if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
+ // symVal = declOp.getBase();
+ // }
+ result.inclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1167,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reduceVarByRef,
- reductionDeclSymbols, reductionSyms);
-
+ rp.addDeclareReduction(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7b047d4a7567ad..e05f66c7666844 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
+ bool processExclusive(mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processInclusive(mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b424e209d56da9..a26bdcdf343e13 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,
Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: exclusive");
+ // inp.v -> parser::OmpObjectList
+ return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Fail make(const parser::OmpClause::Fail &inp,
@@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,
Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: inclusive");
+ // inp.v -> parser::OmpObjectList
+ return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Indirect make(const parser::OmpClause::Indirect &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1434bcd6330e02..0e0e9a57287c3d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}
+static void genScanClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::ScanOperands &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInclusive(loc, clauseOps);
+ cp.processExclusive(loc, clauseOps);
+}
+
static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
@@ -1975,6 +1984,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}
+static mlir::omp::ScanOp
+genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+
+ mlir::omp::ScanOperands clauseOps;
+ genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+ return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
+ converter.getCurrentLocation(), clauseOps);
+}
+
/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
@@ -2978,7 +2998,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
- TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+ genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 2cd21107a916e4..e6ca5d5073c336 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
+#include <string>
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
@@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
+mlir::omp::ReductionModifier
+translateReductionModifier(const ReductionModifier &m) {
+ switch (m) {
+ case ReductionModifier::Default:
+ return mlir::omp::ReductionModifier::defaultmod;
+ case ReductionModifier::Inscan:
+ return mlir::omp::ReductionModifier::inscan;
+ case ReductionModifier::Task:
+ return mlir::omp::ReductionModifier::task;
+ }
+ return mlir::omp::ReductionModifier::defaultmod;
+}
+
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
- reduction.t))
- TODO(currentLocation, "Reduction modifiers are not supported");
+ auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
+ std::string modStr = "default";
+ if (mod.value() == ReductionModifier::Task)
+ modStr = "task";
+ TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
+ }
mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
@@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+ auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (redMod.has_value())
+ *reductionMod = mlir::omp::ReductionModifierAttr::get(
+ firOpBuilder.getContext(),
+ translateReductionModifier(redMod.value()));
}
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 5f4d742b62cb10..44ab67979d5db9 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
@@ -126,7 +127,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod);
};
template <typename FloatOp, typename IntegerOp>
@@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
return builder.create<ComplexOp>(loc, op1, op2);
}
+using ReductionModifier = omp::clause::Reduction::ReductionModifier;
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
deleted file mode 100644
index 152d91a16f80fe..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-subroutine reduction_inscan()
- integer :: i,j
- i = 0
-
- !$omp do reduction(inscan, +:i)
- do j=1,10
- !$omp scan inclusive(i)
- i = i + 1
- end do
- !$omp end do
-end subroutine reduction_inscan
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
deleted file mode 100644
index 82625ed8c5f31c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-
-subroutine foo()
- integer :: i, j
- j = 0
- !$omp do reduction (inscan, *: j)
- do i = 1, 10
- !$omp scan inclusive(j)
- j = j + 1
- end do
-end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
index 6707f65e1a4cc3..b746872e9e7edf 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90
+++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! CHECK: not yet implemented: Reduction modifiers are not supported
+! CHECK: not yet implemented: Reduction modifier task is not supported
subroutine reduction_task()
integer :: i
i = 0
diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90
new file mode 100644
index 00000000000000..9cf2174a7f3314
--- /dev/null
+++ b/flang/test/Lower/OpenMP/scan.f90
@@ -0,0 +1,34 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine inclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan inclusive({{.*}})
+ !$omp scan inclusive(x)
+ b(k) = x
+ end do
+end subroutine inclusive_scan
+
+
+subroutine exclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan exclusive({{.*}})
+ !$omp scan exclusive(x)
+ b(k) = x
+ end do
+end subroutine exclusive_scan
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 5d0003911bca87..056b7f989128a4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
- omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
+ omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
+ RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
More information about the flang-commits mailing list