[flang-commits] [flang] 40f5f90 - [mlir][openacc][flang] Simplify gang, vector and worker representation (#77667)

via flang-commits flang-commits at lists.llvm.org
Thu Jan 11 13:02:11 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-01-11T13:02:06-08:00
New Revision: 40f5f90507725c1f93779c0e6e8d5a13854b4a3f

URL: https://github.com/llvm/llvm-project/commit/40f5f90507725c1f93779c0e6e8d5a13854b4a3f
DIFF: https://github.com/llvm/llvm-project/commit/40f5f90507725c1f93779c0e6e8d5a13854b4a3f.diff

LOG: [mlir][openacc][flang] Simplify gang, vector and worker representation (#77667)

The IR representation for gang, vector and worker has grown with the
support for device_type. This patch simplifies the IR representation of
the gang, vector and worker information on the acc.loop operation.

When only the keyword is present without any values, the information
is printed in the same place as when there are values. The device_type
is omitted if there are no values and it is equal to None. Otherwise the
full information is displayed: first the keyword-only device_type
information, and then the values with their device_type.

Added: 
    

Modified: 
    flang/test/Lower/OpenACC/acc-kernels-loop.f90
    flang/test/Lower/OpenACC/acc-loop.f90
    flang/test/Lower/OpenACC/acc-parallel-loop.f90
    flang/test/Lower/OpenACC/acc-serial-loop.f90
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index b17f2e2c80b20f..755111a69467bf 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 

diff  --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index e7f65770498fe2..42e14afb35f522 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -67,10 +67,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop gang {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -109,10 +109,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop vector {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {vector = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop vector(128)
   DO i = 1, n
@@ -141,10 +141,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop worker {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {worker = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop worker(128)
   DO i = 1, n
@@ -320,4 +320,10 @@ program acc_loop
 ! CHECK: acc.loop
 ! CHECK: fir.do_loop
 
+  !$acc loop gang device_type(nvidia) gang(8)
+  DO i = 1, n
+  END DO
+
+! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
+
 end program

diff  --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index e9150a71f3826b..faef8517850e0d 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -512,10 +512,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -565,10 +565,10 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -606,10 +606,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 

diff  --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 6041e7fb1b4906..9333761e4c2962 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -447,10 +447,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -500,10 +500,10 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -541,10 +541,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e6954062a50e0c..24f129d92805c0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1483,7 +1483,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
     Example:
 
     ```mlir
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -1492,10 +1492,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
         }
       }
       acc.yield
-    } attributes {
-      collapse = [3], gang = [#acc.device_type<none>],
-      vector = [#acc.device_type<none>]
-    }
+    } attributes { collapse = [3] }
     ```
   }];
 
@@ -1613,13 +1610,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` `(` custom<GangClause>($gangOperands, type($gangOperands),
+        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
-            $gangOperandsSegments) `)`
-      | `worker` `` `(` custom<DeviceTypeOperands>($workerNumOperands,
-            type($workerNumOperands), $workerNumOperandsDeviceType) `)`
-      | `vector` `` `(` custom<DeviceTypeOperands>($vectorOperands,
-            type($vectorOperands), $vectorOperandsDeviceType) `)`
+            $gangOperandsSegments, $gang)
+      | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $workerNumOperands, type($workerNumOperands),
+            $workerNumOperandsDeviceType, $worker)
+      | `vector` `` custom<DeviceTypeOperandsWithKeywordOnly>($vectorOperands,
+            type($vectorOperands), $vectorOperandsDeviceType, $vector)
       | `private` `(` custom<SymOperandList>(
             $privateOperands, type($privateOperands), $privatizations) `)`
       | `tile` `(` custom<DeviceTypeOperandsWithSegment>($tileOperands,

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c53673fa426038..bf3264b5da9802 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -921,6 +921,12 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
   return success();
 }
 
+static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
+  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+    p << " [" << attr << "]";
+}
+
 static void printDeviceTypeOperandsWithSegment(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -937,10 +943,7 @@ static void printDeviceTypeOperandsWithSegment(
       ++opIdx;
     }
     p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
 }
 
@@ -978,11 +981,120 @@ printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
     if (i != 0)
       p << ", ";
     p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
+  }
+}
+
+static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::ArrayAttr &keywordOnlyDeviceType) {
+
+  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  keywordOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(attributes.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  return success();
+}
+
+bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+  p << "[";
+  for (unsigned i = 0; i < deviceTypes.value().size(); ++i) {
+    if (i != 0)
+      p << ", ";
     auto deviceTypeAttr =
         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    p << deviceTypeAttr;
+  }
+  p << "]";
+}
+
+static void printDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
+    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
+    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
+      keywordOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnlyDeviceTypes);
+
+  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
+  p << ")";
 }
 
 //===----------------------------------------------------------------------===//
@@ -1215,7 +1327,7 @@ static ParseResult parseGangValue(
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
     llvm::SmallVectorImpl<Type> &types,
     llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
-    bool &needComa, bool &newValue) {
+    bool &needCommaBetweenValues, bool &newValue) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (parser.parseEqual())
       return failure();
@@ -1223,7 +1335,7 @@ static ParseResult parseGangValue(
         parser.parseColonType(types.emplace_back()))
       return failure();
     attributes.push_back(gangArgType);
-    needComa = true;
+    needCommaBetweenValues = true;
     newValue = true;
   }
   return success();
@@ -1233,11 +1345,37 @@ static ParseResult parseGangClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
     llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
-    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments) {
-  llvm::SmallVector<GangArgTypeAttr> attributes;
-  llvm::SmallVector<DeviceTypeAttr> deviceTypeAttributes;
+    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
+    mlir::ArrayAttr &gangOnlyDeviceType) {
+  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
   llvm::SmallVector<int32_t> seg;
-  bool needComa = false;
+  bool needCommaBetweenValues = false;
+  bool needCommaBeforeOperands = false;
+
+  // Gang only keyword
+  if (failed(parser.parseOptionalLParen())) {
+    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gangOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse gang only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  gangOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
 
   auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                 mlir::acc::GangArgType::Num);
@@ -1247,6 +1385,11 @@ static ParseResult parseGangClause(
       parser.getContext(), mlir::acc::GangArgType::Static);
 
   do {
+    if (needCommaBeforeOperands) {
+      needCommaBeforeOperands = false;
+      continue;
+    }
+
     if (failed(parser.parseLBrace()))
       return failure();
 
@@ -1254,7 +1397,7 @@ static ParseResult parseGangClause(
     while (true) {
       bool newValue = false;
       bool needValue = false;
-      if (needComa) {
+      if (needCommaBetweenValues) {
         if (succeeded(parser.parseOptionalComma()))
           needValue = true; // expect a new value after comma.
         else
@@ -1262,16 +1405,19 @@ static ParseResult parseGangClause(
       }
 
       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argNum, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argNum,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argDim, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argDim,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argStatic, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argStatic,
+                                needCommaBetweenValues, newValue)))
         return failure();
 
       if (!newValue && needValue) {
@@ -1305,13 +1451,18 @@ static ParseResult parseGangClause(
 
   } while (succeeded(parser.parseOptionalComma()));
 
-  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
-                                               attributes.end());
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
+                                               gangArgTypeAttributes.end());
   gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
+  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
+
+  llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
+      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
+  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
 
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttr(
-      deviceTypeAttributes.begin(), deviceTypeAttributes.end());
-  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttr);
   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
   return success();
 }
@@ -1320,33 +1471,63 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
                      mlir::OperandRange operands, mlir::TypeRange types,
                      std::optional<mlir::ArrayAttr> gangArgTypes,
                      std::optional<mlir::ArrayAttr> deviceTypes,
-                     std::optional<mlir::DenseI32ArrayAttr> segments) {
-  unsigned opIdx = 0;
-  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
-    if (i != 0)
-      p << ", ";
-    p << "{";
-    for (int32_t j = 0; j < (*segments)[i]; ++j) {
-      if (j != 0)
+                     std::optional<mlir::DenseI32ArrayAttr> segments,
+                     std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+      gangOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes)) {
+    p << "[";
+    for (unsigned i = 0; i < gangOnlyDeviceTypes.value().size(); ++i) {
+      if (i != 0)
         p << ", ";
-      auto gangArgTypeAttr =
-          mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
-      if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
-        p << LoopOp::getGangNumKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
-        p << LoopOp::getGangDimKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
-        p << LoopOp::getGangStaticKeyword();
-      p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
-      ++opIdx;
+      auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[i]);
+      p << deviceTypeAttr;
     }
+    p << "]";
+  }
 
-    p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  if (deviceTypes) {
+    unsigned opIdx = 0;
+    for (unsigned i = 0; i < deviceTypes->size(); ++i) {
+      if (i != 0)
+        p << ", ";
+      p << "{";
+      for (int32_t j = 0; j < (*segments)[i]; ++j) {
+        if (j != 0)
+          p << ", ";
+        auto gangArgTypeAttr =
+            mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
+        if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
+          p << LoopOp::getGangNumKeyword();
+        else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
+          p << LoopOp::getGangDimKeyword();
+        else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
+          p << LoopOp::getGangStaticKeyword();
+        p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
+        ++opIdx;
+      }
+
+      p << "}";
+      auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+      if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+        p << " [" << (*deviceTypes)[i] << "]";
+    }
   }
+  p << ")";
 }
 
 bool hasDuplicateDeviceTypes(

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index ce5bfa490013e0..8fa37bc98294ce 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -11,7 +11,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
   %async = arith.constant 1 : i64
 
   acc.parallel async(%async: i64) {
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -25,7 +25,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
         }
       }
       acc.yield
-    } attributes { collapse = [3], collapseDeviceType = [#acc.device_type<none>], vector = [#acc.device_type<none>], gang = [#acc.device_type<none>]}
+    } attributes { collapse = [3], collapseDeviceType = [#acc.device_type<none>]}
     acc.yield
   }
 
@@ -38,7 +38,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
 //  CHECK-NEXT:   %{{.*}} = arith.constant 1 : index
 //  CHECK-NEXT:   [[ASYNC:%.*]] = arith.constant 1 : i64
 //  CHECK-NEXT:   acc.parallel async([[ASYNC]] : i64) {
-//  CHECK-NEXT:     acc.loop {
+//  CHECK-NEXT:     acc.loop gang vector {
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -52,7 +52,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
 //  CHECK-NEXT:         }
 //  CHECK-NEXT:       }
 //  CHECK-NEXT:       acc.yield
-//  CHECK-NEXT:     } attributes {collapse = [3], collapseDeviceType = [#acc.device_type<none>], gang = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
+//  CHECK-NEXT:     } attributes {collapse = [3], collapseDeviceType = [#acc.device_type<none>]}
 //  CHECK-NEXT:     acc.yield
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
@@ -138,9 +138,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
   acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
     %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32>
     acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
-      acc.loop {
+      acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
-          acc.loop {
+          acc.loop worker {
             scf.for %y = %lb to %c10 step %st {
               %axy = memref.load %a[%x, %y] : memref<10x10xf32>
               %bxy = memref.load %b[%x, %y] : memref<10x10xf32>
@@ -148,7 +148,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
               memref.store %tmp, %c[%y] : memref<10xf32>
             }
             acc.yield
-          } attributes {worker = [#acc.device_type<none>]}
+          }
 
           acc.loop {
             // for i = 0 to 10 step 1
@@ -163,7 +163,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
           } attributes {seq = [#acc.device_type<none>]}
         }
         acc.yield
-      } attributes {gang = [#acc.device_type<none>]}
+      }
       acc.yield
     }
     acc.terminator
@@ -181,9 +181,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK:        acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
 // CHECK-NEXT:     %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> 
 // CHECK-NEXT:     acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
-// CHECK-NEXT:       acc.loop {
+// CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
-// CHECK-NEXT:           acc.loop {
+// CHECK-NEXT:           acc.loop worker {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
@@ -191,7 +191,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:               memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
-// CHECK-NEXT:           } attributes {worker = [#acc.device_type<none>]}
+// CHECK-NEXT:           }
 // CHECK-NEXT:           acc.loop {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}] : memref<10xf32>
@@ -203,7 +203,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:           } attributes {seq = [#acc.device_type<none>]}
 // CHECK-NEXT:         }
 // CHECK-NEXT:         acc.yield
-// CHECK-NEXT:       } attributes {gang = [#acc.device_type<none>]}
+// CHECK-NEXT:       }
 // CHECK-NEXT:       acc.yield
 // CHECK-NEXT:     }
 // CHECK-NEXT:     acc.terminator
@@ -218,10 +218,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
   %i32Value = arith.constant 128 : i32
   %idxValue = arith.constant 8 : index
 
-  acc.loop {
+  acc.loop gang vector worker {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
-  } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>], gang = [#acc.device_type<none>]}
+  }
   acc.loop gang({num=%i64Value: i64}) {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
@@ -254,10 +254,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
   }
-  acc.loop gang({num=%i64Value: i64}) {
+  acc.loop gang({num=%i64Value: i64}) worker vector {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
-  } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+  }
   acc.loop gang({num=%i64Value: i64, static=%i64Value: i64}) worker(%i64Value: i64) vector(%i64Value: i64) {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
@@ -293,10 +293,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
 // CHECK:      [[I64VALUE:%.*]] = arith.constant 1 : i64
 // CHECK-NEXT: [[I32VALUE:%.*]] = arith.constant 128 : i32
 // CHECK-NEXT: [[IDXVALUE:%.*]] = arith.constant 8 : index
-// CHECK:      acc.loop {
+// CHECK:      acc.loop gang worker vector {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
-// CHECK-NEXT: } attributes {gang = [#acc.device_type<none>], vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+// CHECK-NEXT: }
 // CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
@@ -329,10 +329,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) {
+// CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) worker vector {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
-// CHECK-NEXT: } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+// CHECK-NEXT: }
 // CHECK:      acc.loop gang({num=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64}) worker([[I64VALUE]] : i64) vector([[I64VALUE]] : i64) {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield


        


More information about the flang-commits mailing list