[llvm-branch-commits] [flang] [Flang][OMP] Replace SUM intrinsic call with SUM operations (PR #113082)

Tom Eccles via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Oct 23 04:48:09 PDT 2024


================
@@ -335,49 +336,129 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
     for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
       bool isLast = i + 1 == regions.size();
       if (std::holds_alternative<SingleRegion>(opOrSingle)) {
-        OpBuilder singleBuilder(sourceRegion.getContext());
-        Block *singleBlock = new Block();
-        singleBuilder.setInsertionPointToStart(singleBlock);
-
         OpBuilder allocaBuilder(sourceRegion.getContext());
         Block *allocaBlock = new Block();
         allocaBuilder.setInsertionPointToStart(allocaBlock);
 
-        OpBuilder parallelBuilder(sourceRegion.getContext());
-        Block *parallelBlock = new Block();
-        parallelBuilder.setInsertionPointToStart(parallelBlock);
-
-        auto [allParallelized, copyprivateVars] =
-            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
-                         singleBuilder, parallelBuilder);
-        if (allParallelized) {
-          // The single region was not required as all operations were safe to
-          // parallelize
-          assert(copyprivateVars.empty());
-          assert(allocaBlock->empty());
-          delete singleBlock;
+        it = block.begin();
+        while (&*it != terminator)
+          if (isa<hlfir::SumOp>(it))
+            break;
+          else
+            it++;
+
+        if (auto sumOp = dyn_cast<hlfir::SumOp>(it)) {
+          /// Implementation:
+          /// Intrinsic function `SUM` operations
+          /// --
+          /// x = sum(array)
+          ///
+          /// is converted to
+          ///
+          /// !$omp parallel do
+          /// do i = 1, size(array)
+          ///     x = x + array(i)
+          /// end do
+          /// !$omp end parallel do
+
+          OpBuilder wslBuilder(sourceRegion.getContext());
+          Block *wslBlock = new Block();
+          wslBuilder.setInsertionPointToStart(wslBlock);
+
+          Value target = dyn_cast<hlfir::AssignOp>(++it).getLhs();
+          Value array = sumOp.getArray();
+          Value dim = sumOp.getDim();
+          fir::SequenceType arrayTy = dyn_cast<fir::SequenceType>(
+              hlfir::getFortranElementOrSequenceType(array.getType()));
+          llvm::ArrayRef<int64_t> arrayShape = arrayTy.getShape();
+          if (arrayShape.size() == 1 && !dim) {
----------------
tblah wrote:

I think you also need to check that the element type is floating point.

https://github.com/llvm/llvm-project/pull/113082


More information about the llvm-branch-commits mailing list