[Mlir-commits] [mlir] [mlir][openacc][NFC] Cleanup hasOnly functions for device_type support (PR #78800)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 19 14:30:18 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Just a cleanup of all the `has.*Only()` functions to avoid code duplication.

---
Full diff: https://github.com/llvm/llvm-project/pull/78800.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+49-101) 


``````````diff
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bc03adbcae64df..a63d6afa0e8532 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
       *getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// device_type support helpers
+//===----------------------------------------------------------------------===//
+
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (auto attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  p << "[";
+  llvm::interleaveComma(*deviceTypes, p,
+                        [&](mlir::Attribute attr) { p << attr; });
+  p << "]";
+}
+
 //===----------------------------------------------------------------------===//
 // DataBoundsOp
 //===----------------------------------------------------------------------===//
@@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
 }
 
 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::ParallelOp::getAsyncValue() {
@@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
 }
 
 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range ParallelOp::getWaitValues() {
@@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
   return success();
 }
 
-static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
-  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
-    return true;
-  return false;
-}
-
-static void printDeviceTypes(mlir::OpAsmPrinter &p,
-                             std::optional<mlir::ArrayAttr> deviceTypes) {
-  if (!hasDeviceTypeValues(deviceTypes))
-    return;
-
-  p << "[";
-  llvm::interleaveComma(*deviceTypes, p,
-                        [&](mlir::Attribute attr) { p << attr; });
-  p << "]";
-}
-
 static void printDeviceTypeOperandsWithKeywordOnly(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
 }
 
 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::SerialOp::getAsyncValue() {
@@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
 }
 
 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range SerialOp::getWaitValues() {
@@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
 }
 
 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value acc::KernelsOp::getAsyncValue() {
@@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
 }
 
 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range KernelsOp::getWaitValues() {
@@ -1646,11 +1640,7 @@ Value LoopOp::getDataOperand(unsigned i) {
 bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAuto_()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAuto_(), deviceType);
 }
 
 bool LoopOp::hasIndependent() {
@@ -1658,21 +1648,13 @@ bool LoopOp::hasIndependent() {
 }
 
 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getIndependent()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getIndependent(), deviceType);
 }
 
 bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getSeq()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getSeq(), deviceType);
 }
 
 mlir::Value LoopOp::getVectorValue() {
@@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
 bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getVector()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getVector(), deviceType);
 }
 
 mlir::Value LoopOp::getWorkerValue() {
@@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
 bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWorker()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWorker(), deviceType);
 }
 
 mlir::Operation::operand_range LoopOp::getTileValues() {
@@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
 bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
 
 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getGang()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getGang(), deviceType);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
 }
 
 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getAsyncOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getAsyncOnly(), deviceType);
 }
 
 mlir::Value DataOp::getAsyncValue() {
@@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
 bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
 
 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
-  if (auto arrayAttr = getWaitOnly()) {
-    if (findSegment(*arrayAttr, deviceType))
-      return true;
-  }
-  return false;
+  return hasDeviceType(getWaitOnly(), deviceType);
 }
 
 mlir::Operation::operand_range DataOp::getWaitValues() {
@@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
 // RoutineOp
 //===----------------------------------------------------------------------===//
 
-static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
-                          mlir::acc::DeviceType deviceType) {
-  if (!hasDeviceTypeValues(arrayAttr))
-    return false;
-
-  for (auto attr : *arrayAttr) {
-    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
-    if (deviceTypeAttr.getValue() == deviceType)
-      return true;
-  }
-
-  return false;
-}
-
 static unsigned getParallelismForDeviceType(acc::RoutineOp op,
                                             acc::DeviceType dtype) {
   unsigned parallelism = 0;

``````````

</details>


https://github.com/llvm/llvm-project/pull/78800


More information about the Mlir-commits mailing list