[Mlir-commits] [mlir] f4aec22 - [mlir][acc] Fix async only api on data entry operations (#122818)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 14 07:30:09 PST 2025


Author: Razvan Lupusoru
Date: 2025-01-14T07:30:05-08:00
New Revision: f4aec22e4776218d2d94f5357e19897bc2e726d4

URL: https://github.com/llvm/llvm-project/commit/f4aec22e4776218d2d94f5357e19897bc2e726d4
DIFF: https://github.com/llvm/llvm-project/commit/f4aec22e4776218d2d94f5357e19897bc2e726d4.diff

LOG: [mlir][acc] Fix async only api on data entry operations (#122818)

Data entry operations which are created from constructs with async
clause that has no value (aka `acc data copyin(var) async`) end up
holding an attribute array named to keep track of this information.
However, in cases where `async` clause is not used, calling
`hasAsyncOnly` ends up crashing since this attribute is not set.

Thus, to fix this issue, ensure that we check for this attribute before
trying to walk the attribute array.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index a47f70b168066e..c60eb5cc620a7d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -445,7 +445,10 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
     }
     /// Return true if the op has the async attribute for the given device_type.
     bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-      for (auto attr : getAsyncOnlyAttr()) {
+      mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+      if (!asyncOnly)
+        return false;
+      for (auto attr : asyncOnly) {
         auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
         if (deviceTypeAttr.getValue() == deviceType)
           return true;
@@ -817,7 +820,10 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
     }
     /// Return true if the op has the async attribute for the given device_type.
     bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-      for (auto attr : getAsyncOnlyAttr()) {
+      mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+      if (!asyncOnly)
+        return false;
+      for (auto attr : asyncOnly) {
         auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
         if (deviceTypeAttr.getValue() == deviceType)
           return true;

diff  --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index cfb8aa767b6f86..aa16421cbec512 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) {
   testAsyncOnly<SerialOp>(b, context, loc, dtypes);
 }
 
+template <typename Op>
+void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+                            llvm::SmallVector<DeviceType> &dtypes) {
+  auto memrefTy = MemRefType::get({}, b.getI32Type());
+  OwningOpRef<memref::AllocaOp> varPtrOp =
+      b.create<memref::AllocaOp>(loc, memrefTy);
+
+  TypedValue<PointerLikeType> varPtr =
+      cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+  OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+                                    /*structured=*/true, /*implicit=*/true);
+
+  EXPECT_FALSE(op->hasAsyncOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op->hasAsyncOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasAsyncOnly());
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+
+  op->removeAsyncOnlyAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) {
+  testAsyncOnlyDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<PresentOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<CopyinOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<CreateOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<NoCreateOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<AttachOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
 template <typename Op>
 void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
                     llvm::SmallVector<DeviceType> &dtypes) {
@@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) {
   testAsyncValue<SerialOp>(b, context, loc, dtypes);
 }
 
+template <typename Op>
+void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+                             llvm::SmallVector<DeviceType> &dtypes) {
+  auto memrefTy = MemRefType::get({}, b.getI32Type());
+  OwningOpRef<memref::AllocaOp> varPtrOp =
+      b.create<memref::AllocaOp>(loc, memrefTy);
+
+  TypedValue<PointerLikeType> varPtr =
+      cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+  OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+                                    /*structured=*/true, /*implicit=*/true);
+
+  mlir::Value empty;
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getAsyncValue(d), empty);
+
+  OwningOpRef<arith::ConstantIndexOp> val =
+      b.create<arith::ConstantIndexOp>(loc, 1);
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op->getAsyncOperandsMutable().assign(val->getResult());
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
+
+  op->getAsyncOperandsMutable().clear();
+  op->removeAsyncOperandsDeviceTypeAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) {
+  testAsyncValueDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<PresentOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<CopyinOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<CreateOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<NoCreateOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<AttachOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
 template <typename Op>
 void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
                         llvm::SmallVector<DeviceType> &dtypes,


        


More information about the Mlir-commits mailing list