[Mlir-commits] [mlir] ac9ee61 - [acc] Improve LegalizeDataValues pass to handle data constructs (#112990)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 21 09:50:02 PDT 2024
Author: Razvan Lupusoru
Date: 2024-10-21T09:49:58-07:00
New Revision: ac9ee618572537bcd77c58899aaab1d41dbad206
URL: https://github.com/llvm/llvm-project/commit/ac9ee618572537bcd77c58899aaab1d41dbad206
DIFF: https://github.com/llvm/llvm-project/commit/ac9ee618572537bcd77c58899aaab1d41dbad206.diff
LOG: [acc] Improve LegalizeDataValues pass to handle data constructs (#112990)
Renames LegalizeData to LegalizeDataValues, since this pass fixes up SSA
values; the old name, LegalizeData, suggested that it fixed data mapping.
This change also adds support for fixing up SSA values for data clause
operations on structured data constructs (acc.data and acc.declare):
compute regions nested within a data region now also use the SSA values
produced by that data region's data clause operations. SSA values used
within a data region but outside any compute region are not updated.
This change supports the requirement in the OpenACC specification that a
visible data clause is not only one on the current compute construct but
also one on a lexically containing data construct or a visible declare
directive.
Added:
mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
Modified:
flang/test/Fir/OpenACC/legalize-data.fir
mlir/include/mlir/Dialect/OpenACC/OpenACC.h
mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
mlir/test/Dialect/OpenACC/legalize-data.mlir
Removed:
mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
################################################################################
diff --git a/flang/test/Fir/OpenACC/legalize-data.fir b/flang/test/Fir/OpenACC/legalize-data.fir
index 3b8695434e6e47..6bc81dc08db303 100644
--- a/flang/test/Fir/OpenACC/legalize-data.fir
+++ b/flang/test/Fir/OpenACC/legalize-data.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s
+// RUN: fir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s
func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
%0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
@@ -22,3 +22,36 @@ func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
// CHECK: acc.yield
// CHECK: }
// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<i32> {fir.bindc_name = "i"}) {
+ %0:2 = hlfir.declare %arg0 {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %1 = acc.copyin varPtr(%0#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ acc.data dataOperands(%1 : !fir.ref<i32>) {
+ %c0_i32 = arith.constant 0 : i32
+ hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+ acc.serial {
+ hlfir.assign %c0_i32 to %0#0 : i32, !fir.ref<i32>
+ acc.yield
+ }
+ acc.terminator
+ }
+ acc.copyout accPtr(%1 : !fir.ref<i32>) to varPtr(%0#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
+ return
+}
+
+// CHECK-LABEL: func.func @_QPsub1
+// CHECK-SAME: (%[[ARG0:.*]]: !fir.ref<i32> {fir.bindc_name = "i"})
+// CHECK: %[[I:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%[[I]]#0 : !fir.ref<i32>) -> !fir.ref<i32> {dataClause = #acc<data_clause acc_copy>, name = "i"}
+// CHECK: acc.data dataOperands(%[[COPYIN]] : !fir.ref<i32>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: hlfir.assign %[[C0]] to %0#0 : i32, !fir.ref<i32>
+// CHECK: acc.serial {
+// CHECK: hlfir.assign %[[C0]] to %[[COPYIN]] : i32, !fir.ref<i32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[COPYIN]] : !fir.ref<i32>) to varPtr(%[[I]]#0 : !fir.ref<i32>) {dataClause = #acc<data_clause acc_copy>, name = "i"}
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index ca96ce62ae404e..cda07d6a913649 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -56,14 +56,14 @@
mlir::acc::ParallelOp, mlir::acc::KernelsOp, mlir::acc::SerialOp
#define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
-#define OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS \
+#define ACC_DATA_CONSTRUCT_STRUCTURED_OPS \
mlir::acc::DataOp, mlir::acc::DeclareOp
#define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS \
mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp, \
mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp, \
mlir::acc::DeclareExitOp
#define ACC_DATA_CONSTRUCT_OPS \
- OPENACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
+ ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS \
ACC_COMPUTE_CONSTRUCT_OPS, ACC_DATA_CONSTRUCT_OPS
#define ACC_COMPUTE_LOOP_AND_DATA_CONSTRUCT_OPS \
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
index bb93c78bf6eadf..57d532b078b9e3 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.h
@@ -11,9 +11,6 @@
#include "mlir/Pass/Pass.h"
-#define GEN_PASS_DECL
-#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
-
namespace mlir {
namespace func {
@@ -22,8 +19,8 @@ class FuncOp;
namespace acc {
-/// Create a pass to replace ssa values in region with device/host values.
-std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeDataInRegion();
+#define GEN_PASS_DECL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
/// Generate the code for registering conversion passes.
#define GEN_PASS_REGISTRATION
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index abbc27765e3423..9ceb91e5679a1e 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -11,18 +11,20 @@
include "mlir/Pass/PassBase.td"
-def LegalizeDataInRegion : Pass<"openacc-legalize-data", "mlir::func::FuncOp"> {
- let summary = "Legalize the data in the compute region";
+def LegalizeDataValuesInRegion : Pass<"openacc-legalize-data-values", "mlir::func::FuncOp"> {
+ let summary = "Legalizes SSA values in compute regions with results from data clause operations";
let description = [{
- This pass replace uses of varPtr in the compute region with their accPtr
- gathered from the data clause operands.
+ This pass replaces uses of the `varPtr` in compute regions (kernels,
+ parallel, serial) with the result of data clause operations (`accPtr`).
}];
let options = [
Option<"hostToDevice", "host-to-device", "bool", "true",
"Replace varPtr uses with accPtr if true. Replace accPtr uses with "
- "varPtr if false">
+ "varPtr if false">,
+ Option<"applyToAccDataConstruct", "apply-to-acc-data-construct", "bool", "true",
+ "Replaces varPtr uses with accPtr for acc compute regions contained "
+ "within acc.data or acc.declare region.">
];
- let constructor = "::mlir::acc::createLegalizeDataInRegion()";
}
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 41ba7f8f53d367..7d934956089a5a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
- LegalizeData.cpp
+ LegalizeDataValues.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
similarity index 54%
rename from mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
rename to mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index db6b472ff9733a..4038e333adb8b6 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -1,4 +1,4 @@
-//===- LegalizeData.cpp - -------------------------------------------------===//
+//===- LegalizeDataValues.cpp - -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -12,10 +12,11 @@
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/ErrorHandling.h"
namespace mlir {
namespace acc {
-#define GEN_PASS_DEF_LEGALIZEDATAINREGION
+#define GEN_PASS_DEF_LEGALIZEDATAVALUESINREGION
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
} // namespace acc
} // namespace mlir
@@ -24,6 +25,17 @@ using namespace mlir;
namespace {
+static bool insideAccComputeRegion(mlir::Operation *op) {
+ mlir::Operation *parent{op->getParentOp()};
+ while (parent) {
+ if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(parent)) {
+ return true;
+ }
+ parent = parent->getParentOp();
+ }
+ return false;
+}
+
static void collectPtrs(mlir::ValueRange operands,
llvm::SmallVector<std::pair<Value, Value>> &values,
bool hostToDevice) {
@@ -39,6 +51,25 @@ static void collectPtrs(mlir::ValueRange operands,
}
}
+template <typename Op>
+static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
+ Region &outerRegion) {
+ for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
+ if (outerRegion.isAncestor(use.getOwner()->getParentRegion())) {
+ if constexpr (std::is_same_v<Op, acc::DataOp> ||
+ std::is_same_v<Op, acc::DeclareOp>) {
+ // For data construct regions, only replace uses in contained compute
+ // regions.
+ if (insideAccComputeRegion(use.getOwner())) {
+ use.set(replacement);
+ }
+ } else {
+ use.set(replacement);
+ }
+ }
+ }
+}
+
template <typename Op>
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
@@ -48,7 +79,9 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
} else {
collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
- if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
+ !std::is_same_v<Op, acc::DataOp> &&
+ !std::is_same_v<Op, acc::DeclareOp>) {
collectPtrs(op.getReductionOperands(), values, hostToDevice);
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
@@ -56,18 +89,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
}
for (auto p : values)
- replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
+ op.getRegion());
}
-struct LegalizeDataInRegion
- : public acc::impl::LegalizeDataInRegionBase<LegalizeDataInRegion> {
+class LegalizeDataValuesInRegion
+ : public acc::impl::LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion> {
+public:
+ using LegalizeDataValuesInRegionBase<
+ LegalizeDataValuesInRegion>::LegalizeDataValuesInRegionBase;
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
+ !(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
+ applyToAccDataConstruct))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -78,14 +118,15 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
} else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
+ } else if (auto dataOp = dyn_cast<acc::DataOp>(*op)) {
+ collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
+ } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
+ collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
+ } else {
+ llvm_unreachable("unsupported acc region op");
}
});
}
};
} // end anonymous namespace
-
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::acc::createLegalizeDataInRegion() {
- return std::make_unique<LegalizeDataInRegion>();
-}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 113fe90450ab7b..842f8e260c499e 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -split-input-file --openacc-legalize-data %s | FileCheck %s --check-prefixes=CHECK,DEVICE
-// RUN: mlir-opt -split-input-file --openacc-legalize-data=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
+// RUN: mlir-opt -split-input-file --openacc-legalize-data-values %s | FileCheck %s --check-prefixes=CHECK,DEVICE
+// RUN: mlir-opt -split-input-file --openacc-legalize-data-values=host-to-device=false %s | FileCheck %s --check-prefixes=CHECK,HOST
func.func @test(%a: memref<10xf32>, %i : index) {
%create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
@@ -61,6 +61,32 @@ func.func @test(%a: memref<10xf32>, %i : index) {
// -----
+func.func @test(%a: memref<10xf32>, %i : index) {
+ %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.data dataOperands(%create : memref<10xf32>) {
+ %c0 = arith.constant 0.000000e+00 : f32
+ memref.store %c0, %a[%i] : memref<10xf32>
+ acc.serial {
+ %cs = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>, %[[I:.*]]: index)
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.data dataOperands(%[[CREATE]] : memref<10xf32>) {
+// CHECK: memref.store %{{.*}}, %[[A]][%[[I]]] : memref<10xf32>
+// DEVICE: %{{.*}} = memref.load %[[CREATE]][%[[I]]] : memref<10xf32>
+// HOST: %{{.*}} = memref.load %[[A]][%[[I]]] : memref<10xf32>
+// CHECK: acc.terminator
+// CHECK: }
+
+// -----
+
func.func @test(%a: memref<10xf32>) {
%lb = arith.constant 0 : index
%st = arith.constant 1 : index
More information about the Mlir-commits
mailing list