[Mlir-commits] [mlir] 0a415db - Revert "[mlir][acc] Add ACCIfClauseLowering pass (#173447)"

Emilio Cota llvmlistbot at llvm.org
Wed Dec 24 19:33:29 PST 2025


Author: Emilio Cota
Date: 2025-12-24T22:29:42-05:00
New Revision: 0a415db5e20b362fde0d2033c9a077828188b59f

URL: https://github.com/llvm/llvm-project/commit/0a415db5e20b362fde0d2033c9a077828188b59f
DIFF: https://github.com/llvm/llvm-project/commit/0a415db5e20b362fde0d2033c9a077828188b59f.diff

LOG: Revert "[mlir][acc] Add ACCIfClauseLowering pass (#173447)"

This reverts commit f64bc988959f1ac028d2b64500791014537d3706.

The revert is needed because the reverted commit depends on an earlier commit
(PR #173407) that is itself about to be reverted due to a use-after-free; see
https://github.com/llvm/llvm-project/pull/173407#issuecomment-3690793823

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
    mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Removed: 
    mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
    mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index 68a52e0706d60..e10fde3c2691f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -252,39 +252,4 @@ def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp">
   ];
 }
 
-def ACCIfClauseLowering : Pass<"acc-if-clause-lowering", "mlir::func::FuncOp"> {
-  let summary = "Lower if clauses in ACC compute constructs";
-  let description = [{
-    This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
-    `if` clauses using region specialization. It creates two execution paths:
-    device execution when the condition is true, host execution when false.
-
-    When an ACC compute construct has an `if` clause, the construct should only
-    execute on the device when the condition is true. If the condition is false,
-    the code should execute on the host instead. This pass transforms:
-
-    ```mlir
-    acc.parallel if(%cond) { ... }
-    ```
-
-    Into:
-
-    ```mlir
-    scf.if %cond {
-      // Device path: clone data ops, compute construct without if, exit ops
-      acc.parallel { ... }
-    } else {
-      // Host path: original region body with ACC ops converted to host
-    }
-    ```
-
-    The transformation handles:
-    - Data entry operations (acc.copyin, acc.create, etc.) are cloned to device path
-    - Data exit operations (acc.copyout, acc.delete, etc.) are cloned to device path
-    - The host path uses `populateACCHostFallbackPatterns` to convert ACC ops
-  }];
-  let dependentDialects = ["mlir::acc::OpenACCDialect",
-      "mlir::scf::SCFDialect"];
-}
-
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
deleted file mode 100644
index 5524c291a80e7..0000000000000
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp
+++ /dev/null
@@ -1,245 +0,0 @@
-//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
-// `if` clauses using region specialization. It creates two execution paths:
-// device execution when the condition is true, host execution when false.
-//
-// Overview:
-// ---------
-// When an ACC compute construct has an `if` clause, the construct should only
-// execute on the device when the condition is true. If the condition is false,
-// the code should execute on the host instead. This pass transforms:
-//
-//   acc.parallel if(%cond) { ... }
-//
-// Into:
-//
-//   scf.if %cond {
-//     // Device path: clone data ops, compute construct without if, exit ops
-//     acc.parallel { ... }
-//   } else {
-//     // Host path: original region body with ACC ops converted to host
-//   }
-//
-// Transformations:
-// ----------------
-// For each compute construct with an `if` clause:
-//
-// 1. Device Path (true branch):
-//    - Clone data entry operations (acc.copyin, acc.create, etc.)
-//    - Clone the compute construct without the `if` clause
-//    - Clone data exit operations (acc.copyout, acc.delete, etc.)
-//
-// 2. Host Path (false branch):
-//    - Move the original region body to the else branch
-//    - Apply host fallback patterns to convert ACC ops to host equivalents
-//
-// 3. Cleanup:
-//    - Erase the original compute construct and data operations
-//    - Replace uses of ACC variables with host variables in the else branch
-//
-// Requirements:
-// -------------
-// To use this pass in a pipeline, the following requirements exist:
-//
-// 1. Analysis Registration (Optional): If custom behavior is needed for
-//    emitting not-yet-implemented messages for unsupported cases, the pipeline
-//    should pre-register the `acc::OpenACCSupport` analysis.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
-
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
-#include "mlir/Dialect/OpenACC/OpenACC.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
-
-namespace mlir {
-namespace acc {
-#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
-} // namespace acc
-} // namespace mlir
-
-#define DEBUG_TYPE "acc-if-clause-lowering"
-
-using namespace mlir;
-using namespace mlir::acc;
-
-namespace {
-
-class ACCIfClauseLowering
-    : public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
-  using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
-
-private:
-  OpenACCSupport *accSupport = nullptr;
-
-  void convertHostRegion(Operation *computeOp, Region &region);
-
-  template <typename OpTy>
-  void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
-                                        SmallVector<Operation *> &eraseOps);
-
-public:
-  void runOnOperation() override;
-};
-
-void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
-                                            Region &region) {
-  // Only collect ACC dialect operations - other ops don't need conversion
-  SmallVector<Operation *> hostOps;
-  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (isa<acc::OpenACCDialect>(op->getDialect()))
-      hostOps.push_back(op);
-  });
-
-  RewritePatternSet patterns(computeOp->getContext());
-  populateACCHostFallbackPatterns(patterns, *accSupport);
-
-  GreedyRewriteConfig config;
-  config.setUseTopDownTraversal(true);
-  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
-  if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
-    accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
-}
-
-// Template function to handle if condition conversion for ACC compute
-// constructs
-template <typename OpTy>
-void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
-    OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
-  Value ifCond = computeConstructOp.getIfCond();
-  if (!ifCond)
-    return;
-
-  IRRewriter rewriter(computeConstructOp);
-
-  LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
-                          << " with if condition: " << computeConstructOp
-                          << "\n");
-
-  // Collect data clause operations that need to be recreated in the if
-  // condition
-  SmallVector<Operation *> dataEntryOps;
-  SmallVector<Operation *> dataExitOps;
-
-  // Collect data entry operations
-  for (Value operand : computeConstructOp.getDataClauseOperands()) {
-    if (Operation *defOp = operand.getDefiningOp())
-      if (isa<ACC_DATA_ENTRY_OPS>(defOp))
-        dataEntryOps.push_back(defOp);
-  }
-
-  // Find corresponding exit operations for each entry operation.
-  // Iterate backwards through entry ops since exit ops appear in reverse order.
-  for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
-    for (Operation *user : dataEntryOp->getUsers())
-      if (isa<ACC_DATA_EXIT_OPS>(user))
-        dataExitOps.push_back(user);
-
-  // Create scf.if with device and host execution paths
-  auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
-                                TypeRange{}, ifCond, /*withElseRegion=*/true);
-
-  // Declare deviceMapping at function scope for later use
-  IRMapping deviceMapping;
-
-  // Device execution path (true branch)
-  Block &thenBlock = ifOp.getThenRegion().front();
-  rewriter.setInsertionPointToStart(&thenBlock);
-
-  // Clone data entry operations
-  SmallVector<Value> deviceDataOperands;
-
-  LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
-                          << " data entry operations for device path\n");
-
-  for (Operation *dataOp : dataEntryOps) {
-    Operation *clonedDataOp = rewriter.clone(*dataOp, deviceMapping);
-    deviceDataOperands.push_back(clonedDataOp->getResult(0));
-    deviceMapping.map(dataOp->getResult(0), clonedDataOp->getResult(0));
-  }
-
-  // Create new compute op without if condition for device execution by
-  // cloning
-  OpTy newComputeOp = cast<OpTy>(
-      rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
-  newComputeOp.getIfCondMutable().clear();
-  newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
-
-  // Clone data exit operations
-  rewriter.setInsertionPointAfter(newComputeOp);
-  for (Operation *dataOp : dataExitOps)
-    rewriter.clone(*dataOp, deviceMapping);
-
-  rewriter.setInsertionPointToEnd(&thenBlock);
-  if (!thenBlock.getTerminator())
-    scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
-
-  // Host execution path (false branch)
-  if (!computeConstructOp.getRegion().hasOneBlock()) {
-    accSupport->emitNYI(computeConstructOp.getLoc(),
-                        "region with multiple blocks");
-    return;
-  }
-
-  // Don't need to clone original ops, just take them and legalize for host
-  ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
-
-  // Swap acc yield for scf yield
-  Block &elseBlock = ifOp.getElseRegion().front();
-  elseBlock.getTerminator()->erase();
-  rewriter.setInsertionPointToEnd(&elseBlock);
-  scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
-
-  convertHostRegion(computeConstructOp, ifOp.getElseRegion());
-
-  // The original op is now empty and can be erased
-  eraseOps.push_back(computeConstructOp);
-
-  // TODO: Can probably 'move' the data ops instead of cloning them
-  // which would eliminate need to explicitly erase
-  for (Operation *dataOp : dataExitOps)
-    eraseOps.push_back(dataOp);
-
-  for (Operation *dataOp : dataEntryOps) {
-    // The new host code may contain uses of the acc variables. Replace them by
-    // the host values.
-    getAccVar(dataOp).replaceAllUsesWith(getVar(dataOp));
-    eraseOps.push_back(dataOp);
-  }
-}
-
-void ACCIfClauseLowering::runOnOperation() {
-  func::FuncOp funcOp = getOperation();
-  accSupport = &getAnalysis<OpenACCSupport>();
-
-  SmallVector<Operation *> eraseOps;
-  funcOp.walk([&](Operation *op) {
-    if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
-      lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
-    else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
-      lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
-    else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
-      lowerIfClauseForComputeConstruct(serialOp, eraseOps);
-  });
-
-  for (Operation *op : eraseOps)
-    op->erase();
-}
-
-} // namespace

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 3a0ca338766e4..e94ac6f332834 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIROpenACCTransforms
-  ACCIfClauseLowering.cpp
   ACCImplicitData.cpp
   ACCLoopTiling.cpp
   ACCImplicitDeclare.cpp

diff  --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
deleted file mode 100644
index 3f0df18619bc0..0000000000000
--- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir
+++ /dev/null
@@ -1,224 +0,0 @@
-// RUN: mlir-opt %s -acc-if-clause-lowering -split-input-file | FileCheck %s
-
-// Test acc.parallel with if condition
-// CHECK-LABEL: func.func @test_parallel_if
-func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
-  %c0_i32 = arith.constant 0 : i32
-  %c1 = arith.constant 1 : index
-  %c10 = arith.constant 10 : index
-
-  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
-  %create = acc.create varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copyout>}
-
-  // CHECK-NOT: acc.parallel if
-  // CHECK: scf.if %{{.*}} {
-  // CHECK:   %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}}) -> memref<10xi32>
-  // CHECK:   %[[CREATE:.*]] = acc.create varPtr(%{{.*}}) -> memref<10xi32>
-  // CHECK:   acc.parallel dataOperands(%[[COPYIN]], %[[CREATE]] : memref<10xi32>, memref<10xi32>) {
-  // CHECK:     scf.for
-  // CHECK:     acc.yield
-  // CHECK:   }
-  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<10xi32>)
-  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr(%{{.*}} : memref<10xi32>)
-  // CHECK: } else {
-  // CHECK:   scf.for
-  // CHECK: }
-  acc.parallel dataOperands(%copyin, %create : memref<10xi32>, memref<10xi32>) if(%cond) {
-    scf.for %i = %c1 to %c10 step %c1 {
-      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
-    }
-    acc.yield
-  }
-
-  acc.delete accPtr(%create : memref<10xi32>)
-  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
-  return
-}
-
-// -----
-
-// Test acc.kernels with if condition
-// CHECK-LABEL: func.func @test_kernels_if
-func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
-  %c1_i32 = arith.constant 1 : i32
-  %c1 = arith.constant 1 : index
-  %c5 = arith.constant 5 : index
-
-  %copyin = acc.copyin varPtr(%arg0 : memref<5xi32>) -> memref<5xi32>
-  %create = acc.create varPtr(%arg0 : memref<5xi32>) -> memref<5xi32> {dataClause = #acc<data_clause acc_copyout>}
-
-  // CHECK-NOT: acc.kernels if
-  // CHECK: scf.if %{{.*}} {
-  // CHECK:   %[[COPYIN:.*]] = acc.copyin
-  // CHECK:   %[[CREATE:.*]] = acc.create
-  // CHECK:   acc.kernels dataOperands(%[[COPYIN]], %[[CREATE]] : memref<5xi32>, memref<5xi32>) {
-  // CHECK:     scf.for
-  // CHECK:     acc.terminator
-  // CHECK:   }
-  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<5xi32>)
-  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<5xi32>) to varPtr(%{{.*}} : memref<5xi32>)
-  // CHECK: } else {
-  // CHECK:   scf.for
-  // CHECK: }
-  acc.kernels dataOperands(%copyin, %create : memref<5xi32>, memref<5xi32>) if(%cond) {
-    scf.for %i = %c1 to %c5 step %c1 {
-      memref.store %c1_i32, %arg0[%i] : memref<5xi32>
-    }
-    acc.terminator
-  }
-
-  acc.delete accPtr(%create : memref<5xi32>)
-  acc.copyout accPtr(%copyin : memref<5xi32>) to varPtr(%arg0 : memref<5xi32>)
-  return
-}
-
-// -----
-
-// Test acc.serial with if condition
-// CHECK-LABEL: func.func @test_serial_if
-func.func @test_serial_if(%arg0: memref<8xi32>, %cond: i1) {
-  %c2_i32 = arith.constant 2 : i32
-  %c1 = arith.constant 1 : index
-  %c8 = arith.constant 8 : index
-
-  %copyin = acc.copyin varPtr(%arg0 : memref<8xi32>) -> memref<8xi32>
-  %create = acc.create varPtr(%arg0 : memref<8xi32>) -> memref<8xi32> {dataClause = #acc<data_clause acc_copyout>}
-
-  // CHECK-NOT: acc.serial if
-  // CHECK: scf.if %{{.*}} {
-  // CHECK:   %[[COPYIN:.*]] = acc.copyin
-  // CHECK:   %[[CREATE:.*]] = acc.create
-  // CHECK:   acc.serial dataOperands(%[[COPYIN]], %[[CREATE]] : memref<8xi32>, memref<8xi32>) {
-  // CHECK:     scf.for
-  // CHECK:     acc.yield
-  // CHECK:   }
-  // CHECK:   acc.delete accPtr(%[[CREATE]] : memref<8xi32>)
-  // CHECK:   acc.copyout accPtr(%[[COPYIN]] : memref<8xi32>) to varPtr(%{{.*}} : memref<8xi32>)
-  // CHECK: } else {
-  // CHECK:   scf.for
-  // CHECK: }
-  acc.serial dataOperands(%copyin, %create : memref<8xi32>, memref<8xi32>) if(%cond) {
-    scf.for %i = %c1 to %c8 step %c1 {
-      memref.store %c2_i32, %arg0[%i] : memref<8xi32>
-    }
-    acc.yield
-  }
-
-  acc.delete accPtr(%create : memref<8xi32>)
-  acc.copyout accPtr(%copyin : memref<8xi32>) to varPtr(%arg0 : memref<8xi32>)
-  return
-}
-
-// -----
-
-// Test that acc.parallel without if condition is not modified
-// CHECK-LABEL: func.func @test_parallel_no_if
-func.func @test_parallel_no_if(%arg0: memref<10xi32>) {
-  %c0_i32 = arith.constant 0 : i32
-  %c1 = arith.constant 1 : index
-  %c10 = arith.constant 10 : index
-
-  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
-
-  // CHECK-NOT: scf.if
-  // CHECK: acc.parallel dataOperands(%{{.*}}) {
-  acc.parallel dataOperands(%copyin : memref<10xi32>) {
-    scf.for %i = %c1 to %c10 step %c1 {
-      memref.store %c0_i32, %arg0[%i] : memref<10xi32>
-    }
-    acc.yield
-  }
-
-  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
-  return
-}
-
-// -----
-
-// Test with private and reduction clauses inside compute construct
-acc.private.recipe @privatization_memref_i32 : memref<i32> init {
-^bb0(%arg0: memref<i32>):
-  %0 = memref.alloca() : memref<i32>
-  acc.yield %0 : memref<i32>
-}
-
-acc.reduction.recipe @reduction_add_memref_f32 : memref<f32> reduction_operator <add> init {
-^bb0(%arg0: memref<f32>):
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = memref.alloca() : memref<f32>
-  memref.store %cst, %0[] : memref<f32>
-  acc.yield %0 : memref<f32>
-} combiner {
-^bb0(%arg0: memref<f32>, %arg1: memref<f32>):
-  %0 = memref.load %arg0[] : memref<f32>
-  %1 = memref.load %arg1[] : memref<f32>
-  %2 = arith.addf %0, %1 : f32
-  memref.store %2, %arg0[] : memref<f32>
-  acc.yield %arg0 : memref<f32>
-}
-
-// CHECK-LABEL: func.func @test_reduction_if
-func.func @test_reduction_if(%r: memref<f32>, %a: memref<8xf32>, %cond: i1) {
-  %c8_i32 = arith.constant 8 : i32
-  %c1_i32 = arith.constant 1 : i32
-
-  %copyin = acc.copyin varPtr(%r : memref<f32>) -> memref<f32> {dataClause = #acc<data_clause acc_reduction>, implicit = true}
-
-  // CHECK: scf.if
-  // CHECK:   acc.parallel
-  // CHECK: } else {
-  // The else branch should have acc ops converted to host
-  // CHECK-NOT: acc.loop
-  // CHECK-NOT: acc.reduction
-  // CHECK-NOT: acc.private
-  // CHECK: }
-  acc.parallel combined(loop) dataOperands(%copyin : memref<f32>) if(%cond) {
-    %red = acc.reduction varPtr(%r : memref<f32>) recipe(@reduction_add_memref_f32) -> memref<f32>
-    %iter_var = memref.alloca() : memref<i32>
-    %priv = acc.private varPtr(%iter_var : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
-    acc.loop combined(parallel) vector private(%priv : memref<i32>) reduction(%red : memref<f32>) control(%iv : i32) = (%c1_i32 : i32) to (%c8_i32 : i32) step (%c1_i32 : i32) {
-      memref.store %iv, %priv[] : memref<i32>
-      %idx = memref.load %priv[] : memref<i32>
-      %idx_cast = arith.index_cast %idx : i32 to index
-      %elem = memref.load %a[%idx_cast] : memref<8xf32>
-      %r_val = memref.load %r[] : memref<f32>
-      %new_r = arith.addf %r_val, %elem : f32
-      memref.store %new_r, %r[] : memref<f32>
-      acc.yield
-    } attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
-    acc.yield
-  }
-
-  acc.copyout accPtr(%copyin : memref<f32>) to varPtr(%r : memref<f32>) {dataClause = #acc<data_clause acc_reduction>, implicit = true}
-  return
-}
-
-// -----
-
-// Test that acc variable uses in host path are replaced with host variables
-// CHECK-LABEL: func.func @test_acc_var_replacement
-func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
-  %c0_i32 = arith.constant 0 : i32
-  %c1 = arith.constant 1 : index
-  %c10 = arith.constant 10 : index
-
-  %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
-
-  // In the else branch, uses of %copyin should be replaced with %arg0
-  // CHECK: scf.if
-  // CHECK: } else {
-  // CHECK:   scf.for
-  // CHECK:     memref.store %{{.*}}, %arg0[%{{.*}}]
-  // CHECK: }
-  acc.parallel dataOperands(%copyin : memref<10xi32>) if(%cond) {
-    scf.for %i = %c1 to %c10 step %c1 {
-      // Use the acc ptr inside the region
-      memref.store %c0_i32, %copyin[%i] : memref<10xi32>
-    }
-    acc.yield
-  }
-
-  acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
-  return
-}
-


        


More information about the Mlir-commits mailing list