[Mlir-commits] [mlir] [mlir][acc] Add LegalizeDataValues support for DeclareEnterOp (PR #138008)

Susan Tan ス-ザン タン llvmlistbot at llvm.org
Wed Apr 30 12:32:11 PDT 2025


https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/138008

>From 4c458ea3cbe0d88b8179010e45148463d1723eac Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 11:10:26 -0700
Subject: [PATCH 1/4] add support for acc.declare_enter

---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h   |  3 +-
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++++++--
 2 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ff5845343313c..9c141550b184e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,7 +58,8 @@
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
 #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                      \
-  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
+  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp,              \
+      mlir::acc::DeclareEnterOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
       mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index f2abeab744d17..24f0ed3d35a0e 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -70,8 +71,38 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
   }
 }
 
+// Helper function to process declare enter/exit pairs
+static void processDeclareEnterExit(
+    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+    DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
+  // For declare enter/exit pairs, verify there is exactly one exit op using the
+  // token
+  if (!op.getToken().hasOneUse())
+    op.emitError("declare enter token must have exactly one use");
+  Operation *user = *op.getToken().getUsers().begin();
+  auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+  if (!declareExit)
+    op.emitError("declare enter token must be used by declare exit op");
+
+  for (auto p : values) {
+    Value hostVal = std::get<0>(p);
+    Value deviceVal = std::get<1>(p);
+    for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
+      Operation *owner = use.getOwner();
+      if (!domInfo.dominates(op.getOperation(), owner) ||
+          !postDomInfo.postDominates(declareExit.getOperation(), owner))
+        continue;
+      if (insideAccComputeRegion(owner))
+        use.set(deviceVal);
+    }
+  }
+}
+
 template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+static void
+collectAndReplaceInRegion(Op &op, bool hostToDevice,
+                          DominanceInfo *domInfo = nullptr,
+                          PostDominanceInfo *postDomInfo = nullptr) {
   llvm::SmallVector<std::pair<Value, Value>> values;
 
   if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +113,24 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp> &&
-                  !std::is_same_v<Op, acc::HostDataOp>) {
+                  !std::is_same_v<Op, acc::HostDataOp> &&
+                  !std::is_same_v<Op, acc::DeclareEnterOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 
-  for (auto p : values)
-    replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
-                                              op.getRegion());
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+    assert(domInfo && postDomInfo &&
+           "Dominance info required for DeclareEnterOp");
+    processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+  } else {
+    for (auto p : values) {
+      replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+                                                op.getRegion());
+    }
+  }
 }
 
 class LegalizeDataValuesInRegion
@@ -105,6 +144,10 @@ class LegalizeDataValuesInRegion
     func::FuncOp funcOp = getOperation();
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
+    // Get dominance info for the function
+    DominanceInfo domInfo(funcOp);
+    PostDominanceInfo postDomInfo(funcOp);
+
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
@@ -125,6 +168,9 @@ class LegalizeDataValuesInRegion
         collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
       } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
         collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
+      } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+        collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
+                                  &postDomInfo);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

>From 03b89b17de143e1237f7644b22281077a17ccd56 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:14:23 -0700
Subject: [PATCH 2/4] add support for enter data

---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h   |  3 +-
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 56 +++++++++++++------
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 9c141550b184e..ff5845343313c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,8 +58,7 @@
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
 #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                      \
-  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp,              \
-      mlir::acc::DeclareEnterOp
+  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
       mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 24f0ed3d35a0e..dd47aabdecf76 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -71,26 +71,43 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
   }
 }
 
-// Helper function to process declare enter/exit pairs
-static void processDeclareEnterExit(
-    acc::DeclareEnterOp op, llvm::SmallVector<std::pair<Value, Value>> &values,
+template <typename Op>
+static void replaceAllUsesInUnstructuredComputeRegionWith(
+    Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
     DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
-  // For declare enter/exit pairs, verify there is exactly one exit op using the
-  // token
-  if (!op.getToken().hasOneUse())
-    op.emitError("declare enter token must have exactly one use");
-  Operation *user = *op.getToken().getUsers().begin();
-  auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
-  if (!declareExit)
-    op.emitError("declare enter token must be used by declare exit op");
+
+  Operation *exitOp = op.getOperation();
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+    // For declare enter/exit pairs, verify there is exactly one exit op using
+    // the token
+    if (!op.getToken().hasOneUse())
+      op.emitError("declare enter token must have exactly one use");
+    Operation *user = *op.getToken().getUsers().begin();
+    auto declareExit = dyn_cast<acc::DeclareExitOp>(user);
+    if (!declareExit)
+      op.emitError("declare enter token must be used by declare exit op");
+    exitOp = declareExit;
+  } else if constexpr (std::is_same_v<Op, acc::EnterDataOp>) {
+    // For enter/exit data pairs, use the next exit_data op after this one
+    Operation *nextOp = op.getOperation()->getNextNode();
+    while (nextOp && !isa<acc::ExitDataOp>(nextOp))
+      nextOp = nextOp->getNextNode();
+    if (!nextOp)
+      op.emitError("enter data must have a corresponding exit data op");
+    exitOp = nextOp;
+  }
 
   for (auto p : values) {
     Value hostVal = std::get<0>(p);
     Value deviceVal = std::get<1>(p);
     for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
       Operation *owner = use.getOwner();
+      // Only rewrite a use when both of the following hold:
+      //  - the acc entry operation dominates the use, and
+      //  - the selected acc exit operation (`exitOp`) post-dominates the
+      //    use.
       if (!domInfo.dominates(op.getOperation(), owner) ||
-          !postDomInfo.postDominates(declareExit.getOperation(), owner))
+          !postDomInfo.postDominates(exitOp, owner))
         continue;
       if (insideAccComputeRegion(owner))
         use.set(deviceVal);
@@ -114,17 +131,20 @@ collectAndReplaceInRegion(Op &op, bool hostToDevice,
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp> &&
                   !std::is_same_v<Op, acc::HostDataOp> &&
-                  !std::is_same_v<Op, acc::DeclareEnterOp>) {
+                  !std::is_same_v<Op, acc::DeclareEnterOp> &&
+                  !std::is_same_v<Op, acc::EnterDataOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 
-  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
+  if constexpr (std::is_same_v<Op, acc::DeclareEnterOp> ||
+                std::is_same_v<Op, acc::EnterDataOp>) {
     assert(domInfo && postDomInfo &&
            "Dominance info required for DeclareEnterOp");
-    processDeclareEnterExit(op, values, *domInfo, *postDomInfo);
+    replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
+                                                      *postDomInfo);
   } else {
     for (auto p : values) {
       replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
@@ -151,7 +171,8 @@ class LegalizeDataValuesInRegion
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
           !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
-            applyToAccDataConstruct))
+            applyToAccDataConstruct) &&
+          !isa<acc::DeclareEnterOp>(*op) && !isa<acc::EnterDataOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -171,6 +192,9 @@ class LegalizeDataValuesInRegion
       } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
         collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
+      } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+        collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
+                                  &postDomInfo);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

>From cd70987cfa06e9caca53abf72c92119c7f5554c2 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:30:18 -0700
Subject: [PATCH 3/4] change how dominfo is computed

---
 .../OpenACC/Transforms/LegalizeDataValues.cpp       | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index dd47aabdecf76..fac3171c27d50 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -164,9 +164,10 @@ class LegalizeDataValuesInRegion
     func::FuncOp funcOp = getOperation();
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
-    // Get dominance info for the function
+    // Initialize dominance info
     DominanceInfo domInfo(funcOp);
     PostDominanceInfo postDomInfo(funcOp);
+    bool computedDomInfo = false;
 
     funcOp.walk([&](Operation *op) {
       if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
@@ -190,9 +191,19 @@ class LegalizeDataValuesInRegion
       } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
         collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
       } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
+        if (!computedDomInfo) {
+          domInfo = DominanceInfo(funcOp);
+          postDomInfo = PostDominanceInfo(funcOp);
+          computedDomInfo = true;
+        }
         collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
       } else if (auto enterDataOp = dyn_cast<acc::EnterDataOp>(*op)) {
+        if (!computedDomInfo) {
+          domInfo = DominanceInfo(funcOp);
+          postDomInfo = PostDominanceInfo(funcOp);
+          computedDomInfo = true;
+        }
         collectAndReplaceInRegion(enterDataOp, replaceHostVsDevice, &domInfo,
                                   &postDomInfo);
       } else {

>From 014a8b84f93b447ef03d721fc346a679dee24e26 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 30 Apr 2025 12:32:00 -0700
Subject: [PATCH 4/4] leave out initialization

---
 mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index fac3171c27d50..d991da3ca4ff6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -165,8 +165,8 @@ class LegalizeDataValuesInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     // Initialize dominance info
-    DominanceInfo domInfo(funcOp);
-    PostDominanceInfo postDomInfo(funcOp);
+    DominanceInfo domInfo;
+    PostDominanceInfo postDomInfo;
     bool computedDomInfo = false;
 
     funcOp.walk([&](Operation *op) {



More information about the Mlir-commits mailing list