[Mlir-commits] [mlir] [mlir][acc] Add LegalizeDataValues support for DeclareEnterOp (PR #138008)

Susan Tan ス-ザン タン llvmlistbot at llvm.org
Wed Apr 30 13:08:08 PDT 2025


https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/138008

>From 4c458ea3cbe0d88b8179010e45148463d1723eac Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 11:10:26 -0700
Subject: [PATCH 1/7] add support for acc.declare_enter

---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h   |  3 +-
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++++++--
 2 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ff5845343313c..9c141550b184e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,7 +58,8 @@
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
 #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                      \
-  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
+  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp,              \
+      mlir::acc::DeclareEnterOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
       mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index f2abeab744d17..24f0ed3d35a0e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
   }
 }
 
+// Helper function to process declare enter/exit pairs
+static void processDeclareEnterExit(
+    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+    DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
+  // For declare enter/exit pairs, verify there is exactly one exit op using the
+  // token
+  if (!op.getToken().hasOneUse())
+    op.emitError("declare enter token must have exactly one use");
+  Operation *user = *op.getToken().getUsers().begin();
+  auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+  if (!declareExit)
+    op.emitError("declare enter token must be used by declare exit op");
+
+  for (auto p : values) {
+    Value hostVal = std::get<0>(p);
+    Value deviceVal = std::get<1>(p);
+    for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
+      Operation *owner = use.getOwner();
+      if (!domInfo.dominates(op.getOperation(), owner) ||
+          !postDomInfo.postDominates(declareExit.getOperation(), owner))
+        continue;
+      if (insideAccComputeRegion(owner))
+        use.set(deviceVal);
+    }
+  }
+}
+
 template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+static void
+collectAndReplaceInRegion(Op &op, bool hostToDevice,
+                          DominanceInfo *domInfo = nullptr,
+                          PostDominanceInfo *postDomInfo = nullptr) {
   llvm::SmallVector<std::pair<Value, Value>> values;
 
   if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp> &&
-                  !std::is_same_v<Op, acc::HostDataOp>) {
+                  !std::is_same_v<Op, acc::HostDataOp> &&
+                  !std::is_same_v<Op, acc::DeclareEnterOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 
-  for (auto p : values)
-    replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
-                                              op.getRegion());
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+    assert(domInfo && postDomInfo &&
+           "Dominance info required for DeclareEnterOp");
+    processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+  } else {
+    for (auto p : values) {
+      replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+                                                op.getRegion());
+    }
+  }
 }
 
 class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
     func::FuncOp funcOp = getOperation();
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
+    // Get dominance info for the function
+    DominanceInfo domInfo(funcOp);
+    PostDominanceInfo postDomInfo(funcOp);
+
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
         collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
       } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
         collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
+      } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+        collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
+                                  &postDomInfo);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

>From 03b89b17de143e1237f7644b22281077a17ccd56 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:14:23 -0700
Subject: [PATCH 2/7] add support for enter data

---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h   |  3 +-
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 9c141550b184e..ff5845343313c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,8 +58,7 @@
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
 #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                      \
-  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp,              \
-      mlir::acc::DeclareEnterOp
+  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
       mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 24f0ed3d35a0e..dd47aabdecf76 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -71,26 +71,43 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
   }
 }
 
-// Helper function to process declare enter/exit pairs
-static void processDeclareEnterExit(
-    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+template <typename Op>
+static void replaceAllUsesInUnstructuredComputeRegionWith(
+    Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
     DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
-  // For declare enter/exit pairs, verify there is exactly one exit op using the
-  // token
-  if (!op.getToken().hasOneUse())
-    op.emitError("declare enter token must have exactly one use");
-  Operation *user = *op.getToken().getUsers().begin();
-  auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
-  if (!declareExit)
-    op.emitError("declare enter token must be used by declare exit op");
+
+  Operation *exitOp = op.getOperation();
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+    // For declare enter/exit pairs, verify there is exactly one exit op using
+    // the token
+    if (!op.getToken().hasOneUse())
+      op.emitError("declare enter token must have exactly one use");
+    Operation *user = *op.getToken().getUsers().begin();
+    auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+    if (!declareExit)
+      op.emitError("declare enter token must be used by declare exit op");
+    exitOp = declareExit;
+  } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
+    // For enter/exit data pairs, find the corresponding exit_data op
+    Operation *nextOp = op.getOperation()->getNextNode();
+    while (nextOp && !isa<acc::ExitDataOp>(nextOp))
+      nextOp = nextOp->getNextNode();
+    if (!nextOp)
+      op.emitError("enter data must have a corresponding exit data op");
+    exitOp = nextOp;
+  }
 
   for (auto p : values) {
     Value hostVal = std::get<0>(p);
     Value deviceVal = std::get<1>(p);
     for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
       Operation *owner = use.getOwner();
+      // Check that:
+      // It's the case that the acc entry operation dominates the use.
+      // It's the case that one of the acc exit operations consuming the token
+      // post-dominates the use
       if (!domInfo.dominates(op.getOperation(), owner) ||
-          !postDomInfo.postDominates(declareExit.getOperation(), owner))
+          !postDomInfo.postDominates(exitOp, owner))
         continue;
       if (insideAccComputeRegion(owner))
         use.set(deviceVal);
@@ -114,17 +131,20 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp> &&
                   !std::is_same_v<Op, acc::HostDataOp> &&
-                  !std::is_same_v<Op, acc::DeclareEnterOp>) {
+                  !std::is_same_v<Op, acc::DeclareEnterOp> &&
+                  !std::is_same_v<Op, acc::EnterDataOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 
-  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
+                std::is_same_v<Op, acc::EnterDataOp>) {
     assert(domInfo && postDomInfo &&
            "Dominance info required for DeclareEnterOp");
-    processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+    replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
+                                                      *postDomInfo);
   } else {
     for (auto p : values) {
       replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
@@ -151,7 +171,8 @@ class LegalizeDataValuesInRegion
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
-            applyToAccDataConstruct))
+            applyToAccDataConstruct) &&
+          !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -171,6 +192,9 @@ class LegalizeDataValuesInRegion
       } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
         collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
+      } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+        collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
+                                  &postDomInfo);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

>From cd70987cfa06e9caca53abf72c92119c7f5554c2 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:30:18 -0700
Subject: [PATCH 3/7] compute dominance info lazily on first use

---
 .../OpenACC/Transforms/LegalizeDataValues.cpp       | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index dd47aabdecf76..fac3171c27d50 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -164,9 +164,10 @@ class LegalizeDataValuesInRegion
     func::FuncOp funcOp = getOperation();
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
-    // Get dominance info for the function
+    // Initialize dominance info
     DominanceInfo domInfo(funcOp);
     PostDominanceInfo postDomInfo(funcOp);
+    bool computedDomInfo = false;
 
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
@@ -190,9 +191,19 @@ class LegalizeDataValuesInRegion
       } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
         collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
       } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+        if (!computedDomInfo) {
+          domInfo = DominanceInfo(funcOp);
+          postDomInfo = PostDominanceInfo(funcOp);
+          computedDomInfo = true;
+        }
         collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
       } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+        if (!computedDomInfo) {
+          domInfo = DominanceInfo(funcOp);
+          postDomInfo = PostDominanceInfo(funcOp);
+          computedDomInfo = true;
+        }
         collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
       } else {

>From 014a8b84f93b447ef03d721fc346a679dee24e26 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:32:00 -0700
Subject: [PATCH 4/7] drop eager construction of dominance info

---
 mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index fac3171c27d50..d991da3ca4ff6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -165,8 +165,8 @@ class LegalizeDataValuesInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     // Initialize dominance info
-    DominanceInfo domInfo(funcOp);
-    PostDominanceInfo postDomInfo(funcOp);
+    DominanceInfo domInfo;
+    PostDominanceInfo postDomInfo;
     bool computedDomInfo = false;
 
     funcOp.walk([&](Operation *op) {

>From 823d31d006405976a70d10e831323061da20ac6e Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:55:47 -0700
Subject: [PATCH 5/7] add lit test

---
 mlir/test/Dialect/OpenACC/legalize-data.mlir | 29 +++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 9461225e9a7e0..28ef6761a6ef4 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -245,4 +245,31 @@ func.func private @foo(memref<10xf32>)
 // CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
 // DEVICE:   func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
 // CHECK:   acc.terminator
-// CHECK: }
\ No newline at end of file
+// CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+  %declare = acc.create varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
+  %token = acc.declare_enter dataOperands(%declare : memref<10xf32>)
+  acc.kernels dataOperands(%declare : memref<10xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1.000000e+00 : f32
+    memref.store %c1, %a[%c0] : memref<10xf32>
+    acc.terminator
+  }
+  acc.declare_exit token(%token) dataOperands(%declare : memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[DECLARE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DECLARE]] : memref<10xf32>)
+// CHECK: acc.kernels dataOperands(%[[DECLARE]] : memref<10xf32>) {
+// DEVICE:   memref.store %{{.*}}, %[[DECLARE]][%{{.*}}] : memref<10xf32>
+// HOST:    memref.store %{{.*}}, %[[A]][%{{.*}}] : memref<10xf32>
+// CHECK:   acc.terminator
+// CHECK: }
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DECLARE]] : memref<10xf32>)
+// CHECK: return
\ No newline at end of file

>From 4814f8f7a3a4e0c6da2f4fc7325e07cbf6063948 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 13:04:48 -0700
Subject: [PATCH 6/7] change to support multiple exits

---
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 61 ++++++++-----------
 1 file changed, 27 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index d991da3ca4ff6..073961f44a7e1 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -76,25 +76,16 @@ static void replaceAllUsesInUnstructuredComputeRegionWith(
     Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
     DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
 
-  Operation *exitOp = op.getOperation();
+  SmallVector<Operation *> exitOps;
   if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
-    // For declare enter/exit pairs, verify there is exactly one exit op using
-    // the token
-    if (!op.getToken().hasOneUse())
-      op.emitError("declare enter token must have exactly one use");
-    Operation *user = *op.getToken().getUsers().begin();
-    auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
-    if (!declareExit)
-      op.emitError("declare enter token must be used by declare exit op");
-    exitOp = declareExit;
-  } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
-    // For enter/exit data pairs, find the corresponding exit_data op
-    Operation *nextOp = op.getOperation()->getNextNode();
-    while (nextOp && !isa<acc::ExitDataOp>(nextOp))
-      nextOp = nextOp->getNextNode();
-    if (!nextOp)
-      op.emitError("enter data must have a corresponding exit data op");
-    exitOp = nextOp;
+    // For declare enter/exit pairs, collect all exit ops
+    for (auto *user : op.getToken().getUsers()) {
+      if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
+        exitOps.push_back(declareExit);
+    }
+    if (exitOps.empty())
+      op.emitError(
+          "declare enter token must be used by at least one declare exit op");
   }
 
   for (auto p : values) {
@@ -102,13 +93,24 @@ static void replaceAllUsesInUnstructuredComputeRegionWith(
     Value deviceVal = std::get<1>(p);
     for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
       Operation *owner = use.getOwner();
-      // Check that:
-      // It's the case that the acc entry operation dominates the use.
-      // It's the case that one of the acc exit operations consuming the token
+
+      // Skip uses that the acc entry operation does not dominate.
+      if (!domInfo.dominates(op.getOperation(), owner))
+        continue;
+
+      // Skip uses unless at least one of the acc exit operations
       // post-dominates the use
-      if (!domInfo.dominates(op.getOperation(), owner) ||
-          !postDomInfo.postDominates(exitOp, owner))
+      bool hasPostDominatingExit = false;
+      for (auto *exit : exitOps) {
+        if (postDomInfo.postDominates(exit, owner)) {
+          hasPostDominatingExit = true;
+          break;
+        }
+      }
+
+      if (!hasPostDominatingExit)
         continue;
+
       if (insideAccComputeRegion(owner))
         use.set(deviceVal);
     }
@@ -131,8 +133,7 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp> &&
                   !std::is_same_v<Op, acc::HostDataOp> &&
-                  !std::is_same_v<Op, acc::DeclareEnterOp> &&
-                  !std::is_same_v<Op, acc::EnterDataOp>) {
+                  !std::is_same_v<Op, acc::DeclareEnterOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
@@ -173,7 +174,7 @@ class LegalizeDataValuesInRegion
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
             applyToAccDataConstruct) &&
-          !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
+          !isa<acc::DeclareEnterOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -198,14 +199,6 @@ class LegalizeDataValuesInRegion
         }
         collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
-      } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
-        if (!computedDomInfo) {
-          domInfo = DominanceInfo(funcOp);
-          postDomInfo = PostDominanceInfo(funcOp);
-          computedDomInfo = true;
-        }
-        collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
-                                  &postDomInfo);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

>From cb0a1b097c08d69323346b35315eb1a1285cc798 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 13:07:56 -0700
Subject: [PATCH 7/7] remove remaining enter_data handling

---
 mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 073961f44a7e1..b63840bbe0b8e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -140,8 +140,7 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
     }
   }
 
-  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
-                std::is_same_v<Op, acc::EnterDataOp>) {
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
     assert(domInfo && postDomInfo &&
            "Dominance info required for DeclareEnterOp");
     replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,



More information about the Mlir-commits mailing list