[Mlir-commits] [mlir] 4c9717c - [mlir][openacc] Add private/reduction in legalize data pass (#80882)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 6 13:21:17 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-02-06T13:21:13-08:00
New Revision: 4c9717c3bee3525c2bf0251469191cc65e246a14

URL: https://github.com/llvm/llvm-project/commit/4c9717c3bee3525c2bf0251469191cc65e246a14
DIFF: https://github.com/llvm/llvm-project/commit/4c9717c3bee3525c2bf0251469191cc65e246a14.diff

LOG: [mlir][openacc] Add private/reduction in legalize data pass (#80882)

This is a follow-up to #80351 that also handles the private and reduction
operands of acc.loop, acc.parallel and acc.serial operations.

Added: 
    

Modified: 
    mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
    mlir/test/Dialect/OpenACC/legalize-data.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
index ef44a0ec68d9c..db6b472ff9733 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeData.cpp
@@ -24,10 +24,10 @@ using namespace mlir;
 
 namespace {
 
-template <typename Op>
-static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
-  llvm::SmallVector<std::pair<Value, Value>> values;
-  for (auto operand : op.getDataClauseOperands()) {
+static void collectPtrs(mlir::ValueRange operands,
+                        llvm::SmallVector<std::pair<Value, Value>> &values,
+                        bool hostToDevice) {
+  for (auto operand : operands) {
     Value varPtr = acc::getVarPtr(operand.getDefiningOp());
     Value accPtr = acc::getAccPtr(operand.getDefiningOp());
     if (varPtr && accPtr) {
@@ -37,6 +37,23 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
         values.push_back({accPtr, varPtr});
     }
   }
+}
+
+template <typename Op>
+static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
+  llvm::SmallVector<std::pair<Value, Value>> values;
+
+  if constexpr (std::is_same_v<Op, acc::LoopOp>) {
+    collectPtrs(op.getReductionOperands(), values, hostToDevice);
+    collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+  } else {
+    collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+    if constexpr (!std::is_same_v<Op, acc::KernelsOp>) {
+      collectPtrs(op.getReductionOperands(), values, hostToDevice);
+      collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+    }
+  }
 
   for (auto p : values)
     replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), op.getRegion());
@@ -50,7 +67,7 @@ struct LegalizeDataInRegion
     bool replaceHostVsDevice = this->hostToDevice.getValue();
 
     funcOp.walk([&](Operation *op) {
-      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op))
+      if (!isa<ACC_COMPUTE_CONSTRUCT_OPS>(*op) && !isa<acc::LoopOp>(*op))
         return;
 
       if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -59,6 +76,8 @@ struct LegalizeDataInRegion
         collectAndReplaceInRegion(serialOp, replaceHostVsDevice);
       } else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(*op)) {
         collectAndReplaceInRegion(kernelsOp, replaceHostVsDevice);
+      } else if (auto loopOp = dyn_cast<acc::LoopOp>(*op)) {
+        collectAndReplaceInRegion(loopOp, replaceHostVsDevice);
       }
     });
   }

diff  --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index 4c86223c720a3..113fe90450ab7 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -86,3 +86,117 @@ func.func @test(%a: memref<10xf32>) {
 // CHECK:   }
 // CHECK:   acc.yield
 // CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32>
+  acc.terminator
+}
+// Private operand on the enclosing acc.parallel: in host-to-device mode the load of %a in the nested acc.loop should use the acc.private result.
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32>
+  acc.terminator
+}
+// Private operand on the acc.loop itself: in host-to-device mode the load of %a in the loop body should use the acc.private result.
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.parallel {
+    acc.loop private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.parallel  {
+// CHECK:   acc.loop private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }
+
+// -----
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32>
+  acc.terminator
+}
+// Private operand on the enclosing acc.serial: in host-to-device mode the load of %a in the nested acc.loop should use the acc.private result.
+func.func @test(%a: memref<10xf32>) {
+  %lb = arith.constant 0 : index
+  %st = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %p1 = acc.private varPtr(%a : memref<10xf32>) -> memref<10xf32>
+  acc.serial private(@privatization_memref_10_f32 -> %p1 : memref<10xf32>) {
+    acc.loop control(%i : index) = (%lb : index) to (%c10 : index) step (%st : index) {
+      %ci = memref.load %a[%i] : memref<10xf32>
+      acc.yield
+    }
+    acc.yield
+  }
+  return
+}
+
+// CHECK: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) -> memref<10xf32>
+// CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
+// CHECK:   acc.loop control(%[[I:.*]] : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index) {
+// DEVICE:    %{{.*}} = memref.load %[[PRIVATE]][%[[I]]] : memref<10xf32>
+// CHECK:     acc.yield
+// CHECK:   }
+// CHECK:   acc.yield
+// CHECK: }


        


More information about the Mlir-commits mailing list