[flang-commits] [flang] [llvm] [mlir] [flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref (PR #84304)

David Truby via flang-commits flang-commits at lists.llvm.org
Mon Mar 11 09:10:13 PDT 2024


================
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
   builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
   mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
   mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+  mlir::Value outAddr = op1;
+
+  op1 = builder.loadIfRef(loc, op1);
+  op2 = builder.loadIfRef(loc, op2);
 
   mlir::Value reductionOp =
       createScalarCombiner(builder, loc, redId, type, op1, op2);
-  builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  if (isByRef) {
+    builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+    builder.create<mlir::omp::YieldOp>(loc, outAddr);
+  } else {
+    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+  }
 
   return decl;
 }
 
+bool ReductionProcessor::doReductionByRef(
+    const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+  if (reductionVars.empty())
+    return false;
+  if (forceByrefReduction)
+    return true;
+
+  for (mlir::Value reductionVar : reductionVars) {
+    if (auto declare =
+            mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+      reductionVar = declare.getMemref();
+
+    if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+      return true;
----------------
DavidTruby wrote:

That's probably true with respect to performance, but I believe that for `openmp target` reductions, doing by-reference reductions on basic types will also affect correctness. As far as I remember, you do not need to manually ensure that, for example, an INTEGER exists on the target device, whereas you do for an array (and you would need to for the integer too if it were passed by reference).

I think fixing that in a subsequent patch is probably fine though, as long as we add a TODO mentioning that ideally it should be considered separately per argument.
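
To make the "separately per argument" idea concrete, here is a rough standalone sketch. It does not use the real MLIR/FIR API: `ReductionVar`, `isTrivialType`, `doReductionByRefAll` and `doReductionByRefPerVar` are hypothetical stand-ins, with `isTrivialType` playing the role of the `fir::isa_trivial(fir::unwrapRefType(...))` check in the patch. The first function mirrors the all-or-nothing decision in the hunk above; the second shows what a per-variable decision might look like.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for a reduction variable; not mlir::Value.
struct ReductionVar {
  std::string name;
  // Assumption: models fir::isa_trivial(fir::unwrapRefType(type)).
  bool isTrivialType;
};

// Mirrors the patch: a single by-ref decision for the whole clause.
// One non-trivial variable forces every variable to be passed by reference.
bool doReductionByRefAll(const std::vector<ReductionVar> &vars,
                         bool forceByRef) {
  if (vars.empty())
    return false;
  if (forceByRef)
    return true;
  for (const ReductionVar &var : vars)
    if (!var.isTrivialType)
      return true;
  return false;
}

// Possible follow-up (TODO in the patch): decide per variable, so a trivial
// scalar stays by-value even when it shares a clause with an array.
std::vector<bool> doReductionByRefPerVar(const std::vector<ReductionVar> &vars,
                                         bool forceByRef) {
  std::vector<bool> byRef;
  byRef.reserve(vars.size());
  for (const ReductionVar &var : vars)
    byRef.push_back(forceByRef || !var.isTrivialType);
  return byRef;
}
```

With a clause containing an INTEGER scalar and an array, the first function returns `true` for the whole clause, while the second would keep the scalar by value and pass only the array by reference.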

https://github.com/llvm/llvm-project/pull/84304

