[flang-commits] [flang] [mlir] [acc] Improve LegalizeDataValues pass to handle data constructs (PR #112990)
Razvan Lupusoru via flang-commits
flang-commits at lists.llvm.org
Fri Oct 18 16:07:11 PDT 2024
https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/112990
>From 51c4116f433ba91e92e6484edb332c06d272439e Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Fri, 18 Oct 2024 15:01:08 -0700
Subject: [PATCH] [acc] Improve LegalizeDataValues pass to handle data
constructs
Renames LegalizeData to LegalizeDataValues since this pass fixes up SSA
values. The old name, LegalizeData, suggested that the pass fixed data mapping.
This change also adds support to fix up SSA values for data clause
operations. Effectively, compute regions within a data region also use
the SSA values from data operations. The SSA values within data regions
but not within compute regions are not updated.
This change supports the requirement in the OpenACC spec, which notes
that a visible data clause is not just one on the current compute
construct but also one on a lexically containing data construct or a
visible declare directive.
---
flang/test/Fir/OpenACC/legalize-data.fir | 35 +++++++++-
mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 4 +-
.../mlir/Dialect/OpenACC/Transforms/Passes.h | 7 +-
.../mlir/Dialect/OpenACC/Transforms/Passes.td | 14 ++--
.../Dialect/OpenACC/Transforms/CMakeLists.txt | 2 +-
...egalizeData.cpp => LegalizeDataValues.cpp} | 65 +++++++++++++++----
mlir/test/Dialect/OpenACC/legalize-data.mlir | 30 ++++++++-
7 files changed, 128 insertions(+), 29 deletions(-)
rename mlir/lib/Dialect/OpenACC/Transforms/{LegalizeData.cpp => LegalizeDataValues.cpp} (54%)
diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir
index 3b8695434e6e47..6bc81dc08db303 100644
--- a/flang/test/Fir/OpenACC/legalize-data.fir
+++ b/flang/test/Fir/OpenACC/legalize-data.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
+// RUN: fir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -22,3 +22,36 @@ func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
+ %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ acc.data dataOperands(%1 : !fir.ref<i32>) {
+ %c0_i32 = arith.constant 0 : i32
+ hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+ acc.serial {
+ hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+ acc.yield
+ }
+ acc.terminator
+ }
+ acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1
+// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
+// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+// CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: hlfir.assign %[[C0]] to %0#0 : i32, !fir.ref<i32>
+// CHECK: acc.serial {
+// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref<i32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ca96ce62ae404e..60fe4c5fb9d4cc 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -56,14 +56,14 @@
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
-#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
+#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
mlir::acc::DeclareExitOp
#define ACC_DATA_CONSTRUCT_OPS \
- OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
+ ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
index bb93c78bf6eadf..57d532b078b9e3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -11,9 +11,6 @@
#include "mlir/Pass/Pass.h"
-#define GEN_PASS_DECL
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
-
namespace mlir {
namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
namespace acc {
-/// Create a pass to replace ssa values in region with device/host values.
-std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index abbc27765e3423..9ceb91e5679a1e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -11,18 +11,20 @@
include "mlir/Pass/PassBase.td"
-def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
- let summary = "Legalize the data in the compute region";
+def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
+ let summary = "Legalizes SSA values in compute regions with results from data clause operations";
let description = [{
- This pass replace uses of varPtr in the compute region with their accPtr
- gathered from the data clause operands.
+ This pass replace uses of the `varPtr` in compute regions (kernels,
+ parallel, serial) with the result of data clause operations (`accPtr`).
}];
let options = [
Option<"hostToDevice", "host-to-device", "bool", "true",
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
- "varPtr if false">
+ "varPtr if false">,
+ Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
+ "Replaces varPtr uses with accPtr for acc compute regions contained "
+ "within acc.data or acc.declare region.">
];
- let constructor = "::mlir::acc::createLegalizeDataInRegion()";
}
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 41ba7f8f53d367..7d934956089a5a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
- LegalizeData.cpp
+ LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
similarity index 54%
rename from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
rename to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index db6b472ff9733a..4038e333adb8b6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeData.cpp - -------------------------------------------------===//
+//===- LegalizeDataValues.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace acc {
-#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
namespace {
+static bool insideAccComputeRegion(mlir::Operation *op) {
+ mlir::Operation *parent{op->getParentOp()};
+ while (parent) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
+ return true;
+ }
+ parent = parent->getParentOp();
+ }
+ return false;
+}
+
static void collectPtrs(mlir::ValueRange operands,
llvm::SmallVector<std::pair<Value, Value>> &values,
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
}
}
+template <typename Op>
+static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
+ Region &outerRegion) {
+ for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
+ if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
+ if constexpr (std::is_same_v<Op, acc::DataOp> ||
+ std::is_same_v<Op, acc::DeclareOp>) {
+ // For data construct regions, only replace uses in contained compute
+ // regions.
+ if (insideAccComputeRegion(use.getOwner())) {
+ use.set(replacement);
+ }
+ } else {
+ use.set(replacement);
+ }
+ }
+ }
+}
+
template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,7 +79,9 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
} else {
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
- if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
+ !std::is_same_v<Op, acc::DataOp> &&
+ !std::is_same_v<Op, acc::DeclareOp>) {
collectPtrs(op.getReductionOperands(), values, hostToDevice);
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
@@ -56,18 +89,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
}
for (auto p : values)
- replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+ op.getRegion());
}
-struct LegalizeDataInRegion
- : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+class LegalizeDataValuesInRegion
+ : public acc::impl::LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion> {
+public:
+ using LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
+ !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
+ applyToAccDataConstruct))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
+ } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
+ collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
+ } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
+ collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
+ } else {
+ llvm_unreachable("unsupported acc region op");
}
});
}
};
} // end anonymous namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::acc::createLegalizeDataInRegion() {
- return std::make_unique<LegalizeDataInRegion>();
-}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 113fe90450ab7b..842f8e260c499e 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
-// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
+// RUN: mlir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s --check-prefixes=CHECK,DEVICE
+// RUN: mlir-opt -split-input-file --openacc-legalize-data-values=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
@@ -61,6 +61,32 @@ func.func @test(%a: memref<10xf32>, %i : index) {
// -----
+func.func @test(%a: memref<10xf32>, %i : index) {
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.data dataOperands(%create : memref<10xf32>) {
+ %c0 = arith.constant 0.000000e+00 : f32
+ memref.store %c0, %a[%i] : memref<10xf32>
+ acc.serial {
+ %cs = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.data dataOperands(%[[CREATE]] : memref<10xf32>) {
+// CHECK: memref.store %{{.*}}, %[[A]][%[[I]]] : memref<10xf32>
+// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+// -----
+
func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
More information about the flang-commits
mailing list