[Mlir-commits] [mlir] [OpenACC][MLIR] clone firstprivate operands during ACCIfClauseLowering (PR #176856)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 19 21:12:03 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir
Author: Scott Manley (rscottmanley)
<details>
<summary>Changes</summary>
Clone the firstprivate operands into the device (compute region) side of the generated if. This also fixes an issue where references to acc.firstprivate results remained on the host side.
---
Full diff: https://github.com/llvm/llvm-project/pull/176856.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp (+23-7)
- (modified) mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir (+46)
``````````diff
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
index 5524c291a80e7..ddefa1653f213 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -136,6 +136,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
// condition
SmallVector<Operation *> dataEntryOps;
SmallVector<Operation *> dataExitOps;
+ SmallVector<Operation *> firstprivateOps;
// Collect data entry operations
for (Value operand : computeConstructOp.getDataClauseOperands()) {
@@ -143,6 +144,11 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
if (isa<ACC_DATA_ENTRY_OPS>(defOp))
dataEntryOps.push_back(defOp);
}
+ // Collect firstprivate operations
+ for (Value operand : computeConstructOp.getFirstprivateOperands()) {
+ if (Operation *defOp = operand.getDefiningOp())
+ firstprivateOps.push_back(defOp);
+ }
// Find corresponding exit operations for each entry operation.
// Iterate backwards through entry ops since exit ops appear in reverse order.
@@ -155,8 +161,8 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
TypeRange{}, ifCond, /*withElseRegion=*/true);
- // Declare deviceMapping at function scope for later use
- IRMapping deviceMapping;
+ LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
+ << " data entry operations for device path\n");
// Device execution path (true branch)
Block &thenBlock = ifOp.getThenRegion().front();
@@ -164,15 +170,20 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
// Clone data entry operations
SmallVector<Value> deviceDataOperands;
+ SmallVector<Value> firstprivateOperands;
- LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
- << " data entry operations for device path\n");
-
+ // Map the data entry and firstprivate ops for the cloned region
+ IRMapping deviceMapping;
for (Operation *dataOp : dataEntryOps) {
Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
deviceDataOperands.push_back(clonedDataOp->getResult(0));
deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
}
+ for (Operation *firstprivateOp : firstprivateOps) {
+ Operation *clonedOp = rewriter.clone(*firstprivateOp, deviceMapping);
+ firstprivateOperands.push_back(clonedOp->getResult(0));
+ deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0));
+ }
// Create new compute op without if condition for device execution by
// cloning
@@ -180,6 +191,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
newComputeOp.getIfCondMutable().clear();
newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
+ newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
// Clone data exit operations
rewriter.setInsertionPointAfter(newComputeOp);
@@ -216,12 +228,16 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
for (Operation *dataOp : dataExitOps)
eraseOps.push_back(dataOp);
+ // The new host code may contain uses of the acc variables. Replace them by
+ // the host values.
for (Operation *dataOp : dataEntryOps) {
- // The new host code may contain uses of the acc variables. Replace them by
- // the host values.
getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
eraseOps.push_back(dataOp);
}
+ for (Operation *firstprivateOp : firstprivateOps) {
+ getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp));
+ eraseOps.push_back(firstprivateOp);
+ }
}
void ACCIfClauseLowering::runOnOperation() {
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
index 3f0df18619bc0..fdef532fb8cb4 100644
--- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -222,3 +222,49 @@ func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
return
}
+// -----
+
+// Test that acc variable uses in host path are replaced with host variables;
+// and the firstprivate operands are cloned
+// CHECK-LABEL: func.func @test_acc_firstprivate
+
+acc.firstprivate.recipe @memref_i32 : memref<i32> init {
+^bb0(%arg0: memref<i32>):
+ %0 = memref.alloca() : memref<i32>
+ acc.yield %0 : memref<i32>
+} copy {
+^bb0(%arg0: memref<i32>, %arg1: memref<i32>):
+ %0 = memref.load %arg0[] : memref<i32>
+ memref.store %0, %arg1[] : memref<i32>
+ acc.terminator
+}
+
+func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %cond: i1) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+ %firstprivate = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
+
+ // In the else branch, uses of %copyin should be replaced with %arg0
+ // CHECK: scf.if
+ // CHECK: [[FIRSTPRIVATE:%.*]] = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
+ // CHECK: acc.parallel {{.*}} firstprivate([[FIRSTPRIVATE]] : memref<i32>) {
+ // CHECK: } else {
+ // CHECK: [[LOAD:%.*]] = memref.load %arg1[] : memref<i32>
+ // CHECK: }
+
+ acc.parallel dataOperands(%copyin : memref<10xi32>) firstprivate(%firstprivate : memref<i32>) if(%cond) {
+ %load = memref.load %firstprivate[] : memref<i32>
+ %ub = arith.index_cast %load : i32 to index
+ scf.for %i = %c1 to %ub step %c1 {
+ // Use the acc ptr inside the region
+ memref.store %c0_i32, %copyin[%i] : memref<10xi32>
+ }
+ acc.yield
+ }
+
+ acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/176856
More information about the Mlir-commits mailing list