[flang-commits] [flang] [mlir] [mlir][acc] Improve implicit deviceptr detection for alias (PR #195934)

Razvan Lupusoru via flang-commits flang-commits at lists.llvm.org
Tue May 5 13:48:14 PDT 2026


https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/195934

The ACCImplicitData pass can automatically use the deviceptr clause when a variable is detected as device data. However, it was missing a check for the variable's own `acc declare deviceptr` attribute.

>From 123050845961c8adac39e3233975668960b867be Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Tue, 5 May 2026 13:47:15 -0700
Subject: [PATCH] [mlir][acc] Improve implicit deviceptr detection for alias

The ACCImplicitData pass can automatically use the deviceptr clause
when a variable is detected as device data. However, it was missing
a check for the variable's own `acc declare deviceptr` attribute.
---
 .../Transforms/OpenACC/acc-implicit-data.fir  | 33 +++++++++++++
 .../mlir/Dialect/OpenACC/OpenACCUtils.h       |  4 +-
 .../OpenACC/Transforms/ACCImplicitData.cpp    | 23 +++++----
 .../Dialect/OpenACC/Utils/OpenACCUtils.cpp    | 38 +++++++++------
 .../Dialect/OpenACC/acc-implicit-data.mlir    | 28 +++++++++++
 .../Dialect/OpenACC/OpenACCUtilsTest.cpp      | 47 +++++++++++++++++++
 6 files changed, 150 insertions(+), 23 deletions(-)

diff --git a/flang/test/Transforms/OpenACC/acc-implicit-data.fir b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
index 050fe55747d23..18d9f310eabee 100644
--- a/flang/test/Transforms/OpenACC/acc-implicit-data.fir
+++ b/flang/test/Transforms/OpenACC/acc-implicit-data.fir
@@ -394,3 +394,36 @@ func.func private @_FortranAAllocatableSetBounds(!fir.ref<!fir.box<none>>, i32,
 // CHECK-NOT: acc.copyin
 // CHECK: acc.deviceptr
 // CHECK-NOT: acc.copyout
+
+// -----
+
+// Test an argument mapped with deviceptr but used directly in the region rather than through its data-clause result.
+func.func @test_fir_declare_deviceptr_arg_in_parallel(%arg0: !fir.ref<!fir.array<10xf64>>) {
+  %c10 = arith.constant 10 : index
+  %shape = fir.shape %c10 : (index) -> !fir.shape<1>
+  %arr_decl = fir.declare %arg0(%shape) {acc.declare = #acc.declare<dataClause = acc_deviceptr>, uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xf64>>
+  %arr_box = fir.embox %arr_decl(%shape) : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>) -> !fir.box<!fir.array<10xf64>>
+  %devptr = acc.deviceptr var(%arr_box : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
+  %token = acc.declare_enter dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
+  acc.parallel {
+    %addr = fir.box_addr %arr_box : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
+    %elem = fir.array_coor %arr_decl(%shape) %c10 : (!fir.ref<!fir.array<10xf64>>, !fir.shape<1>, index) -> !fir.ref<f64>
+    acc.yield
+  }
+  acc.declare_exit token(%token) dataOperands(%devptr : !fir.box<!fir.array<10xf64>>)
+  return
+}
+
+// CHECK-LABEL: func.func @test_fir_declare_deviceptr_arg_in_parallel
+// CHECK: %[[DECL:.*]] = fir.declare %{{.*}}{{.*}}{acc.declare = #acc.declare<dataClause = acc_deviceptr>{{.*}}
+// CHECK: %[[BOX:.*]] = fir.embox %[[DECL]]
+// CHECK: %[[DEVPTR:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {name = "a"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
+// CHECK: %[[IMPLICIT_BOX:.*]] = acc.deviceptr var(%[[BOX]] : !fir.box<!fir.array<10xf64>>) -> !fir.box<!fir.array<10xf64>> {implicit = true, name = "a"}
+// CHECK: %[[IMPLICIT_REF:.*]] = acc.deviceptr varPtr(%[[DECL]] : !fir.ref<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>> {implicit = true, name = "a"}
+// CHECK: acc.parallel dataOperands(%[[IMPLICIT_BOX]], %[[IMPLICIT_REF]] : !fir.box<!fir.array<10xf64>>, !fir.ref<!fir.array<10xf64>>) {
+// CHECK: fir.box_addr %[[IMPLICIT_BOX]] : (!fir.box<!fir.array<10xf64>>) -> !fir.ref<!fir.array<10xf64>>
+// CHECK: fir.array_coor %[[IMPLICIT_REF]]
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : !fir.box<!fir.array<10xf64>>)
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
index c26ddbd54f1b9..5a8e4e362108d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtils.h
@@ -78,7 +78,9 @@ bool isValidSymbolUse(mlir::Operation *user, mlir::SymbolRefAttr symbol,
 
 /// Check if a value represents device data.
 /// This checks if the value represents device data via the
-/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces.
+/// MappableType, PointerLikeType, and GlobalVariableOpInterface interfaces,
+/// and whether the defining operation carries `acc.declare` with the deviceptr
+/// clause.
 /// \param val The value to check
 /// \return true if the value is device data, false otherwise
 bool isDeviceValue(mlir::Value val);
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
index 95c8d1076ccb0..2de714ffcbc35 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitData.cpp
@@ -312,8 +312,13 @@ Operation *ACCImplicitData::getOriginalDataClauseOpForAlias(
       // Only accept clauses that guarantee that the alias is present.
       if (isa<acc::CopyinOp, acc::CreateOp, acc::PresentOp, acc::NoCreateOp,
               acc::DevicePtrOp>(dataClauseOp))
-        if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust())
+        if (aliasAnalysis.alias(acc::getVar(dataClauseOp), var).isMust()) {
+          LLVM_DEBUG(llvm::dbgs()
+                         << "Using existing data clause:\n\t" << *dataClauseOp
+                         << "\n\tas reference when processing var:\n\t" << var
+                         << "\n";);
           return dataClauseOp;
+        }
     }
   }
   return nullptr;
@@ -452,6 +457,15 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
       typeCategory, acc::VariableTypeCategory::aggregate);
   Location loc = computeConstructOp->getLoc();
 
+  if (acc::isDeviceValue(var)) {
+    // If the variable is device data, use deviceptr clause.
+    LLVM_DEBUG(llvm::dbgs() << "Using deviceptr clause because variable is "
+                               "device data\n");
+    return acc::DevicePtrOp::create(builder, loc, var,
+                                    /*structured=*/true, /*implicit=*/true,
+                                    accSupport.getVariableName(var));
+  }
+
   Operation *op = nullptr;
   op = getOriginalDataClauseOpForAlias(var, builder, computeConstructOp,
                                        dominatingDataClauses);
@@ -476,13 +490,6 @@ Operation *ACCImplicitData::generateDataClauseOpForCandidate(
                                   acc::getBounds(op));
   }
 
-  if (acc::isDeviceValue(var)) {
-    // If the variable is device data, use deviceptr clause.
-    return acc::DevicePtrOp::create(builder, loc, var,
-                                    /*structured=*/true, /*implicit=*/true,
-                                    accSupport.getVariableName(var));
-  }
-
   if (isScalar) {
     if (enableImplicitReductionCopy &&
         acc::isOnlyUsedByReductionClauses(var,
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index f20ace4398696..411b0a7a4457d 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -245,23 +245,33 @@ bool mlir::acc::isDeviceValue(mlir::Value val) {
     if (pointerLikeTy.isDeviceData(val))
       return true;
 
+  mlir::Operation *defOp = val.getDefiningOp();
+  if (!defOp)
+    return false;
+
+  // `acc.declare` with deviceptr marks data that is already associated with
+  // the device.
+  if (auto declareAttr = defOp->getAttrOfType<mlir::acc::DeclareAttr>(
+          mlir::acc::getDeclareAttrName()))
+    if (declareAttr.getDataClause().getValue() ==
+        mlir::acc::DataClause::acc_deviceptr)
+      return true;
+
   // Handle operations that access a partial entity - check if the base entity
   // is device data.
-  if (auto *defOp = val.getDefiningOp()) {
-    if (auto partialAccess =
-            dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
-      if (mlir::Value base = partialAccess.getBaseEntity())
-        return isDeviceValue(base);
-    }
+  if (auto partialAccess =
+          dyn_cast<mlir::acc::PartialEntityAccessOpInterface>(defOp)) {
+    if (mlir::Value base = partialAccess.getBaseEntity())
+      return isDeviceValue(base);
+  }
 
-    // Handle address_of - check if the referenced global is device data.
-    if (auto addrOfIface =
-            dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
-      auto symbol = addrOfIface.getSymbol();
-      if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
-              mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
-        return global.isDeviceData();
-    }
+  // Handle address_of - check if the referenced global is device data.
+  if (auto addrOfIface =
+          dyn_cast<mlir::acc::AddressOfGlobalOpInterface>(defOp)) {
+    auto symbol = addrOfIface.getSymbol();
+    if (auto global = mlir::SymbolTable::lookupNearestSymbolFrom<
+            mlir::acc::GlobalVariableOpInterface>(defOp, symbol))
+      return global.isDeviceData();
   }
 
   return false;
diff --git a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
index df0dbbfee8b1d..3b6b5e1ade5e0 100644
--- a/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-implicit-data.mlir
@@ -259,3 +259,31 @@ func.func @test_device_global_in_parallel() {
 // CHECK: acc.deviceptr varPtr({{.*}} : memref<10xf32, #gpu.address_space<global>>) -> memref<10xf32, #gpu.address_space<global>> {implicit = true, name = ""}
 // CHECK-NOT: acc.copyin
 // CHECK-NOT: acc.copyout
+
+// -----
+
+// Test a memref.view tagged with acc.declare deviceptr whose underlying buffer (%arg0) is used directly in the region.
+func.func @test_declare_deviceptr_arg_in_parallel(%arg0: memref<?xi8>) {
+  %c0 = arith.constant 0 : index
+  %view = memref.view %arg0[%c0][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
+  %devptr = acc.deviceptr varPtr(%view : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
+  %token = acc.declare_enter dataOperands(%devptr : memref<10xf32>)
+  acc.parallel {
+    %c0_1 = arith.constant 0 : index
+    %load = memref.load %arg0[%c0_1] : memref<?xi8>
+    acc.yield
+  }
+  acc.declare_exit token(%token) dataOperands(%devptr : memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @test_declare_deviceptr_arg_in_parallel
+// CHECK: %[[VIEW:.*]] = memref.view %{{.*}}[{{.*}}][] {acc.declare = #acc.declare<dataClause = acc_deviceptr>} : memref<?xi8> to memref<10xf32>
+// CHECK: %[[DEVPTR:.*]] = acc.deviceptr varPtr(%[[VIEW]] : memref<10xf32>) -> memref<10xf32> {name = "arg0"}
+// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DEVPTR]] : memref<10xf32>)
+// CHECK: %[[IMPLICIT_DEVPTR:.*]] = acc.deviceptr varPtr(%{{.*}} : memref<?xi8>) -> memref<?xi8> {implicit = true, name = ""}
+// CHECK: acc.parallel dataOperands(%[[IMPLICIT_DEVPTR]] : memref<?xi8>) {
+// CHECK: memref.load %[[IMPLICIT_DEVPTR]][{{.*}}] : memref<?xi8>
+// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DEVPTR]] : memref<10xf32>)
+// CHECK-NOT: acc.copyin
+// CHECK-NOT: acc.copyout
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
index 489fe9108c04e..30c6f4312efaf 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsTest.cpp
@@ -1398,6 +1398,33 @@ TEST_F(OpenACCUtilsTest, getDominatingDataClausesEmpty) {
 // isDeviceValue Tests
 //===----------------------------------------------------------------------===//
 
+namespace {
+static Value memrefViewFromBlockArgWithDeclare(OpBuilder &builder, Location loc,
+                                               MLIRContext *ctx,
+                                               DataClause clause,
+                                               ModuleOp module,
+                                               StringRef funcName) {
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(module.getBody());
+
+  auto i8BufTy = MemRefType::get({40}, builder.getI8Type());
+  auto viewTy = MemRefType::get({10}, builder.getI32Type());
+  auto funcType = builder.getFunctionType({i8BufTy}, {});
+  func::FuncOp funcOp = func::FuncOp::create(builder, loc, funcName, funcType);
+  Block *entry = funcOp.addEntryBlock();
+
+  builder.setInsertionPointToStart(entry);
+  Value buf = entry->getArgument(0);
+  Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
+  memref::ViewOp viewOp =
+      memref::ViewOp::create(builder, loc, viewTy, buf, c0, ValueRange{});
+  viewOp->setAttr(getDeclareAttrName(),
+                  DeclareAttr::get(ctx, DataClauseAttr::get(ctx, clause)));
+  func::ReturnOp::create(builder, loc);
+  return viewOp.getResult();
+}
+} // namespace
+
 TEST_F(OpenACCUtilsTest, isDeviceValueMemrefGlobalAddressSpace) {
   // Test that a memref with GPU global address space is considered device data
   auto gpuAddressSpace =
@@ -1525,6 +1552,26 @@ TEST_F(OpenACCUtilsTest, isDeviceValueGlobalWithoutGPUAddressSpace) {
   EXPECT_FALSE(isDeviceValue(val));
 }
 
+TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareDeviceptr) {
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(module->getBody());
+  Value val = memrefViewFromBlockArgWithDeclare(
+      b, loc, &context, DataClause::acc_deviceptr, module.get(),
+      "test_memref_view_declare_devptr");
+  EXPECT_TRUE(isDeviceValue(val));
+}
+
+TEST_F(OpenACCUtilsTest, isDeviceValueAccDeclareNonDeviceptr) {
+  OwningOpRef<ModuleOp> module = ModuleOp::create(loc);
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(module->getBody());
+  Value val = memrefViewFromBlockArgWithDeclare(
+      b, loc, &context, DataClause::acc_copyin, module.get(),
+      "test_memref_view_declare_copyin");
+  EXPECT_FALSE(isDeviceValue(val));
+}
+
 //===----------------------------------------------------------------------===//
 // isValidValueUse Tests
 //===----------------------------------------------------------------------===//



More information about the flang-commits mailing list