[Mlir-commits] [mlir] [acc] Improve LegalizeDataValues pass to handle data constructs (PR #112990)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 18 15:03:14 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/112990.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+2-2) 
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h (+2-5) 
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+8-6) 
- (modified) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+1-1) 
- (renamed) mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp (+53-12) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ca96ce62ae404e..60fe4c5fb9d4cc 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -56,14 +56,14 @@
   mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
-#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS                                  \
+#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                  \
   mlir::acc::DataOp, mlir::acc::DeclareOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
       mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp,                        \
       mlir::acc::DeclareExitOp
 #define ACC_DATA_CONSTRUCT_OPS                                                 \
-  OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
+  ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
 #define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
 #define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS                                \
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
index bb93c78bf6eadf..57d532b078b9e3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -11,9 +11,6 @@
 
 #include "mlir/Pass/Pass.h"
 
-#define GEN_PASS_DECL
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
-
 namespace mlir {
 
 namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
 
 namespace acc {
 
-/// Create a pass to replace ssa values in region with device/host values.
-std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
 
 /// Generate the code for registering conversion passes.
 #define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index abbc27765e3423..9ceb91e5679a1e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -11,18 +11,20 @@
 
 include "mlir/Pass/PassBase.td"
 
-def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
-  let summary = "Legalize the data in the compute region";
+def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
+  let summary = "Legalizes SSA values in compute regions with results from data clause operations";
   let description = [{
-    This pass replace uses of varPtr in the compute region with their accPtr
-    gathered from the data clause operands.
+    This pass replaces uses of the `varPtr` in compute regions (kernels,
+    parallel, serial) with the result of data clause operations (`accPtr`).
   }];
   let options = [
     Option<"hostToDevice", "host-to-device", "bool", "true",
            "Replace varPtr uses with accPtr if true. Replace accPtr uses with "
-           "varPtr if false">
+           "varPtr if false">,
+    Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
+           "Replaces varPtr uses with accPtr for acc compute regions contained "
+           "within acc.data or acc.declare region.">
   ];
-  let constructor = "::mlir::acc::createLegalizeDataInRegion()";
 }
 
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 41ba7f8f53d367..7d934956089a5a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIROpenACCTransforms
-  LegalizeData.cpp
+  LegalizeDataValues.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
similarity index 54%
rename from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
rename to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index db6b472ff9733a..4038e333adb8b6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeData.cpp - -------------------------------------------------===//
+//===- LegalizeDataValues.cpp - -------------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace mlir {
 namespace acc {
-#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
 #include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
 } // namespace acc
 } // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
 
 namespace {
 
+static bool insideAccComputeRegion(mlir::Operation *op) {
+  mlir::Operation *parent{op->getParentOp()};
+  while (parent) {
+    if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
+      return true;
+    }
+    parent = parent->getParentOp();
+  }
+  return false;
+}
+
 static void collectPtrs(mlir::ValueRange operands,
                         llvm::SmallVector<std::pair<Value, Value>> &values,
                         bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
   }
 }
 
+template <typename Op>
+static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
+                                                  Region &outerRegion) {
+  for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
+    if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
+      if constexpr (std::is_same_v<Op, acc::DataOp> ||
+                    std::is_same_v<Op, acc::DeclareOp>) {
+        // For data construct regions, only replace uses in contained compute
+        // regions.
+        if (insideAccComputeRegion(use.getOwner())) {
+          use.set(replacement);
+        }
+      } else {
+        use.set(replacement);
+      }
+    }
+  }
+}
+
 template <typename Op>
 static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
   llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,7 +79,9 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
     collectPtrs(op.getPrivateOperands(), values, hostToDevice);
   } else {
     collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
-    if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+    if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
+                  !std::is_same_v<Op, acc::DataOp> &&
+                  !std::is_same_v<Op, acc::DeclareOp>) {
       collectPtrs(op.getReductionOperands(), values, hostToDevice);
       collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
       collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
@@ -56,18 +89,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
   }
 
   for (auto p : values)
-    replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+    replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+                                              op.getRegion());
 }
 
-struct LegalizeDataInRegion
-    : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+class LegalizeDataValuesInRegion
+    : public acc::impl::LegalizeDataValuesInRegionBase<
+          LegalizeDataValuesInRegion> {
+public:
+  using LegalizeDataValuesInRegionBase<
+      LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
 
   void runOnOperation() override {
     func::FuncOp funcOp = getOperation();
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     funcOp.walk([&](Operation *op) {
-      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
+      if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
+          !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
+            applyToAccDataConstruct))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
       } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
         collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
+      } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
+        collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
+      } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
+        collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
+      } else {
+        llvm_unreachable("unsupported acc region op");
       }
     });
   }
 };
 
 } // end anonymous namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::acc::createLegalizeDataInRegion() {
-  return std::make_unique<LegalizeDataInRegion>();
-}

``````````

</details>


https://github.com/llvm/llvm-project/pull/112990


More information about the Mlir-commits mailing list