[flang-commits] [flang] 319705d - [flang] `do concurrent`: fix reduction symbol resolution when mapping to OpenMP (#155355)

via flang-commits flang-commits at lists.llvm.org
Wed Aug 27 08:06:20 PDT 2025


Author: Kareem Ergawy
Date: 2025-08-27T17:06:16+02:00
New Revision: 319705d0ab6f7b78ca26ee49b87393473ae63082

URL: https://github.com/llvm/llvm-project/commit/319705d0ab6f7b78ca26ee49b87393473ae63082
DIFF: https://github.com/llvm/llvm-project/commit/319705d0ab6f7b78ca26ee49b87393473ae63082.diff

LOG: [flang] `do concurrent`: fix reduction symbol resolution when mapping to OpenMP (#155355)

Fixes #155273

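This PR introduces 2 changes:
1. The `do concurrent` to OpenMP pass is now a module pass rather than a
function pass.
2. OpenMP reduction declarations (`omp.declare_reduction`) are looked up in
the parent module's symbol table before being created; if a matching
declaration already exists, it is reused instead of being created again.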

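The benefit of using a module pass is that a single OpenMP reduction
declaration can be shared across multiple functions whose reductions refer
to the same FIR reduction declaration (for example, the two `+` reductions
over `real` values in the added test).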

Added: 
    flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90

Modified: 
    flang/include/flang/Optimizer/OpenMP/Passes.td
    flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 99202f6ee81e7..e2f092024c250 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -50,7 +50,7 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
   ];
 }
 
-def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::func::FuncOp"> {
+def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::ModuleOp"> {
   let summary = "Map `DO CONCURRENT` loops to OpenMP worksharing loops.";
 
   let description = [{ This is an experimental pass to map `DO CONCURRENT` loops

diff  --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 2b3ac169e8b5b..c928b76065ade 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -173,9 +173,11 @@ class DoConcurrentConversion
 
   DoConcurrentConversion(
       mlir::MLIRContext *context, bool mapToDevice,
-      llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
+      llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
+      mlir::SymbolTable &moduleSymbolTable)
       : OpConversionPattern(context), mapToDevice(mapToDevice),
-        concurrentLoopsToSkip(concurrentLoopsToSkip) {}
+        concurrentLoopsToSkip(concurrentLoopsToSkip),
+        moduleSymbolTable(moduleSymbolTable) {}
 
   mlir::LogicalResult
   matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
@@ -332,8 +334,8 @@ class DoConcurrentConversion
                loop.getLocalVars(),
                loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
                loop.getRegionLocalArgs())) {
-        auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
-            fir::LocalitySpecifierOp>(loop, sym);
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
         if (localizer.getLocalitySpecifierType() ==
             fir::LocalitySpecifierType::LocalInit)
           TODO(localizer.getLoc(),
@@ -352,6 +354,8 @@ class DoConcurrentConversion
         cloneFIRRegionToOMP(localizer.getDeallocRegion(),
                             privatizer.getDeallocRegion());
 
+        moduleSymbolTable.insert(privatizer);
+
         wsloopClauseOps.privateVars.push_back(op);
         wsloopClauseOps.privateSyms.push_back(
             mlir::SymbolRefAttr::get(privatizer));
@@ -362,28 +366,34 @@ class DoConcurrentConversion
                loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
                loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
                loop.getRegionReduceArgs())) {
-        auto firReducer =
-            mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
-                loop, sym);
+        auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+            sym.getLeafReference());
 
         mlir::OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPointAfter(firReducer);
-
-        auto ompReducer = mlir::omp::DeclareReductionOp::create(
-            rewriter, firReducer.getLoc(),
-            sym.getLeafReference().str() + ".omp",
-            firReducer.getTypeAttr().getValue());
-
-        cloneFIRRegionToOMP(firReducer.getAllocRegion(),
-                            ompReducer.getAllocRegion());
-        cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
-                            ompReducer.getInitializerRegion());
-        cloneFIRRegionToOMP(firReducer.getReductionRegion(),
-                            ompReducer.getReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
-                            ompReducer.getAtomicReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
-                            ompReducer.getCleanupRegion());
+        std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+
+        auto ompReducer =
+            moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+                rewriter.getStringAttr(ompReducerName));
+
+        if (!ompReducer) {
+          ompReducer = mlir::omp::DeclareReductionOp::create(
+              rewriter, firReducer.getLoc(), ompReducerName,
+              firReducer.getTypeAttr().getValue());
+
+          cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+                              ompReducer.getAllocRegion());
+          cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+                              ompReducer.getInitializerRegion());
+          cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+                              ompReducer.getReductionRegion());
+          cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+                              ompReducer.getAtomicReductionRegion());
+          cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+                              ompReducer.getCleanupRegion());
+          moduleSymbolTable.insert(ompReducer);
+        }
 
         wsloopClauseOps.reductionVars.push_back(op);
         wsloopClauseOps.reductionByref.push_back(byRef);
@@ -431,6 +441,7 @@ class DoConcurrentConversion
 
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
+  mlir::SymbolTable &moduleSymbolTable;
 };
 
 class DoConcurrentConversionPass
@@ -444,12 +455,9 @@ class DoConcurrentConversionPass
       : DoConcurrentConversionPassBase(options) {}
 
   void runOnOperation() override {
-    mlir::func::FuncOp func = getOperation();
-
-    if (func.isDeclaration())
-      return;
-
+    mlir::ModuleOp module = getOperation();
     mlir::MLIRContext *context = &getContext();
+    mlir::SymbolTable moduleSymbolTable(module);
 
     if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
         mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
@@ -463,7 +471,7 @@ class DoConcurrentConversionPass
     mlir::RewritePatternSet patterns(context);
     patterns.insert<DoConcurrentConversion>(
         context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
-        concurrentLoopsToSkip);
+        concurrentLoopsToSkip, moduleSymbolTable);
     mlir::ConversionTarget target(*context);
     target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
         [&](fir::DoConcurrentOp op) {
@@ -472,8 +480,8 @@ class DoConcurrentConversionPass
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
 
-    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
-                                               std::move(patterns)))) {
+    if (mlir::failed(
+            mlir::applyFullConversion(module, target, std::move(patterns)))) {
       signalPassFailure();
     }
   }

diff  --git a/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90 b/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90
new file mode 100644
index 0000000000000..ab56a4f6c7e70
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=host %s -o - \
+! RUN:   | FileCheck %s
+
+subroutine test1(x,s,N)
+  real :: x(N), s
+  integer :: N
+  do concurrent(i=1:N) reduce(+:s)
+     s=s+x(i)
+  end do
+end subroutine test1
+subroutine test2(x,s,N)
+  real :: x(N), s
+  integer :: N
+  do concurrent(i=1:N) reduce(+:s)
+     s=s+x(i)
+  end do
+end subroutine test2
+
+! CHECK:       omp.declare_reduction @[[RED_SYM:.*]] : f32 init
+! CHECK-NOT:   omp.declare_reduction
+
+! CHECK-LABEL: func.func @_QPtest1
+! CHECK:         omp.parallel {
+! CHECK:           omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
+! CHECK:           }
+! CHECK:         }
+
+! CHECK-LABEL: func.func @_QPtest2
+! CHECK:         omp.parallel {
+! CHECK:           omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
+! CHECK:           }
+! CHECK:         }


        


More information about the flang-commits mailing list