[Mlir-commits] [mlir] [mlir][acc] Add LegalizeDataValues support for DeclareEnterOp (PR #138008)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 30 11:14:39 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-mlir
Author: Susan Tan (ス-ザン タン) (SusanTan)
<details>
<summary>Changes</summary>
This patch extends the existing LegalizeDataValues pass to support DeclareEnter/DeclareExit pairs.
Unlike the other ops, DeclareEnter and DeclareExit do not define a region, so we use dominance and post-dominance information instead. This ensures that the only uses updated to the device data are those dominated by the DeclareEnter, post-dominated by the matching DeclareExit, and located inside an acc compute region.
---
Full diff: https://github.com/llvm/llvm-project/pull/138008.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+2-1)
- (modified) mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp (+51-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ff5845343313c..9c141550b184e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,7 +58,8 @@
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
+// NOTE(review): DeclareEnterOp defines no region (unlike DataOp, DeclareOp and
+// HostDataOp) and is also listed in ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS below.
+// It is added here so LegalizeDataValues picks it up; confirm that every other
+// user of this "structured" list can handle a regionless op.
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
- mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
+ mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
+ mlir::acc::DeclareEnterOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index f2abeab744d17..24f0ed3d35a0e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
}
}
+// Helper function to process a declare enter/exit pair.
+//
+// DeclareEnterOp has no region, so the scope in which host values are replaced
+// is approximated with dominance: a use of a host value is rewritten to its
+// device counterpart only if the use is dominated by `op`, post-dominated by
+// the DeclareExitOp consuming `op`'s token, and inside an acc compute region.
+//
+// If the token is not used by exactly one DeclareExitOp, an error is emitted
+// on `op` and no uses are rewritten.
+static void processDeclareEnterExit(
+    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+    DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
+  // The token must have exactly one user, and it must be a DeclareExitOp.
+  // Bail out after reporting: dereferencing the user list of an unused token,
+  // or querying post-dominance with a null DeclareExitOp, is undefined.
+  Value token = op.getToken();
+  if (!token.hasOneUse()) {
+    op.emitError("declare enter token must have exactly one use");
+    return;
+  }
+  auto declareExit = dyn_cast<acc::DeclareExitOp>(*token.getUsers().begin());
+  if (!declareExit) {
+    op.emitError("declare enter token must be used by declare exit op");
+    return;
+  }
+
+  for (auto [hostVal, deviceVal] : values) {
+    // Early-increment iteration: use.set() unlinks `use` from hostVal's list.
+    for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
+      Operation *owner = use.getOwner();
+      if (!domInfo.dominates(op.getOperation(), owner) ||
+          !postDomInfo.postDominates(declareExit.getOperation(), owner))
+        continue;
+      if (insideAccComputeRegion(owner))
+        use.set(deviceVal);
+    }
+  }
+}
+
template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+static void
+collectAndReplaceInRegion(Op &op, bool hostToDevice,
+ DominanceInfo *domInfo = nullptr,
+ PostDominanceInfo *postDomInfo = nullptr) {
llvm::SmallVector<std::pair<Value, Value>> values;
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp> &&
- !std::is_same_v<Op, acc::HostDataOp>) {
+ !std::is_same_v<Op, acc::HostDataOp> &&
+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}
- for (auto p : values)
- replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
- op.getRegion());
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+ assert(domInfo && postDomInfo &&
+ "Dominance info required for DeclareEnterOp");
+ processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+ } else {
+ for (auto p : values) {
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+ op.getRegion());
+ }
+ }
}
class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
+ // Get dominance info for the function
+ DominanceInfo domInfo(funcOp);
+ PostDominanceInfo postDomInfo(funcOp);
+
funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
+ } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+ collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
+ &postDomInfo);
} else {
llvm_unreachable("unsupported acc region op");
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/138008
More information about the Mlir-commits
mailing list