[Mlir-commits] [mlir] [mlir][openacc] Add private/reduction in legalize data pass (PR #80882)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 6 09:52:25 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This is a follow-up to #80351. It adds handling of the private and reduction operands of the acc.loop, acc.parallel, and acc.serial operations to the legalize data pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/80882.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp (+24-5)
- (modified) mlir/test/Dialect/OpenACC/legalize-data.mlir (+114)
``````````diff
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9ca..db6b472ff9733a 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
namespace {
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
- llvm::SmallVector<std::pair<Value, Value>> values;
- for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+ llvm::SmallVector<std::pair<Value, Value>> &values,
+ bool hostToDevice) {
+ for (auto operand : operands) {
Value varPtr = acc::getVarPtr(operand.getDefiningOp());
Value accPtr = acc::getAccPtr(operand.getDefiningOp());
if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
values.push_back({accPtr, varPtr});
}
}
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+ llvm::SmallVector<std::pair<Value, Value>> values;
+
+ if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+ } else {
+ collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+ if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+ collectPtrs(op.getReductionOperands(), values, hostToDevice);
+ collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+ collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+ }
+ }
for (auto p : values)
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
bool replaceHostVsDevice = this->hostToDevice.getValue();
funcOp.walk([&](Operation *op) {
- if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+ if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
return;
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
} else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+ } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+ collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
}
});
}
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a33..113fe90450ab7b 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
// CHECK: }
// CHECK: acc.yield
// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.parallel {
+ acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel {
+// CHECK: acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+ %0 = memref.alloc() : memref<10xf32>
+ acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+ memref.dealloc %arg0 : memref<10xf32>
+ acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+ %lb = arith.constant 0 : index
+ %st = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+ acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+ acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+ %ci = memref.load %a[%i] : memref<10xf32>
+ acc.yield
+ }
+ acc.yield
+ }
+ return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK: acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index) {
+// DEVICE: %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK: acc.yield
+// CHECK: }
+// CHECK: acc.yield
+// CHECK: }
``````````
</details>
https://github.com/llvm/llvm-project/pull/80882
More information about the Mlir-commits mailing list