[Mlir-commits] [mlir] [mlir][acc] Add ACCIfClauseLowering pass (PR #173573)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 25 07:30:23 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openacc
Author: Razvan Lupusoru (razvanlupusoru)
<details>
<summary>Changes</summary>
This pass lowers OpenACC compute constructs with `if` clauses into `scf.if` with separate device and host paths.
Before:
```
%d = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
acc.parallel dataOperands(%d) if(%cond) {
  acc.loop control(%i : i32) = (%c0 : i32) to (%c10 : i32) step (%c1 : i32) {
    // loop body
    acc.yield
  }
  acc.yield
}
acc.copyout accPtr(%d) to varPtr(%a)
```
```
After:
```
scf.if %cond {
  %d = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
  acc.parallel dataOperands(%d) {
    acc.loop control(%i : i32) = (%c0 : i32) to (%c10 : i32) step (%c1 : i32) {
      // loop body
      acc.yield
    }
    acc.yield
  }
  acc.copyout accPtr(%d) to varPtr(%a)
} else {
  scf.for %i = %c0 to %c10 step %c1 {
    // loop body
  }
}
```
---
Full diff: https://github.com/llvm/llvm-project/pull/173573.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+35)
- (added) mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp (+245)
- (modified) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+1)
- (added) mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir (+224)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index e10fde3c2691f..68a52e0706d60 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -252,4 +252,39 @@ def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp">
];
}
+// Function-level pass (anchored on func.func) that splits ACC compute
+// constructs carrying an `if` clause into device and host paths under scf.if.
+def ACCIfClauseLowering : Pass<"acc-if-clause-lowering", "mlir::func::FuncOp"> {
+ let summary = "Lower if clauses in ACC compute constructs";
+ let description = [{
+ This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+ `if` clauses using region specialization. It creates two execution paths:
+ device execution when the condition is true, host execution when false.
+
+ When an ACC compute construct has an `if` clause, the construct should only
+ execute on the device when the condition is true. If the condition is false,
+ the code should execute on the host instead. This pass transforms:
+
+ ```mlir
+ acc.parallel if(%cond) { ... }
+ ```
+
+ Into:
+
+ ```mlir
+ scf.if %cond {
+ // Device path: clone data ops, compute construct without if, exit ops
+ acc.parallel { ... }
+ } else {
+ // Host path: original region body with ACC ops converted to host
+ }
+ ```
+
+ The transformation handles:
+ - Data entry operations (acc.copyin, acc.create, etc.) are cloned to device path
+ - Data exit operations (acc.copyout, acc.delete, etc.) are cloned to device path
+ - The host path uses `populateACCHostFallbackPatterns` to convert ACC ops
+ }];
+ // The pass creates scf.if/scf.yield and clones acc operations, so both
+ // dialects must be loaded before it runs.
+ let dependentDialects = ["mlir::acc::OpenACCDialect",
+ "mlir::scf::SCFDialect"];
+}
+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
new file mode 100644
index 0000000000000..5524c291a80e7
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -0,0 +1,245 @@
+//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+// `if` clauses using region specialization. It creates two execution paths:
+// device execution when the condition is true, host execution when false.
+//
+// Overview:
+// ---------
+// When an ACC compute construct has an `if` clause, the construct should only
+// execute on the device when the condition is true. If the condition is false,
+// the code should execute on the host instead. This pass transforms:
+//
+// acc.parallel if(%cond) { ... }
+//
+// Into:
+//
+// scf.if %cond {
+// // Device path: clone data ops, compute construct without if, exit ops
+// acc.parallel { ... }
+// } else {
+// // Host path: original region body with ACC ops converted to host
+// }
+//
+// Transformations:
+// ----------------
+// For each compute construct with an `if` clause:
+//
+// 1. Device Path (true branch):
+// - Clone data entry operations (acc.copyin, acc.create, etc.)
+// - Clone the compute construct without the `if` clause
+// - Clone data exit operations (acc.copyout, acc.delete, etc.)
+//
+// 2. Host Path (false branch):
+// - Move the original region body to the else branch
+// - Apply host fallback patterns to convert ACC ops to host equivalents
+//
+// 3. Cleanup:
+// - Erase the original compute construct and data operations
+// - Replace uses of ACC variables with host variables in the else branch
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements exist:
+//
+// 1. Analysis Registration (Optional): If custom behavior is needed for
+// emitting not-yet-implemented messages for unsupported cases, the pipeline
+// should pre-register the `acc::OpenACCSupport` analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+// Instantiate the tablegen-generated pass base class
+// (acc::impl::ACCIfClauseLoweringBase) inside the mlir::acc namespace.
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+// Tag for the LLVM_DEBUG output in this file.
+#define DEBUG_TYPE "acc-if-clause-lowering"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+/// Lowers the `if` clause of acc.parallel, acc.kernels and acc.serial into an
+/// scf.if whose then-branch runs the construct on the device and whose
+/// else-branch runs the original region body on the host.
+class ACCIfClauseLowering
+    : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
+  using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
+
+private:
+  /// Support analysis passed to the host fallback patterns and used for NYI
+  /// reporting. Set at the start of runOnOperation.
+  OpenACCSupport *accSupport = nullptr;
+
+  /// Converts the ACC operations in `region` to host equivalents.
+  void convertHostRegion(Operation *computeOp, Region &region);
+
+  /// Rewrites one compute construct that carries an `if` clause; original
+  /// operations that must be deleted are appended to `eraseOps`.
+  template <typename OpTy>
+  void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
+                                        SmallVector<Operation *> &eraseOps);
+
+public:
+  void runOnOperation() override;
+};
+
+/// Rewrites the ACC dialect operations found in `region` using the patterns
+/// from `populateACCHostFallbackPatterns`. Only operations that exist when the
+/// rewrite starts are processed (ExistingOps strictness, top-down traversal).
+/// If the greedy driver reports failure, an NYI diagnostic is emitted at
+/// `computeOp`'s location.
+void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
+                                            Region &region) {
+  // Only collect ACC dialect operations - other ops don't need conversion.
+  // Unregistered operations have no Dialect object; tolerate a null dialect
+  // instead of asserting inside isa<>.
+  SmallVector<Operation *> hostOps;
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (llvm::isa_and_present<acc::OpenACCDialect>(op->getDialect()))
+      hostOps.push_back(op);
+  });
+
+  RewritePatternSet patterns(computeOp->getContext());
+  populateACCHostFallbackPatterns(patterns, *accSupport);
+
+  GreedyRewriteConfig config;
+  config.setUseTopDownTraversal(true);
+  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+  if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
+    accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
+}
+
+/// Lowers the `if` clause of `computeConstructOp`, if it has one, into:
+///
+///   scf.if %cond {
+///     <clones of the data entry ops feeding the construct>
+///     <clone of the construct without its `if` clause>
+///     <clones of the data exit ops that use those entry ops>
+///   } else {
+///     <original region body, converted to host code>
+///   }
+///
+/// The scf.if is inserted right before the construct. The original construct
+/// and its data entry/exit ops are appended to `eraseOps` instead of being
+/// erased here, because this is called from inside a walk of the function.
+template <typename OpTy>
+void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
+    OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
+  Value ifCond = computeConstructOp.getIfCond();
+  if (!ifCond)
+    return;
+
+  LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
+                          << " with if condition: " << computeConstructOp
+                          << "\n");
+
+  // Reject unsupported input before touching the IR. Bailing out after the
+  // scf.if has been built would leave both the device clone and the original
+  // construct (still carrying its if clause) in the function.
+  if (!computeConstructOp.getRegion().hasOneBlock()) {
+    accSupport->emitNYI(computeConstructOp.getLoc(),
+                        "region with multiple blocks");
+    return;
+  }
+
+  // Collect the data entry operations feeding the construct. An entry op that
+  // appears more than once in the operand list is recorded once, so it (and
+  // its exit ops) are cloned and erased exactly once.
+  SmallVector<Operation *> dataEntryOps;
+  for (Value operand : computeConstructOp.getDataClauseOperands()) {
+    Operation *defOp = operand.getDefiningOp();
+    if (defOp && isa<ACC_DATA_ENTRY_OPS>(defOp) &&
+        !llvm::is_contained(dataEntryOps, defOp))
+      dataEntryOps.push_back(defOp);
+  }
+
+  // Find corresponding exit operations among the users of each entry op.
+  // Iterate backwards through entry ops since exit ops appear in reverse order.
+  SmallVector<Operation *> dataExitOps;
+  for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
+    for (Operation *user : dataEntryOp->getUsers())
+      if (isa<ACC_DATA_EXIT_OPS>(user) &&
+          !llvm::is_contained(dataExitOps, user))
+        dataExitOps.push_back(user);
+
+  // Create scf.if with device and host execution paths. With no results, the
+  // builder already gives both regions an scf.yield terminator.
+  IRRewriter rewriter(computeConstructOp);
+  auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
+                                TypeRange{}, ifCond, /*withElseRegion=*/true);
+
+  // Device execution path (true branch): insert ahead of its scf.yield.
+  Block &thenBlock = ifOp.getThenRegion().front();
+  rewriter.setInsertionPointToStart(&thenBlock);
+
+  LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
+                          << " data entry operations for device path\n");
+
+  // Cloning records each original->clone result mapping in `deviceMapping`,
+  // so the construct and exit ops cloned below use the device-path entries.
+  IRMapping deviceMapping;
+  for (Operation *dataOp : dataEntryOps)
+    rewriter.clone(*dataOp, deviceMapping);
+
+  // Clone the construct without its if condition. The clone remaps its data
+  // operands (and uses inside its region) through `deviceMapping`; operands
+  // not produced by a collected entry op are kept rather than dropped.
+  auto newComputeOp = cast<OpTy>(
+      rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
+  newComputeOp.getIfCondMutable().clear();
+
+  // Clone data exit operations right after the device construct.
+  rewriter.setInsertionPointAfter(newComputeOp);
+  for (Operation *dataOp : dataExitOps)
+    rewriter.clone(*dataOp, deviceMapping);
+
+  // Host execution path (false branch). Don't need to clone the original ops;
+  // take the region body and legalize it for host.
+  Region &elseRegion = ifOp.getElseRegion();
+  elseRegion.takeBody(computeConstructOp.getRegion());
+
+  // Swap the ACC region terminator for scf.yield.
+  Block &elseBlock = elseRegion.front();
+  elseBlock.getTerminator()->erase();
+  rewriter.setInsertionPointToEnd(&elseBlock);
+  scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+
+  convertHostRegion(computeConstructOp, elseRegion);
+
+  // The original op is now empty and can be erased once the walk is done.
+  eraseOps.push_back(computeConstructOp);
+
+  // TODO: Can probably 'move' the data ops instead of cloning them
+  // which would eliminate need to explicitly erase
+  llvm::append_range(eraseOps, dataExitOps);
+
+  for (Operation *dataOp : dataEntryOps) {
+    // The new host code may contain uses of the acc variables. Replace them by
+    // the host values.
+    // NOTE(review): replaceAllUsesWith rewrites every remaining use in the
+    // function; an entry op shared with another construct would leave that
+    // construct with a non-entry data operand - confirm sharing cannot occur.
+    getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
+    eraseOps.push_back(dataOp);
+  }
+}
+
+/// Lowers the `if` clause of every acc.parallel, acc.kernels and acc.serial
+/// in the function, then erases the original operations that were replaced.
+void ACCIfClauseLowering::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  accSupport = &getAnalysis<OpenACCSupport>();
+
+  // Erasure is deferred until the walk finishes: the exit ops scheduled for
+  // deletion may still lie ahead of the walk.
+  SmallVector<Operation *> eraseOps;
+  funcOp.walk([&](Operation *op) {
+    if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
+      lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
+    else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
+      lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
+    else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
+      lowerIfClauseForComputeConstruct(serialOp, eraseOps);
+  });
+
+  for (Operation *op : eraseOps)
+    op->erase();
+}
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index e94ac6f332834..3a0ca338766e4 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
+ ACCIfClauseLowering.cpp
ACCImplicitData.cpp
ACCLoopTiling.cpp
ACCImplicitDeclare.cpp
diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
new file mode 100644
index 0000000000000..3f0df18619bc0
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -0,0 +1,224 @@
+// RUN: mlir-opt %s -acc-if-clause-lowering -split-input-file | FileCheck %s
+
+// Test acc.parallel with if condition
+// The device branch re-creates both entry ops, the construct without its if
+// clause, and the exit ops in reverse entry order (delete before copyout); the
+// host branch keeps only the plain scf.for.
+// CHECK-LABEL: func.func @test_parallel_if
+func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+ %create = acc.create varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+ // CHECK-NOT: acc.parallel if
+ // CHECK: scf.if %{{.*}} {
+ // CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}}) -> memref<10xi32>
+ // CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}}) -> memref<10xi32>
+ // CHECK: acc.parallel dataOperands(%[[COPYIN]], %[[CREATE]] : memref<10xi32>, memref<10xi32>) {
+ // CHECK: scf.for
+ // CHECK: acc.yield
+ // CHECK: }
+ // CHECK: acc.delete accPtr(%[[CREATE]] : memref<10xi32>)
+ // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr(%{{.*}} : memref<10xi32>)
+ // CHECK: } else {
+ // CHECK: scf.for
+ // CHECK: }
+ acc.parallel dataOperands(%copyin, %create : memref<10xi32>, memref<10xi32>) if(%cond) {
+ scf.for %i = %c1 to %c10 step %c1 {
+ memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+ }
+ acc.yield
+ }
+
+ acc.delete accPtr(%create : memref<10xi32>)
+ acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+ return
+}
+
+// -----
+
+// Test acc.kernels with if condition
+// Same shape as the parallel case; the cloned acc.kernels keeps its original
+// acc.terminator.
+// CHECK-LABEL: func.func @test_kernels_if
+func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
+ %c1_i32 = arith.constant 1 : i32
+ %c1 = arith.constant 1 : index
+ %c5 = arith.constant 5 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<5xi32>) -> memref<5xi32>
+ %create = acc.create varPtr(%arg0 : memref<5xi32>) -> memref<5xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+ // CHECK-NOT: acc.kernels if
+ // CHECK: scf.if %{{.*}} {
+ // CHECK: %[[COPYIN:.*]] = acc.copyin
+ // CHECK: %[[CREATE:.*]] = acc.create
+ // CHECK: acc.kernels dataOperands(%[[COPYIN]], %[[CREATE]] : memref<5xi32>, memref<5xi32>) {
+ // CHECK: scf.for
+ // CHECK: acc.terminator
+ // CHECK: }
+ // CHECK: acc.delete accPtr(%[[CREATE]] : memref<5xi32>)
+ // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<5xi32>) to varPtr(%{{.*}} : memref<5xi32>)
+ // CHECK: } else {
+ // CHECK: scf.for
+ // CHECK: }
+ acc.kernels dataOperands(%copyin, %create : memref<5xi32>, memref<5xi32>) if(%cond) {
+ scf.for %i = %c1 to %c5 step %c1 {
+ memref.store %c1_i32, %arg0[%i] : memref<5xi32>
+ }
+ acc.terminator
+ }
+
+ acc.delete accPtr(%create : memref<5xi32>)
+ acc.copyout accPtr(%copyin : memref<5xi32>) to varPtr(%arg0 : memref<5xi32>)
+ return
+}
+
+// -----
+
+// Test acc.serial with if condition
+// Same shape as the parallel case, exercised through acc.serial.
+// CHECK-LABEL: func.func @test_serial_if
+func.func @test_serial_if(%arg0: memref<8xi32>, %cond: i1) {
+ %c2_i32 = arith.constant 2 : i32
+ %c1 = arith.constant 1 : index
+ %c8 = arith.constant 8 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<8xi32>) -> memref<8xi32>
+ %create = acc.create varPtr(%arg0 : memref<8xi32>) -> memref<8xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+ // CHECK-NOT: acc.serial if
+ // CHECK: scf.if %{{.*}} {
+ // CHECK: %[[COPYIN:.*]] = acc.copyin
+ // CHECK: %[[CREATE:.*]] = acc.create
+ // CHECK: acc.serial dataOperands(%[[COPYIN]], %[[CREATE]] : memref<8xi32>, memref<8xi32>) {
+ // CHECK: scf.for
+ // CHECK: acc.yield
+ // CHECK: }
+ // CHECK: acc.delete accPtr(%[[CREATE]] : memref<8xi32>)
+ // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<8xi32>) to varPtr(%{{.*}} : memref<8xi32>)
+ // CHECK: } else {
+ // CHECK: scf.for
+ // CHECK: }
+ acc.serial dataOperands(%copyin, %create : memref<8xi32>, memref<8xi32>) if(%cond) {
+ scf.for %i = %c1 to %c8 step %c1 {
+ memref.store %c2_i32, %arg0[%i] : memref<8xi32>
+ }
+ acc.yield
+ }
+
+ acc.delete accPtr(%create : memref<8xi32>)
+ acc.copyout accPtr(%copyin : memref<8xi32>) to varPtr(%arg0 : memref<8xi32>)
+ return
+}
+
+// -----
+
+// Test that acc.parallel without if condition is not modified
+// The pass returns early when the construct has no if operand, so no scf.if
+// may appear in the output.
+// CHECK-LABEL: func.func @test_parallel_no_if
+func.func @test_parallel_no_if(%arg0: memref<10xi32>) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+ // CHECK-NOT: scf.if
+ // CHECK: acc.parallel dataOperands(%{{.*}}) {
+ acc.parallel dataOperands(%copyin : memref<10xi32>) {
+ scf.for %i = %c1 to %c10 step %c1 {
+ memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+ }
+ acc.yield
+ }
+
+ acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+ return
+}
+
+// -----
+
+// Test with private and reduction clauses inside compute construct
+// Privatization recipe: the init region allocates a new memref<i32> with
+// memref.alloca and yields it.
+acc.private.recipe @privatization_memref_i32 : memref<i32> init {
+^bb0(%arg0: memref<i32>):
+ %0 = memref.alloca() : memref<i32>
+ acc.yield %0 : memref<i32>
+}
+
+// Add-reduction recipe for memref<f32>: init yields an alloca holding 0.0.
+acc.reduction.recipe @reduction_add_memref_f32 : memref<f32> reduction_operator <add> init {
+^bb0(%arg0: memref<f32>):
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = memref.alloca() : memref<f32>
+ memref.store %cst, %0[] : memref<f32>
+ acc.yield %0 : memref<f32>
+} combiner {
+ // Combiner: stores (%arg0 value + %arg1 value) into %arg0 and yields %arg0.
+^bb0(%arg0: memref<f32>, %arg1: memref<f32>):
+ %0 = memref.load %arg0[] : memref<f32>
+ %1 = memref.load %arg1[] : memref<f32>
+ %2 = arith.addf %0, %1 : f32
+ memref.store %2, %arg0[] : memref<f32>
+ acc.yield %arg0 : memref<f32>
+}
+
+// Combined parallel loop with private and reduction operands: the device branch
+// keeps acc.parallel, and the host branch must no longer contain acc.loop,
+// acc.reduction or acc.private after the host fallback patterns run.
+// NOTE(review): the negative patterns below only cover the output up to the
+// next closing brace after the else; confirm that spans the whole host branch.
+// CHECK-LABEL: func.func @test_reduction_if
+func.func @test_reduction_if(%r: memref<f32>, %a: memref<8xf32>, %cond: i1) {
+ %c8_i32 = arith.constant 8 : i32
+ %c1_i32 = arith.constant 1 : i32
+
+ %copyin = acc.copyin varPtr(%r : memref<f32>) -> memref<f32> {dataClause = #acc<data_clause acc_reduction>, implicit = true}
+
+ // CHECK: scf.if
+ // CHECK: acc.parallel
+ // CHECK: } else {
+ // The else branch should have acc ops converted to host
+ // CHECK-NOT: acc.loop
+ // CHECK-NOT: acc.reduction
+ // CHECK-NOT: acc.private
+ // CHECK: }
+ acc.parallel combined(loop) dataOperands(%copyin : memref<f32>) if(%cond) {
+ %red = acc.reduction varPtr(%r : memref<f32>) recipe(@reduction_add_memref_f32) -> memref<f32>
+ %iter_var = memref.alloca() : memref<i32>
+ %priv = acc.private varPtr(%iter_var : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
+ acc.loop combined(parallel) vector private(%priv : memref<i32>) reduction(%red : memref<f32>) control(%iv : i32) = (%c1_i32 : i32) to (%c8_i32 : i32) step (%c1_i32 : i32) {
+ memref.store %iv, %priv[] : memref<i32>
+ %idx = memref.load %priv[] : memref<i32>
+ %idx_cast = arith.index_cast %idx : i32 to index
+ %elem = memref.load %a[%idx_cast] : memref<8xf32>
+ %r_val = memref.load %r[] : memref<f32>
+ %new_r = arith.addf %r_val, %elem : f32
+ memref.store %new_r, %r[] : memref<f32>
+ acc.yield
+ } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
+ acc.yield
+ }
+
+ acc.copyout accPtr(%copyin : memref<f32>) to varPtr(%r : memref<f32>) {dataClause = #acc<data_clause acc_reduction>, implicit = true}
+ return
+}
+
+// -----
+
+// Test that acc variable uses in host path are replaced with host variables
+// The pass rewrites uses of each erased entry op's result to that op's varPtr
+// operand, so the host branch must store to %arg0 directly.
+// CHECK-LABEL: func.func @test_acc_var_replacement
+func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
+ %c0_i32 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+
+ %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+ // In the else branch, uses of %copyin should be replaced with %arg0
+ // CHECK: scf.if
+ // CHECK: } else {
+ // CHECK: scf.for
+ // CHECK: memref.store %{{.*}}, %arg0[%{{.*}}]
+ // CHECK: }
+ acc.parallel dataOperands(%copyin : memref<10xi32>) if(%cond) {
+ scf.for %i = %c1 to %c10 step %c1 {
+ // Use the acc ptr inside the region
+ memref.store %c0_i32, %copyin[%i] : memref<10xi32>
+ }
+ acc.yield
+ }
+
+ acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+ return
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/173573
More information about the Mlir-commits
mailing list