[Mlir-commits] [mlir] ce6dd8b - [flang][acc] Add ACCUseDeviceCanonicalizer pass (#175228)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 14 08:19:18 PST 2026


Author: Razvan Lupusoru
Date: 2026-01-14T08:19:10-08:00
New Revision: ce6dd8b02d13c33612ddeb4284a37ba6e13ee446

URL: https://github.com/llvm/llvm-project/commit/ce6dd8b02d13c33612ddeb4284a37ba6e13ee446
DIFF: https://github.com/llvm/llvm-project/commit/ce6dd8b02d13c33612ddeb4284a37ba6e13ee446.diff

LOG: [flang][acc] Add ACCUseDeviceCanonicalizer pass (#175228)

This pass canonicalizes the use_device clause on acc.host_data
constructs to enable simpler runtime lowering. For use_device operands
that are box types or references to boxes, the pass:

1. Extracts the host base address for mapping to a device address using
acc.use_device
2. Creates a new boxed descriptor with the device address as the base
address for use inside the host_data region

The pass also removes unused use_device clauses to reduce runtime calls.

This canonicalization hoists load/box_addr patterns out of the host_data
region so they are applied to the host variable before acc.use_device,
ensuring the device pointer is used directly inside the region.

Example transformation for a reference to a box (!fir.ref<!fir.box<>>):

Before:
```
  %ptr = acc.use_device varPtr(%ref : !fir.ref<!fir.box<!fir.ptr<i32>>>)
  acc.host_data dataOperands(%ptr) {
    %box = fir.load %ptr
    %addr = fir.box_addr %box
    // use %addr
  }
```

After:
```
  %box = fir.load %ref
  %addr = fir.box_addr %box
  %dev_ptr = acc.use_device varPtr(%addr : !fir.ptr<i32>)
  acc.host_data dataOperands(%dev_ptr) {
    %new_box = fir.embox %dev_ptr
    // use device pointer through new descriptor
  }
```

---------

Co-authored-by: nvptm <pmathew at nvidia.com>

Added: 
    flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp
    flang/test/Fir/OpenACC/use-device-canonicalizer.mlir

Modified: 
    flang/include/flang/Optimizer/OpenACC/Passes.h
    flang/include/flang/Optimizer/OpenACC/Passes.td
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/OpenACC/Passes.h b/flang/include/flang/Optimizer/OpenACC/Passes.h
index c27c7ebc3b06f..64ddb84e63c3e 100644
--- a/flang/include/flang/Optimizer/OpenACC/Passes.h
+++ b/flang/include/flang/Optimizer/OpenACC/Passes.h
@@ -30,6 +30,7 @@ namespace acc {
 
 std::unique_ptr<mlir::Pass> createACCInitializeFIRAnalysesPass();
 std::unique_ptr<mlir::Pass> createACCRecipeBufferizationPass();
+std::unique_ptr<mlir::Pass> createACCUseDeviceCanonicalizerPass();
 
 } // namespace acc
 } // namespace fir

diff  --git a/flang/include/flang/Optimizer/OpenACC/Passes.td b/flang/include/flang/Optimizer/OpenACC/Passes.td
index d947aa470494a..8579a471d9a56 100644
--- a/flang/include/flang/Optimizer/OpenACC/Passes.td
+++ b/flang/include/flang/Optimizer/OpenACC/Passes.td
@@ -49,4 +49,25 @@ def ACCRecipeBufferization
   }];
 }
 
+def ACCUseDeviceCanonicalizer
+    : Pass<"acc-use-device-canonicalizer", "mlir::func::FuncOp"> {
+  let summary = "Canonicalize acc.use_device operations for FIR box types";
+  let description = [{
+    This pass canonicalizes the use_device clause on a host_data construct such
+    that use_device(x) can be lowered to a simple runtime call that takes the
+    actual host pointer as argument.
+
+    For a use_device operand that is a box type or a reference to a box, the
+    pass:
+      1. Extracts the host base address for mapping to a device address using
+         acc.use_device.
+      2. Creates a new boxed descriptor with the device address as the base
+         address for use inside the host_data region.
+
+    The pass also removes unused use_device clauses, reducing the number of
+    runtime calls.
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect", "fir::FIROpsDialect"];
+}
+
 #endif // FORTRAN_OPTIMIZER_OPENACC_PASSES

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 062366f87eb09..cfbcf06c4b39e 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3399,7 +3399,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
-      // When CUDA Fotran is enabled, extra symbols are used in the host_data
+      // When CUDA Fortran is enabled, extra symbols are used in the host_data
       // region. Look for them and bind their values with the symbols in the
       // outer scope.
       if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {

diff  --git a/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp
new file mode 100644
index 0000000000000..51ab796021e4b
--- /dev/null
+++ b/flang/lib/Optimizer/OpenACC/Transforms/ACCUseDeviceCanonicalizer.cpp
@@ -0,0 +1,400 @@
+//===- ACCUseDeviceCanonicalizer.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass canonicalizes the use_device clause on a host_data construct such
+// that use_device(x) can be lowered to a simple runtime call that takes the
+// actual host pointer as argument.
+//
+// For a use_device operand that is a box type or a reference to a box, the
+// pass:
+//   1. Extracts the host base address for mapping to a device address using
+//      acc.use_device.
+//   2. Creates a new boxed descriptor with the device address as the base
+//      address for use inside the host_data region.
+//
+// The pass also removes unused use_device clauses, reducing the number of
+// runtime calls.
+//
+// Supported use_device operand types:
+//
+//   Scalars:
+//     - !fir.ref<i32>, !fir.ref<f64>, etc.
+//
+//   Arrays:
+//     - Explicit shape (no descriptor): !fir.ref<!fir.array<100xi32>>
+//     - Adjustable size: !fir.ref<!fir.array<?xi32>>
+//     - Assumed shape (handled by hoistBox): !fir.box<!fir.array<?xi32>>
+//     - Assumed size: !fir.ref<!fir.array<?xi32>>
+//     - Deferred shape (handled by hoistRefToBox):
+//         - Allocatable: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+//         - Pointer: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+//     - Subarray specification (handled by hoistBox):
+//     !fir.box<!fir.array<?xi32>>
+//
+//   Not yet supported:
+//     - Assumed rank arrays
+//     - Composite variables: !fir.ref<!fir.type<...>>
+//     - Array elements (device pointer arithmetic in host_data region)
+//     - Composite variable members
+//     - Fortran common blocks: use_device(/cm_block/)
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include <cassert>
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCUSEDEVICECANONICALIZER
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+#define DEBUG_TYPE "acc-use-device-canonicalizer"
+
+using namespace mlir;
+
+namespace {
+
+struct UseDeviceHostDataHoisting : public OpRewritePattern<acc::HostDataOp> {
+  using OpRewritePattern<acc::HostDataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(acc::HostDataOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> usedOperands;
+    SmallVector<Value> unusedUseDeviceOperands;
+    SmallVector<acc::UseDeviceOp> refToBoxUseDeviceOps;
+    SmallVector<acc::UseDeviceOp> boxUseDeviceOps;
+
+    for (Value operand : op.getDataClauseOperands()) {
+      if (acc::UseDeviceOp useDeviceOp =
+              operand.getDefiningOp<acc::UseDeviceOp>()) {
+        if (fir::isBoxAddress(useDeviceOp.getVar().getType())) {
+          if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
+            refToBoxUseDeviceOps.push_back(useDeviceOp);
+        } else if (isa<fir::BoxType>(useDeviceOp.getVar().getType())) {
+          if (!llvm::hasSingleElement(useDeviceOp->getUsers()))
+            boxUseDeviceOps.push_back(useDeviceOp);
+        }
+
+        // host_data is the only user of this use_device operand - mark for
+        // removal
+        if (llvm::hasSingleElement(useDeviceOp->getUsers()))
+          unusedUseDeviceOperands.push_back(useDeviceOp.getResult());
+        else
+          usedOperands.push_back(useDeviceOp.getResult());
+      } else {
+        // Operand is not an `acc.use_device` result, keep it as is.
+        usedOperands.push_back(operand);
+      }
+    }
+
+    assert(!usedOperands.empty() && "Host_data operation has no used operands");
+
+    if (!unusedUseDeviceOperands.empty()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "ACCUseDeviceCanonicalizer: Removing "
+                 << unusedUseDeviceOperands.size()
+                 << " unused use_device operands from host_data operation\n");
+
+      // Update the host_data operation to have only used operands
+      rewriter.modifyOpInPlace(op, [&]() {
+        op.getDataClauseOperandsMutable().assign(usedOperands);
+      });
+
+      // Remove unused use_device operations
+      for (Value operand : unusedUseDeviceOperands) {
+        acc::UseDeviceOp useDeviceOp =
+            operand.getDefiningOp<acc::UseDeviceOp>();
+        LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Erasing: "
+                                << *useDeviceOp << "\n");
+        rewriter.eraseOp(useDeviceOp);
+      }
+      return success();
+    }
+
+    // Handle references to box types
+    bool modified = false;
+    for (acc::UseDeviceOp useDeviceOp : refToBoxUseDeviceOps)
+      modified |=
+          hoistRefToBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
+
+    // Handle box types
+    for (acc::UseDeviceOp useDeviceOp : boxUseDeviceOps)
+      modified |= hoistBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op);
+
+    return modified ? success() : failure();
+  }
+
+private:
+  /// Collect users of `acc.use_device` operation inside the `acc.host_data`
+  /// region that need to be updated with the final replacement value.
+  void collectUseDeviceUsersToUpdate(
+      acc::UseDeviceOp useDeviceOp, acc::HostDataOp hostDataOp,
+      SmallVectorImpl<Operation *> &usersToUpdate) const {
+    for (mlir::Operation *user : useDeviceOp->getUsers())
+      if (hostDataOp.getRegion().isAncestor(user->getParentRegion()))
+        usersToUpdate.push_back(user);
+  }
+
+  /// Create new `acc.use_device` operation with the given box address as
+  /// operand. Updates the `acc.host_data` operation to use the new
+  /// `acc.use_device` result.
+  acc::UseDeviceOp createNewUseDeviceOp(PatternRewriter &rewriter,
+                                        acc::UseDeviceOp useDeviceOp,
+                                        acc::HostDataOp hostDataOp,
+                                        fir::BoxAddrOp boxAddr) const {
+    // Create use_device on the raw pointer
+    acc::UseDeviceOp newUseDeviceOp = acc::UseDeviceOp::create(
+        rewriter, useDeviceOp.getLoc(), boxAddr.getType(), boxAddr.getResult(),
+        useDeviceOp.getVarTypeAttr(), useDeviceOp.getVarPtrPtr(),
+        useDeviceOp.getBounds(), useDeviceOp.getAsyncOperands(),
+        useDeviceOp.getAsyncOperandsDeviceTypeAttr(),
+        useDeviceOp.getAsyncOnlyAttr(), useDeviceOp.getDataClauseAttr(),
+        useDeviceOp.getStructuredAttr(), useDeviceOp.getImplicitAttr(),
+        useDeviceOp.getModifiersAttr(), useDeviceOp.getNameAttr(),
+        useDeviceOp.getRecipeAttr());
+
+    LLVM_DEBUG(llvm::dbgs() << "Created new hoisted pattern for box access:\n"
+                            << "  box_addr: " << *boxAddr << "\n"
+                            << "  new use_device: " << *newUseDeviceOp << "\n");
+
+    // Replace the old `acc.use_device` operand in the `acc.host_data` operation
+    // with the new one
+    rewriter.modifyOpInPlace(hostDataOp, [&]() {
+      hostDataOp->replaceUsesOfWith(useDeviceOp.getResult(),
+                                    newUseDeviceOp.getResult());
+    });
+
+    return newUseDeviceOp;
+  }
+
+  /// Canonicalize use_device operand that is a reference to a box.
+  /// Transforms:
+  ///   %3 = fir.address_of(@_QFEtgt) : !fir.ref<i32>
+  ///   %5 = fir.embox %3 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
+  ///   fir.store %5 to %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  ///   %9 = acc.use_device varPtr(%0 : !fir.ref<!fir.box<!fir.ptr<i32>>>)
+  ///   -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
+  ///   acc.host_data dataOperands(%9 : !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+  ///     %loaded = fir.load %9 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  ///     %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+  ///     %conv = fir.convert %addr : (!fir.ptr<i32>) -> i64
+  ///     fir.call @foo(%conv) : (i64) -> ()
+  ///     acc.terminator
+  ///   }
+  /// into:
+  ///   %loaded = fir.load %0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  ///   %addr = fir.box_addr %loaded : (!fir.box<!fir.ptr<i32>>) ->
+  ///   !fir.ptr<i32>
+  ///   %dev_ptr = acc.use_device varPtr(%addr : !fir.ptr<i32>)
+  ///   varType(!fir.box<!fir.ptr<i32>>)
+  ///   -> !fir.ptr<i32> {name = "ptr"}
+  ///   acc.host_data dataOperands(%dev_ptr : !fir.ptr<i32>)
+  ///   {
+  ///     %embox = fir.embox %dev_ptr : (!fir.ptr<i32>) -> !fir.box<!fir.ptr<i32>>
+  ///     %alloca = fir.alloca !fir.box<!fir.ptr<i32>>
+  ///     fir.store %embox to %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  ///     %loaded2 = fir.load %alloca : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  ///     %addr2 = fir.box_addr %loaded2 : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+  ///     %conv = fir.convert %addr2 : (!fir.ptr<i32>) -> i64
+  ///     fir.call @foo(%conv) : (i64) -> ()
+  ///     acc.terminator
+  ///   }
+  bool hoistRefToBox(PatternRewriter &rewriter, Value operand,
+                     acc::UseDeviceOp useDeviceOp,
+                     acc::HostDataOp hostDataOp) const {
+
+    // Safety check: if the use_device operation is already using a box_addr
+    // result, it means it has already been processed, so skip to avoid infinite
+    // loop
+    if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
+      LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Skipping "
+                                 "already processed use_device operation\n");
+      return false;
+    }
+    // Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
+    ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
+
+    // Collect users of the original `acc.use_device` operation that need to be
+    // updated
+    SmallVector<Operation *> usersToUpdate;
+    collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
+
+    rewriter.setInsertionPoint(useDeviceOp);
+    // Create a load operation to get the box from the variable
+    fir::LoadOp box = fir::LoadOp::create(rewriter, useDeviceOp.getLoc(),
+                                          useDeviceOp.getVar());
+    // Create a box_addr operation to get the address from the box
+    fir::BoxAddrOp boxAddr =
+        fir::BoxAddrOp::create(rewriter, useDeviceOp.getLoc(), box);
+
+    acc::UseDeviceOp newUseDeviceOp =
+        createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Created new hoisted pattern for pointer access:\n"
+               << "  load box: " << *box << "\n"
+               << "  box_addr: " << *boxAddr << "\n"
+               << "  new use_device: " << *newUseDeviceOp << "\n");
+
+    // Set insertion point to the first op inside the host_data region
+    rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
+
+    // Create a FirOpBuilder from the PatternRewriter using the module we got
+    // earlier
+    fir::FirOpBuilder builder(rewriter, mod);
+    Value newBoxwithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
+        builder, useDeviceOp.getLoc(), box.getResult(),
+        newUseDeviceOp.getResult());
+
+    // Create new memory location and store the newBoxwithDevicePtr into new
+    // memory location
+    fir::AllocaOp newMemLoc = fir::AllocaOp::create(
+        rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr.getType());
+    [[maybe_unused]] fir::StoreOp newStoreOp = fir::StoreOp::create(
+        rewriter, useDeviceOp.getLoc(), newBoxwithDevicePtr, newMemLoc);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "host_data region updated with new host descriptor "
+                  "containing device pointer:\n"
+               << "  box with device pointer: "
+               << *newBoxwithDevicePtr.getDefiningOp() << "\n"
+               << "  mem loc: " << *newMemLoc << "\n"
+               << "  store op: " << *newStoreOp << "\n");
+
+    // Replace all uses of the original `acc.use_device` operation inside the
+    // `acc.host_data` region with the new memory location containing the box
+    // with device pointer
+    for (mlir::Operation *user : usersToUpdate)
+      user->replaceUsesOfWith(useDeviceOp.getResult(), newMemLoc);
+
+    assert(useDeviceOp.getResult().use_empty() &&
+           "expected all uses of use_device to be replaced");
+    rewriter.eraseOp(useDeviceOp);
+    return true;
+  }
+
+  /// Canonicalize use_device operand that is a box type.
+  /// Transforms:
+  ///   %box = ... : !fir.box<!fir.array<?xi32>>
+  ///   %dev_box = acc.use_device varPtr(%box : !fir.box<!fir.array<?xi32>>)
+  ///   -> !fir.box<!fir.array<?xi32>>
+  ///   acc.host_data dataOperands(%dev_box : !fir.box<!fir.array<?xi32>>) {
+  ///     %addr = fir.box_addr %dev_box : (!fir.box<!fir.array<?xi32>>) ->
+  ///     !fir.heap<!fir.array<?xi32>>
+  ///     // use %addr
+  ///   }
+  /// into:
+  ///   %box = ... : !fir.box<!fir.array<?xi32>>
+  ///   %addr = fir.box_addr %box : (!fir.box<!fir.array<?xi32>>) ->
+  ///   !fir.heap<!fir.array<?xi32>>
+  ///   %dev_ptr = acc.use_device varPtr(%addr : !fir.heap<!fir.array<?xi32>>)
+  ///   -> !fir.heap<!fir.array<?xi32>>
+  ///   acc.host_data dataOperands(%dev_ptr : !fir.heap<!fir.array<?xi32>>) {
+  ///     %new_box = fir.embox %dev_ptr ... : !fir.box<!fir.array<?xi32>>
+  ///     %new_addr = fir.box_addr %new_box : (!fir.box<!fir.array<?xi32>>) ->
+  ///     !fir.heap<!fir.array<?xi32>>
+  ///     // use %new_addr instead of %addr
+  ///   }
+  bool hoistBox(PatternRewriter &rewriter, Value operand,
+                acc::UseDeviceOp useDeviceOp,
+                acc::HostDataOp hostDataOp) const {
+
+    // Safety check: if the use_device operation is already using a box_addr
+    // result, it means it has already been processed, so skip to avoid infinite
+    // loop
+    if (useDeviceOp.getVar().getDefiningOp<fir::BoxAddrOp>()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "ACCUseDeviceCanonicalizer: Skipping "
+                    "already processed box use_device operation\n");
+      return false;
+    }
+
+    // Collect users of the original `acc.use_device` operation that need to be
+    // updated
+    SmallVector<Operation *> usersToUpdate;
+    collectUseDeviceUsersToUpdate(useDeviceOp, hostDataOp, usersToUpdate);
+
+    // Get the ModuleOp before we erase useDeviceOp to avoid invalid reference
+    ModuleOp mod = useDeviceOp->getParentOfType<ModuleOp>();
+
+    rewriter.setInsertionPoint(useDeviceOp);
+    // Extract the raw pointer from the box descriptor
+    fir::BoxAddrOp boxAddr = fir::BoxAddrOp::create(
+        rewriter, useDeviceOp.getLoc(), useDeviceOp.getVar());
+
+    acc::UseDeviceOp newUseDeviceOp =
+        createNewUseDeviceOp(rewriter, useDeviceOp, hostDataOp, boxAddr);
+
+    // Set insertion point to the first op inside the host_data region
+    rewriter.setInsertionPoint(&hostDataOp.getRegion().front().front());
+
+    // Create a FirOpBuilder from the PatternRewriter using the module we got
+    // earlier
+    fir::FirOpBuilder builder(rewriter, mod);
+
+    // Create a new host descriptor at the start of the host_data region
+    // with the device pointer as the base address
+    Value newBoxWithDevicePtr = fir::factory::getDescriptorWithNewBaseAddress(
+        builder, useDeviceOp.getLoc(), useDeviceOp.getVar(),
+        newUseDeviceOp.getResult());
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "host_data region updated with new host descriptor "
+                  "containing device pointer:\n"
+               << "  box with device pointer: "
+               << *newBoxWithDevicePtr.getDefiningOp() << "\n");
+
+    // Replace all uses of the original `acc.use_device` operation inside the
+    // `acc.host_data` region with the new box containing device pointer
+    for (mlir::Operation *user : usersToUpdate)
+      user->replaceUsesOfWith(useDeviceOp.getResult(), newBoxWithDevicePtr);
+
+    assert(useDeviceOp.getResult().use_empty() &&
+           "expected all uses of use_device to be replaced");
+    rewriter.eraseOp(useDeviceOp);
+    return true;
+  }
+};
+
+class ACCUseDeviceCanonicalizer
+    : public fir::acc::impl::ACCUseDeviceCanonicalizerBase<
+          ACCUseDeviceCanonicalizer> {
+public:
+  void runOnOperation() override {
+    MLIRContext *context = getOperation()->getContext();
+
+    RewritePatternSet patterns(context);
+
+    // Add the custom use_device canonicalization patterns
+    patterns.insert<UseDeviceHostDataHoisting>(context);
+
+    // Apply patterns greedily
+    GreedyRewriteConfig config;
+    // Prevent the pattern driver from merging blocks.
+    config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
+    config.setUseTopDownTraversal(true);
+
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<mlir::Pass> fir::acc::createACCUseDeviceCanonicalizerPass() {
+  return std::make_unique<ACCUseDeviceCanonicalizer>();
+}

diff  --git a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
index d41e99a6c0679..9bd56c1618544 100644
--- a/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_flang_library(FIROpenACCTransforms
+  ACCUseDeviceCanonicalizer.cpp
   ACCInitializeFIRAnalyses.cpp
   ACCRecipeBufferization.cpp
 
@@ -7,6 +8,7 @@ add_flang_library(FIROpenACCTransforms
 
   LINK_LIBS
   FIRAnalysis
+  FIRBuilder
   FIRDialect
   FIROpenACCAnalysis
   HLFIRDialect
@@ -16,4 +18,5 @@ add_flang_library(FIROpenACCTransforms
   MLIRPass
   MLIROpenACCDialect
   MLIROpenACCUtils
+  MLIRTransformUtils
 )

diff  --git a/flang/test/Fir/OpenACC/use-device-canonicalizer.mlir b/flang/test/Fir/OpenACC/use-device-canonicalizer.mlir
new file mode 100644
index 0000000000000..6ec583dd4fe22
--- /dev/null
+++ b/flang/test/Fir/OpenACC/use-device-canonicalizer.mlir
@@ -0,0 +1,96 @@
+// RUN: fir-opt %s --acc-use-device-canonicalizer -split-input-file | FileCheck %s
+
+// -----
+
+// Test hoisting of load/box_addr/convert pattern out of acc.host_data with function call
+func.func @test_host_data_hoisting_function_call(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) {
+  // CHECK: %[[LOADED:.*]] = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[ADDR:.*]] = fir.box_addr %[[LOADED]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.heap<!fir.array<?xf64>>
+  // CHECK: %[[DEV_PTR:.*]] = acc.use_device varPtr(%[[ADDR]] : !fir.heap<!fir.array<?xf64>>) varType(!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.heap<!fir.array<?xf64>>
+  // CHECK: acc.host_data dataOperands(%[[DEV_PTR]]
+  // CHECK: %[[EMBOX:.*]] = fir.embox %[[DEV_PTR]]
+  // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf64>>>
+  // CHECK: fir.store %[[EMBOX]] to %[[ALLOCA]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[LOAD2:.*]] = fir.load %[[ALLOCA]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[DEV_PTRADDR:.*]] = fir.box_addr %[[LOAD2]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>
+  // CHECK: %[[CONV:.*]] = fir.convert %[[DEV_PTRADDR]] : (!fir.heap<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+  // CHECK: fir.call @_QMmPvadd(%[[CONV]]
+  %dev_ptr = acc.use_device varPtr(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>> {name = "a"}
+  acc.host_data dataOperands(%dev_ptr : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) {
+    %loaded = fir.load %dev_ptr : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+    %addr = fir.box_addr %loaded : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.heap<!fir.array<?xf64>>
+    %conv = fir.convert %addr : (!fir.heap<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+    fir.call @_QMmPvadd(%conv, %conv) : (!fir.ref<!fir.array<?xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+    acc.terminator
+  }
+  return
+}
+
+// -----
+
+// Test hoisting of load/box_addr/convert pattern out of acc.host_data with load operation
+func.func @test_host_data_hoisting_load(%arg0: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) {
+  // CHECK: %[[LOADED:.*]] = fir.load %arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[ADDR:.*]] = fir.box_addr %[[LOADED]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.heap<!fir.array<?xf64>>
+  // CHECK: %[[DEV_PTR:.*]] = acc.use_device varPtr(%[[ADDR]] :
+  // CHECK: acc.host_data dataOperands(%[[DEV_PTR]]
+  // CHECK: %[[EMBOX:.*]] = fir.embox %[[DEV_PTR]]
+  // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?xf64>>>
+  // CHECK: fir.store %[[EMBOX]] to %[[ALLOCA]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[LOAD2:.*]] = fir.load %[[ALLOCA]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+  // CHECK: %[[DEV_PTRADDR:.*]] = fir.box_addr %[[LOAD2]] : (!fir.box<!fir.heap<!fir.array<?xf64>>>
+  // CHECK: %[[CONV:.*]] = fir.convert %[[DEV_PTRADDR]] : (!fir.heap<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+  // CHECK: %[[VAL:.*]] = fir.load %[[CONV]]
+  %dev_ptr = acc.use_device varPtr(%arg0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>> {name = "a"}
+  acc.host_data dataOperands(%dev_ptr : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>) {
+    %loaded = fir.load %dev_ptr : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf64>>>>
+    %addr = fir.box_addr %loaded : (!fir.box<!fir.heap<!fir.array<?xf64>>>) -> !fir.heap<!fir.array<?xf64>>
+    %conv = fir.convert %addr : (!fir.heap<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+    %val = fir.load %conv : !fir.ref<!fir.array<?xf64>>
+    fir.call @foo(%val) : (!fir.array<?xf64>) -> ()
+    acc.terminator
+  }
+  return
+}
+
+// -----
+
+// Test hoisting for pointer attributes: load/box_addr hoisted, remove additional
+// unused use_device clause for a different variable
+func.func @test_host_data_hoisting_ref_to_box() {
+  %1 = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "ptr", uniq_name = "_QFEptr"}
+  // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "ptr", uniq_name = "_QFEptr"}
+  %4 = fir.declare %1 {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFEptr"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>>
+  // CHECK: %[[DECLARE:.*]] = fir.declare %[[ALLOCA]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFEptr"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>>
+  // Second pointer variable (unused in host_data region)
+  %ptr2_alloca = fir.alloca !fir.box<!fir.ptr<i32>> {bindc_name = "ptr2", uniq_name = "_QFEptr2"}
+  %ptr2_decl = fir.declare %ptr2_alloca {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFEptr2"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>>
+  %5 = fir.address_of(@_QFEtgt) : !fir.ref<i32>
+  %6 = fir.declare %5 {fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFEtgt"} : (!fir.ref<i32>) -> !fir.ref<i32>
+  %8 = fir.embox %6 : (!fir.ref<i32>) -> !fir.box<!fir.ptr<i32>>
+  fir.store %8 to %4 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  fir.store %8 to %ptr2_decl : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  // CHECK: %[[LOAD:.*]] = fir.load %[[DECLARE]]
+  // CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[LOAD]]
+  // CHECK: %[[DEV_PTR:.*]] = acc.use_device varPtr(%[[BOXADDR]] : !fir.ptr<i32>) varType(!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32> {name = "ptr"}
+  // CHECK: acc.host_data dataOperands(%[[DEV_PTR]] : !fir.ptr<i32>) {
+  // CHECK: %[[EMBOX:.*]] = fir.embox %[[DEV_PTR]] : (!fir.ptr<i32>) -> !fir.box<!fir.ptr<i32>>
+  // CHECK: %[[ALLOCA2:.*]] = fir.alloca !fir.box<!fir.ptr<i32>>
+  // CHECK: fir.store %[[EMBOX]] to %[[ALLOCA2]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  // CHECK: %[[LOAD2:.*]] = fir.load %[[ALLOCA2]] : !fir.ref<!fir.box<!fir.ptr<i32>>>
+  // CHECK: %[[DEV_PTRADDR:.*]] = fir.box_addr %[[LOAD2]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+  // CHECK: %[[CONV:.*]] = fir.convert %[[DEV_PTRADDR]] : (!fir.ptr<i32>) -> i64
+  // CHECK: fir.call @foo(%[[CONV]]) : (i64) -> ()
+  %9 = acc.use_device varPtr(%4 : !fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr"}
+  // This use_device clause is for a different variable (ptr2) and has no uses - should be removed
+  %12 = acc.use_device varPtr(%ptr2_decl : !fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "ptr2"}
+  acc.host_data dataOperands(%9, %12 : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+    %14 = fir.load %9 : !fir.ref<!fir.box<!fir.ptr<i32>>>
+    %15 = fir.box_addr %14 : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
+    %16 = fir.convert %15 : (!fir.ptr<i32>) -> i64
+    fir.call @foo(%16) : (i64) -> ()
+    acc.terminator
+  }
+  return
+}
+

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 50b4d0563faef..66cba7f07adb6 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2889,9 +2889,17 @@ LogicalResult acc::HostDataOp::verify() {
     return emitError("at least one operand must appear on the host_data "
                      "operation");
 
-  for (mlir::Value operand : getDataClauseOperands())
-    if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
+  llvm::SmallPtrSet<mlir::Value, 4> seenVars;
+  for (mlir::Value operand : getDataClauseOperands()) {
+    auto useDeviceOp =
+        mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
+    if (!useDeviceOp)
       return emitError("expect data entry operation as defining op");
+
+    // Check for duplicate use_device clauses
+    if (!seenVars.insert(useDeviceOp.getVar()).second)
+      return emitError("duplicate use_device variable");
+  }
   return success();
 }
 

diff  --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index d1a1c93800264..1ca27df984073 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -956,3 +956,15 @@ func.func @verify_data(%arg0 : memref<i32>) {
   }
   return
 }
+
+// -----
+
+func.func @verify_host_data_duplicate_use_device(%arg0 : memref<i32>) {
+  %0 = acc.use_device varPtr(%arg0 : memref<i32>) -> memref<i32>
+  %1 = acc.use_device varPtr(%arg0 : memref<i32>) -> memref<i32>
+// expected-error @below {{duplicate use_device variable}}
+  acc.host_data dataOperands(%0, %1 : memref<i32>, memref<i32>) {
+    acc.terminator
+  }
+  return
+}


        


More information about the Mlir-commits mailing list