[Mlir-commits] [mlir] 668d474 - [OpenACC][MLIR] clone private operands during ACCIfClauseLowering (#177458)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 23 05:58:01 PST 2026
Author: Scott Manley
Date: 2026-01-23T07:57:56-06:00
New Revision: 668d474185340099377820464eb1f334ce8cd875
URL: https://github.com/llvm/llvm-project/commit/668d474185340099377820464eb1f334ce8cd875
DIFF: https://github.com/llvm/llvm-project/commit/668d474185340099377820464eb1f334ce8cd875.diff
LOG: [OpenACC][MLIR] clone private operands during ACCIfClauseLowering (#177458)
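Clone the acc.private operands into the device (compute region) branch of
the generated scf.if, in the same way as the data entry, firstprivate, and
reduction operands. This also fixes an issue where uses of the acc.private
result remained in the host (else) branch instead of being replaced with
the original host variable.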
Added:
Modified:
mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
index c103ba29ed287..9095d7c915fa8 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -137,25 +137,14 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
SmallVector<Operation *> dataEntryOps;
SmallVector<Operation *> dataExitOps;
SmallVector<Operation *> firstprivateOps;
+ SmallVector<Operation *> privateOps;
SmallVector<Operation *> reductionOps;
// Collect data entry operations
- for (Value operand : computeConstructOp.getDataClauseOperands()) {
+ for (Value operand : computeConstructOp.getDataClauseOperands())
if (Operation *defOp = operand.getDefiningOp())
if (isa<ACC_DATA_ENTRY_OPS>(defOp))
dataEntryOps.push_back(defOp);
- }
- // Collect firstprivate operations
- for (Value operand : computeConstructOp.getFirstprivateOperands()) {
- if (Operation *defOp = operand.getDefiningOp())
- firstprivateOps.push_back(defOp);
- }
-
- // Collect reduction operations
- for (Value operand : computeConstructOp.getReductionOperands()) {
- if (Operation *defOp = operand.getDefiningOp())
- reductionOps.push_back(defOp);
- }
// Find corresponding exit operations for each entry operation.
// Iterate backwards through entry ops since exit ops appear in reverse order.
@@ -164,6 +153,16 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
if (isa<ACC_DATA_EXIT_OPS>(user))
dataExitOps.push_back(user);
+ // Collect firstprivate, private, and reduction operations
+ auto collectOps = [&](SmallVector<Operation *> &ops, OperandRange operands) {
+ for (Value operand : operands)
+ if (Operation *defOp = operand.getDefiningOp())
+ ops.push_back(defOp);
+ };
+ collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
+ collectOps(privateOps, computeConstructOp.getPrivateOperands());
+ collectOps(reductionOps, computeConstructOp.getReductionOperands());
+
// Create scf.if with device and host execution paths
auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
TypeRange{}, ifCond, /*withElseRegion=*/true);
@@ -178,25 +177,23 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
// Clone data entry operations
SmallVector<Value> deviceDataOperands;
SmallVector<Value> firstprivateOperands;
+ SmallVector<Value> privateOperands;
SmallVector<Value> reductionOperands;
// Map the data entry and firstprivate ops for the cloned region
IRMapping deviceMapping;
- for (Operation *dataOp : dataEntryOps) {
- Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
- deviceDataOperands.push_back(clonedDataOp->getResult(0));
- deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
- }
- for (Operation *firstprivateOp : firstprivateOps) {
- Operation *clonedOp = rewriter.clone(*firstprivateOp, deviceMapping);
- firstprivateOperands.push_back(clonedOp->getResult(0));
- deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0));
- }
- for (Operation *reductionOp : reductionOps) {
- Operation *clonedOp = rewriter.clone(*reductionOp, deviceMapping);
- reductionOperands.push_back(clonedOp->getResult(0));
- deviceMapping.map(reductionOp->getResult(0), clonedOp->getResult(0));
- }
+ auto cloneAndMapOps = [&](SmallVector<Operation *> &ops,
+ SmallVector<Value> &operands) {
+ for (Operation *op : ops) {
+ Operation *clonedOp = rewriter.clone(*op, deviceMapping);
+ operands.push_back(clonedOp->getResult(0));
+ deviceMapping.map(op->getResult(0), clonedOp->getResult(0));
+ }
+ };
+ cloneAndMapOps(dataEntryOps, deviceDataOperands);
+ cloneAndMapOps(firstprivateOps, firstprivateOperands);
+ cloneAndMapOps(privateOps, privateOperands);
+ cloneAndMapOps(reductionOps, reductionOperands);
// Create new compute op without if condition for device execution by
// cloning
@@ -205,6 +202,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
newComputeOp.getIfCondMutable().clear();
newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
+ newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
// Clone data exit operations
@@ -244,18 +242,16 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
// The new host code may contain uses of the acc variables. Replace them by
// the host values.
- for (Operation *dataOp : dataEntryOps) {
- getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
- eraseOps.push_back(dataOp);
- }
- for (Operation *firstprivateOp : firstprivateOps) {
- getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp));
- eraseOps.push_back(firstprivateOp);
- }
- for (Operation *reductionOp : reductionOps) {
- getAccVar(reductionOp).replaceAllUsesWith(getVar(reductionOp));
- eraseOps.push_back(reductionOp);
- }
+ auto replaceAndEraseOps = [&](SmallVector<Operation *> &ops) {
+ for (Operation *op : ops) {
+ getAccVar(op).replaceAllUsesWith(getVar(op));
+ eraseOps.push_back(op);
+ }
+ };
+ replaceAndEraseOps(dataEntryOps);
+ replaceAndEraseOps(firstprivateOps);
+ replaceAndEraseOps(privateOps);
+ replaceAndEraseOps(reductionOps);
}
void ACCIfClauseLowering::runOnOperation() {
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
index 5b942c121d568..75f3a5cd211e0 100644
--- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -312,3 +312,29 @@ func.func @test_acc_reduction(%arg0: memref<i32>, %cond: i1) {
}
return
}
+
+acc.private.recipe @privatization_memref_i32 : memref<i32> init {
+^bb0(%arg0: memref<i32>):
+ %0 = memref.alloca() : memref<i32>
+ acc.yield %0 : memref<i32>
+}
+
+func.func @test_acc_private(%arg0: memref<i32>, %cond: i1) {
+
+ %c0_i32 = arith.constant 0 : i32
+ %private = acc.private varPtr(%arg0 : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
+
+ // In the else branch, uses of %private should be replaced with %arg0
+ // CHECK: scf.if
+ // CHECK: [[PRIVATE:%.*]] = acc.private varPtr(%arg0 : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
+ // CHECK: acc.parallel private([[PRIVATE]] : memref<i32>) {
+ // CHECK: } else {
+ // CHECK: memref.store {{.*}}, %arg0[] : memref<i32>
+ // CHECK: }
+
+ acc.parallel private(%private : memref<i32>) if(%cond) {
+ memref.store %c0_i32, %private[] : memref<i32>
+ acc.yield
+ }
+ return
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list