[Mlir-commits] [mlir] [mlir][acc] Update LegalizeDataValues pass to allow MappableType (PR #125134)
Razvan Lupusoru
llvmlistbot at llvm.org
Thu Jan 30 15:21:00 PST 2025
https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/125134
With the addition of the new type interface MappableType, the LegalizeDataValues pass should no longer assume that it can obtain a pointer to the data: acc::getVarPtr() is no longer guaranteed to return a value, so acc::getVar() must be used instead.
This patch therefore updates the pass to handle any var used in its data clause operations.
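For readers without an MLIR build at hand, the pairing logic the pass relies on can be sketched in standalone C++. This is only an illustration under stated assumptions: `Value` and `DataClauseOp` below are hypothetical stand-ins (an optional string and a plain struct), not the real mlir::Value or acc operations, and the function only mirrors the shape of the helper in the patch.

```cpp
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Stand-in for mlir::Value (assumption): an empty optional means
// "no value available", like a null mlir::Value.
using Value = std::optional<std::string>;

// Hypothetical model of a data clause operation. Only the two values the
// pass cares about are modeled.
struct DataClauseOp {
  Value var;    // host-side variable (need not be a pointer)
  Value accVar; // device-side counterpart
};

// Mirrors the shape of collectVars in the patch: for every operand that has
// both a host and a device variable, record a (from, to) replacement pair.
inline void
collectVars(const std::vector<DataClauseOp> &operands,
            std::vector<std::pair<std::string, std::string>> &values,
            bool hostToDevice) {
  for (const DataClauseOp &op : operands) {
    if (!op.var || !op.accVar)
      continue; // skip operands that lack either side
    if (hostToDevice)
      values.emplace_back(*op.var, *op.accVar);
    else
      values.emplace_back(*op.accVar, *op.var);
  }
}
```

The direction flag only swaps the order of each pair; operands missing either side are skipped, matching the `if (var && accVar)` guard in the patch.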
From d790ba6c352abd797e0b59575a1dc6ce568f7b23 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 30 Jan 2025 15:19:11 -0800
Subject: [PATCH] [mlir][acc] Update LegalizeDataValues pass to allow
MappableType
With the addition of new type interface MappableType, the
LegalizeDataValues should not make the assumption it can obtain a
pointer to the data (aka acc::getVarPtr() is now not guaranteed to
get a value - acc::getVar() must be used instead).
Thus update the pass to ensure it handles any var used in its data
clause operations.
---
.../OpenACC/Transforms/LegalizeDataValues.cpp | 24 +++++++++----------
1 file changed, 12 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 026b309ce4969d..a553653c73479b 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -36,17 +36,17 @@ static bool insideAccComputeRegion(mlir::Operation *op) {
return false;
}
-static void collectPtrs(mlir::ValueRange operands,
+static void collectVars(mlir::ValueRange operands,
llvm::SmallVector<std::pair<Value, Value>> &values,
bool hostToDevice) {
for (auto operand : operands) {
- Value varPtr = acc::getVarPtr(operand.getDefiningOp());
- Value accPtr = acc::getAccPtr(operand.getDefiningOp());
- if (varPtr && accPtr) {
+ Value var = acc::getVar(operand.getDefiningOp());
+ Value accVar = acc::getAccVar(operand.getDefiningOp());
+ if (var && accVar) {
if (hostToDevice)
- values.push_back({varPtr, accPtr});
+ values.push_back({var, accVar});
else
- values.push_back({accPtr, varPtr});
+ values.push_back({accVar, var});
}
}
}
@@ -75,16 +75,16 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
llvm::SmallVector<std::pair<Value, Value>> values;
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
- collectPtrs(op.getReductionOperands(), values, hostToDevice);
- collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+ collectVars(op.getReductionOperands(), values, hostToDevice);
+ collectVars(op.getPrivateOperands(), values, hostToDevice);
} else {
- collectPtrs(op.getDataClauseOperands(), values, hostToDevice);
+ collectVars(op.getDataClauseOperands(), values, hostToDevice);
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp>) {
- collectPtrs(op.getReductionOperands(), values, hostToDevice);
- collectPtrs(op.getPrivateOperands(), values, hostToDevice);
- collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
+ collectVars(op.getReductionOperands(), values, hostToDevice);
+ collectVars(op.getPrivateOperands(), values, hostToDevice);
+ collectVars(op.getFirstprivateOperands(), values, hostToDevice);
}
}