[Mlir-commits] [mlir] [OpenACC][MLIR] clone reduction operands during ACCIfClauseLowering (PR #177196)

Scott Manley llvmlistbot at llvm.org
Wed Jan 21 08:32:02 PST 2026


https://github.com/rscottmanley created https://github.com/llvm/llvm-project/pull/177196

Clone the reduction operands into the compute region side. This also fixes an issue where references to acc.reduction remain on the host side.

>From e13e64aa646f37d9640a6accb7e9c089c6fe05c1 Mon Sep 17 00:00:00 2001
From: Scott Manley <scmanley at nvidia.com>
Date: Wed, 21 Jan 2026 08:30:46 -0800
Subject: [PATCH] [OpenACC][MLIR] clone reduction operands during
 ACCIfClauseLowering

Clone the reduction operands into the compute region side. This also
fixes an issue where references to acc.reduction remain on the host
side.
---
 .../Transforms/ACCIfClauseLowering.cpp        | 18 ++++++++
 .../OpenACC/acc-if-clause-lowering.mlir       | 45 ++++++++++++++++++-
 2 files changed, 62 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
index ddefa1653f213..c103ba29ed287 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -137,6 +137,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
   SmallVector<Operation *> dataEntryOps;
   SmallVector<Operation *> dataExitOps;
   SmallVector<Operation *> firstprivateOps;
+  SmallVector<Operation *> reductionOps;
 
   // Collect data entry operations
   for (Value operand : computeConstructOp.getDataClauseOperands()) {
@@ -150,6 +151,12 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
       firstprivateOps.push_back(defOp);
   }
 
+  // Collect reduction operations
+  for (Value operand : computeConstructOp.getReductionOperands()) {
+    if (Operation *defOp = operand.getDefiningOp())
+      reductionOps.push_back(defOp);
+  }
+
   // Find corresponding exit operations for each entry operation.
   // Iterate backwards through entry ops since exit ops appear in reverse order.
   for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
@@ -171,6 +178,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
   // Clone data entry operations
   SmallVector<Value> deviceDataOperands;
   SmallVector<Value> firstprivateOperands;
+  SmallVector<Value> reductionOperands;
 
   // Map the data entry and firstprivate ops for the cloned region
   IRMapping deviceMapping;
@@ -184,6 +192,11 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
     firstprivateOperands.push_back(clonedOp->getResult(0));
     deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0));
   }
+  for (Operation *reductionOp : reductionOps) {
+    Operation *clonedOp = rewriter.clone(*reductionOp, deviceMapping);
+    reductionOperands.push_back(clonedOp->getResult(0));
+    deviceMapping.map(reductionOp->getResult(0), clonedOp->getResult(0));
+  }
 
   // Create new compute op without if condition for device execution by
   // cloning
@@ -192,6 +205,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
   newComputeOp.getIfCondMutable().clear();
   newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
   newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
+  newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
 
   // Clone data exit operations
   rewriter.setInsertionPointAfter(newComputeOp);
@@ -238,6 +252,10 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
     getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp));
     eraseOps.push_back(firstprivateOp);
   }
+  for (Operation *reductionOp : reductionOps) {
+    getAccVar(reductionOp).replaceAllUsesWith(getVar(reductionOp));
+    eraseOps.push_back(reductionOp);
+  }
 }
 
 void ACCIfClauseLowering::runOnOperation() {
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
index fdef532fb8cb4..e2c07d3db0e36 100644
--- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -247,7 +247,7 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %con
   %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
   %firstprivate = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
 
-  // In the else branch, uses of %copyin should be replaced with %arg0
+  // In the else branch, uses of %firstprivate should be replaced with %arg1
   // CHECK: scf.if
   // CHECK: [[FIRSTPRIVATE:%.*]] = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
   // CHECK: acc.parallel {{.*}} firstprivate([[FIRSTPRIVATE]] : memref<i32>) {
@@ -268,3 +268,46 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %con
   acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
   return
 }
+
+// -----
+
+// Test that uses of the acc variable on the host path are replaced with the
+// host variable, and that the reduction operands are cloned for the device path
+// CHECK-LABEL: func.func @test_acc_reduction
+
+acc.reduction.recipe @reduction_add_memref_i32 : memref<i32> reduction_operator <add> init {  // add-reduction recipe over memref<i32>
+^bb0(%arg0: memref<i32>):  // init region: %arg0 is not read here
+  %c0_i32 = arith.constant 0 : i32  // initial value stored into the new buffer
+  %0 = memref.alloca() : memref<i32>  // fresh buffer yielded by the init region
+  memref.store %c0_i32, %0[] : memref<i32>
+  acc.yield %0 : memref<i32>
+} combiner {
+^bb0(%arg0: memref<i32>, %arg1: memref<i32>):  // combiner region: folds %arg1 into %arg0
+  %0 = memref.load %arg1[] : memref<i32>
+  %1 = memref.load %arg0[] : memref<i32>
+  %2 = arith.addi %1, %0 : i32  // sum of the two loaded values
+  memref.store %2, %arg0[] : memref<i32>  // result is written back through %arg0
+  acc.yield %arg0 : memref<i32>
+}
+
+func.func @test_acc_reduction(%arg0: memref<i32>, %cond: i1) {
+  // The acc.reduction entry op is expected to be cloned into the scf.if branch that keeps acc.parallel
+  %reduction = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
+
+  // In the else branch, uses of %reduction should be replaced with %arg0
+  // CHECK: scf.if
+  // CHECK: [[REDUCTION:%.*]] = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
+  // CHECK: acc.parallel{{.*}} reduction([[REDUCTION]] : memref<i32>) {
+  // CHECK: } else {
+  // CHECK: [[LOAD:%.*]] = memref.load %arg0[] : memref<i32>
+  // CHECK: memref.store {{.*}}, %arg0[] : memref<i32>
+  // CHECK: }
+
+  acc.parallel reduction(%reduction : memref<i32>) if(%cond) {
+    %load = memref.load %reduction[] : memref<i32>
+    %add = arith.addi %load, %load : i32
+    memref.store %add, %reduction[] : memref<i32>
+    acc.yield
+  }
+  return
+}
\ No newline at end of file



More information about the Mlir-commits mailing list