[Mlir-commits] [flang] [llvm] [mlir] [flang][OpenMP][OMPIRBuilder][mlir] Optionally pass reduction vars by ref (PR #84304)
David Truby
llvmlistbot at llvm.org
Mon Mar 11 03:23:33 PDT 2024
================
@@ -294,14 +322,41 @@ mlir::omp::ReductionDeclareOp ReductionProcessor::createReductionDecl(
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+ mlir::Value outAddr = op1;
+
+ op1 = builder.loadIfRef(loc, op1);
+ op2 = builder.loadIfRef(loc, op2);
mlir::Value reductionOp =
createScalarCombiner(builder, loc, redId, type, op1, op2);
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ if (isByRef) {
+ builder.create<fir::StoreOp>(loc, reductionOp, outAddr);
+ builder.create<mlir::omp::YieldOp>(loc, outAddr);
+ } else {
+ builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+ }
return decl;
}
+bool ReductionProcessor::doReductionByRef(
+ const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
+ if (reductionVars.empty())
+ return false;
+ if (forceByrefReduction)
+ return true;
+
+ for (mlir::Value reductionVar : reductionVars) {
+ if (auto declare =
+ mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
+ reductionVar = declare.getMemref();
+
+ if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
+ return true;
----------------
DavidTruby wrote:
Does this imply that all reductions on a clause have to be either by ref or by val? E.g. if we have an array reduction on the clause, does that mean an integer reduction on the same clause also changes to by ref?
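
To make the question concrete, here is a minimal standalone sketch of the decision as it reads in the hunk above. It is not the flang implementation: `ReductionVar` and its `isTrivial` flag are hypothetical stand-ins for an `mlir::Value` and the `fir::isa_trivial(fir::unwrapRefType(...))` check, the `forceByRef` parameter stands in for `forceByrefReduction`, and the trailing `return false` is an assumption because the quoted hunk is cut off before the end of the function.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for one reduction variable on a clause. It records
// only whether the unwrapped type is trivial (e.g. a scalar integer).
struct ReductionVar {
  bool isTrivial;
};

// Same shape as the quoted doReductionByRef: a single answer for the whole
// clause. The final `return false` is assumed; the hunk is truncated there.
bool doReductionByRef(const std::vector<ReductionVar> &vars,
                      bool forceByRef = false) {
  if (vars.empty())
    return false;
  if (forceByRef)
    return true;
  for (const ReductionVar &var : vars)
    if (!var.isTrivial)
      return true; // one non-trivial variable decides for every variable
  return false;
}

// A clause carrying one integer (trivial) and one array (non-trivial)
// reduction: under this reading, the integer is also handled by reference.
void mixedClauseExample() {
  std::vector<ReductionVar> mixed = {{true}, {false}};
  assert(doReductionByRef(mixed));
}
```

If that reading is right, the by-ref decision is made per clause rather than per variable, which is the behaviour the question above is asking about.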
https://github.com/llvm/llvm-project/pull/84304