[Mlir-commits] [mlir] [mlir][acc] Update LegalizeDataValues pass to allow MappableType (PR #125134)

Razvan Lupusoru llvmlistbot at llvm.org
Thu Jan 30 15:21:00 PST 2025


https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/125134

With the addition of the new type interface MappableType, the LegalizeDataValues pass can no longer assume that it can obtain a pointer to the data: acc::getVarPtr() is no longer guaranteed to return a value, so acc::getVar() must be used instead.

Accordingly, update the pass so that it handles any var used in its data clause operations, whether or not that var is pointer-like.
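The effect of the change on the collection step can be modeled outside MLIR. The sketch below is hypothetical: DataClauseInfo, collect, collectPtrsOld, collectVarsNew, and the string-valued "SSA values" are illustrative stand-ins, not the acc dialect API. It assumes that for a pointer-like var acc::getVarPtr() yields the same value as acc::getVar(), while for a var whose type only implements MappableType it yields nothing, so the pre-patch filter silently drops that operand.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a data clause operation such as acc.copyin.
// SSA values are modeled as strings; std::nullopt models a null mlir::Value.
struct DataClauseInfo {
  std::optional<std::string> var;     // models acc::getVar(): always set
  std::optional<std::string> varPtr;  // models acc::getVarPtr(): unset for non-pointer vars
  std::optional<std::string> accVar;  // models acc::getAccVar()
};

// (value to replace, replacement) pairs, same shape as the pass's `values`.
using Pairs = std::vector<std::pair<std::string, std::string>>;

// Shared loop; `hostField` selects which host-side accessor is being modeled.
Pairs collect(const std::vector<DataClauseInfo> &ops, bool hostToDevice,
              std::optional<std::string> DataClauseInfo::*hostField) {
  Pairs values;
  for (const auto &op : ops) {
    const auto &host = op.*hostField;
    if (host && op.accVar) {
      if (hostToDevice)
        values.push_back({*host, *op.accVar});
      else
        values.push_back({*op.accVar, *host});
    }
  }
  return values;
}

// Pre-patch behavior (collectPtrs): operands without a pointer are skipped.
Pairs collectPtrsOld(const std::vector<DataClauseInfo> &ops, bool hostToDevice) {
  return collect(ops, hostToDevice, &DataClauseInfo::varPtr);
}

// Post-patch behavior (collectVars): every operand with a var is collected.
Pairs collectVarsNew(const std::vector<DataClauseInfo> &ops, bool hostToDevice) {
  return collect(ops, hostToDevice, &DataClauseInfo::var);
}
```

For example, given one pointer-like operand (var and varPtr both "%a_ptr", accVar "%a_dev") and one MappableType operand (var "%b_val", no varPtr, accVar "%b_dev"), collectPtrsOld returns a single pair while collectVarsNew returns both.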

From d790ba6c352abd797e0b59575a1dc6ce568f7b23 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 30 Jan 2025 15:19:11 -0800
Subject: [PATCH] [mlir][acc] Update LegalizeDataValues pass to allow
 MappableType

With the addition of the new type interface MappableType, the
LegalizeDataValues pass can no longer assume that it can obtain a
pointer to the data: acc::getVarPtr() is no longer guaranteed to
return a value, so acc::getVar() must be used instead.

Accordingly, update the pass so that it handles any var used in its
data clause operations.
---
 .../OpenACC/Transforms/LegalizeDataValues.cpp | 24 +++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 026b309ce4969d..a553653c73479b 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -36,17 +36,17 @@ static bool insideAccComputeRegion(mlir::Operation *op) {
   return false;
 }
 
-static void collectPtrs(mlir::ValueRange operands,
+static void collectVars(mlir::ValueRange operands,
                         llvm::SmallVector<std::pair<Value, Value>> &values,
                         bool hostToDevice) {
   for (auto operand : operands) {
-    Value varPtr = acc::getVarPtr(operand.getDefiningOp());
-    Value accPtr = acc::getAccPtr(operand.getDefiningOp());
-    if (varPtr && accPtr) {
+    Value var = acc::getVar(operand.getDefiningOp());
+    Value accVar = acc::getAccVar(operand.getDefiningOp());
+    if (var && accVar) {
       if (hostToDevice)
-        values.push_back({varPtr, accPtr});
+        values.push_back({var, accVar});
       else
-        values.push_back({accPtr, varPtr});
+        values.push_back({accVar, var});
     }
   }
 }
@@ -75,16 +75,16 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
   llvm::SmallVector<std::pair<Value, Value>> values;
 
   if constexpr (std::is_same_v<Op, acc::LoopOp>) {
-    collectPtrs(op.getReductionOperands(), values, hostToDevice);
-    collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+    collectVars(op.getReductionOperands(), values, hostToDevice);
+    collectVars(op.getPrivateOperands(), values, hostToDevice);
   } else {
-    collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+    collectVars(op.getDataClauseOperands(), values, hostToDevice);
     if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp>) {
-      collectPtrs(op.getReductionOperands(), values, hostToDevice);
-      collectPtrs(op.getPrivateOperands(), values, hostToDevice);
-      collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
+      collectVars(op.getReductionOperands(), values, hostToDevice);
+      collectVars(op.getPrivateOperands(), values, hostToDevice);
+      collectVars(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 
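The second hunk only renames the calls inside collectAndReplaceInRegion; as that function's name suggests, the collected pairs are then presumably used to redirect uses inside the construct's region. The hypothetical sketch below models that replacement step with plain strings (Value, Pairs, Op, and replaceInRegion are illustrative names, not MLIR API): with hostToDevice pairs, uses of the host var inside the region are rewritten to the accelerator var, and the reversed pairs undo the rewrite.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using Value = std::string;  // hypothetical stand-in for mlir::Value
using Pairs = std::vector<std::pair<Value, Value>>;

// Hypothetical stand-in for an operation nested in the compute region.
struct Op {
  std::vector<Value> operands;
};

// Rewrites every use of pair.first inside `region` to pair.second.
// Uses outside the region are not modeled and would stay untouched.
void replaceInRegion(std::vector<Op> &region, const Pairs &values) {
  for (const auto &[from, to] : values)
    for (auto &op : region)
      for (auto &operand : op.operands)
        if (operand == from)
          operand = to;
}
```

Because the direction is encoded in the order of each pair, the same replacement routine serves both hostToDevice settings; only the collection step needs to know the direction.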