[Mlir-commits] [flang] [mlir] [flang][acc] Add ACCUseDeviceCanonicalizer pass (PR #175228)

Razvan Lupusoru llvmlistbot at llvm.org
Tue Jan 13 09:08:54 PST 2026


================
@@ -0,0 +1,428 @@
+//===- ACCUseDeviceCanonicalizer.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass canonicalizes the use_device clause on a host_data construct such
+// that use_device(x) can be lowered to a simple runtime call that takes the
+// actual host pointer as argument.
+//
+// For a use_device operand that is a box type or a reference to a box, the
+// pass:
+//   1. Extracts the host base address for mapping to a device address using
+//      acc.use_device.
+//   2. Creates a new boxed descriptor with the device address as the base
+//      address for use inside the host_data region.
+//
+// The pass also removes unused use_device clauses, reducing the number of
+// runtime calls.
+//
+// Supported use_device operand types:
+//
+//   Scalars:
+//     - !fir.ref<i32>, !fir.ref<f64>, etc.
+//
+//   Arrays:
+//     - Explicit shape (no descriptor): !fir.ref<!fir.array<100xi32>>
+//     - Adjustable size: !fir.ref<!fir.array<?xi32>>
+//     - Assumed shape (handled by hoistBox): !fir.box<!fir.array<?xi32>>
+//     - Assumed size: !fir.ref<!fir.array<?xi32>>
+//     - Deferred shape (handled by hoistRefToBox):
+//         - Allocatable: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+//         - Pointer: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>
+//     - Subarray specification (handled by hoistBox):
+//         !fir.box<!fir.array<?xi32>>
+//
+//   Not yet supported:
+//     - Assumed rank arrays
+//     - Composite variables: !fir.ref<!fir.type<...>>
+//     - Array elements (device pointer arithmetic in host_data region)
+//     - Composite variable members
+//     - Fortran common blocks: use_device(/cm_block/)
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Builder/BoxValue.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenACC/Passes.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+#include <cassert>
+
+namespace fir::acc {
+#define GEN_PASS_DEF_ACCUSEDEVICECANONICALIZER
+#include "flang/Optimizer/OpenACC/Passes.h.inc"
+} // namespace fir::acc
+
+#define DEBUG_TYPE "acc-use-device-canonicalizer"
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true when \p type is a FIR reference-like type (as classified by
+/// fir::isa_ref_type) whose element type is a !fir.box, e.g.
+/// !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>.
+static bool isFirRefToBox(Type type) {
+  // Short-circuit: only unwrap once we know `type` is reference-like.
+  return fir::isa_ref_type(type) &&
+         isa<fir::BoxType>(fir::unwrapRefType(type));
+}
+
+struct UseDeviceHostDataHoisting : public OpRewritePattern<acc::HostDataOp> {
+  using OpRewritePattern<acc::HostDataOp>::OpRewritePattern;
+
+  /// Canonicalize the use_device operands of an acc.host_data operation.
+  ///
+  /// Each invocation performs at most one kind of rewrite, so the greedy
+  /// driver revisits the operation with up-to-date IR:
+  ///   1. Drop use_device operands whose only user is this host_data (the
+  ///      result is never referenced inside the region) and erase the
+  ///      now-dead acc.use_device ops.
+  ///   2. Otherwise, hoist one use_device whose var is a reference to a box
+  ///      (see hoistRefToBox).
+  ///   3. Otherwise, hoist one use_device whose var is a box (see hoistBox).
+  /// Returns failure when none of these rewrites applied.
+  LogicalResult matchAndRewrite(acc::HostDataOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> usedOperands;
+    SmallVector<Value> unusedUseDeviceOperands;
+    SmallVector<acc::UseDeviceOp> refToBoxUseDeviceOps;
+    SmallVector<acc::UseDeviceOp> boxUseDeviceOps;
+
+    for (Value operand : op.getDataClauseOperands()) {
+      auto useDeviceOp = operand.getDefiningOp<acc::UseDeviceOp>();
+      if (!useDeviceOp) {
+        // use_device operands that don't need hoisting, they are passed as
+        // is to __tgt_acc_get_deviceptr.
+        usedOperands.push_back(operand);
+        continue;
+      }
+
+      // This host_data is itself one user of the result, so a single user
+      // means the result is not referenced anywhere else (including the
+      // region) and the clause can be dropped.
+      if (llvm::hasSingleElement(useDeviceOp->getUsers())) {
+        unusedUseDeviceOperands.push_back(useDeviceOp.getResult());
+        continue;
+      }
+
+      usedOperands.push_back(useDeviceOp.getResult());
+      Type varType = useDeviceOp.getVar().getType();
+      if (isFirRefToBox(varType))
+        refToBoxUseDeviceOps.push_back(useDeviceOp);
+      else if (isa<fir::BoxType>(varType))
+        boxUseDeviceOps.push_back(useDeviceOp);
+    }
+
+    // The host_data must not be left without operands. When every operand is
+    // an unused use_device, that is legal input, so keep the first one rather
+    // than asserting or producing an operand-less host_data.
+    if (usedOperands.empty() && !unusedUseDeviceOperands.empty()) {
+      usedOperands.push_back(unusedUseDeviceOperands.front());
+      unusedUseDeviceOperands.erase(unusedUseDeviceOperands.begin());
+    }
+
+    if (!unusedUseDeviceOperands.empty()) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "ACCUseDeviceCanonicalizer: Removing "
+                 << unusedUseDeviceOperands.size()
+                 << " unused use_device operands from host_data operation\n");
+
+      // Update the host_data operation to have only used operands
+      rewriter.modifyOpInPlace(op, [&]() {
+        op.getDataClauseOperandsMutable().assign(usedOperands);
+      });
+
+      // Remove the use_device operations; their only user was just dropped.
+      for (Value operand : unusedUseDeviceOperands) {
+        auto useDeviceOp = operand.getDefiningOp<acc::UseDeviceOp>();
+        LLVM_DEBUG(llvm::dbgs() << "ACCUseDeviceCanonicalizer: Erasing: "
+                                << *useDeviceOp << "\n");
+        rewriter.eraseOp(useDeviceOp);
+      }
+      return success();
+    }
+
+    // Hoist at most one candidate per invocation. We return as soon as a
+    // hoist succeeds, so no iteration continues over rewritten IR. A
+    // candidate that cannot be hoisted must not block the remaining ones.
+    // NOTE(review): relies on hoistRefToBox/hoistBox leaving the IR
+    // unchanged when they return false (the usual pattern contract); confirm.
+    for (acc::UseDeviceOp useDeviceOp : refToBoxUseDeviceOps)
+      if (hoistRefToBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op))
+        return success();
+
+    for (acc::UseDeviceOp useDeviceOp : boxUseDeviceOps)
+      if (hoistBox(rewriter, useDeviceOp.getResult(), useDeviceOp, op))
+        return success();
+
+    return failure();
+  }
+
+private:
+  /// Create new use_device operation with the given box address as operand.
+  acc::UseDeviceOp createNewUseDeviceOp(PatternRewriter &rewriter,
+                                        acc::UseDeviceOp useDeviceOp,
+                                        acc::HostDataOp hostDataOp,
+                                        fir::BoxAddrOp boxAddr) const {
+
+    // Create use_device on the raw pointer
+    acc::UseDeviceOp newUseDeviceOp = acc::UseDeviceOp::create(
+        rewriter, useDeviceOp.getLoc(), boxAddr.getType(), boxAddr.getResult(),
+        useDeviceOp.getVarTypeAttr(), useDeviceOp.getVarPtrPtr(),
+        useDeviceOp.getBounds(), useDeviceOp.getAsyncOperands(),
+        useDeviceOp.getAsyncOperandsDeviceTypeAttr(),
+        useDeviceOp.getAsyncOnlyAttr(), useDeviceOp.getDataClauseAttr(),
+        useDeviceOp.getStructuredAttr(), useDeviceOp.getImplicitAttr(),
+        useDeviceOp.getModifiersAttr(), useDeviceOp.getNameAttr(),
+        useDeviceOp.getRecipeAttr());
+
+    // Replace the old use_device operand in the host_data with the new one
+    SmallVector<Value> newOperands;
+    for (Value operand : hostDataOp.getDataClauseOperands()) {
+      if (operand == useDeviceOp.getResult())
+        newOperands.push_back(newUseDeviceOp.getResult());
+      else
+        newOperands.push_back(operand);
+    }
----------------
razvanlupusoru wrote:

Yes - done!

https://github.com/llvm/llvm-project/pull/175228


More information about the Mlir-commits mailing list