[Mlir-commits] [flang] [mlir] [acc] Introduce varType to acc data clause operations (PR #119007)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 6 10:39:41 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

<details>
<summary>Changes</summary>

The acc data clause operations hold an operand named `varPtr`. This operand was intended to hold a pointer to a variable, where the element type of that pointer specifies the type of the variable. However, this assumption holds for neither the memref nor the LLVM dialect. For memref, the element type of a type such as `memref<10xf32>` is simply `f32` rather than the full variable type. For LLVM, since the switch to opaque pointers, the variable type can no longer be recovered from the pointer type at all.

Thus, introduce a `varType` attribute that records the variable type explicitly, so that the intended semantics are preserved.

The parser and printer allow this new type attribute to be omitted whenever it equals the type returned by the dialect's `getElementType()` for `varPtr`'s type. In that case the printer elides it, and the parser infers it from `varPtr`'s type. In particular, for FIR no changes are needed in the MLIR unit tests.

---

Patch is 71.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119007.diff


9 Files Affected:

- (modified) flang/lib/Lower/OpenACC.cpp (+5-3) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+65-66) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+39) 
- (modified) mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir (+9-9) 
- (modified) mlir/test/Dialect/OpenACC/canonicalize.mlir (+14-14) 
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+7-7) 
- (modified) mlir/test/Dialect/OpenACC/legalize-data.mlir (+16-16) 
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+113-113) 
- (modified) mlir/test/Target/LLVMIR/openacc-llvm.mlir (+10-10) 


``````````diff
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 878dccc4ecbc4b..75dcf6ec3e1107 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -139,6 +139,8 @@ createDataEntryOp(fir::FirOpBuilder &builder, mlir::Location loc,
   op.setStructured(structured);
   op.setImplicit(implicit);
   op.setDataClause(dataClause);
+  op.setVarType(mlir::cast<mlir::acc::PointerLikeType>(baseAddr.getType())
+                    .getElementType());
   op->setAttr(Op::getOperandSegmentSizeAttr(),
               builder.getDenseI32ArrayAttr(operandSegments));
   if (!asyncDeviceTypes.empty())
@@ -266,8 +268,8 @@ static void createDeclareDeallocFuncWithArg(
   if constexpr (std::is_same_v<ExitOp, mlir::acc::CopyoutOp> ||
                 std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
     builder.create<ExitOp>(entryOp.getLoc(), entryOp.getAccPtr(),
-                           entryOp.getVarPtr(), entryOp.getBounds(),
-                           entryOp.getAsyncOperands(),
+                           entryOp.getVarPtr(), entryOp.getVarType(),
+                           entryOp.getBounds(), entryOp.getAsyncOperands(),
                            entryOp.getAsyncOperandsDeviceTypeAttr(),
                            entryOp.getAsyncOnlyAttr(), entryOp.getDataClause(),
                            /*structured=*/false, /*implicit=*/false,
@@ -450,7 +452,7 @@ static void genDataExitOperations(fir::FirOpBuilder &builder,
                   std::is_same_v<ExitOp, mlir::acc::UpdateHostOp>)
       builder.create<ExitOp>(
           entryOp.getLoc(), entryOp.getAccPtr(), entryOp.getVarPtr(),
-          entryOp.getBounds(), entryOp.getAsyncOperands(),
+          entryOp.getVarType(), entryOp.getBounds(), entryOp.getAsyncOperands(),
           entryOp.getAsyncOperandsDeviceTypeAttr(), entryOp.getAsyncOnlyAttr(),
           entryOp.getDataClause(), structured, entryOp.getImplicit(),
           builder.getStringAttr(*entryOp.getName()));
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 8d7e27405cfa46..d089519d7fd808 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -381,17 +381,18 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
     OpenACC_Op<mnemonic, !listconcat(traits,
         [AttrSizedOperandSegments,
          MemoryEffects<[MemRead<OpenACC_CurrentDeviceIdResource>]>])> {
-  let arguments = !con(additionalArgs,
-                      (ins
-                       Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
-                       Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
-                       Variadic<IntOrIndex>:$asyncOperands,
-                       OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
-                       OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
-                       DefaultValuedAttr<OpenACC_DataClauseAttr,clause>:$dataClause,
-                       DefaultValuedAttr<BoolAttr, "true">:$structured,
-                       DefaultValuedAttr<BoolAttr, "false">:$implicit,
-                       OptionalAttr<StrAttr>:$name));
+  let arguments = !con(
+      additionalArgs,
+      (ins TypeAttr:$varType,
+          Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
+          Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
+          Variadic<IntOrIndex>:$asyncOperands,
+          OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+          OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+          DefaultValuedAttr<OpenACC_DataClauseAttr, clause>:$dataClause,
+          DefaultValuedAttr<BoolAttr, "true">:$structured,
+          DefaultValuedAttr<BoolAttr, "false">:$implicit,
+          OptionalAttr<StrAttr>:$name));
 
   let description = !strconcat(extraDescription, [{
     Description of arguments:
@@ -458,7 +459,7 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
   }];
 
   let assemblyFormat = [{
-    `varPtr` `(` $varPtr `:` type($varPtr) `)`
+    `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType)  `)`
     oilist(
         `varPtrPtr` `(` $varPtrPtr `:` type($varPtrPtr) `)`
       | `bounds` `(` $bounds `)`
@@ -469,32 +470,35 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
 
   let hasVerifier = 1;
 
-  let builders = [
-    OpBuilder<(ins "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
-        build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
-          bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+  let builders = [OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
+        build($_builder, $_state, varPtr.getType(), varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+            ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
+          /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+          /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
-      }]
-    >,
-    OpBuilder<(ins "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      "const ::llvm::Twine &":$name,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
-        build($_builder, $_state, varPtr.getType(), varPtr, /*varPtrPtr=*/{},
-          bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
+      }]>,
+                  OpBuilder<(ins "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit, "const ::llvm::Twine &":$name,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
+        build($_builder, $_state, varPtr.getType(), varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+            ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
+          /*varPtrPtr=*/{}, bounds, /*asyncOperands=*/{},
+          /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit),
           /*name=*/$_builder.getStringAttr(name));
-      }]
-    >
-  ];
+      }]>];
 }
 
 //===----------------------------------------------------------------------===//
@@ -794,63 +798,58 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
     }
   }];
 
-  let assemblyFormat = [{
-    `accPtr` `(` $accPtr `:` type($accPtr) `)`
-    oilist(
-        `bounds` `(` $bounds `)`
-      | `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
-      | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
-            type($asyncOperands), $asyncOperandsDeviceType) `)`
-    ) attr-dict
-  }];
-
   let hasVerifier = 1;
 }
 
-class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause> :
-    OpenACC_DataExitOp<mnemonic, clause,
-      "- `varPtr`: The address of variable to copy back to.",
-      [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
-                      MemWrite<OpenACC_RuntimeCounters>]>],
-      (ins Arg<OpenACC_PointerLikeTypeInterface,"Address of device variable",[MemRead]>:$accPtr,
-           Arg<OpenACC_PointerLikeTypeInterface,"Address of variable",[MemWrite]>:$varPtr)> {
+class OpenACC_DataExitOpWithVarPtr<string mnemonic, string clause>
+    : OpenACC_DataExitOp<
+          mnemonic, clause,
+          "- `varPtr`: The address of variable to copy back to.",
+          [MemoryEffects<[MemRead<OpenACC_RuntimeCounters>,
+                          MemWrite<OpenACC_RuntimeCounters>]>],
+          (ins Arg<OpenACC_PointerLikeTypeInterface,
+                   "Address of device variable", [MemRead]>:$accPtr,
+              Arg<OpenACC_PointerLikeTypeInterface,
+                  "Address of variable", [MemWrite]>:$varPtr,
+              TypeAttr:$varType)> {
   let assemblyFormat = [{
     `accPtr` `(` $accPtr `:` type($accPtr) `)`
     (`bounds` `(` $bounds^ `)` )?
     (`async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType)^ `)`)?
-    `to` `varPtr` `(` $varPtr `:` type($varPtr) `)`
+    `to` `varPtr` `(` $varPtr `:` custom<varPtrTypes>(type($varPtr), $varType) `)`
     attr-dict
   }];
 
-  let builders = [
-    OpBuilder<(ins "::mlir::Value":$accPtr,
-      "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+  let builders = [OpBuilder<(ins "::mlir::Value":$accPtr,
+                                "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
         build($_builder, $_state, accPtr, varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+          ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
           bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr);
-      }]
-    >,
-    OpBuilder<(ins "::mlir::Value":$accPtr,
-      "::mlir::Value":$varPtr,
-      "bool":$structured,
-      "bool":$implicit,
-      "const ::llvm::Twine &":$name,
-      CArg<"::mlir::ValueRange", "{}">:$bounds), [{
+      }]>,
+                  OpBuilder<(ins "::mlir::Value":$accPtr,
+                                "::mlir::Value":$varPtr, "bool":$structured,
+                                "bool":$implicit, "const ::llvm::Twine &":$name,
+                                CArg<"::mlir::ValueRange", "{}">:$bounds),
+                            [{
         build($_builder, $_state, accPtr, varPtr,
+          /*varType=*/::mlir::TypeAttr::get(
+          ::mlir::cast<::mlir::acc::PointerLikeType>(
+              varPtr.getType()).getElementType()),
           bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr,
           /*asyncOnly=*/nullptr, /*dataClause=*/nullptr,
           /*structured=*/$_builder.getBoolAttr(structured),
           /*implicit=*/$_builder.getBoolAttr(implicit),
           /*name=*/$_builder.getStringAttr(name));
-      }]
-    >
-  ];
+      }]>];
 }
 
 class OpenACC_DataExitOpNoVarPtr<string mnemonic, string clause> :
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 280260e0485bb5..4daba2679bd91c 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -18,6 +19,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/LogicalResult.h"
 
 using namespace mlir;
 using namespace acc;
@@ -190,6 +192,43 @@ static LogicalResult checkWaitAndAsyncConflict(Op op) {
   return success();
 }
 
+static ParseResult parsevarPtrTypes(mlir::OpAsmParser &parser,
+                                    mlir::Type &varPtrRawType,
+                                    mlir::TypeAttr &varTypeAttr) {
+  if (failed(parser.parseType(varPtrRawType))) {
+    return failure();
+  }
+
+  // If there is no comma, it means that the varType is implied from the
+  // element type of varPtr.
+  if (succeeded(parser.parseOptionalComma())) {
+    mlir::Type varType;
+    if (failed(parser.parseType(varType)))
+      return failure();
+    varTypeAttr = mlir::TypeAttr::get(varType);
+  } else {
+    varTypeAttr = mlir::TypeAttr::get(
+        mlir::cast<mlir::acc::PointerLikeType>(varPtrRawType).getElementType());
+  }
+
+  return success();
+}
+
+static void printvarPtrTypes(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                             mlir::Type varPtrType,
+                             mlir::TypeAttr varTypeAttr) {
+  p.printType(varPtrType);
+  mlir::Type varType = varTypeAttr.getValue();
+
+  // Avoid printing the varType if it is already captured as the element type
+  // of varPtr's type.
+  if (mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType() !=
+      varType) {
+    p << ", ";
+    p.printType(varType);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // DataBoundsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
index d8e89f64f8bc04..d83baf3df114bf 100644
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-openacc-to-scf -split-input-file | FileCheck %s
 
 func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -14,7 +14,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -28,7 +28,7 @@ func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.update_device varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.update if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -42,7 +42,7 @@ func.func @testupdateop(%a: memref<f32>, %ifCond: i1) -> () {
 
 func.func @update_true(%arg0: memref<f32>) {
   %true = arith.constant true
-  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
   acc.update if(%true) dataOperands(%0 : memref<f32>)
   return
 }
@@ -55,7 +55,7 @@ func.func @update_true(%arg0: memref<f32>) {
 
 func.func @update_false(%arg0: memref<f32>) {
   %false = arith.constant false
-  %0 = acc.update_device varPtr(%arg0 : memref<f32>) -> memref<f32>
+  %0 = acc.update_device varPtr(%arg0 : memref<f32>, f32) -> memref<f32>
   acc.update if(%false) dataOperands(%0 : memref<f32>)
   return
 }
@@ -67,7 +67,7 @@ func.func @update_false(%arg0: memref<f32>) {
 
 func.func @enter_data_true(%d1 : memref<f32>) {
   %true = arith.constant true
-  %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
   return
 }
@@ -80,7 +80,7 @@ func.func @enter_data_true(%d1 : memref<f32>) {
 
 func.func @enter_data_false(%d1 : memref<f32>) {
   %false = arith.constant false
-  %0 = acc.create varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
   return
 }
@@ -92,7 +92,7 @@ func.func @enter_data_false(%d1 : memref<f32>) {
 
 func.func @exit_data_true(%d1 : memref<f32>) {
   %true = arith.constant true
-  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%true) dataOperands(%0 : memref<f32>) attributes {async}
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -106,7 +106,7 @@ func.func @exit_data_true(%d1 : memref<f32>) {
 
 func.func @exit_data_false(%d1 : memref<f32>) {
   %false = arith.constant false
-  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%d1 : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%false) dataOperands(%0 : memref<f32>) attributes {async}
   acc.delete accPtr(%0 : memref<f32>)
   return
diff --git a/mlir/test/Dialect/OpenACC/canonicalize.mlir b/mlir/test/Dialect/OpenACC/canonicalize.mlir
index e43a27f6b9e89a..c5272f579c1d23 100644
--- a/mlir/test/Dialect/OpenACC/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenACC/canonicalize.mlir
@@ -2,7 +2,7 @@
 
 func.func @testenterdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant true
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -13,7 +13,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
 
 func.func @testenterdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant false
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -25,7 +25,7 @@ func.func @testenterdataop(%a: memref<f32>) -> () {
 
 func.func @testexitdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant true
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -37,7 +37,7 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
 
 func.func @testexitdataop(%a: memref<f32>) -> () {
   %ifCond = arith.constant false
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32>)
   return
@@ -49,8 +49,8 @@ func.func @testexitdataop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
-  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
   %ifCond = arith.constant true
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
@@ -61,8 +61,8 @@ func.func @testupdateop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testupdateop(%a: memref<f32>) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
-  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>)
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
+  acc.update_host accPtr(%0 : memref<f32>) to varPtr(%a : memref<f32>, f32)
   %ifCond = arith.constant false
   acc.update if(%ifCond) dataOperands(%0: memref<f32>)
   return
@@ -74,7 +74,7 @@ func.func @testupdateop(%a: memref<f32>) -> () {
 // -----
 
 func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.create varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.create varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.enter_data if(%ifCond) dataOperands(%0 : memref<f32>)
   return
 }
@@ -85,7 +85,7 @@ func.func @testenterdataop(%a: memref<f32>, %ifCond: i1) -> () {
 // -----
 
 func.func @testexitdataop(%a: memref<f32>, %ifCond: i1) -> () {
-  %0 = acc.getdeviceptr varPtr(%a : memref<f32>) -> memref<f32>
+  %0 = acc.getdeviceptr varPtr(%a : memref<f32>, f32) -> memref<f32>
   acc.exit_data if(%ifCond) dataOperands(%0 : memref<f32>)
   acc.delete accPtr(%0 : memref<f32...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/119007


More information about the Mlir-commits mailing list