[llvm-branch-commits] [flang] [Flang][OMP] Replace SUM intrinsic call with SUM operations (PR #113082)
Tom Eccles via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Oct 23 04:48:09 PDT 2024
================
@@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
bool isLast = i + 1 == regions.size();
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
- OpBuilder singleBuilder(sourceRegion.getContext());
- Block *singleBlock = new Block();
- singleBuilder.setInsertionPointToStart(singleBlock);
-
OpBuilder allocaBuilder(sourceRegion.getContext());
Block *allocaBlock = new Block();
allocaBuilder.setInsertionPointToStart(allocaBlock);
- OpBuilder parallelBuilder(sourceRegion.getContext());
- Block *parallelBlock = new Block();
- parallelBuilder.setInsertionPointToStart(parallelBlock);
-
- auto [allParallelized, copyprivateVars] =
- moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
- singleBuilder, parallelBuilder);
- if (allParallelized) {
- // The single region was not required as all operations were safe to
- // parallelize
- assert(copyprivateVars.empty());
- assert(allocaBlock->empty());
- delete singleBlock;
+ it = block.begin();
+ while (&*it != terminator)
+ if (isa<hlfir::SumOp>(it))
+ break;
+ else
+ it++;
+
+ if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) {
+ /// Implementation:
+ /// Intrinsic function `SUM` operations
+ /// --
+ /// x = sum(array)
+ ///
+ /// is converted to
+ ///
+ /// !$omp parallel do
+ /// do i = 1, size(array)
+ /// x = x + array(i)
+ /// end do
+ /// !$omp end parallel do
+
+ OpBuilder wslBuilder(sourceRegion.getContext());
+ Block *wslBlock = new Block();
+ wslBuilder.setInsertionPointToStart(wslBlock);
+
+ Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs();
+ Value array = sumOp.getArray();
+ Value dim = sumOp.getDim();
+ fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>(
+ hlfir::getFortranElementOrSequenceType(array.getType()));
+ llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+ if (arrayShape.size() == 1 && !dim) {
----------------
tblah wrote:
I think you also need to check that the element type is floating point.
https://github.com/llvm/llvm-project/pull/113082
More information about the llvm-branch-commits
mailing list