[Mlir-commits] [flang] [mlir] [mlir][openacc][flang] Simplify gang, vector and worker representation (PR #77667)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 10 10:53:36 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-openacc

@llvm/pr-subscribers-mlir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

The IR representation for gang, vector and worker has grown with the support for device_type. This patch simplifies the IR representation for gang, vector and worker information on the acc.loop operation.

When only the keyword is present without any values, the information is printed in the same place as when there are values. The device_type is omitted if there are no values and it is equal to None. Otherwise, the full information is displayed: first the keyword-only device_type information, and then the values with their device_type.

---

Patch is 29.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77667.diff


7 Files Affected:

- (modified) flang/test/Lower/OpenACC/acc-kernels-loop.f90 (+6-6) 
- (modified) flang/test/Lower/OpenACC/acc-loop.f90 (+12-6) 
- (modified) flang/test/Lower/OpenACC/acc-parallel-loop.f90 (+6-6) 
- (modified) flang/test/Lower/OpenACC/acc-serial-loop.f90 (+6-6) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+9-11) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+228-47) 
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+20-20) 


``````````diff
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index b17f2e2c80b20f..755111a69467bf 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index e7f65770498fe2..42e14afb35f522 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -67,10 +67,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop gang {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -109,10 +109,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop vector {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {vector = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop vector(128)
   DO i = 1, n
@@ -141,10 +141,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop worker {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {worker = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop worker(128)
   DO i = 1, n
@@ -320,4 +320,10 @@ program acc_loop
 ! CHECK: acc.loop
 ! CHECK: fir.do_loop
 
+  !$acc loop gang device_type(nvidia) gang(8)
+  DO i = 1, n
+  END DO
+
+! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
+
 end program
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index e9150a71f3826b..faef8517850e0d 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -512,10 +512,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -565,10 +565,10 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -606,10 +606,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 6041e7fb1b4906..9333761e4c2962 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -447,10 +447,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -500,10 +500,10 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -541,10 +541,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e6954062a50e0c..24f129d92805c0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1483,7 +1483,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
     Example:
 
     ```mlir
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -1492,10 +1492,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
         }
       }
       acc.yield
-    } attributes {
-      collapse = [3], gang = [#acc.device_type<none>],
-      vector = [#acc.device_type<none>]
-    }
+    } attributes { collapse = [3] }
     ```
   }];
 
@@ -1613,13 +1610,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` `(` custom<GangClause>($gangOperands, type($gangOperands),
+        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
-            $gangOperandsSegments) `)`
-      | `worker` `` `(` custom<DeviceTypeOperands>($workerNumOperands,
-            type($workerNumOperands), $workerNumOperandsDeviceType) `)`
-      | `vector` `` `(` custom<DeviceTypeOperands>($vectorOperands,
-            type($vectorOperands), $vectorOperandsDeviceType) `)`
+            $gangOperandsSegments, $gang)
+      | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $workerNumOperands, type($workerNumOperands),
+            $workerNumOperandsDeviceType, $worker)
+      | `vector` `` custom<DeviceTypeOperandsWithKeywordOnly>($vectorOperands,
+            type($vectorOperands), $vectorOperandsDeviceType, $vector)
       | `private` `(` custom<SymOperandList>(
             $privateOperands, type($privateOperands), $privatizations) `)`
       | `tile` `(` custom<DeviceTypeOperandsWithSegment>($tileOperands,
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c53673fa426038..bf3264b5da9802 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -921,6 +921,12 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
   return success();
 }
 
+static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
+  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+    p << " [" << attr << "]";
+}
+
 static void printDeviceTypeOperandsWithSegment(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -937,10 +943,7 @@ static void printDeviceTypeOperandsWithSegment(
       ++opIdx;
     }
     p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
 }
 
@@ -978,11 +981,120 @@ printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
     if (i != 0)
       p << ", ";
     p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
+  }
+}
+
+static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::ArrayAttr &keywordOnlyDeviceType) {
+
+  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  keywordOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(attributes.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  return success();
+}
+
+bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+  p << "[";
+  for (unsigned i = 0; i < deviceTypes.value().size(); ++i) {
+    if (i != 0)
+      p << ", ";
     auto deviceTypeAttr =
         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    p << deviceTypeAttr;
+  }
+  p << "]";
+}
+
+static void printDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
+    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
+    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
+      keywordOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnlyDeviceTypes);
+
+  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
+  p << ")";
 }
 
 //===----------------------------------------------------------------------===//
@@ -1215,7 +1327,7 @@ static ParseResult parseGangValue(
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
     llvm::SmallVectorImpl<Type> &types,
     llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
-    bool &needComa, bool &newValue) {
+    bool &needCommaBetweenValues, bool &newValue) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (parser.parseEqual())
       return failure();
@@ -1223,7 +1335,7 @@ static ParseResult parseGangValue(
         parser.parseColonType(types.emplace_back()))
       return failure();
     attributes.push_back(gangArgType);
-    needComa = true;
+    needCommaBetweenValues = true;
     newValue = true;
   }
   return success();
@@ -1233,11 +1345,37 @@ static ParseResult parseGangClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
     llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
-    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments) {
-  llvm::SmallVector<GangArgTypeAttr> attributes;
-  llvm::SmallVector<DeviceTypeAttr> deviceTypeAttributes;
+    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
+    mlir::ArrayAttr &gangOnlyDeviceType) {
+  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
   llvm::SmallVector<int32_t> seg;
-  bool needComa = false;
+  bool needCommaBetweenValues = false;
+  bool needCommaBeforeOperands = false;
+
+  // Gang only keyword
+  if (failed(parser.parseOptionalLParen())) {
+    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gangOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse gang only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  gangOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
 
   auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                 mlir::acc::GangArgType::Num);
@@ -1247,6 +1385,11 @@ static ParseResult parseGangClause(
       parser.getContext(), mlir::acc::GangArgType::Static);
 
   do {
+    if (needCommaBeforeOperands) {
+      needCommaBeforeOperands = false;
+      continue;
+    }
+
     if (failed(parser.parseLBrace()))
       return failure();
 
@@ -1254,7 +1397,7 @@ static ParseResult parseGangClause(
     while (true) {
       bool newValue = false;
       bool needValue = false;
-      if (needComa) {
+      if (needCommaBetweenValues) {
         if (succeeded(parser.parseOptionalComma()))
           needValue = true; // expect a new value after comma.
         else
@@ -1262,16 +1405,19 @@ static ParseResult parseGangClause(
       }
 
       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argNum, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argNum,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argDim, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argDim,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argStatic, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argStatic,
+                                needCommaBetweenValues, newValue)))
         return failure();
 
       if (!newValue && needValue) {
@@ -1305,13 +1451,18 @@ static ParseResult parseGangClause(
 
   } while (succeeded(parser.parseOptionalComma()));
 
-  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
-                                               attributes.end());
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
+                                               gangArgTypeAttributes.end());
   gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
+  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
+
+  llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
+      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
+  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
 
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttr(
-      deviceTypeAttributes.begin(), deviceTypeAttributes.end());
-  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttr);
   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
   return success();
 }
@@ -1320,33 +1471,63 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
                      mlir::OperandRange operands, mlir::TypeRange types,
                      std::optional<mlir::ArrayAttr> gangArgTypes,
                      std::optional<mlir::ArrayAttr> deviceTypes,
-                     std::optional<mlir::DenseI32ArrayAttr> segments) {
-  unsigned opIdx = 0;
-  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
-    if (i != 0)
-      p << ", ";
-    p << "{";
-    for (int32_t j = 0; j < (*segments)[i]; ++j) {
-      if (j != 0)
+                     std::optional<mlir::DenseI32ArrayAttr> segments,
+                     std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+      gangOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes)) {
+    p << "[";
+    for (unsigned i = 0; i < gangOnlyDeviceTypes.value().size(); ++i) {
+      if (i != 0)
         p << ", ";
-      auto gangArgTypeAttr =
-          mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
-      if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
-        p << LoopOp::getGangNumKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
-        p << LoopOp::getGangDimKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
-        p << LoopOp::getGangStaticKeyword();
-      p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
-      ++opIdx;
+      auto deviceTypeAttr =
+          mlir::dyn...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77667


More information about the Mlir-commits mailing list