[llvm-branch-commits] [flang] [Flang][OMP] Replace SUM intrinsic call with SUM operations (PR #113082)

Tom Eccles via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Oct 23 04:48:09 PDT 2024


================
@@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
-        OpBuilder singleBuilder(sourceRegion.getContext());
-        Block *singleBlock = new Block();
-        singleBuilder.setInsertionPointToStart(singleBlock);
-
         OpBuilder allocaBuilder(sourceRegion.getContext());
         Block *allocaBlock = new Block();
         allocaBuilder.setInsertionPointToStart(allocaBlock);
 
-        OpBuilder parallelBuilder(sourceRegion.getContext());
-        Block *parallelBlock = new Block();
-        parallelBuilder.setInsertionPointToStart(parallelBlock);
-
-        auto [allParallelized, copyprivateVars] =
-            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
-                         singleBuilder, parallelBuilder);
-        if (allParallelized) {
-          // The single region was not required as all operations were safe to
-          // parallelize
-          assert(copyprivateVars.empty());
-          assert(allocaBlock->empty());
-          delete singleBlock;
+        it = block.begin();
+        while (&*it != terminator)
+          if (isa<hlfir::SumOp>(it))
+            break;
+          else
+            it++;
+
+        if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) {
+          /// Implementation:
+          /// Intrinsic function `SUM` operations
+          /// --
+          /// x = sum(array)
+          ///
+          /// is converted to
+          ///
+          /// !$omp parallel do
+          /// do i = 1, size(array)
+          ///     x = x + array(i)
+          /// end do
+          /// !$omp end parallel do
+
+          OpBuilder wslBuilder(sourceRegion.getContext());
+          Block *wslBlock = new Block();
+          wslBuilder.setInsertionPointToStart(wslBlock);
+
+          Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs();
+          Value array = sumOp.getArray();
+          Value dim = sumOp.getDim();
+          fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>(
+              hlfir::getFortranElementOrSequenceType(array.getType()));
+          llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+          if (arrayShape.size() == 1 && !dim) {
+            Value itr = allocaBuilder.create<fir::AllocaOp>(
+                loc, allocaBuilder.getI64Type());
+            Value c_one = allocaBuilder.create<arith::ConstantOp>(
+                loc, allocaBuilder.getI64IntegerAttr(1));
+            Value c_arr_size = allocaBuilder.create<arith::ConstantOp>(
+                loc, allocaBuilder.getI64IntegerAttr(arrayShape[0]));
+            // Value c_zero = allocaBuilder.create<arith::ConstantOp>(loc,
+            //     allocaBuilder.getZeroAttr(arrayTy.getEleTy()));
+            // allocaBuilder.create<fir::StoreOp>(loc, c_zero, target);
+
+            omp::WsloopOperands wslOps;
+            omp::WsloopOp wslOp =
+                rootBuilder.create<omp::WsloopOp>(loc, wslOps);
+
+            hlfir::LoopNest ln;
+            ln.outerOp = wslOp;
+            omp::LoopNestOperands lnOps;
+            lnOps.loopLowerBounds.push_back(c_one);
+            lnOps.loopUpperBounds.push_back(c_arr_size);
+            lnOps.loopSteps.push_back(c_one);
+            lnOps.loopInclusive = wslBuilder.getUnitAttr();
+            omp::LoopNestOp lnOp =
+                wslBuilder.create<omp::LoopNestOp>(loc, lnOps);
+            Block *lnBlock = wslBuilder.createBlock(&lnOp.getRegion());
+            lnBlock->addArgument(c_one.getType(), loc);
+            wslBuilder.create<fir::StoreOp>(
+                loc, lnOp.getRegion().getArgument(0), itr);
+            Value tarLoad = wslBuilder.create<fir::LoadOp>(loc, target);
+            Value itrLoad = wslBuilder.create<fir::LoadOp>(loc, itr);
+            hlfir::DesignateOp arrDesOp = wslBuilder.create<hlfir::DesignateOp>(
+                loc, fir::ReferenceType::get(arrayTy.getEleTy()), array,
+                itrLoad);
+            Value desLoad = wslBuilder.create<fir::LoadOp>(loc, arrDesOp);
+            Value addf =
+                wslBuilder.create<arith::AddFOp>(loc, tarLoad, desLoad);
+            wslBuilder.create<fir::StoreOp>(loc, addf, target);
+            wslBuilder.create<omp::YieldOp>(loc);
+            ln.body = lnBlock;
+            wslOp.getRegion().push_back(wslBlock);
+            targetRegion.front().getOperations().splice(
+                wslOp->getIterator(), allocaBlock->getOperations());
+          } else {
+            emitError(loc, "Only 1D array scalar assignment for sum "
----------------
tblah wrote:

Instead of emitting an error here, it would be better to go to the outer else branch and use the runtime library version of SUM.

https://github.com/llvm/llvm-project/pull/113082


More information about the llvm-branch-commits mailing list