[Mlir-commits] [flang] [mlir] [mlir][openacc][flang] Simplify gang, vector and worker representation (PR #77667)
    Valentin Clement バレンタイン クレメン 
    llvmlistbot at llvm.org
       
    Wed Jan 10 10:53:09 PST 2024
    
    
  
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/77667
The IR representation for gang, vector and worker has grown with the support for device_type. This patch simplifies the IR representation for gang, vector and worker information on the acc.loop operation.
When only the keyword is present without any values, the information is printed in the same place as when there are values. The device_type is omitted if there are no values and it is equal to None. Otherwise the full information is displayed: first the keyword-only device_type information, and then the values with their device_type.
>From b02da126359e9593f9b82610467e82f733b3ab64 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 8 Jan 2024 11:51:28 -0800
Subject: [PATCH] [mlir][openacc][flang] Simplify gang, vector and worker
 representation
This patch simplifies the IR representation for gang, vector and worker
information on the acc.loop operation.
The device_type is omitted when possible when it is equal to `None`.
---
 flang/test/Lower/OpenACC/acc-kernels-loop.f90 |  12 +-
 flang/test/Lower/OpenACC/acc-loop.f90         |  18 +-
 .../test/Lower/OpenACC/acc-parallel-loop.f90  |  12 +-
 flang/test/Lower/OpenACC/acc-serial-loop.f90  |  12 +-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  20 +-
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 275 +++++++++++++++---
 mlir/test/Dialect/OpenACC/ops.mlir            |  40 +--
 7 files changed, 287 insertions(+), 102 deletions(-)
diff --git a/flang/test/Lower/OpenACC/acc-kernels-loop.f90 b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
index b17f2e2c80b20f..755111a69467bf 100644
--- a/flang/test/Lower/OpenACC/acc-kernels-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-kernels-loop.f90
@@ -497,10 +497,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -550,10 +550,10 @@ subroutine acc_kernels_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
@@ -591,10 +591,10 @@ subroutine acc_kernels_loop
   END DO
 
 ! CHECK:      acc.kernels {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.terminator
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-loop.f90 b/flang/test/Lower/OpenACC/acc-loop.f90
index e7f65770498fe2..42e14afb35f522 100644
--- a/flang/test/Lower/OpenACC/acc-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-loop.f90
@@ -67,10 +67,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop gang {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {gang = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop gang(num: 8)
   DO i = 1, n
@@ -109,10 +109,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop vector {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {vector = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop vector(128)
   DO i = 1, n
@@ -141,10 +141,10 @@ program acc_loop
     a(i) = b(i)
   END DO
 
-!CHECK:      acc.loop {
+!CHECK:      acc.loop worker {
 !CHECK:        fir.do_loop
 !CHECK:        acc.yield
-!CHECK-NEXT: } attributes {worker = [#acc.device_type<none>]}{{$}}
+!CHECK-NEXT: }{{$}}
 
   !$acc loop worker(128)
   DO i = 1, n
@@ -320,4 +320,10 @@ program acc_loop
 ! CHECK: acc.loop
 ! CHECK: fir.do_loop
 
+  !$acc loop gang device_type(nvidia) gang(8)
+  DO i = 1, n
+  END DO
+
+! CHECK: acc.loop gang([#acc.device_type<none>], {num=%c8{{.*}} : i32} [#acc.device_type<nvidia>])
+
 end program
diff --git a/flang/test/Lower/OpenACC/acc-parallel-loop.f90 b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
index e9150a71f3826b..faef8517850e0d 100644
--- a/flang/test/Lower/OpenACC/acc-parallel-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-parallel-loop.f90
@@ -512,10 +512,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -565,10 +565,10 @@ subroutine acc_parallel_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -606,10 +606,10 @@ subroutine acc_parallel_loop
   END DO
 
 ! CHECK:      acc.parallel {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/flang/test/Lower/OpenACC/acc-serial-loop.f90 b/flang/test/Lower/OpenACC/acc-serial-loop.f90
index 6041e7fb1b4906..9333761e4c2962 100644
--- a/flang/test/Lower/OpenACC/acc-serial-loop.f90
+++ b/flang/test/Lower/OpenACC/acc-serial-loop.f90
@@ -447,10 +447,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop gang {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {gang = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -500,10 +500,10 @@ subroutine acc_serial_loop
     a(i) = b(i)
   END DO
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop vector {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {vector = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
@@ -541,10 +541,10 @@ subroutine acc_serial_loop
   END DO
 
 ! CHECK:      acc.serial {
-! CHECK:        acc.loop {
+! CHECK:        acc.loop worker {
 ! CHECK:          fir.do_loop
 ! CHECK:          acc.yield
-! CHECK-NEXT:   } attributes {worker = [#acc.device_type<none>]}{{$}}
+! CHECK-NEXT:   }{{$}}
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index e6954062a50e0c..24f129d92805c0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1483,7 +1483,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
     Example:
 
     ```mlir
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -1492,10 +1492,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
         }
       }
       acc.yield
-    } attributes {
-      collapse = [3], gang = [#acc.device_type<none>],
-      vector = [#acc.device_type<none>]
-    }
+    } attributes { collapse = [3] }
     ```
   }];
 
@@ -1613,13 +1610,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` `(` custom<GangClause>($gangOperands, type($gangOperands),
+        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
-            $gangOperandsSegments) `)`
-      | `worker` `` `(` custom<DeviceTypeOperands>($workerNumOperands,
-            type($workerNumOperands), $workerNumOperandsDeviceType) `)`
-      | `vector` `` `(` custom<DeviceTypeOperands>($vectorOperands,
-            type($vectorOperands), $vectorOperandsDeviceType) `)`
+            $gangOperandsSegments, $gang)
+      | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
+            $workerNumOperands, type($workerNumOperands),
+            $workerNumOperandsDeviceType, $worker)
+      | `vector` `` custom<DeviceTypeOperandsWithKeywordOnly>($vectorOperands,
+            type($vectorOperands), $vectorOperandsDeviceType, $vector)
       | `private` `(` custom<SymOperandList>(
             $privateOperands, type($privateOperands), $privatizations) `)`
       | `tile` `(` custom<DeviceTypeOperandsWithSegment>($tileOperands,
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c53673fa426038..bf3264b5da9802 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -921,6 +921,12 @@ static ParseResult parseDeviceTypeOperandsWithSegment(
   return success();
 }
 
+static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
+  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+    p << " [" << attr << "]";
+}
+
 static void printDeviceTypeOperandsWithSegment(
     mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
     mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
@@ -937,10 +943,7 @@ static void printDeviceTypeOperandsWithSegment(
       ++opIdx;
     }
     p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
 }
 
@@ -978,11 +981,120 @@ printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
     if (i != 0)
       p << ", ";
     p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
+  }
+}
+
+static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
+    mlir::ArrayAttr &keywordOnlyDeviceType) {
+
+  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
+  bool needCommaBeforeOperands = false;
+
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    keywordOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  keywordOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  llvm::SmallVector<DeviceTypeAttr> attributes;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(attributes.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
+  return success();
+}
+
+bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+  if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
+    return true;
+  return false;
+}
+
+static void printDeviceTypes(mlir::OpAsmPrinter &p,
+                             std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+  p << "[";
+  for (unsigned i = 0; i < deviceTypes.value().size(); ++i) {
+    if (i != 0)
+      p << ", ";
     auto deviceTypeAttr =
         mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+    p << deviceTypeAttr;
+  }
+  p << "]";
+}
+
+static void printDeviceTypeOperandsWithKeywordOnly(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
+    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
+    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && keywordOnlyDeviceTypes &&
+      keywordOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*keywordOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, keywordOnlyDeviceTypes);
+
+  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  for (unsigned i = 0, e = deviceTypes->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << operands[i] << " : " << operands[i].getType();
+    printSingleDeviceType(p, (*deviceTypes)[i]);
   }
+  p << ")";
 }
 
 //===----------------------------------------------------------------------===//
@@ -1215,7 +1327,7 @@ static ParseResult parseGangValue(
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
     llvm::SmallVectorImpl<Type> &types,
     llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
-    bool &needComa, bool &newValue) {
+    bool &needCommaBetweenValues, bool &newValue) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (parser.parseEqual())
       return failure();
@@ -1223,7 +1335,7 @@ static ParseResult parseGangValue(
         parser.parseColonType(types.emplace_back()))
       return failure();
     attributes.push_back(gangArgType);
-    needComa = true;
+    needCommaBetweenValues = true;
     newValue = true;
   }
   return success();
@@ -1233,11 +1345,37 @@ static ParseResult parseGangClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
     llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
-    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments) {
-  llvm::SmallVector<GangArgTypeAttr> attributes;
-  llvm::SmallVector<DeviceTypeAttr> deviceTypeAttributes;
+    mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
+    mlir::ArrayAttr &gangOnlyDeviceType) {
+  llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
+  llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
   llvm::SmallVector<int32_t> seg;
-  bool needComa = false;
+  bool needCommaBetweenValues = false;
+  bool needCommaBeforeOperands = false;
+
+  // Gang only keyword
+  if (failed(parser.parseOptionalLParen())) {
+    gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gangOnlyDeviceType =
+        ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
+    return success();
+  }
+
+  // Parse gang only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(
+                  gangOnlyDeviceTypeAttributes.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
+      return failure();
+    needCommaBeforeOperands = true;
+  }
 
   auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
                                                 mlir::acc::GangArgType::Num);
@@ -1247,6 +1385,11 @@ static ParseResult parseGangClause(
       parser.getContext(), mlir::acc::GangArgType::Static);
 
   do {
+    if (needCommaBeforeOperands) {
+      needCommaBeforeOperands = false;
+      continue;
+    }
+
     if (failed(parser.parseLBrace()))
       return failure();
 
@@ -1254,7 +1397,7 @@ static ParseResult parseGangClause(
     while (true) {
       bool newValue = false;
       bool needValue = false;
-      if (needComa) {
+      if (needCommaBetweenValues) {
         if (succeeded(parser.parseOptionalComma()))
           needValue = true; // expect a new value after comma.
         else
@@ -1262,16 +1405,19 @@ static ParseResult parseGangClause(
       }
 
       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argNum, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argNum,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argDim, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argDim,
+                                needCommaBetweenValues, newValue)))
         return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
-                                gangOperands, gangOperandsType, attributes,
-                                argStatic, needComa, newValue)))
+                                gangOperands, gangOperandsType,
+                                gangArgTypeAttributes, argStatic,
+                                needCommaBetweenValues, newValue)))
         return failure();
 
       if (!newValue && needValue) {
@@ -1305,13 +1451,18 @@ static ParseResult parseGangClause(
 
   } while (succeeded(parser.parseOptionalComma()));
 
-  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
-                                               attributes.end());
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
+                                               gangArgTypeAttributes.end());
   gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
+  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
+
+  llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
+      gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
+  gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
 
-  llvm::SmallVector<mlir::Attribute> deviceTypeAttr(
-      deviceTypeAttributes.begin(), deviceTypeAttributes.end());
-  deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttr);
   segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
   return success();
 }
@@ -1320,33 +1471,63 @@ void printGangClause(OpAsmPrinter &p, Operation *op,
                      mlir::OperandRange operands, mlir::TypeRange types,
                      std::optional<mlir::ArrayAttr> gangArgTypes,
                      std::optional<mlir::ArrayAttr> deviceTypes,
-                     std::optional<mlir::DenseI32ArrayAttr> segments) {
-  unsigned opIdx = 0;
-  for (unsigned i = 0; i < deviceTypes->size(); ++i) {
-    if (i != 0)
-      p << ", ";
-    p << "{";
-    for (int32_t j = 0; j < (*segments)[i]; ++j) {
-      if (j != 0)
+                     std::optional<mlir::DenseI32ArrayAttr> segments,
+                     std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
+
+  if (operands.begin() == operands.end() && gangOnlyDeviceTypes &&
+      gangOnlyDeviceTypes->size() == 1) {
+    auto deviceTypeAttr =
+        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[0]);
+    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes)) {
+    p << "[";
+    for (unsigned i = 0; i < gangOnlyDeviceTypes.value().size(); ++i) {
+      if (i != 0)
         p << ", ";
-      auto gangArgTypeAttr =
-          mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
-      if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
-        p << LoopOp::getGangNumKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
-        p << LoopOp::getGangDimKeyword();
-      else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
-        p << LoopOp::getGangStaticKeyword();
-      p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
-      ++opIdx;
+      auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gangOnlyDeviceTypes)[i]);
+      p << deviceTypeAttr;
     }
+    p << "]";
+  }
 
-    p << "}";
-    auto deviceTypeAttr =
-        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
-    if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
-      p << " [" << (*deviceTypes)[i] << "]";
+  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
+      hasDeviceTypeValues(deviceTypes))
+    p << ", ";
+
+  if (deviceTypes) {
+    unsigned opIdx = 0;
+    for (unsigned i = 0; i < deviceTypes->size(); ++i) {
+      if (i != 0)
+        p << ", ";
+      p << "{";
+      for (int32_t j = 0; j < (*segments)[i]; ++j) {
+        if (j != 0)
+          p << ", ";
+        auto gangArgTypeAttr =
+            mlir::dyn_cast<mlir::acc::GangArgTypeAttr>((*gangArgTypes)[opIdx]);
+        if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
+          p << LoopOp::getGangNumKeyword();
+        else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
+          p << LoopOp::getGangDimKeyword();
+        else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Static)
+          p << LoopOp::getGangStaticKeyword();
+        p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
+        ++opIdx;
+      }
+
+      p << "}";
+      auto deviceTypeAttr =
+          mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[i]);
+      if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
+        p << " [" << (*deviceTypes)[i] << "]";
+    }
   }
+  p << ")";
 }
 
 bool hasDuplicateDeviceTypes(
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index ce5bfa490013e0..8fa37bc98294ce 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -11,7 +11,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
   %async = arith.constant 1 : i64
 
   acc.parallel async(%async: i64) {
-    acc.loop {
+    acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -25,7 +25,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
         }
       }
       acc.yield
-    } attributes { collapse = [3], collapseDeviceType = [#acc.device_type<none>], vector = [#acc.device_type<none>], gang = [#acc.device_type<none>]}
+    } attributes { collapse = [3], collapseDeviceType = [#acc.device_type<none>]}
     acc.yield
   }
 
@@ -38,7 +38,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
 //  CHECK-NEXT:   %{{.*}} = arith.constant 1 : index
 //  CHECK-NEXT:   [[ASYNC:%.*]] = arith.constant 1 : i64
 //  CHECK-NEXT:   acc.parallel async([[ASYNC]] : i64) {
-//  CHECK-NEXT:     acc.loop {
+//  CHECK-NEXT:     acc.loop gang vector {
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -52,7 +52,7 @@ func.func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
 //  CHECK-NEXT:         }
 //  CHECK-NEXT:       }
 //  CHECK-NEXT:       acc.yield
-//  CHECK-NEXT:     } attributes {collapse = [3], collapseDeviceType = [#acc.device_type<none>], gang = [#acc.device_type<none>], vector = [#acc.device_type<none>]}
+//  CHECK-NEXT:     } attributes {collapse = [3], collapseDeviceType = [#acc.device_type<none>]}
 //  CHECK-NEXT:     acc.yield
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
@@ -138,9 +138,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
   acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
     %private = acc.private varPtr(%c : memref<10xf32>) -> memref<10xf32>
     acc.parallel num_gangs({%numGangs: i64}) num_workers(%numWorkers: i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %private : memref<10xf32>) {
-      acc.loop {
+      acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
-          acc.loop {
+          acc.loop worker {
             scf.for %y = %lb to %c10 step %st {
               %axy = memref.load %a[%x, %y] : memref<10x10xf32>
               %bxy = memref.load %b[%x, %y] : memref<10x10xf32>
@@ -148,7 +148,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
               memref.store %tmp, %c[%y] : memref<10xf32>
             }
             acc.yield
-          } attributes {worker = [#acc.device_type<none>]}
+          }
 
           acc.loop {
             // for i = 0 to 10 step 1
@@ -163,7 +163,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
           } attributes {seq = [#acc.device_type<none>]}
         }
         acc.yield
-      } attributes {gang = [#acc.device_type<none>]}
+      }
       acc.yield
     }
     acc.terminator
@@ -181,9 +181,9 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK:        acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
 // CHECK-NEXT:     %[[P_ARG2:.*]] = acc.private varPtr([[ARG2]] : memref<10xf32>) -> memref<10xf32> 
 // CHECK-NEXT:     acc.parallel num_gangs({[[NUMGANG]] : i64}) num_workers([[NUMWORKERS]] : i64 [#acc.device_type<nvidia>]) private(@privatization_memref_10_f32 -> %[[P_ARG2]] : memref<10xf32>) {
-// CHECK-NEXT:       acc.loop {
+// CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
-// CHECK-NEXT:           acc.loop {
+// CHECK-NEXT:           acc.loop worker {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
@@ -191,7 +191,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:               memref.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
-// CHECK-NEXT:           } attributes {worker = [#acc.device_type<none>]}
+// CHECK-NEXT:           }
 // CHECK-NEXT:           acc.loop {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = memref.load %{{.*}}[%{{.*}}] : memref<10xf32>
@@ -203,7 +203,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:           } attributes {seq = [#acc.device_type<none>]}
 // CHECK-NEXT:         }
 // CHECK-NEXT:         acc.yield
-// CHECK-NEXT:       } attributes {gang = [#acc.device_type<none>]}
+// CHECK-NEXT:       }
 // CHECK-NEXT:       acc.yield
 // CHECK-NEXT:     }
 // CHECK-NEXT:     acc.terminator
@@ -218,10 +218,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
   %i32Value = arith.constant 128 : i32
   %idxValue = arith.constant 8 : index
 
-  acc.loop {
+  acc.loop gang vector worker {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
-  } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>], gang = [#acc.device_type<none>]}
+  }
   acc.loop gang({num=%i64Value: i64}) {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
@@ -254,10 +254,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
   }
-  acc.loop gang({num=%i64Value: i64}) {
+  acc.loop gang({num=%i64Value: i64}) worker vector {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
-  } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+  }
   acc.loop gang({num=%i64Value: i64, static=%i64Value: i64}) worker(%i64Value: i64) vector(%i64Value: i64) {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
@@ -293,10 +293,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
 // CHECK:      [[I64VALUE:%.*]] = arith.constant 1 : i64
 // CHECK-NEXT: [[I32VALUE:%.*]] = arith.constant 128 : i32
 // CHECK-NEXT: [[IDXVALUE:%.*]] = arith.constant 8 : index
-// CHECK:      acc.loop {
+// CHECK:      acc.loop gang worker vector {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
-// CHECK-NEXT: } attributes {gang = [#acc.device_type<none>], vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+// CHECK-NEXT: }
 // CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
@@ -329,10 +329,10 @@ func.func @testloopop(%a : memref<10xf32>) -> () {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) {
+// CHECK:      acc.loop gang({num=[[I64VALUE]] : i64}) worker vector {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
-// CHECK-NEXT: } attributes {vector = [#acc.device_type<none>], worker = [#acc.device_type<none>]}
+// CHECK-NEXT: }
 // CHECK:      acc.loop gang({num=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64}) worker([[I64VALUE]] : i64) vector([[I64VALUE]] : i64) {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
    
    
More information about the Mlir-commits
mailing list