[Mlir-commits] [mlir] 3633de7 - [mlir][acc] Handle OpenACC host_data in LegalizeDataValues (#134767)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 14 16:29:20 PDT 2025


Author: nvptm
Date: 2025-04-14T16:29:17-07:00
New Revision: 3633de702985e1580d72223ae38a31d0e6fd480b

URL: https://github.com/llvm/llvm-project/commit/3633de702985e1580d72223ae38a31d0e6fd480b
DIFF: https://github.com/llvm/llvm-project/commit/3633de702985e1580d72223ae38a31d0e6fd480b.diff

LOG: [mlir][acc] Handle OpenACC host_data in LegalizeDataValues (#134767)

`LegalizeDataValuesInRegion` is intended to replace the SSA values used
in a region with the output of data operations, but it does not handle
the OpenACC `host_data` construct. As a result, the following code
currently
```
 !$acc host_data use_device(%var)
   ...%var...
 !$acc end host_data

```
is lowered to

```
 %dev_var = acc.use_device(%var)
 acc.host_data data_operands(%dev_var) {
   ...%var...
 }
```

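This pull request updates the `LegalizeDataValuesInRegion` pass to handle
`HostDataOp` (which is also moved from the unstructured to the structured
data-construct op list, since it owns a region), so that lowering results in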

```
 %dev_var = acc.use_device(%var)
 acc.host_data data_operands(%dev_var) {
   ...%dev_var...
 }
```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACC.h
    mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
    mlir/test/Dialect/OpenACC/legalize-data.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 748cb7f28fc8c..ff5845343313c 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -58,11 +58,10 @@
 #define ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS                                     \
   ACC_COMPUTE_CONSTRUCT_OPS, mlir::acc::LoopOp
 #define ACC_DATA_CONSTRUCT_STRUCTURED_OPS                                      \
-  mlir::acc::DataOp, mlir::acc::DeclareOp
+  mlir::acc::DataOp, mlir::acc::DeclareOp, mlir::acc::HostDataOp
 #define ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS                                    \
   mlir::acc::EnterDataOp, mlir::acc::ExitDataOp, mlir::acc::UpdateOp,          \
-      mlir::acc::HostDataOp, mlir::acc::DeclareEnterOp,                        \
-      mlir::acc::DeclareExitOp
+      mlir::acc::DeclareEnterOp, mlir::acc::DeclareExitOp
 #define ACC_DATA_CONSTRUCT_OPS                                                 \
   ACC_DATA_CONSTRUCT_STRUCTURED_OPS, ACC_DATA_CONSTRUCT_UNSTRUCTURED_OPS
 #define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS                                     \

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index a553653c73479..f2abeab744d17 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -81,7 +81,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
     collectVars(op.getDataClauseOperands(), values, hostToDevice);
     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
                   !std::is_same_v<Op, acc::DataOp> &&
-                  !std::is_same_v<Op, acc::DeclareOp>) {
+                  !std::is_same_v<Op, acc::DeclareOp> &&
+                  !std::is_same_v<Op, acc::HostDataOp>) {
       collectVars(op.getReductionOperands(), values, hostToDevice);
       collectVars(op.getPrivateOperands(), values, hostToDevice);
       collectVars(op.getFirstprivateOperands(), values, hostToDevice);
@@ -122,6 +123,8 @@ class LegalizeDataValuesInRegion
         collectAndReplaceInRegion(dataOp, replaceHostVsDevice);
       } else if (auto declareOp = dyn_cast<acc::DeclareOp>(*op)) {
         collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
+      } else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
+        collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
       } else {
         llvm_unreachable("unsupported acc region op");
       }

diff  --git a/mlir/test/Dialect/OpenACC/legalize-data.mlir b/mlir/test/Dialect/OpenACC/legalize-data.mlir
index baa72ae416c92..9461225e9a7e0 100644
--- a/mlir/test/Dialect/OpenACC/legalize-data.mlir
+++ b/mlir/test/Dialect/OpenACC/legalize-data.mlir
@@ -102,7 +102,7 @@ func.func @test(%a: memref<10xf32>) {
   return
 }
 
-// CHECK: func.func @test
+// CHECK-LABEL: func.func @test
 // CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
 // CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
 // CHECK: acc.parallel dataOperands(%[[CREATE]] : memref<10xf32>) {
@@ -140,7 +140,7 @@ func.func @test(%a: memref<10xf32>) {
   return
 }
 
-// CHECK: func.func @test
+// CHECK-LABEL: func.func @test
 // CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
 // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
 // CHECK: acc.parallel private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
@@ -178,7 +178,7 @@ func.func @test(%a: memref<10xf32>) {
   return
 }
 
-// CHECK: func.func @test
+// CHECK-LABEL: func.func @test
 // CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
 // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
 // CHECK: acc.parallel  {
@@ -216,7 +216,7 @@ func.func @test(%a: memref<10xf32>) {
   return
 }
 
-// CHECK: func.func @test
+// CHECK-LABEL: func.func @test
 // CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
 // CHECK: %[[PRIVATE:.*]] = acc.private varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
 // CHECK: acc.serial private(@privatization_memref_10_f32 -> %[[PRIVATE]] : memref<10xf32>) {
@@ -226,3 +226,23 @@ func.func @test(%a: memref<10xf32>) {
 // CHECK:   }
 // CHECK:   acc.yield
 // CHECK: }
+
+// -----
+
+func.func @test(%a: memref<10xf32>) {
+  %devptr = acc.use_device varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+  acc.host_data dataOperands(%devptr : memref<10xf32>) {
+    func.call @foo(%a) : (memref<10xf32>) -> ()
+    acc.terminator
+  }
+  return
+}
+func.func private @foo(memref<10xf32>)
+
+// CHECK-LABEL: func.func @test
+// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
+// CHECK: %[[USE_DEVICE:.*]] = acc.use_device varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32>
+// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
+// DEVICE:   func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
+// CHECK:   acc.terminator
+// CHECK: }
\ No newline at end of file


        


More information about the Mlir-commits mailing list