[flang-commits] [flang] 689afa8 - [mlir][openacc] Cleanup acc.update from old data clause operands

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Mon May 8 10:03:36 PDT 2023


Author: Valentin Clement
Date: 2023-05-08T10:03:28-07:00
New Revision: 689afa88ae8b8f3fc29bcd3be098b91f8a12e62e

URL: https://github.com/llvm/llvm-project/commit/689afa88ae8b8f3fc29bcd3be098b91f8a12e62e
DIFF: https://github.com/llvm/llvm-project/commit/689afa88ae8b8f3fc29bcd3be098b91f8a12e62e.diff

LOG: [mlir][openacc] Cleanup acc.update from old data clause operands

Since the new data operand operations were added in D148389 and
adopted by acc.update in D149909, the old clause operands are no longer
needed. This is the first patch in a series cleaning up the OpenACC
operations that have data clause operands.

The `LegalizeDataOpForLLVMTranslation` pattern will become obsolete once all
operations have been cleaned up. For the time being, only the parts
related to acc.update are removed.

`processOperands` will also receive some updates once all the operands
come from an acc data operand operation.

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D150053

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
    flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
    mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir
    mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
    mlir/test/Dialect/OpenACC/canonicalize.mlir
    mlir/test/Dialect/OpenACC/invalid.mlir
    mlir/test/Dialect/OpenACC/ops.mlir
    mlir/test/Target/LLVMIR/openacc-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fe7b34121e997..d46c0c54ee460 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -1494,7 +1494,6 @@ genACCUpdateOp(Fortran::lower::AbstractConverter &converter,
   addOperand(operands, operandSegments, waitDevnum);
   addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, deviceTypeOperands);
-  operandSegments.append({0, 0});
   addOperands(operands, operandSegments, dataClauseOperands);
 
   mlir::acc::UpdateOp updateOp = createSimpleOp<mlir::acc::UpdateOp>(

diff  --git a/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
index c68b06878949a..a612c4c4641ed 100644
--- a/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/OpenACC/OpenACCDataOperandConversion.cpp
@@ -119,7 +119,6 @@ void OpenACCDataOperandConversion::runOnOperation() {
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
-  patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
 
   ConversionTarget target(*context);
   target.addLegalDialect<fir::FIROpsDialect>();
@@ -182,12 +181,6 @@ void OpenACCDataOperandConversion::runOnOperation() {
                allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
       });
 
-  target.addDynamicallyLegalOp<acc::UpdateOp>(
-      [allDataOperandsAreConverted](acc::UpdateOp op) {
-        return allDataOperandsAreConverted(op.getHostOperands()) &&
-               allDataOperandsAreConverted(op.getDeviceOperands());
-      });
-
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
index f22b1a24fceaa..bfe5c378bb364 100644
--- a/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
+++ b/flang/test/Transforms/OpenACC/convert-data-operands-to-llvmir.fir
@@ -56,28 +56,6 @@ fir.global internal @_QFEa : !fir.array<10xf32> {
   fir.has_value %0 : !fir.array<10xf32>
 }
 
-func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
-  %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
-  acc.update device(%0 : !fir.ref<!fir.array<10xf32>>)
-  return
-}
-
-// CHECK-LABEL: func.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
-// CHECK: %[[ADDR:.*]] = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ADDR]] : !fir.ref<!fir.array<10xf32>> to !llvm.ptr<array<10 x f32>>
-// CHECK: acc.update device(%[[CAST]] : !llvm.ptr<array<10 x f32>>)
-
-// LLVMIR-LABEL: llvm.func @_QQsub_update() attributes {fir.bindc_name = "a"} {
-// LLVMIR: %[[ADDR:.*]] = llvm.mlir.addressof @_QFEa : !llvm.ptr<array<10 x f32>>
-// LLVMIR: acc.update device(%[[ADDR]] : !llvm.ptr<array<10 x f32>>)
-
-// -----
-
-fir.global internal @_QFEa : !fir.array<10xf32> {
-  %0 = fir.undefined !fir.array<10xf32>
-  fir.has_value %0 : !fir.array<10xf32>
-}
-
 func.func @_QQsub_parallel() attributes {fir.bindc_name = "test"} {
   %0 = fir.address_of(@_QFEa) : !fir.ref<!fir.array<10xf32>>
   %1 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index b028b5eeab88e..6e6cf7c57451f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1035,8 +1035,6 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
                        UnitAttr:$async,
                        UnitAttr:$wait,
                        Variadic<IntOrIndex>:$deviceTypeOperands,
-                       Variadic<AnyType>:$hostOperands,
-                       Variadic<AnyType>:$deviceOperands,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
                        UnitAttr:$ifPresent);
 
@@ -1056,8 +1054,6 @@ def OpenACC_UpdateOp : OpenACC_Op<"update", [AttrSizedOperandSegments]> {
       | `device_type` `(` $deviceTypeOperands `:`
           type($deviceTypeOperands) `)`
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
-      | `host` `(` $hostOperands `:` type($hostOperands) `)`
-      | `device` `(` $deviceOperands `:` type($deviceOperands) `)`
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
     )
     attr-dict-with-keyword

diff  --git a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
index be60afac04222..fc1cc3feeb009 100644
--- a/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenACCToLLVM/OpenACCToLLVM.cpp
@@ -159,7 +159,6 @@ void mlir::populateOpenACCToLLVMConversionPatterns(
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>>(converter);
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>>(converter);
   patterns.add<LegalizeDataOpForLLVMTranslation<acc::ParallelOp>>(converter);
-  patterns.add<LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
 }
 
 namespace {
@@ -243,12 +242,6 @@ void ConvertOpenACCToLLVMPass::runOnOperation() {
                allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
       });
 
-  target.addDynamicallyLegalOp<acc::UpdateOp>(
-      [allDataOperandsAreConverted](acc::UpdateOp op) {
-        return allDataOperandsAreConverted(op.getHostOperands()) &&
-               allDataOperandsAreConverted(op.getDeviceOperands());
-      });
-
   if (failed(applyPartialConversion(op, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6f53c4e75b7e7..e5b22aa0a83d9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -664,10 +664,8 @@ LogicalResult acc::ShutdownOp::verify() {
 
 LogicalResult acc::UpdateOp::verify() {
   // At least one of host or device should have a value.
-  if (getHostOperands().empty() && getDeviceOperands().empty() &&
-      getDataClauseOperands().empty())
-    return emitError(
-        "at least one value must be present in hostOperands or deviceOperands");
+  if (getDataClauseOperands().empty())
+    return emitError("at least one value must be present in dataOperands");
 
   // The async attribute represent the async clause without value. Therefore the
   // attribute and operand cannot appear at the same time.
@@ -692,8 +690,7 @@ LogicalResult acc::UpdateOp::verify() {
 }
 
 unsigned UpdateOp::getNumDataOperands() {
-  return getHostOperands().size() + getDeviceOperands().size() +
-         getDataClauseOperands().size();
+  return getDataClauseOperands().size();
 }
 
 Value UpdateOp::getDataOperand(unsigned i) {

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 4390839e14af0..17403a2f0d87d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -214,19 +214,28 @@ processDataOperands(llvm::IRBuilderBase &builder,
   unsigned index = 0;
 
   // Host operands are handled as `from` call.
-  if (failed(processOperands(builder, moduleTranslation, op,
-                             op.getHostOperands(), op.getNumDataOperands(),
+  // Device operands are handled as `to` call.
+  llvm::SmallVector<mlir::Value> from, to;
+  for (mlir::Value dataOp : op.getDataClauseOperands()) {
+    if (auto getDevicePtrOp = mlir::dyn_cast_or_null<acc::GetDevicePtrOp>(
+            dataOp.getDefiningOp())) {
+      from.push_back(getDevicePtrOp.getVarPtr());
+    } else if (auto updateDeviceOp =
+                   mlir::dyn_cast_or_null<acc::UpdateDeviceOp>(
+                       dataOp.getDefiningOp())) {
+      to.push_back(updateDeviceOp.getVarPtr());
+    }
+  }
+
+  if (failed(processOperands(builder, moduleTranslation, op, from, from.size(),
                              kHostCopyoutFlag, flags, names, index,
                              mapperAllocas)))
     return failure();
 
-  // Device operands are handled as `to` call.
-  if (failed(processOperands(builder, moduleTranslation, op,
-                             op.getDeviceOperands(), op.getNumDataOperands(),
+  if (failed(processOperands(builder, moduleTranslation, op, to, to.size(),
                              kDeviceCopyinFlag, flags, names, index,
                              mapperAllocas)))
     return failure();
-
   return success();
 }
 
@@ -486,6 +495,10 @@ LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
                "unexpected OpenACC terminator with operands");
         return success();
       })
+      .Case([&](acc::UpdateDeviceOp) {
+        // NOP
+        return success();
+      })
       .Default([&](Operation *op) {
         return op->emitError("unsupported OpenACC operation: ")
                << op->getName();

diff  --git a/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir b/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir
index f6966f3f9f056..f7cc3cf733fc7 100644
--- a/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenACCToLLVM/convert-data-operands-to-llvmir.mlir
@@ -74,43 +74,6 @@ func.func @testexitdataop(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
-  acc.update host(%b : memref<10xf32>) device(%a : memref<10xf32>)
-  return
-}
-
-// CHECK: acc.update host(%{{.*}} : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) device(%{{.*}} : !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>)
-
-// -----
-
-func.func @testupdateop(%a: !llvm.ptr, %b: memref<10xf32>) -> () {
-  acc.update host(%b : memref<10xf32>) device(%a : !llvm.ptr)
-  return
-}
-
-// CHECK: acc.update host(%{{.*}} : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) device(%{{.*}} : !llvm.ptr)
-
-// -----
-
-func.func @testupdateop(%a: memref<10xi64>, %b: memref<10xf32>) -> () {
-  acc.update host(%b : memref<10xf32>) device(%a : memref<10xi64>) attributes {async}
-  return
-}
-
-// CHECK: acc.update host(%{{.*}} : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) device(%{{.*}} : !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) attributes {async}
-
-// -----
-
-func.func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
-  %ifCond = arith.constant true
-  acc.update if(%ifCond) host(%b : memref<10xf32>) device(%a : memref<10xf32>)
-  return
-}
-
-// CHECK: acc.update if(%{{.*}}) host(%{{.*}} : !llvm.struct<"openacc_data", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>) device(%{{.*}} : !llvm.struct<"openacc_data.1", (struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, ptr, i64)>)
-
-// -----
-
 func.func @testdataregion(%a: memref<10xf32>, %b: memref<10xf32>) -> () {
   acc.data copy(%b : memref<10xf32>) copyout(%a : memref<10xf32>) {
     acc.parallel {

diff  --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
index 6f01f5fe3ea06..ce3ad44a83e50 100644
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -24,38 +24,41 @@ func.func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
-  acc.update if(%ifCond) host(%a: memref<10xf32>)
+func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.update_device varPtr(%a : memref<f32>) -> memref<f32>
+  acc.update if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
 
-// CHECK:      func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
+// CHECK:      func @testupdateop(%{{.*}}: memref<f32>, [[IFCOND:%.*]]: i1)
 // CHECK:        scf.if [[IFCOND]] {
-// CHECK-NEXT:     acc.update host(%{{.*}} : memref<10xf32>)
+// CHECK:          acc.update dataOperands(%{{.*}} : memref<f32>)
 // CHECK-NEXT:   }
 
 // -----
 
-func.func @update_true(%arg0: memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+func.func @update_true(%arg0: memref<f32>) {
   %true = arith.constant true
-  acc.update if(%true) host(%arg0 : memref<10xf32, #spirv.storage_class<StorageBuffer>>)
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  acc.update if(%true) dataOperands(%0 : memref<f32>)
   return
 }
 
 // CHECK-LABEL: func.func @update_true
-// CHECK-NOT:if
-// CHECK:acc.update host
+// CHECK-NOT:     if
+// CHECK:         acc.update dataOperands
 
 // -----
 
-func.func @update_false(%arg0: memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+func.func @update_false(%arg0: memref<f32>) {
   %false = arith.constant false
-  acc.update if(%false) host(%arg0 : memref<10xf32, #spirv.storage_class<StorageBuffer>>)
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  acc.update if(%false) dataOperands(%0 : memref<f32>)
   return
 }
 
 // CHECK-LABEL: func.func @update_false
-// CHECK-NOT:acc.update
+// CHECK-NOT:     acc.update dataOperands
 
 // -----
 
@@ -66,8 +69,8 @@ func.func @enter_data_true(%d1 : memref<10xf32>) {
 }
 
 // CHECK-LABEL: func.func @enter_data_true
-// CHECK-NOT:if
-// CHECK:acc.enter_data create
+// CHECK-NOT:     if
+// CHECK:           acc.enter_data create
 
 // -----
 

diff  --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index 10cb19f128829..c0d8e152a5458 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -42,24 +42,28 @@ func.func @testexitdataop(%a: memref<10xf32>) -> () {
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>) -> () {
+func.func @testupdateop(%a: memref<f32>) -> () {
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
   %ifCond = arith.constant true
-  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
 }
 
-// CHECK: acc.update host(%{{.*}} : memref<10xf32>)
+// CHECK: acc.update dataOperands(%{{.*}} : memref<f32>)
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>) -> () {
+func.func @testupdateop(%a: memref<f32>) -> () {
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
   %ifCond = arith.constant false
-  acc.update if(%ifCond) host(%a: memref<10xf32>)
+  acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
 }
 
 // CHECK: func @testupdateop
-// CHECK-NOT: acc.update
+// CHECK-NOT: acc.update{{.$}}
 
 // -----
 
@@ -83,10 +87,12 @@ func.func @testexitdataop(%a: memref<10xf32>, %ifCond: i1) -> () {
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>, %ifCond: i1) -> () {
-  acc.update if(%ifCond) host(%a: memref<10xf32>)
+func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+  acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
 }
 
-// CHECK:  func @testupdateop(%{{.*}}: memref<10xf32>, [[IFCOND:%.*]]: i1)
-// CHECK:    acc.update if(%{{.*}}) host(%{{.*}} : memref<10xf32>)
+// CHECK:  func @testupdateop(%{{.*}}: memref<f32>, [[IFCOND:%.*]]: i1)
+// CHECK:    acc.update if(%{{.*}}) dataOperands(%{{.*}} : memref<f32>)

diff  --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 66644a7f31125..d4795dff6f239 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -84,29 +84,32 @@ acc.data dataOperands(%value : memref<10xf32>) {
 
 // -----
 
-// expected-error@+1 {{at least one value must be present in hostOperands or deviceOperands}}
+// expected-error@+1 {{at least one value must be present in dataOperands}}
 acc.update
 
 // -----
 
 %cst = arith.constant 1 : index
-%value = memref.alloc() : memref<10xf32>
+%value = memref.alloc() : memref<f32>
+%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error@+1 {{wait_devnum cannot appear without waitOperands}}
-acc.update wait_devnum(%cst: index) host(%value: memref<10xf32>)
+acc.update wait_devnum(%cst: index) dataOperands(%0: memref<f32>)
 
 // -----
 
 %cst = arith.constant 1 : index
-%value = memref.alloc() : memref<10xf32>
+%value = memref.alloc() : memref<f32>
+%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error@+1 {{async attribute cannot appear with asyncOperand}}
-acc.update async(%cst: index) host(%value: memref<10xf32>) attributes {async}
+acc.update async(%cst: index) dataOperands(%0 : memref<f32>) attributes {async}
 
 // -----
 
 %cst = arith.constant 1 : index
-%value = memref.alloc() : memref<10xf32>
+%value = memref.alloc() : memref<f32>
+%0 = acc.update_device varPtr(%value : memref<f32>) -> memref<f32>
 // expected-error@+1 {{wait attribute cannot appear with waitOperands}}
-acc.update wait(%cst: index) host(%value: memref<10xf32>) attributes {wait}
+acc.update wait(%cst: index) dataOperands(%0: memref<f32>) attributes {wait}
 
 // -----
 

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 463acd1fd60be..e46167e3fb29b 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -761,41 +761,45 @@ func.func @testdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf
 
 // -----
 
-func.func @testupdateop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
   %i64Value = arith.constant 1 : i64
   %i32Value = arith.constant 1 : i32
   %idxValue = arith.constant 1 : index
   %ifCond = arith.constant true
-  acc.update async(%i64Value: i64) host(%a: memref<10xf32>)
-  acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
-  acc.update async(%i32Value: i32) host(%a: memref<10xf32>)
-  acc.update async(%idxValue: index) host(%a: memref<10xf32>)
-  acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) host(%a: memref<10xf32>)
-  acc.update if(%ifCond) host(%a: memref<10xf32>)
-  acc.update device_type(%i32Value : i32) host(%a: memref<10xf32>)
-  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>)
-  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {async}
-  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {wait}
-  acc.update host(%a: memref<10xf32>) device(%b, %c : memref<10xf32>, memref<10x10xf32>) attributes {ifPresent}
+  %0 = acc.update_device varPtr(%a : memref<f32>) -> memref<f32>
+  %1 = acc.update_device varPtr(%b : memref<f32>) -> memref<f32>
+  %2 = acc.update_device varPtr(%c : memref<f32>) -> memref<f32>
+  
+  acc.update async(%i64Value: i64) dataOperands(%0: memref<f32>)
+  acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
+  acc.update async(%i32Value: i32) dataOperands(%0: memref<f32>)
+  acc.update async(%idxValue: index) dataOperands(%0: memref<f32>)
+  acc.update wait_devnum(%i64Value: i64) wait(%i32Value, %idxValue : i32, index) dataOperands(%0: memref<f32>)
+  acc.update if(%ifCond) dataOperands(%0: memref<f32>)
+  acc.update device_type(%i32Value : i32) dataOperands(%0: memref<f32>)
+  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>)
+  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {async}
+  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+  acc.update dataOperands(%0, %1, %2 : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
   return
 }
 
-// CHECK: func @testupdateop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK: func @testupdateop([[ARGA:%.*]]: memref<f32>, [[ARGB:%.*]]: memref<f32>, [[ARGC:%.*]]: memref<f32>) {
 // CHECK:   [[I64VALUE:%.*]] = arith.constant 1 : i64
 // CHECK:   [[I32VALUE:%.*]] = arith.constant 1 : i32
 // CHECK:   [[IDXVALUE:%.*]] = arith.constant 1 : index
 // CHECK:   [[IFCOND:%.*]] = arith.constant true
-// CHECK:   acc.update async([[I64VALUE]] : i64) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update async([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update async([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update async([[IDXVALUE]] : index) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update if([[IFCOND]]) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update device_type([[I32VALUE]] : i32) host([[ARGA]] : memref<10xf32>)
-// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>)
-// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {async}
-// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {wait}
-// CHECK:   acc.update host([[ARGA]] : memref<10xf32>) device([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) attributes {ifPresent}
+// CHECK:   acc.update async([[I64VALUE]] : i64) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update async([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update async([[IDXVALUE]] : index) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update wait_devnum([[I64VALUE]] : i64) wait([[I32VALUE]], [[IDXVALUE]] : i32, index) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update if([[IFCOND]]) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update device_type([[I32VALUE]] : i32) dataOperands(%{{.*}} : memref<f32>)
+// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>)
+// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {async}
+// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {wait}
+// CHECK:   acc.update dataOperands(%{{.*}}, %{{.*}}, %{{.*}} : memref<f32>, memref<f32>, memref<f32>) attributes {ifPresent}
 
 // -----
 

diff  --git a/mlir/test/Target/LLVMIR/openacc-llvm.mlir b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
index df3ce7a1d96d0..61a2d968247fb 100644
--- a/mlir/test/Target/LLVMIR/openacc-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
@@ -118,17 +118,9 @@ llvm.func @testexitdataop(%arg0: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1
 
 // -----
 
-llvm.func @testupdateop(%arg0: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, %arg1: !llvm.ptr<f32>) {
-  %0 = llvm.mlir.constant(10 : index) : i64
-  %1 = llvm.mlir.null : !llvm.ptr<f32>
-  %2 = llvm.getelementptr %1[%0] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-  %3 = llvm.ptrtoint %2 : !llvm.ptr<f32> to i64
-  %4 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-  %5 = llvm.mlir.undef : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
-  %6 = llvm.insertvalue %arg0, %5[0] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
-  %7 = llvm.insertvalue %4, %6[1] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
-  %8 = llvm.insertvalue %3, %7[2] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
-  acc.update host(%8 : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>) device(%arg1 : !llvm.ptr<f32>)
+llvm.func @testupdateop(%arg1: !llvm.ptr<f32>) {
+  %0 = acc.update_device varPtr(%arg1 : !llvm.ptr<f32>) -> !llvm.ptr<f32>
+  acc.update dataOperands(%0 : !llvm.ptr<f32>)
   llvm.return
 }
 
@@ -137,37 +129,26 @@ llvm.func @testupdateop(%arg0: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x
 // CHECK: [[LOCSTR:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};testupdateop;{{[0-9]*}};{{[0-9]*}};;\00", align 1
 // CHECK: [[LOCGLOBAL:@.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 {{[0-9]*}}, ptr [[LOCSTR]] }, align 8
 // CHECK: [[MAPNAME1:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1
-// CHECK: [[MAPNAME2:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1
-// CHECK: [[MAPTYPES:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i64] [i64 2, i64 1]
-// CHECK: [[MAPNAMES:@.*]] = private constant [{{[0-9]*}} x ptr] [ptr [[MAPNAME1]], ptr [[MAPNAME2]]]
-
-// CHECK: define void @testupdateop({ ptr, ptr, i64, [1 x i64], [1 x i64] } %{{.*}}, ptr [[SIMPLEPTR:%.*]])
-// CHECK: [[ARGBASE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x ptr], align 8
-// CHECK: [[ARG_ALLOCA:%.*]] = alloca [{{[0-9]*}} x ptr], align 8
-// CHECK: [[SIZE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i64], align 8
-
-// CHECK: [[ARGBASE:%.*]] = extractvalue %openacc_data %{{.*}}, 0
-// CHECK: [[ARG:%.*]] = extractvalue %openacc_data %{{.*}}, 1
-// CHECK: [[ARGSIZE:%.*]] = extractvalue %openacc_data %{{.*}}, 2
-// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARGBASE_ALLOCA]], i32 0, i32 0
-// CHECK: store { ptr, ptr, i64, [1 x i64], [1 x i64] } [[ARGBASE]], ptr [[ARGBASEGEP]], align 8
-// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARG_ALLOCA]], i32 0, i32 0
-// CHECK: store ptr [[ARG]], ptr [[ARGGEP]], align 8
-// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], ptr [[SIZE_ALLOCA]], i32 0, i32 0
-// CHECK: store i64 [[ARGSIZE]], ptr [[SIZEGEP]], align 4
-
-// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARGBASE_ALLOCA]], i32 0, i32 1
-// CHECK: store ptr [[SIMPLEPTR]], ptr [[ARGBASEGEP]], align 8
-// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARG_ALLOCA]], i32 0, i32 1
-// CHECK: store ptr [[SIMPLEPTR]], ptr [[ARGGEP]], align 8
-// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], ptr [[SIZE_ALLOCA]], i32 0, i32 1
-// CHECK: store i64 ptrtoint (ptr getelementptr (ptr, ptr null, i32 1) to i64), ptr [[SIZEGEP]], align 4
-
-// CHECK: [[ARGBASE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARGBASE_ALLOCA]], i32 0, i32 0
-// CHECK: [[ARG_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x ptr], ptr [[ARG_ALLOCA]], i32 0, i32 0
-// CHECK: [[SIZE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i64], ptr [[SIZE_ALLOCA]], i32 0, i32 0
-
-// CHECK: call void @__tgt_target_data_update_mapper(ptr [[LOCGLOBAL]], i64 -1, i32 2, ptr [[ARGBASE_ALLOCA_GEP]], ptr [[ARG_ALLOCA_GEP]], ptr [[SIZE_ALLOCA_GEP]], ptr [[MAPTYPES]], ptr [[MAPNAMES]], ptr null)
+// CHECK: [[MAPTYPES:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i64] [i64 1]
+// CHECK: [[MAPNAMES:@.*]] = private constant [{{[0-9]*}} x ptr] [ptr [[MAPNAME1]]]
+
+// CHECK: define void @testupdateop(ptr %[[SIMPLEPTR:.*]])
+// CHECK: %[[OFFLOAD_BASEPTRS:.*]] = alloca [{{[0-9]*}} x ptr]
+// CHECK: %[[OFFLOAD_PTRS:.*]] = alloca [{{[0-9]*}} x ptr]
+// CHECK: %[[OFFLOAS_SIZES:.*]] = alloca [{{[0-9]*}} x i64]
+
+// CHECK: %[[BASEGEP:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_BASEPTRS]], i32 0, i32 0
+// CHECK: store ptr %[[SIMPLEPTR]], ptr %[[BASEGEP]]
+// CHECK: %[[ARGGEP:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_PTRS]], i32 0, i32 0
+// CHECK: store ptr %[[SIMPLEPTR]], ptr %[[ARGGEP]]
+// CHECK: %[[SIZEGEP:.*]] = getelementptr inbounds [1 x i64], ptr %[[OFFLOAS_SIZES]], i32 0, i32 0
+// CHECK: store i64 ptrtoint (ptr getelementptr (ptr, ptr null, i32 1) to i64), ptr %[[SIZEGEP]]
+
+// CHECK: %[[OFFLOAD_BASEPTRS_GEP:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_BASEPTRS]], i32 0, i32 0
+// CHECK: %[[OFFLOAD_PTRS_GEP:.*]] = getelementptr inbounds [1 x ptr], ptr %[[OFFLOAD_PTRS]], i32 0, i32 0
+// CHECK: %[[OFFLOAS_SIZES_GEP:.*]] = getelementptr inbounds [1 x i64], ptr %[[OFFLOAS_SIZES]], i32 0, i32 0
+
+// CHECK: call void @__tgt_target_data_update_mapper(ptr [[LOCGLOBAL]], i64 -1, i32 1, ptr %[[OFFLOAD_BASEPTRS_GEP]], ptr %[[OFFLOAD_PTRS_GEP]], ptr %[[OFFLOAS_SIZES_GEP]], ptr [[MAPTYPES]], ptr [[MAPNAMES]], ptr null)
 
 // CHECK: declare void @__tgt_target_data_update_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #0
 


        


More information about the flang-commits mailing list