[Mlir-commits] [mlir] [mlir][openacc] Add private/reduction in legalize data pass (PR #80882)

Valentin Clement バレンタイン クレメン llvmlistbot at llvm.org
Tue Feb 6 09:51:55 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/80882

This is a follow-up to #80351. It extends the legalize data pass to also handle the private and reduction operands of acc.loop, acc.parallel and acc.serial operations.

>From 41de1e507a31f6ddf2c0e018fbce89ab08895206 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 5 Feb 2024 15:17:08 -0800
Subject: [PATCH] [mlir][openacc] Add private/reduction in legalize data pass

---
 .../OpenACC/Transforms/LegalizeData.cpp       |  29 ++++-
 mlir/test/Dialect/OpenACC/legalize-data.mlir  | 114 ++++++++++++++++++
 2 files changed, 138 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9c..db6b472ff9733 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
-  llvm::SmallVector<std::pair<Value, Value>> values;
-  for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+                        llvm::SmallVector<std::pair<Value, Value>> &values,
+                        bool hostToDevice) {
+  for (auto operand : operands) {
     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
     if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
         values.push_back({accPtr, varPtr});
     }
   }
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> values;
+
+  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+    collectPtrs(op.getReductionOperands(), values, hostToDevice);
+    collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+  } else {
+    collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+    if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+      collectPtrs(op.getReductionOperands(), values, hostToDevice);
+      collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+    }
+  }
 
   for (auto p : values)
     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     funcOp.walk([&](Operation *op) {
-      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+      } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+        collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
       }
     });
   }
diff --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a3..113fe90450ab7 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
 // CHECK:   }
 // CHECK:   acc.yield
 // CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel {
+    acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel  {
+// CHECK:   acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE:.*]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }



More information about the Mlir-commits mailing list