[Mlir-commits] [mlir] [mlir][acc] Add LegalizeDataValues support for DeclareEnterOp (PR #138008)
Susan Tan ス-ザン タン
llvmlistbot at llvm.org
Wed Apr 30 12:55:59 PDT 2025
https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/138008
>From 4c458ea3cbe0d88b8179010e45148463d1723eac Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 11:10:26 -0700
Subject: [PATCH 1/5] add support for acc.declare_enter
---
mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 3 +-
.../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++++++--
2 files changed, 53 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ff5845343313c..9c141550b184e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,7 +58,8 @@
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
- mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
+ mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
+ mlir::acc::DeclareEnterOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index f2abeab744d17..24f0ed3d35a0e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
}
}
+// Helper function to process declare enter/exit pairs
+static void processDeclareEnterExit(
+ acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+ DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
+ // For declare enter/exit pairs, verify there is exactly one exit op using the
+ // token
+ if (!op.getToken().hasOneUse())
+ op.emitError("declare enter token must have exactly one use");
+ Operation *user = *op.getToken().getUsers().begin();
+ auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+ if (!declareExit)
+ op.emitError("declare enter token must be used by declare exit op");
+
+ for (auto p : values) {
+ Value hostVal = std::get<0>(p);
+ Value deviceVal = std::get<1>(p);
+ for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
+ Operation *owner = use.getOwner();
+ if (!domInfo.dominates(op.getOperation(), owner) ||
+ !postDomInfo.postDominates(declareExit.getOperation(), owner))
+ continue;
+ if (insideAccComputeRegion(owner))
+ use.set(deviceVal);
+ }
+ }
+}
+
template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+static void
+collectAndReplaceInRegion(Op &op, bool hostToDevice,
+ DominanceInfo *domInfo = nullptr,
+ PostDominanceInfo *postDomInfo = nullptr) {
llvm::SmallVector<std::pair<Value, Value>> values;
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp> &&
- !std::is_same_v<Op, acc::HostDataOp>) {
+ !std::is_same_v<Op, acc::HostDataOp> &&
+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}
- for (auto p : values)
- replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
- op.getRegion());
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+ assert(domInfo && postDomInfo &&
+ "Dominance info required for DeclareEnterOp");
+ processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+ } else {
+ for (auto p : values) {
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+ op.getRegion());
+ }
+ }
}
class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
+ // Get dominance info for the function
+ DominanceInfo domInfo(funcOp);
+ PostDominanceInfo postDomInfo(funcOp);
+
funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
+ } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+ collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
+ &postDomInfo);
} else {
llvm_unreachable("unsupported acc region op");
}
>From 03b89b17de143e1237f7644b22281077a17ccd56 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:14:23 -0700
Subject: [PATCH 2/5] add support for enter data
---
mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 3 +-
.../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++------
2 files changed, 41 insertions(+), 18 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 9c141550b184e..ff5845343313c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,8 +58,7 @@
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
- mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp, \
- mlir::acc::DeclareEnterOp
+ mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 24f0ed3d35a0e..dd47aabdecf76 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -71,26 +71,56 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
}
}
-// Helper function to process declare enter/exit pairs
-static void processDeclareEnterExit(
- acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+// For an unstructured data entry op (acc.declare_enter or acc.enter_data),
+// replace uses of the first value of each pair in `values` with the second,
+// restricted to uses inside acc compute regions that are dominated by `op`
+// and post-dominated by the exit op paired with it.
+template <typename Op>
+static void replaceAllUsesInUnstructuredComputeRegionWith(
+ Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
- // For declare enter/exit pairs, verify there is exactly one exit op using the
- // token
- if (!op.getToken().hasOneUse())
- op.emitError("declare enter token must have exactly one use");
- Operation *user = *op.getToken().getUsers().begin();
- auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
- if (!declareExit)
- op.emitError("declare enter token must be used by declare exit op");
+
+ // Find the exit op paired with `op`. If no valid pairing exists, report an
+ // error and return without rewriting any use: continuing would dereference
+ // a missing token user or pass a null exit op to postDominates.
+ Operation *exitOp = op.getOperation();
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+ // For declare enter/exit pairs, verify there is exactly one exit op using
+ // the token
+ if (!op.getToken().hasOneUse()) {
+ op.emitError("declare enter token must have exactly one use");
+ return;
+ }
+ Operation *user = *op.getToken().getUsers().begin();
+ auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+ if (!declareExit) {
+ op.emitError("declare enter token must be used by declare exit op");
+ return;
+ }
+ exitOp = declareExit;
+ } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
+ // For enter/exit data pairs, use the next exit_data op in the same block.
+ // NOTE(review): this does not check that the exit_data acts on the same
+ // data as this enter_data; confirm the heuristic is sufficient.
+ Operation *nextOp = op.getOperation()->getNextNode();
+ while (nextOp && !isa<acc::ExitDataOp>(nextOp))
+ nextOp = nextOp->getNextNode();
+ if (!nextOp) {
+ op.emitError("enter data must have a corresponding exit data op");
+ return;
+ }
+ exitOp = nextOp;
+ }
for (auto p : values) {
Value hostVal = std::get<0>(p);
Value deviceVal = std::get<1>(p);
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
Operation *owner = use.getOwner();
+ // Only rewrite uses that are dominated by the entry op and post-dominated
+ // by its paired exit op.
if (!domInfo.dominates(op.getOperation(), owner) ||
- !postDomInfo.postDominates(declareExit.getOperation(), owner))
+ !postDomInfo.postDominates(exitOp, owner))
continue;
if (insideAccComputeRegion(owner))
use.set(deviceVal);
@@ -114,17 +144,20 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp> &&
!std::is_same_v<Op, acc::HostDataOp> &&
- !std::is_same_v<Op, acc::DeclareEnterOp>) {
+ !std::is_same_v<Op, acc::DeclareEnterOp> &&
+ !std::is_same_v<Op, acc::EnterDataOp>) {
collectVars(op.getReductionOperands(), values, hostToDevice);
collectVars(op.getPrivateOperands(), values, hostToDevice);
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}
- if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
- assert(domInfo && postDomInfo &&
- "Dominance info required for DeclareEnterOp");
- processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
+ std::is_same_v<Op, acc::EnterDataOp>) {
+ assert(domInfo && postDomInfo &&
+ "Dominance info required for DeclareEnterOp and EnterDataOp");
+ replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
+ *postDomInfo);
} else {
for (auto p : values) {
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
@@ -151,7 +184,8 @@ class LegalizeDataValuesInRegion
funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
- applyToAccDataConstruct))
+ applyToAccDataConstruct) &&
+ !isa<acc::DeclareEnterOp, acc::EnterDataOp>(*op))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -171,6 +205,9 @@ class LegalizeDataValuesInRegion
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
+ } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+ collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
+ &postDomInfo);
} else {
llvm_unreachable("unsupported acc region op");
}
>From cd70987cfa06e9caca53abf72c92119c7f5554c2 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:30:18 -0700
Subject: [PATCH 3/5] change how dominfo is computed
---
.../OpenACC/Transforms/LegalizeDataValues.cpp | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index dd47aabdecf76..fac3171c27d50 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -164,9 +164,10 @@ class LegalizeDataValuesInRegion
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
- // Get dominance info for the function
+ // Initialize dominance info
DominanceInfo domInfo(funcOp);
PostDominanceInfo postDomInfo(funcOp);
+ bool computedDomInfo = false;
funcOp.walk([&](Operation *op) {
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
@@ -190,9 +191,19 @@ class LegalizeDataValuesInRegion
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+ if (!computedDomInfo) {
+ domInfo = DominanceInfo(funcOp);
+ postDomInfo = PostDominanceInfo(funcOp);
+ computedDomInfo = true;
+ }
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
} else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+ if (!computedDomInfo) {
+ domInfo = DominanceInfo(funcOp);
+ postDomInfo = PostDominanceInfo(funcOp);
+ computedDomInfo = true;
+ }
collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
&postDomInfo);
} else {
>From 014a8b84f93b447ef03d721fc346a679dee24e26 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:32:00 -0700
Subject: [PATCH 4/5] leave out initialization
---
mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index fac3171c27d50..d991da3ca4ff6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -165,8 +165,8 @@ class LegalizeDataValuesInRegion
bool replaceHostVsDevice = this->hostToDevice.getValue();
// Initialize dominance info
- DominanceInfo domInfo(funcOp);
- PostDominanceInfo postDomInfo(funcOp);
+ DominanceInfo domInfo;
+ PostDominanceInfo postDomInfo;
bool computedDomInfo = false;
funcOp.walk([&](Operation *op) {
>From 823d31d006405976a70d10e831323061da20ac6e Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:55:47 -0700
Subject: [PATCH 5/5] add lit test
---
mlir/test/Dialect/OpenACC/legalize-data.mlir | 29 +++++++++++++++++++-
1 file changed, 28 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 9461225e9a7e0..28ef6761a6ef4 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -245,4 +245,31 @@ func.func private @foo(memref<10xf32>)
// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
// DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
// CHECK: acc.terminator
-// CHECK: }
\ No newline at end of file
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+ %declare = acc.create varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
+ %token = acc.declare_enter dataOperands(%declare : memref<10xf32>)
+ acc.kernels dataOperands(%declare : memref<10xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1.000000e+00 : f32
+ memref.store %c1, %a[%c0] : memref<10xf32>
+ acc.terminator
+ }
+ acc.declare_exit token(%token) dataOperands(%declare : memref<10xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[DECLARE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DECLARE]] : memref<10xf32>)
+// CHECK: acc.kernels dataOperands(%[[DECLARE]] : memref<10xf32>) {
+// DEVICE: memref.store %{{.*}}, %[[DECLARE]][%{{.*}}] : memref<10xf32>
+// HOST: memref.store %{{.*}}, %[[A]][%{{.*}}] : memref<10xf32>
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DECLARE]] : memref<10xf32>)
+// CHECK: return
More information about the Mlir-commits
mailing list