[Mlir-commits] [mlir] 3016679 - [mlir][acc] Add ACCIfClauseLowering pass (#173573)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 25 11:41:12 PST 2025


Author: Razvan Lupusoru
Date: 2025-12-25T11:41:08-08:00
New Revision: 30166796e8286d96f813dcb6668930dc2427a0d0

URL: https://github.com/llvm/llvm-project/commit/30166796e8286d96f813dcb6668930dc2427a0d0
DIFF: https://github.com/llvm/llvm-project/commit/30166796e8286d96f813dcb6668930dc2427a0d0.diff

LOG: [mlir][acc] Add ACCIfClauseLowering pass (#173573)

This pass lowers OpenACC compute constructs with `if` clauses into
`scf.if` with separate device and host paths.

Before:
```
  %d = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
  acc.parallel dataOperands(%d) if(%cond) {
    acc.loop control(%i : i32) = (%c0 : i32) to (%c10 : i32) step (%c1 :
i32) {
      // loop body
      acc.yield
    }
    acc.yield
  }
  acc.copyout accPtr(%d) to varPtr(%a)
```

After:
```
  scf.if %cond {
    %d = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32>
    acc.parallel dataOperands(%d) {
      acc.loop control(%i : i32) = (%c0 : i32) to (%c10 : i32) step (%c1
: i32) {
        // loop body
        acc.yield
      }
      acc.yield
    }
    acc.copyout accPtr(%d) to varPtr(%a)
  } else {
    scf.for %i = %c0 to %c10 step %c1 {
      // loop body
    }
  }
```

Co-authored-by: Susan Tan <zujunt at nvidia.com>

Added: 
    mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
    mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir

Modified: 
    mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
    mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index e10fde3c2691f..68a52e0706d60 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -252,4 +252,39 @@ def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp">
   ];
 }
 
+def ACCIfClauseLowering : Pass<"acc-if-clause-lowering", "mlir::func::FuncOp"> {
+  let summary = "Lower if clauses in ACC compute constructs";
+  let description = [{
+    This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+    `if` clauses using region specialization. It creates two execution paths:
+    device execution when the condition is true, host execution when false.
+
+    When an ACC compute construct has an `if` clause, the construct should only
+    execute on the device when the condition is true. If the condition is false,
+    the code should execute on the host instead. This pass transforms:
+
+    ```mlir
+    acc.parallel if(%cond) { ... }
+    ```
+
+    Into:
+
+    ```mlir
+    scf.if %cond {
+      // Device path: clone data ops, compute construct without if, exit ops
+      acc.parallel { ... }
+    } else {
+      // Host path: original region body with ACC ops converted to host
+    }
+    ```
+
+    The transformation handles:
+    - Data entry operations (acc.copyin, acc.create, etc.) are cloned to device path
+    - Data exit operations (acc.copyout, acc.delete, etc.) are cloned to device path
+    - The host path uses `populateACCHostFallbackPatterns` to convert ACC ops
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect",
+      "mlir::scf::SCFDialect"];
+}
+
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
new file mode 100644
index 0000000000000..5524c291a80e7
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
@@ -0,0 +1,245 @@
+//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
+// `if` clauses using region specialization. It creates two execution paths:
+// device execution when the condition is true, host execution when false.
+//
+// Overview:
+// ---------
+// When an ACC compute construct has an `if` clause, the construct should only
+// execute on the device when the condition is true. If the condition is false,
+// the code should execute on the host instead. This pass transforms:
+//
+//   acc.parallel if(%cond) { ... }
+//
+// Into:
+//
+//   scf.if %cond {
+//     // Device path: clone data ops, compute construct without if, exit ops
+//     acc.parallel { ... }
+//   } else {
+//     // Host path: original region body with ACC ops converted to host
+//   }
+//
+// Transformations:
+// ----------------
+// For each compute construct with an `if` clause:
+//
+// 1. Device Path (true branch):
+//    - Clone data entry operations (acc.copyin, acc.create, etc.)
+//    - Clone the compute construct without the `if` clause
+//    - Clone data exit operations (acc.copyout, acc.delete, etc.)
+//
+// 2. Host Path (false branch):
+//    - Move the original region body to the else branch
+//    - Apply host fallback patterns to convert ACC ops to host equivalents
+//
+// 3. Cleanup:
+//    - Erase the original compute construct and data operations
+//    - Replace uses of ACC variables with host variables in the else branch
+//
+// Requirements:
+// -------------
+// To use this pass in a pipeline, the following requirements exist:
+//
+// 1. Analysis Registration (Optional): If custom behavior is needed for
+//    emitting not-yet-implemented messages for unsupported cases, the pipeline
+//    should pre-register the `acc::OpenACCSupport` analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-if-clause-lowering"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+class ACCIfClauseLowering
+    : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
+  using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
+
+private:
+  OpenACCSupport *accSupport = nullptr;
+
+  void convertHostRegion(Operation *computeOp, Region &region);
+
+  template <typename OpTy>
+  void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
+                                        SmallVector<Operation *> &eraseOps);
+
+public:
+  void runOnOperation() override;
+};
+
+void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
+                                            Region &region) {
+  // Only collect ACC dialect operations - other ops don't need conversion
+  SmallVector<Operation *> hostOps;
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (isa<acc::OpenACCDialect>(op->getDialect()))
+      hostOps.push_back(op);
+  });
+
+  RewritePatternSet patterns(computeOp->getContext());
+  populateACCHostFallbackPatterns(patterns, *accSupport);
+
+  GreedyRewriteConfig config;
+  config.setUseTopDownTraversal(true);
+  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+  if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
+    accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
+}
+
+// Template function to handle if condition conversion for ACC compute
+// constructs
+template <typename OpTy>
+void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
+    OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
+  Value ifCond = computeConstructOp.getIfCond();
+  if (!ifCond)
+    return;
+
+  IRRewriter rewriter(computeConstructOp);
+
+  LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
+                          << " with if condition: " << computeConstructOp
+                          << "\n");
+
+  // Collect data clause operations that need to be recreated in the if
+  // condition
+  SmallVector<Operation *> dataEntryOps;
+  SmallVector<Operation *> dataExitOps;
+
+  // Collect data entry operations
+  for (Value operand : computeConstructOp.getDataClauseOperands()) {
+    if (Operation *defOp = operand.getDefiningOp())
+      if (isa<ACC_DATA_ENTRY_OPS>(defOp))
+        dataEntryOps.push_back(defOp);
+  }
+
+  // Find corresponding exit operations for each entry operation.
+  // Iterate backwards through entry ops since exit ops appear in reverse order.
+  for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
+    for (Operation *user : dataEntryOp->getUsers())
+      if (isa<ACC_DATA_EXIT_OPS>(user))
+        dataExitOps.push_back(user);
+
+  // Create scf.if with device and host execution paths
+  auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
+                                TypeRange{}, ifCond, /*withElseRegion=*/true);
+
+  // Declare deviceMapping at function scope for later use
+  IRMapping deviceMapping;
+
+  // Device execution path (true branch)
+  Block &thenBlock = ifOp.getThenRegion().front();
+  rewriter.setInsertionPointToStart(&thenBlock);
+
+  // Clone data entry operations
+  SmallVector<Value> deviceDataOperands;
+
+  LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
+                          << " data entry operations for device path\n");
+
+  for (Operation *dataOp : dataEntryOps) {
+    Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
+    deviceDataOperands.push_back(clonedDataOp->getResult(0));
+    deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
+  }
+
+  // Create new compute op without if condition for device execution by
+  // cloning
+  OpTy newComputeOp = cast<OpTy>(
+      rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
+  newComputeOp.getIfCondMutable().clear();
+  newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
+
+  // Clone data exit operations
+  rewriter.setInsertionPointAfter(newComputeOp);
+  for (Operation *dataOp : dataExitOps)
+    rewriter.clone(*dataOp, deviceMapping);
+
+  rewriter.setInsertionPointToEnd(&thenBlock);
+  if (!thenBlock.getTerminator())
+    scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+
+  // Host execution path (false branch)
+  if (!computeConstructOp.getRegion().hasOneBlock()) {
+    accSupport->emitNYI(computeConstructOp.getLoc(),
+                        "region with multiple blocks");
+    return;
+  }
+
+  // Don't need to clone original ops, just take them and legalize for host
+  ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
+
+  // Swap acc yield for scf yield
+  Block &elseBlock = ifOp.getElseRegion().front();
+  elseBlock.getTerminator()->erase();
+  rewriter.setInsertionPointToEnd(&elseBlock);
+  scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
+
+  convertHostRegion(computeConstructOp, ifOp.getElseRegion());
+
+  // The original op is now empty and can be erased
+  eraseOps.push_back(computeConstructOp);
+
+  // TODO: Can probably 'move' the data ops instead of cloning them
+  // which would eliminate need to explicitly erase
+  for (Operation *dataOp : dataExitOps)
+    eraseOps.push_back(dataOp);
+
+  for (Operation *dataOp : dataEntryOps) {
+    // The new host code may contain uses of the acc variables. Replace them by
+    // the host values.
+    getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
+    eraseOps.push_back(dataOp);
+  }
+}
+
+void ACCIfClauseLowering::runOnOperation() {
+  func::FuncOp funcOp = getOperation();
+  accSupport = &getAnalysis<OpenACCSupport>();
+
+  SmallVector<Operation *> eraseOps;
+  funcOp.walk([&](Operation *op) {
+    if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
+      lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
+    else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
+      lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
+    else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
+      lowerIfClauseForComputeConstruct(serialOp, eraseOps);
+  });
+
+  for (Operation *op : eraseOps)
+    op->erase();
+}
+
+} // namespace

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index e94ac6f332834..3a0ca338766e4 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIROpenACCTransforms
+  ACCIfClauseLowering.cpp
   ACCImplicitData.cpp
   ACCLoopTiling.cpp
   ACCImplicitDeclare.cpp

diff  --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
new file mode 100644
index 0000000000000..3f0df18619bc0
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
@@ -0,0 +1,224 @@
+// RUN: mlir-opt %s -acc-if-clause-lowering -split-input-file | FileCheck %s
+
+// Test acc.parallel with if condition
+// CHECK-LABEL: func.func @test_parallel_if
+func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+  %create = acc.create varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+  // CHECK-NOT: acc.parallel if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK:   %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}}) -> memref<10xi32>
+  // CHECK:   %[[CREATE:.*]] = acc.create varPtr(%{{.*}}) -> memref<10xi32>
+  // CHECK:   acc.parallel dataOperands(%[[COPYIN]], %[[CREATE]] : memref<10xi32>, memref<10xi32>) {
+  // CHECK:     scf.for
+  // CHECK:     acc.yield
+  // CHECK:   }
+  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<10xi32>)
+  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr(%{{.*}} : memref<10xi32>)
+  // CHECK: } else {
+  // CHECK:   scf.for
+  // CHECK: }
+  acc.parallel dataOperands(%copyin, %create : memref<10xi32>, memref<10xi32>) if(%cond) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.delete accPtr(%create : memref<10xi32>)
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+
+// -----
+
+// Test acc.kernels with if condition
+// CHECK-LABEL: func.func @test_kernels_if
+func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
+  %c1_i32 = arith.constant 1 : i32
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<5xi32>) -> memref<5xi32>
+  %create = acc.create varPtr(%arg0 : memref<5xi32>) -> memref<5xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+  // CHECK-NOT: acc.kernels if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK:   %[[COPYIN:.*]] = acc.copyin
+  // CHECK:   %[[CREATE:.*]] = acc.create
+  // CHECK:   acc.kernels dataOperands(%[[COPYIN]], %[[CREATE]] : memref<5xi32>, memref<5xi32>) {
+  // CHECK:     scf.for
+  // CHECK:     acc.terminator
+  // CHECK:   }
+  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<5xi32>)
+  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<5xi32>) to varPtr(%{{.*}} : memref<5xi32>)
+  // CHECK: } else {
+  // CHECK:   scf.for
+  // CHECK: }
+  acc.kernels dataOperands(%copyin, %create : memref<5xi32>, memref<5xi32>) if(%cond) {
+    scf.for %i = %c1 to %c5 step %c1 {
+      memref.store %c1_i32, %arg0[%i] : memref<5xi32>
+    }
+    acc.terminator
+  }
+
+  acc.delete accPtr(%create : memref<5xi32>)
+  acc.copyout accPtr(%copyin : memref<5xi32>) to varPtr(%arg0 : memref<5xi32>)
+  return
+}
+
+// -----
+
+// Test acc.serial with if condition
+// CHECK-LABEL: func.func @test_serial_if
+func.func @test_serial_if(%arg0: memref<8xi32>, %cond: i1) {
+  %c2_i32 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<8xi32>) -> memref<8xi32>
+  %create = acc.create varPtr(%arg0 : memref<8xi32>) -> memref<8xi32> {dataClause = #acc<data_clause acc_copyout>}
+
+  // CHECK-NOT: acc.serial if
+  // CHECK: scf.if %{{.*}} {
+  // CHECK:   %[[COPYIN:.*]] = acc.copyin
+  // CHECK:   %[[CREATE:.*]] = acc.create
+  // CHECK:   acc.serial dataOperands(%[[COPYIN]], %[[CREATE]] : memref<8xi32>, memref<8xi32>) {
+  // CHECK:     scf.for
+  // CHECK:     acc.yield
+  // CHECK:   }
+  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<8xi32>)
+  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<8xi32>) to varPtr(%{{.*}} : memref<8xi32>)
+  // CHECK: } else {
+  // CHECK:   scf.for
+  // CHECK: }
+  acc.serial dataOperands(%copyin, %create : memref<8xi32>, memref<8xi32>) if(%cond) {
+    scf.for %i = %c1 to %c8 step %c1 {
+      memref.store %c2_i32, %arg0[%i] : memref<8xi32>
+    }
+    acc.yield
+  }
+
+  acc.delete accPtr(%create : memref<8xi32>)
+  acc.copyout accPtr(%copyin : memref<8xi32>) to varPtr(%arg0 : memref<8xi32>)
+  return
+}
+
+// -----
+
+// Test that acc.parallel without if condition is not modified
+// CHECK-LABEL: func.func @test_parallel_no_if
+func.func @test_parallel_no_if(%arg0: memref<10xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+  // CHECK-NOT: scf.if
+  // CHECK: acc.parallel dataOperands(%{{.*}}) {
+  acc.parallel dataOperands(%copyin : memref<10xi32>) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+
+// -----
+
+// Test with private and reduction clauses inside compute construct
+acc.private.recipe @privatization_memref_i32 : memref<i32> init {
+^bb0(%arg0: memref<i32>):
+  %0 = memref.alloca() : memref<i32>
+  acc.yield %0 : memref<i32>
+}
+
+acc.reduction.recipe @reduction_add_memref_f32 : memref<f32> reduction_operator <add> init {
+^bb0(%arg0: memref<f32>):
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = memref.alloca() : memref<f32>
+  memref.store %cst, %0[] : memref<f32>
+  acc.yield %0 : memref<f32>
+} combiner {
+^bb0(%arg0: memref<f32>, %arg1: memref<f32>):
+  %0 = memref.load %arg0[] : memref<f32>
+  %1 = memref.load %arg1[] : memref<f32>
+  %2 = arith.addf %0, %1 : f32
+  memref.store %2, %arg0[] : memref<f32>
+  acc.yield %arg0 : memref<f32>
+}
+
+// CHECK-LABEL: func.func @test_reduction_if
+func.func @test_reduction_if(%r: memref<f32>, %a: memref<8xf32>, %cond: i1) {
+  %c8_i32 = arith.constant 8 : i32
+  %c1_i32 = arith.constant 1 : i32
+
+  %copyin = acc.copyin varPtr(%r : memref<f32>) -> memref<f32> {dataClause = #acc<data_clause acc_reduction>, implicit = true}
+
+  // CHECK: scf.if
+  // CHECK:   acc.parallel
+  // CHECK: } else {
+  // The else branch should have acc ops converted to host
+  // CHECK-NOT: acc.loop
+  // CHECK-NOT: acc.reduction
+  // CHECK-NOT: acc.private
+  // CHECK: }
+  acc.parallel combined(loop) dataOperands(%copyin : memref<f32>) if(%cond) {
+    %red = acc.reduction varPtr(%r : memref<f32>) recipe(@reduction_add_memref_f32) -> memref<f32>
+    %iter_var = memref.alloca() : memref<i32>
+    %priv = acc.private varPtr(%iter_var : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
+    acc.loop combined(parallel) vector private(%priv : memref<i32>) reduction(%red : memref<f32>) control(%iv : i32) = (%c1_i32 : i32) to (%c8_i32 : i32) step (%c1_i32 : i32) {
+      memref.store %iv, %priv[] : memref<i32>
+      %idx = memref.load %priv[] : memref<i32>
+      %idx_cast = arith.index_cast %idx : i32 to index
+      %elem = memref.load %a[%idx_cast] : memref<8xf32>
+      %r_val = memref.load %r[] : memref<f32>
+      %new_r = arith.addf %r_val, %elem : f32
+      memref.store %new_r, %r[] : memref<f32>
+      acc.yield
+    } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref<f32>) to varPtr(%r : memref<f32>) {dataClause = #acc<data_clause acc_reduction>, implicit = true}
+  return
+}
+
+// -----
+
+// Test that acc variable uses in host path are replaced with host variables
+// CHECK-LABEL: func.func @test_acc_var_replacement
+func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
+
+  // In the else branch, uses of %copyin should be replaced with %arg0
+  // CHECK: scf.if
+  // CHECK: } else {
+  // CHECK:   scf.for
+  // CHECK:     memref.store %{{.*}}, %arg0[%{{.*}}]
+  // CHECK: }
+  acc.parallel dataOperands(%copyin : memref<10xi32>) if(%cond) {
+    scf.for %i = %c1 to %c10 step %c1 {
+      // Use the acc ptr inside the region
+      memref.store %c0_i32, %copyin[%i] : memref<10xi32>
+    }
+    acc.yield
+  }
+
+  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
+  return
+}
+


        


More information about the Mlir-commits mailing list