[Mlir-commits] [mlir] [acc] Improve LegalizeDataValues pass to handle data constructs (PR #112990)
From: llvmlistbot at llvm.org
Fri Oct 18 15:03:14 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Razvan Lupusoru (razvanlupusoru)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/112990.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+2-2)
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h (+2-5)
- (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+8-6)
- (modified) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+1-1)
- (renamed from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp) mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp (+53-12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ca96ce62ae404e..60fe4c5fb9d4cc 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -56,14 +56,14 @@
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
-#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
+#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
mlir::acc::DeclareExitOp
#define ACC_DATA_CONSTRUCT_OPS \
- OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
+ ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
index bb93c78bf6eadf..57d532b078b9e3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -11,9 +11,6 @@
#include "mlir/Pass/Pass.h"
-#define GEN_PASS_DECL
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
-
namespace mlir {
namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
namespace acc {
-/// Create a pass to replace ssa values in region with device/host values.
-std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index abbc27765e3423..9ceb91e5679a1e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -11,18 +11,20 @@
include "mlir/Pass/PassBase.td"
-def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
- let summary = "Legalize the data in the compute region";
+def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
+ let summary = "Legalizes SSA values in compute regions with results from data clause operations";
let description = [{
- This pass replace uses of varPtr in the compute region with their accPtr
- gathered from the data clause operands.
+ This pass replace uses of the `varPtr` in compute regions (kernels,
+ parallel, serial) with the result of data clause operations (`accPtr`).
}];
let options = [
Option<"hostToDevice", "host-to-device", "bool", "true",
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
- "varPtr if false">
+ "varPtr if false">,
+ Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
+ "Replaces varPtr uses with accPtr for acc compute regions contained "
+ "within acc.data or acc.declare region.">
];
- let constructor = "::mlir::acc::createLegalizeDataInRegion()";
}
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 41ba7f8f53d367..7d934956089a5a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
- LegalizeData.cpp
+ LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
similarity index 54%
rename from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
rename to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index db6b472ff9733a..4038e333adb8b6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeData.cpp - -------------------------------------------------===//
+//===- LegalizeDataValues.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace acc {
-#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
namespace {
+static bool insideAccComputeRegion(mlir::Operation *op) {
+ mlir::Operation *parent{op->getParentOp()};
+ while (parent) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
+ return true;
+ }
+ parent = parent->getParentOp();
+ }
+ return false;
+}
+
static void collectPtrs(mlir::ValueRange operands,
llvm::SmallVector<std::pair<Value, Value>> &values,
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
}
}
+template <typename Op>
+static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
+ Region &outerRegion) {
+ for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
+ if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
+ if constexpr (std::is_same_v<Op, acc::DataOp> ||
+ std::is_same_v<Op, acc::DeclareOp>) {
+ // For data construct regions, only replace uses in contained compute
+ // regions.
+ if (insideAccComputeRegion(use.getOwner())) {
+ use.set(replacement);
+ }
+ } else {
+ use.set(replacement);
+ }
+ }
+ }
+}
+
template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,7 +79,9 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
} else {
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
- if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
+ !std::is_same_v<Op, acc::DataOp> &&
+ !std::is_same_v<Op, acc::DeclareOp>) {
collectPtrs(op.getReductionOperands(), values, hostToDevice);
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
@@ -56,18 +89,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
}
for (auto p : values)
- replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+ op.getRegion());
}
-struct LegalizeDataInRegion
- : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+class LegalizeDataValuesInRegion
+ : public acc::impl::LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion> {
+public:
+ using LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
+ !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
+ applyToAccDataConstruct))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
+ } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
+ collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
+ } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
+ collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
+ } else {
+ llvm_unreachable("unsupported acc region op");
}
});
}
};
} // end anonymous namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::acc::createLegalizeDataInRegion() {
- return std::make_unique<LegalizeDataInRegion>();
-}
``````````
</details>
https://github.com/llvm/llvm-project/pull/112990
More information about the Mlir-commits mailing list