[llvm-branch-commits] [flang] [Flang][OMP] Replace SUM intrinsic call with SUM operations (PR #113082)
Tom Eccles via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Oct 23 04:48:09 PDT 2024
================
@@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
bool isLast = i + 1 == regions.size();
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
- OpBuilder singleBuilder(sourceRegion.getContext());
- Block *singleBlock = new Block();
- singleBuilder.setInsertionPointToStart(singleBlock);
-
OpBuilder allocaBuilder(sourceRegion.getContext());
Block *allocaBlock = new Block();
allocaBuilder.setInsertionPointToStart(allocaBlock);
- OpBuilder parallelBuilder(sourceRegion.getContext());
- Block *parallelBlock = new Block();
- parallelBuilder.setInsertionPointToStart(parallelBlock);
-
- auto [allParallelized, copyprivateVars] =
- moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
- singleBuilder, parallelBuilder);
- if (allParallelized) {
- // The single region was not required as all operations were safe to
- // parallelize
- assert(copyprivateVars.empty());
- assert(allocaBlock->empty());
- delete singleBlock;
+ it = block.begin();
+ while (&*it != terminator)
+ if (isa<hlfir::SumOp>(it))
+ break;
+ else
+ it++;
+
+ if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) {
+ /// Implementation:
+ /// Intrinsic function `SUM` operations
+ /// --
+ /// x = sum(array)
+ ///
+ /// is converted to
+ ///
+ /// !$omp parallel do
+ /// do i = 1, size(array)
+ /// x = x + array(i)
+ /// end do
+ /// !$omp end parallel do
+
+ OpBuilder wslBuilder(sourceRegion.getContext());
+ Block *wslBlock = new Block();
+ wslBuilder.setInsertionPointToStart(wslBlock);
+
+ Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs();
+ Value array = sumOp.getArray();
+ Value dim = sumOp.getDim();
+ fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>(
+ hlfir::getFortranElementOrSequenceType(array.getType()));
+ llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+ if (arrayShape.size() == 1 && !dim) {
+ Value itr = allocaBuilder.create<fir::AllocaOp>(
+ loc, allocaBuilder.getI64Type());
+ Value c_one = allocaBuilder.create<arith::ConstantOp>(
+ loc, allocaBuilder.getI64IntegerAttr(1));
+ Value c_arr_size = allocaBuilder.create<arith::ConstantOp>(
+ loc, allocaBuilder.getI64IntegerAttr(arrayShape[0]));
+ // Value c_zero = allocaBuilder.create<arith::ConstantOp>(loc,
+ // allocaBuilder.getZeroAttr(arrayTy.getEleTy()));
+ // allocaBuilder.create<fir::StoreOp>(loc, c_zero, target);
+
+ omp::WsloopOperands wslOps;
+ omp::WsloopOp wslOp =
+ rootBuilder.create<omp::WsloopOp>(loc, wslOps);
+
+ hlfir::LoopNest ln;
+ ln.outerOp = wslOp;
+ omp::LoopNestOperands lnOps;
+ lnOps.loopLowerBounds.push_back(c_one);
+ lnOps.loopUpperBounds.push_back(c_arr_size);
+ lnOps.loopSteps.push_back(c_one);
+ lnOps.loopInclusive = wslBuilder.getUnitAttr();
+ omp::LoopNestOp lnOp =
+ wslBuilder.create<omp::LoopNestOp>(loc, lnOps);
+ Block *lnBlock = wslBuilder.createBlock(&lnOp.getRegion());
+ lnBlock->addArgument(c_one.getType(), loc);
+ wslBuilder.create<fir::StoreOp>(
+ loc, lnOp.getRegion().getArgument(0), itr);
+ Value tarLoad = wslBuilder.create<fir::LoadOp>(loc, target);
+ Value itrLoad = wslBuilder.create<fir::LoadOp>(loc, itr);
+ hlfir::DesignateOp arrDesOp = wslBuilder.create<hlfir::DesignateOp>(
+ loc, fir::ReferenceType::get(arrayTy.getEleTy()), array,
+ itrLoad);
+ Value desLoad = wslBuilder.create<fir::LoadOp>(loc, arrDesOp);
+ Value addf =
+ wslBuilder.create<arith::AddFOp>(loc, tarLoad, desLoad);
+ wslBuilder.create<fir::StoreOp>(loc, addf, target);
+ wslBuilder.create<omp::YieldOp>(loc);
+ ln.body = lnBlock;
+ wslOp.getRegion().push_back(wslBlock);
+ targetRegion.front().getOperations().splice(
+ wslOp->getIterator(), allocaBlock->getOperations());
+ } else {
+ emitError(loc, "Only 1D array scalar assignment for sum "
----------------
tblah wrote:
Instead of emitting an error here, it would be better to go to the outer else branch and use the runtime library version of SUM.
https://github.com/llvm/llvm-project/pull/113082
More information about the llvm-branch-commits
mailing list