[Mlir-commits] [mlir] f4aec22 - [mlir][acc] Fix async only api on data entry operations (#122818)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 14 07:30:09 PST 2025
Author: Razvan Lupusoru
Date: 2025-01-14T07:30:05-08:00
New Revision: f4aec22e4776218d2d94f5357e19897bc2e726d4
URL: https://github.com/llvm/llvm-project/commit/f4aec22e4776218d2d94f5357e19897bc2e726d4
DIFF: https://github.com/llvm/llvm-project/commit/f4aec22e4776218d2d94f5357e19897bc2e726d4.diff
LOG: [mlir][acc] Fix async only api on data entry operations (#122818)
Data entry operations which are created from constructs with an async
clause that has no value (e.g. `acc data copyin(var) async`) end up
holding an attribute array named `asyncOnly` to keep track of this information.
However, in cases where `async` clause is not used, calling
`hasAsyncOnly` ends up crashing since this attribute is not set.
Thus, to fix this issue, ensure that we check for this attribute before
trying to walk the attribute array. The same check is also added to data
exit operations.
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index a47f70b168066e..c60eb5cc620a7d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -445,7 +445,10 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
}
/// Return true if the op has the async attribute for the given device_type.
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- for (auto attr : getAsyncOnlyAttr()) {
+ mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+ if (!asyncOnly)
+ return false;
+ for (auto attr : asyncOnly) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
@@ -817,7 +820,10 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
}
/// Return true if the op has the async attribute for the given device_type.
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- for (auto attr : getAsyncOnlyAttr()) {
+ mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+ if (!asyncOnly)
+ return false;
+ for (auto attr : asyncOnly) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index cfb8aa767b6f86..aa16421cbec512 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) {
testAsyncOnly<SerialOp>(b, context, loc, dtypes);
}
+/// Exercises hasAsyncOnly() and hasAsyncOnly(DeviceType) on a newly built op
+/// of type `Op`. It first checks the op with no `asyncOnly` attribute set
+/// (the case that crashed before this fix), then with several device-type
+/// arrays attached. The attribute is removed after each case, so every case
+/// starts from an op without it.
+template <typename Op>
+void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+ llvm::SmallVector<DeviceType> &dtypes) {
+ // A memref.alloca of i32 supplies the variable-pointer operand for the op.
+ auto memrefTy = MemRefType::get({}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> varPtrOp =
+ b.create<memref::AllocaOp>(loc, memrefTy);
+
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+ // Declared after `varPtrOp`, so `op` (a user of the alloca's result) is
+ // destroyed before it.
+ OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+ /*structured=*/true, /*implicit=*/true);
+
+ // No asyncOnly attribute yet: both query forms must return false instead
+ // of walking a missing attribute array.
+ EXPECT_FALSE(op->hasAsyncOnly());
+ for (auto d : dtypes)
+ EXPECT_FALSE(op->hasAsyncOnly(d));
+
+ // Array holds only DeviceType::None: the test expects the no-argument form
+ // to report it.
+ auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+ EXPECT_TRUE(op->hasAsyncOnly());
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
+ op->removeAsyncOnlyAttr();
+
+ // Array holds only Host: the Host query sees it, the no-argument form must
+ // not.
+ auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+ EXPECT_FALSE(op->hasAsyncOnly());
+ op->removeAsyncOnlyAttr();
+
+ // Array holds Host and Star: each listed type is found. The no-argument
+ // form still reports false. Presumably it only checks DeviceType::None,
+ // which is absent here; its body is not shown in this change.
+ auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+ EXPECT_FALSE(op->hasAsyncOnly());
+
+ // Leave the op without the attribute, as it was created.
+ op->removeAsyncOnlyAttr();
+}
+
+/// Runs testAsyncOnlyDataEntry for each op type listed below. The builder,
+/// context, location and device-type list come from the OpenACCOpsTest
+/// fixture.
+/// NOTE(review): the matching null-attribute guard added to the data exit
+/// ops' hasAsyncOnly in the same change is not exercised by this test.
+TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) {
+ testAsyncOnlyDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<PresentOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<CopyinOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<CreateOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<NoCreateOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<AttachOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
template <typename Op>
void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
llvm::SmallVector<DeviceType> &dtypes) {
@@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) {
testAsyncValue<SerialOp>(b, context, loc, dtypes);
}
+/// Exercises getAsyncValue() and getAsyncValue(DeviceType) on a newly built
+/// op of type `Op`. With no async operands, every query returns a null Value.
+/// After one operand tagged with DeviceType::Nvidia is attached, only the
+/// Nvidia query returns it.
+template <typename Op>
+void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+ llvm::SmallVector<DeviceType> &dtypes) {
+ // A memref.alloca of i32 supplies the variable-pointer operand for the op.
+ auto memrefTy = MemRefType::get({}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> varPtrOp =
+ b.create<memref::AllocaOp>(loc, memrefTy);
+
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+ OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+ /*structured=*/true, /*implicit=*/true);
+
+ // A default-constructed Value is the "no async value" result compared below.
+ mlir::Value empty;
+ EXPECT_EQ(op->getAsyncValue(), empty);
+ for (auto d : dtypes)
+ EXPECT_EQ(op->getAsyncValue(d), empty);
+
+ // Attach a constant index as the single async operand. The device-type
+ // array is set alongside it; presumably entry i tags async operand i.
+ OwningOpRef<arith::ConstantIndexOp> val =
+ b.create<arith::ConstantIndexOp>(loc, 1);
+ auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+ op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+ op->getAsyncOperandsMutable().assign(val->getResult());
+ EXPECT_EQ(op->getAsyncValue(), empty);
+ EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
+
+ // Undo both changes. `val` is declared after `op`, so it is destroyed
+ // first. Clearing the operand list here removes `op`'s use of val's result
+ // beforehand, which is presumably why this cleanup is explicit.
+ op->getAsyncOperandsMutable().clear();
+ op->removeAsyncOperandsDeviceTypeAttr();
+}
+
+/// Runs testAsyncValueDataEntry for each op type listed below. The builder,
+/// context, location and device-type list come from the OpenACCOpsTest
+/// fixture.
+TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) {
+ testAsyncValueDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<PresentOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<CopyinOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<CreateOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<NoCreateOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<AttachOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
template <typename Op>
void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
llvm::SmallVector<DeviceType> &dtypes,
More information about the Mlir-commits mailing list