[Mlir-commits] [mlir] [mlir][acc] Add ACCSpecializeForDevice and ACCSpecializeForHost passes (PR #173407)

llvmlistbot at llvm.org
Tue Dec 23 10:54:57 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-openacc

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

Add two new transformation passes for specializing OpenACC IR for different execution contexts:

ACCSpecializeForDevice:
- Strips OpenACC constructs that are invalid in device code
- Replaces data entry ops with their var operands
- Unwraps regions from compute/data constructs
- Erases runtime operations (init, shutdown, wait, etc.)
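The three kinds of rewrite above (replace with var, erase, unwrap region) can be illustrated with a small, self-contained C++ toy model. It is not the MLIR API. `ToyOp`, `classify`, and `specialize` are hypothetical names. Each op is assumed to have at most one single-block region whose last op is the terminator, and only a subset of the op kinds handled by the pass is listed. Ops are visited in program order, in the spirit of the top-down traversal the pass requests from the greedy driver, so a data entry op is always rewritten before its uses.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an MLIR operation (hypothetical, not the MLIR API).
struct ToyOp {
  std::string kind;                   // e.g. "acc.copyin", "acc.parallel"
  std::string result;                 // result name, empty if none
  std::vector<std::string> operands;  // operands[0] plays the role of "var"
  std::vector<ToyOp> body;            // single-block region; last op = terminator
};

enum class Action { Keep, ReplaceWithVar, Erase, Unwrap };

// Illustrative subset of the op lists from the pass description.
Action classify(const std::string &kind) {
  static const std::map<std::string, Action> table = {
      {"acc.copyin", Action::ReplaceWithVar},
      {"acc.create", Action::ReplaceWithVar},
      {"acc.present", Action::ReplaceWithVar},
      {"acc.copyout", Action::Erase},
      {"acc.delete", Action::Erase},
      {"acc.init", Action::Erase},
      {"acc.wait", Action::Erase},
      {"acc.data", Action::Unwrap},
      {"acc.parallel", Action::Unwrap},
      {"acc.serial", Action::Unwrap},
      {"acc.kernels", Action::Unwrap}};
  auto it = table.find(kind);
  return it == table.end() ? Action::Keep : it->second;
}

// Rewrites `ops` in program order; `renames` maps replaced results to vars.
std::vector<ToyOp> specialize(const std::vector<ToyOp> &ops,
                              std::map<std::string, std::string> &renames) {
  std::vector<ToyOp> out;
  for (ToyOp op : ops) {  // copy so operands can be rewritten
    for (std::string &operand : op.operands) {
      auto it = renames.find(operand);
      if (it != renames.end())
        operand = it->second;  // redirect uses of replaced results
    }
    switch (classify(op.kind)) {
    case Action::ReplaceWithVar:
      renames[op.result] = op.operands.at(0);  // uses now see the var
      break;
    case Action::Erase:
      break;  // drop the op entirely
    case Action::Unwrap: {
      assert(!op.body.empty() && "expected a terminator");
      op.body.pop_back();  // drop the acc.yield / acc.terminator
      std::vector<ToyOp> inner = specialize(op.body, renames);
      out.insert(out.end(), inner.begin(), inner.end());  // inline the region
      break;
    }
    case Action::Keep:
      op.body = specialize(op.body, renames);  // still rewrite nested ops
      out.push_back(op);
      break;
    }
  }
  return out;
}
```

Redirecting uses through the rename map is the toy counterpart of `rewriter.replaceOp(op, op.getVar())`: once the `acc.copyin` is replaced, the exit op is erased and the construct is unwrapped, only the payload remains, now reading the original variable.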

This pass is applicable in two contexts:
1. Functions marked with the `acc.specialized_routine` attribute, where the entire function body is device code
2. Non-specialized functions, where patterns are applied only to `acc` operations nested inside compute constructs (parallel, serial, kernels), not to the constructs themselves
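The scope selection for non-specialized functions can be sketched as a tree walk: for every compute construct, collect the `acc` ops nested under it while skipping the construct itself. `Node`, `collectNestedAcc`, `collectInsideCompute`, and `selectTargets` are hypothetical names in a toy op tree, not MLIR APIs. This sketch deliberately stops descending at the first compute construct it meets, so each nested op is collected once.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy op tree (not the MLIR API).
struct Node {
  std::string kind;            // e.g. "acc.parallel"
  std::vector<Node> children;  // ops nested in this op's regions
};

bool isComputeConstruct(const std::string &kind) {
  return kind == "acc.parallel" || kind == "acc.serial" ||
         kind == "acc.kernels";
}

bool isAccOp(const std::string &kind) { return kind.rfind("acc.", 0) == 0; }

// Collect every acc op nested below `n` at any depth, excluding `n` itself.
void collectNestedAcc(const Node &n, std::vector<std::string> &out) {
  for (const Node &child : n.children) {
    if (isAccOp(child.kind))
      out.push_back(child.kind);
    collectNestedAcc(child, out);
  }
}

// Find compute constructs and collect only the acc ops inside them; the
// construct itself and data ops outside it are not collected.
void collectInsideCompute(const Node &n, std::vector<std::string> &out) {
  for (const Node &child : n.children) {
    if (isComputeConstruct(child.kind))
      collectNestedAcc(child, out);
    else
      collectInsideCompute(child, out);
  }
}

std::vector<std::string> selectTargets(const Node &func,
                                       bool specializedRoutine) {
  std::vector<std::string> out;
  if (specializedRoutine)
    collectNestedAcc(func, out);  // the whole body is device code
  else
    collectInsideCompute(func, out);
  return out;
}
```

For a function containing `acc.copyin`, `acc.parallel { acc.loop { ... } }`, and `acc.copyout`, only the ops inside the `acc.parallel` are selected in the non-specialized case, while a specialized routine selects all of them.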

ACCSpecializeForHost:
- Converts orphan `acc` operations for host execution
- Transforms `acc.atomic.*` into load/store operations via the `PointerLikeType` interface
- Converts `acc.loop` to `scf.for` or `scf.execute_region`
- Replaces orphan data entry ops with their var operands
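The host semantics targeted by the atomic and loop conversions can be sketched in plain C++, with raw pointers standing in for `PointerLikeType` values. The helpers below are hypothetical: they show the load/compute/store shape for `acc.atomic.update`, the load-then-store copy for `acc.atomic.read`, a plain store for `acc.atomic.write`, and a sequential loop standing in for `scf.for`. The loop helper assumes the bounds have already been normalized to a half-open range `[lb, ub)` with a positive step.

```cpp
#include <cassert>

// Hypothetical host-side helpers; raw pointers stand in for PointerLikeType.

// acc.atomic.update -> load + compute + store.
template <typename T, typename UpdateFn>
void hostAtomicUpdate(T *x, UpdateFn update) {
  T old = *x;               // load
  T updated = update(old);  // compute (body of the update region)
  *x = updated;             // store
}

// acc.atomic.read -> load from x, store into v (a copy).
template <typename T>
void hostAtomicRead(T *v, const T *x) {
  *v = *x;
}

// acc.atomic.write -> plain store.
template <typename T>
void hostAtomicWrite(T *x, T value) {
  *x = value;
}

// Orphan acc.loop -> sequential loop over [lb, ub) with step > 0,
// the iteration shape of scf.for.
template <typename Body>
void hostLoop(long lb, long ub, long step, Body body) {
  for (long i = lb; i < ub; i += step)
    body(i);
}
```

For example, `hostAtomicUpdate(&x, [](int old) { return old + 3; })` turns `x == 5` into `x == 8` with an ordinary read and write.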

This pass operates in two modes:
1. Default (orphan) mode: Only converts `acc` operations that are not inside or attached to compute regions. Used for host `acc routine`s where compute constructs should be preserved.
2. Host fallback mode (enable-host-fallback=true): Converts ALL `acc` operations, including compute constructs, data regions, and runtime ops. This mode currently allows testing of the full conversion; the same patterns are intended to handle conditional host execution of `acc` regions that have an `if` clause.
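The difference between the two modes comes down to a per-op predicate. The sketch below is a hypothetical model, not the pass's code: `OpContext` records an op's kind, its enclosing ops, and whether it is attached to a compute construct (for example, a data entry op used as a data operand of `acc.parallel`). The set of constructs preserved in default mode is an illustrative subset.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical description of an op for this sketch (not the MLIR API).
struct OpContext {
  std::string kind;                    // e.g. "acc.atomic.update"
  std::vector<std::string> ancestors;  // enclosing op kinds, outermost first
  bool attachedToCompute = false;      // e.g. data operand of acc.parallel
};

bool isComputeConstruct(const std::string &kind) {
  return kind == "acc.parallel" || kind == "acc.serial" ||
         kind == "acc.kernels";
}

// Constructs left untouched in default mode (illustrative subset).
bool isPreservedConstruct(const std::string &kind) {
  return isComputeConstruct(kind) || kind == "acc.data" ||
         kind == "acc.enter_data" || kind == "acc.exit_data";
}

bool isInsideCompute(const OpContext &op) {
  for (const std::string &ancestor : op.ancestors)
    if (isComputeConstruct(ancestor))
      return true;
  return false;
}

// Default (orphan) mode: convert only acc ops that are neither preserved
// constructs nor inside or attached to a compute construct.
// Host fallback mode: convert every acc op.
bool shouldConvert(const OpContext &op, bool enableHostFallback) {
  if (op.kind.rfind("acc.", 0) != 0)
    return false;  // not an OpenACC op
  if (enableHostFallback)
    return true;
  return !isPreservedConstruct(op.kind) && !isInsideCompute(op) &&
         !op.attachedToCompute;
}
```

An `acc.atomic.update` directly in a host routine is converted in both modes, while the same op nested in `acc.parallel` is only converted once host fallback is enabled.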

The pattern population functions (`populateACCSpecializeForDevicePatterns`, `populateACCOrphanToHostPatterns`, `populateACCHostFallbackPatterns`) are exposed so that other passes can reuse these patterns.

---

Patch is 71.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173407.diff


9 Files Affected:

- (added) mlir/include/mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h (+124) 
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h (+32) 
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+58) 
- (added) mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForDevice.cpp (+176) 
- (added) mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForHost.cpp (+492) 
- (modified) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+3) 
- (added) mlir/test/Dialect/OpenACC/acc-specialize-for-device.mlir (+204) 
- (added) mlir/test/Dialect/OpenACC/acc-specialize-for-host-fallback.mlir (+157) 
- (added) mlir/test/Dialect/OpenACC/acc-specialize-for-host.mlir (+404) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h
new file mode 100644
index 0000000000000..225d61821dab5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h
@@ -0,0 +1,124 @@
+//===- ACCSpecializePatterns.h - Common ACC Specialization Patterns ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common rewrite pattern templates used by both
+// ACCSpecializeForHost and ACCSpecializeForDevice passes.
+//
+// The patterns provide the following transformations:
+//
+// - ACCOpReplaceWithVarConversion<OpTy>: Replaces a data entry operation
+//   with its var operand. Used for ops like acc.copyin, acc.create, etc.
+//
+// - ACCOpEraseConversion<OpTy>: Simply erases an operation. Used for
+//   data exit ops like acc.copyout, acc.delete, and runtime ops.
+//
+// - ACCRegionUnwrapConversion<OpTy>: Inlines the region of an operation
+//   and erases the wrapper. Used for structured data constructs
+//   (acc.data, acc.host_data) and compute constructs (acc.parallel, etc.)
+//
+// - ACCDeclareEnterOpConversion: Erases acc.declare_enter and its
+//   associated acc.declare_exit operation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H
+#define MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H
+
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace acc {
+
+//===----------------------------------------------------------------------===//
+// Generic pattern templates for ACC specialization
+//===----------------------------------------------------------------------===//
+
+/// Pattern to replace an ACC op with its var operand.
+/// Used for data entry ops like acc.copyin, acc.create, acc.attach, etc.
+template <typename OpTy>
+class ACCOpReplaceWithVarConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Replace this op with its var operand; it's possible the op has no uses
+    // if the op that had previously used it was already converted.
+    if (op->use_empty()) {
+      rewriter.eraseOp(op);
+    } else {
+      rewriter.replaceOp(op, op.getVar());
+    }
+    return success();
+  }
+};
+
+/// Pattern to simply erase an ACC op (for ops with no results).
+/// Used for data exit ops like acc.copyout, acc.delete, acc.detach, etc.
+template <typename OpTy>
+class ACCOpEraseConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    assert(op->getNumResults() == 0 && "expected op with no results");
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Pattern to unwrap a region from an ACC op and erase the wrapper.
+/// Moves the region's contents to the parent block and removes the wrapper op.
+/// Used for structured data constructs (acc.data, acc.host_data,
+/// acc.kernel_environment, acc.declare) and compute constructs (acc.parallel,
+/// acc.serial, acc.kernels).
+template <typename OpTy>
+class ACCRegionUnwrapConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    assert(op.getRegion().hasOneBlock() && "expected one block");
+    Block *block = &op.getRegion().front();
+    // Erase the terminator (acc.yield or acc.terminator) before unwrapping
+    rewriter.eraseOp(block->getTerminator());
+    rewriter.inlineBlockBefore(block, op);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Pattern to erase acc.declare_enter and its associated acc.declare_exit.
+/// The declare_enter produces a token that is consumed by declare_exit.
+class ACCDeclareEnterOpConversion
+    : public OpRewritePattern<acc::DeclareEnterOp> {
+  using OpRewritePattern<acc::DeclareEnterOp>::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(acc::DeclareEnterOp op,
+                                PatternRewriter &rewriter) const override {
+    // If the enter token is used by an exit, erase exit first.
+    if (!op->use_empty()) {
+      assert(op->hasOneUse() && "expected one use");
+      auto exitOp = dyn_cast<acc::DeclareExitOp>(*op->getUsers().begin());
+      assert(exitOp && "expected declare exit op");
+      rewriter.eraseOp(exitOp);
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+} // namespace acc
+} // namespace mlir
+
+#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_ACCSPECIALIZEPATTERNS_H
+
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
index 27f65aa15f040..b929c3d03dba4 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -22,9 +23,40 @@ class FuncOp;
 
 namespace acc {
 
+class OpenACCSupport;
+
 #define GEN_PASS_DECL
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
 
+//===----------------------------------------------------------------------===//
+// ACCSpecializeForDevice patterns
+//===----------------------------------------------------------------------===//
+
+/// Populates all patterns for device specialization.
+/// In specialized device code (such as specialized acc routine), many ACC
+/// operations do not make sense because they are host-side constructs. This
+/// function adds patterns to remove or transform them.
+void populateACCSpecializeForDevicePatterns(RewritePatternSet &patterns);
+
+//===----------------------------------------------------------------------===//
+// ACCSpecializeForHost patterns
+//===----------------------------------------------------------------------===//
+
+/// Populates patterns for converting orphan ACC operations to host.
+/// All patterns check that the operation is NOT inside or associated with a
+/// compute region before converting.
+/// @param enableLoopConversion Whether to convert orphan acc.loop operations.
+void populateACCOrphanToHostPatterns(RewritePatternSet &patterns,
+                                     OpenACCSupport &accSupport,
+                                     bool enableLoopConversion = true);
+
+/// Populates all patterns for host fallback path (when `if` clause evaluates
+/// to false). In this mode, ALL ACC operations should be converted or removed.
+/// @param enableLoopConversion Whether to convert orphan acc.loop operations.
+void populateACCHostFallbackPatterns(RewritePatternSet &patterns,
+                                     OpenACCSupport &accSupport,
+                                     bool enableLoopConversion = true);
+
 /// Generate the code for registering conversion passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index 253311e12932d..e10fde3c2691f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -194,4 +194,62 @@ def ACCLoopTiling : Pass<"acc-loop-tiling", "mlir::func::FuncOp"> {
   ];
 }
 
+def ACCSpecializeForDevice : Pass<"acc-specialize-for-device", "mlir::func::FuncOp"> {
+  let summary = "Strip OpenACC constructs inside device code";
+  let description = [{
+    In a specialized acc routine or compute construct, many OpenACC operations
+    do not make sense because they are host-side constructs. This pass removes
+    or transforms these operations appropriately.
+
+    The following operations are handled:
+    - Data entry ops (replaced with var): acc.attach, acc.copyin, acc.create,
+      acc.declare_device_resident, acc.declare_link, acc.deviceptr,
+      acc.get_deviceptr, acc.nocreate, acc.present, acc.update_device,
+      acc.use_device
+    - Data exit ops (erased): acc.copyout, acc.delete, acc.detach,
+      acc.update_host
+    - Structured data (inline region): acc.data, acc.host_data,
+      acc.kernel_environment
+    - Unstructured data (erased): acc.enter_data, acc.exit_data, acc.update,
+      acc.declare_enter, acc.declare_exit
+    - Compute constructs (inline region): acc.parallel, acc.serial, acc.kernels
+    - Runtime ops (erased): acc.init, acc.shutdown, acc.set, acc.wait
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect"];
+}
+
+def ACCSpecializeForHost : Pass<"acc-specialize-for-host", "mlir::func::FuncOp"> {
+  let summary = "Convert OpenACC operations for host execution";
+  let description = [{
+    This pass converts OpenACC operations to host-compatible representations.
+    It serves as a conversion pass that transforms ACC constructs to enable
+    execution on the host rather than on accelerator devices.
+
+    There are two modes of operation:
+
+    1. Default mode (orphan operations only): Only orphan operations that are
+       not allowed outside compute regions are converted. Structured/unstructured
+       data constructs, compute constructs, and their associated data operations
+       are NOT removed.
+
+    2. Host fallback mode (enableHostFallback=true): ALL ACC operations within
+       the region are converted to host equivalents. This is used when the `if`
+       clause evaluates to false at runtime.
+
+    The following operations are handled:
+    - Atomic ops: converted to load/store operations
+    - Loop ops: converted to scf.for or scf.execute_region
+    - Data entry ops (orphan): replaced with var operand
+    - In host fallback mode: all data, compute, and runtime ops are removed
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect",
+      "mlir::scf::SCFDialect"];
+  let options = [
+    Option<"enableHostFallback", "enable-host-fallback", "bool", "false",
+           "Enable host fallback mode which converts ALL ACC operations, "
+           "not just orphan operations. Use this when the `if` clause "
+           "evaluates to false.">
+  ];
+}
+
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForDevice.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForDevice.cpp
new file mode 100644
index 0000000000000..e23291497165f
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForDevice.cpp
@@ -0,0 +1,176 @@
+//===- ACCSpecializeForDevice.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass strips OpenACC constructs that are invalid or unnecessary inside
+// device code (specialized acc routines or compute construct regions).
+//
+// Overview:
+// ---------
+// In a specialized acc routine or compute construct, many OpenACC operations
+// do not make sense because they are host-side constructs. This pass removes
+// or transforms these operations appropriately:
+//
+// - Data operations that manage device memory from host perspective
+// - Compute constructs that launch kernels (we're already on device)
+// - Runtime operations like init/shutdown/set/wait
+//
+// Transformations:
+// ----------------
+// The pass applies the following transformations:
+//
+// 1. Data Entry Ops (replaced with var operand):
+//    acc.attach, acc.copyin, acc.create, acc.declare_device_resident,
+//    acc.declare_link, acc.deviceptr, acc.get_deviceptr, acc.nocreate,
+//    acc.present, acc.update_device, acc.use_device
+//
+// 2. Data Exit Ops (erased):
+//    acc.copyout, acc.delete, acc.detach, acc.update_host
+//
+// 3. Structured Data/Compute Constructs (region inlined):
+//    acc.data, acc.host_data, acc.kernel_environment, acc.parallel,
+//    acc.serial, acc.kernels
+//
+// 4. Unstructured Data Ops (erased):
+//    acc.enter_data, acc.exit_data, acc.update, acc.declare_enter,
+//    acc.declare_exit
+//
+// 5. Runtime Ops (erased):
+//    acc.init, acc.shutdown, acc.set, acc.wait
+//
+// Scope of Application:
+// ---------------------
+// - For functions with `acc.specialized_routine` attribute: patterns are
+//   applied to the entire function body.
+// - For non-specialized functions: patterns are applied only to ACC
+//   operations INSIDE compute constructs (parallel, serial, kernels),
+//   not to the compute constructs themselves or their data operands.
+//
+// Note: acc.cache, acc.private, acc.reduction, acc.firstprivate are NOT
+// transformed by this pass as they are valid in device code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCSPECIALIZEFORDEVICE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+class ACCSpecializeForDevice
+    : public acc::impl::ACCSpecializeForDeviceBase<ACCSpecializeForDevice> {
+public:
+  using ACCSpecializeForDeviceBase<
+      ACCSpecializeForDevice>::ACCSpecializeForDeviceBase;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    RewritePatternSet patterns(&getContext());
+    acc::populateACCSpecializeForDevicePatterns(patterns);
+    GreedyRewriteConfig config;
+    config.setUseTopDownTraversal(true);
+
+    if (acc::isSpecializedAccRoutine(func)) {
+      // For specialized acc routines, apply patterns to the entire function
+      (void)applyPatternsGreedily(func, std::move(patterns), config);
+    } else {
+      // For non-specialized functions, apply patterns only to ACC operations
+      // inside compute constructs (not to the compute constructs themselves).
+      SmallVector<Operation *> opsToTransform;
+      func.walk([&](Operation *op) {
+        if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) {
+          // Walk inside the compute construct and collect ACC ops
+          op->walk([&](Operation *innerOp) {
+            // Skip the compute construct itself
+            if (innerOp == op) {
+              return;
+            }
+            if (isa<acc::OpenACCDialect>(innerOp->getDialect())) {
+              opsToTransform.push_back(innerOp);
+            }
+          });
+        }
+      });
+      if (!opsToTransform.empty()) {
+        (void)applyOpPatternsGreedily(opsToTransform, std::move(patterns),
+                                      config);
+      }
+    }
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population functions
+//===----------------------------------------------------------------------===//
+
+void mlir::acc::populateACCSpecializeForDevicePatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+
+  // Declare patterns - erase declare_enter and its associated declare_exit
+  patterns.insert<ACCDeclareEnterOpConversion>(context);
+
+  // Data entry ops - replaced with their var operand
+  // Note: acc.cache, acc.private, acc.reduction, acc.firstprivate are NOT
+  // included here - they are valid in device code
+  patterns.insert<ACCOpReplaceWithVarConversion<acc::AttachOp>,
+                  ACCOpReplaceWithVarConversion<acc::CopyinOp>,
+                  ACCOpReplaceWithVarConversion<acc::CreateOp>,
+                  ACCOpReplaceWithVarConversion<acc::DeclareDeviceResidentOp>,
+                  ACCOpReplaceWithVarConversion<acc::DeclareLinkOp>,
+                  ACCOpReplaceWithVarConversion<acc::DevicePtrOp>,
+                  ACCOpReplaceWithVarConversion<acc::GetDevicePtrOp>,
+                  ACCOpReplaceWithVarConversion<acc::NoCreateOp>,
+                  ACCOpReplaceWithVarConversion<acc::PresentOp>,
+                  ACCOpReplaceWithVarConversion<acc::UpdateDeviceOp>,
+                  ACCOpReplaceWithVarConversion<acc::UseDeviceOp>>(context);
+
+  // Data exit ops - simply erased (no results)
+  patterns.insert<ACCOpEraseConversion<acc::CopyoutOp>,
+                  ACCOpEraseConversion<acc::DeleteOp>,
+                  ACCOpEraseConversion<acc::DetachOp>,
+                  ACCOpEraseConversion<acc::UpdateHostOp>>(context);
+
+  // Structured data constructs - unwrap their regions
+  patterns.insert<ACCRegionUnwrapConversion<acc::DataOp>,
+                  ACCRegionUnwrapConversion<acc::HostDataOp>,
+                  ACCRegionUnwrapConversion<acc::KernelEnvironmentOp>>(context);
+
+  // Compute constructs - unwrap their regions
+  patterns.insert<ACCRegionUnwrapConversion<acc::ParallelOp>,
+                  ACCRegionUnwrapConversion<acc::SerialOp>,
+                  ACCRegionUnwrapConversion<acc::KernelsOp>>(context);
+
+  // Unstructured data operations - erase them
+  patterns.insert<ACCOpEraseConversion<acc::EnterDataOp>,
+                  ACCOpEraseConversion<acc::ExitDataOp>,
+                  ACCOpEraseConversion<acc::UpdateOp>>(context);
+
+  // Runtime operations - erase them
+  patterns.insert<ACCOpEraseConversion<acc::InitOp>,
+                  ACCOpEraseConversion<acc::ShutdownOp>,
+                  ACCOpEraseConversion<acc::SetOp>,
+                  ACCOpEraseConversion<acc::WaitOp>>(context);
+}
+
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForHost.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForHost.cpp
new file mode 100644
index 0000000000000..0865cab127da4
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCSpecializeForHost.cpp
@@ -0,0 +1,492 @@
+//===- ACCSpecializeForHost.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts OpenACC operations to host-compatible representations,
+// enabling execution on the host rather than on accelerator devices.
+//
+// Overview:
+// ---------
+// The pass operates in two modes depending on the `enableHostFallback` option:
+//
+// 1. Default Mode (Orphan Operations Only):
+//    Only converts "orphan" ACC operations that are not inside or attached to
+//    compute regions. This is used for host routines (acc routine marked for
+//    host) where structured/unstructured data constructs, compute constructs,
+//    and their associated data operations should be preserved.
+//
+// 2. Host Fallback Mode (enableHostFallback=true):
+//    Converts ALL ACC operations within the region to host equivalents. This
+//    is used when the `if` clause evaluates to false at runtime and the
+//    entire ACC region needs to fall back to host execution.
+//
+// Transformations (Orphan Mode):
+// ------------------------------
+// The following orphan operations are converted:
+//
+// 1. Atomic Ops (converted to load/store):
+//    acc.atomic.update -> load + compute + store
+//    acc.atomic.read -> load + store (copy)
+//    acc.atomic.write -> store
+//    acc.atomic.capture -> inline region contents
+//
+// 2. Loop Ops (...
[truncated]

``````````
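The ordering used by `ACCDeclareEnterOpConversion` (erase the `acc.declare_exit` that consumes the token before erasing the `acc.declare_enter` that produces it) can be illustrated with a toy use-list model. `ToyIR` and `eraseDeclarePair` are hypothetical names, not MLIR APIs; each op is identified by name, and its name also stands for its result.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy IR: op name -> names of the ops whose results it consumes.
struct ToyIR {
  std::map<std::string, std::vector<std::string>> operands;

  std::vector<std::string> usersOf(const std::string &producer) const {
    std::vector<std::string> users;
    for (const auto &[op, ins] : operands)
      for (const std::string &in : ins)
        if (in == producer)
          users.push_back(op);
    return users;
  }

  // Erasing is only legal once nothing uses the op's result.
  bool erase(const std::string &op) {
    if (!usersOf(op).empty())
      return false;  // would leave a dangling use
    return operands.erase(op) == 1;
  }
};

// Mirrors the pattern's order: drop the single declare_exit user first,
// then the declare_enter that produced the token.
bool eraseDeclarePair(ToyIR &ir, const std::string &enter) {
  std::vector<std::string> users = ir.usersOf(enter);
  assert(users.size() <= 1 && "expected at most one declare_exit user");
  if (!users.empty() && !ir.erase(users.front()))
    return false;
  return ir.erase(enter);
}
```

Trying to erase the enter op first fails while the exit op still consumes its token; erasing the user first makes both erasures legal.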

</details>


https://github.com/llvm/llvm-project/pull/173407

