[flang-commits] [flang] 7f3d2cc - [mlir][openacc] Add gang dim operand to acc.loop operation

Valentin Clement via flang-commits flang-commits at lists.llvm.org
Tue Jun 13 13:39:37 PDT 2023


Author: Valentin Clement
Date: 2023-06-13T13:39:29-07:00
New Revision: 7f3d2cc26b47568a0ac93327ab579ed1f0e21546

URL: https://github.com/llvm/llvm-project/commit/7f3d2cc26b47568a0ac93327ab579ed1f0e21546
DIFF: https://github.com/llvm/llvm-project/commit/7f3d2cc26b47568a0ac93327ab579ed1f0e21546.diff

LOG: [mlir][openacc] Add gang dim operand to acc.loop operation

OpenACC 3.3 introduces a dim argument on the gang clause. This patch
adds a new operand for it on the acc.loop operation and updates the
custom gang clause parser/printer accordingly.

Depends on D151970

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D151971

Added: 
    

Modified: 
    flang/lib/Lower/OpenACC.cpp
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/invalid.mlir
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 9126b3cdea9ba..4d58096938336 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -799,6 +799,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value workerNum;
   mlir::Value vectorNum;
   mlir::Value gangNum;
+  mlir::Value gangDim;
   mlir::Value gangStatic;
   llvm::SmallVector<mlir::Value, 2> tileOperands, privateOperands,
       reductionOperands;
@@ -883,6 +884,7 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Value> operands;
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, gangNum);
+  addOperand(operands, operandSegments, gangDim);
   addOperand(operands, operandSegments, gangStatic);
   addOperand(operands, operandSegments, workerNum);
   addOperand(operands, operandSegments, vectorNum);

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 097d01e9dd60e..85ffc39586390 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1034,6 +1034,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
   let arguments = (ins OptionalAttr<I64Attr>:$collapse,
                        Optional<IntOrIndex>:$gangNum,
+                       Optional<IntOrIndex>:$gangDim,
                        Optional<IntOrIndex>:$gangStatic,
                        Optional<IntOrIndex>:$workerNum,
                        Optional<IntOrIndex>:$vectorLength,
@@ -1056,13 +1057,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let extraClassDeclaration = [{
     static StringRef getAutoAttrStrName() { return "auto"; }
     static StringRef getGangNumKeyword() { return "num"; }
+    static StringRef getGangDimKeyword() { return "dim"; }
     static StringRef getGangStaticKeyword() { return "static"; }
   }];
 
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` custom<GangClause>($gangNum, type($gangNum), $gangStatic, type($gangStatic), $hasGang)
+        `gang` `` custom<GangClause>($gangNum, type($gangNum), $gangDim, type($gangDim), $gangStatic, type($gangStatic), $hasGang)
       | `worker` `` custom<WorkerClause>($workerNum, type($workerNum), $hasWorker)
       | `vector` `` custom<VectorClause>($vectorLength, type($vectorLength), $hasVector)
       | `private` `(` custom<SymOperandList>(

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index d722d5aef2db2..c6509e55a250a 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -650,14 +650,15 @@ parseGangValue(OpAsmParser &parser, llvm::StringRef keyword,
   return success();
 }
 
-static ParseResult
-parseGangClause(OpAsmParser &parser,
-                std::optional<OpAsmParser::UnresolvedOperand> &gangNum,
-                Type &gangNumType,
-                std::optional<OpAsmParser::UnresolvedOperand> &gangStatic,
-                Type &gangStaticType, UnitAttr &hasGang) {
+static ParseResult parseGangClause(
+    OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &gangNum,
+    Type &gangNumType, std::optional<OpAsmParser::UnresolvedOperand> &gangDim,
+    Type &gangDimType,
+    std::optional<OpAsmParser::UnresolvedOperand> &gangStatic,
+    Type &gangStaticType, UnitAttr &hasGang) {
   hasGang = UnitAttr::get(parser.getBuilder().getContext());
   gangNum = std::nullopt;
+  gangDim = std::nullopt;
   gangStatic = std::nullopt;
   bool needComa = false;
 
@@ -676,6 +677,9 @@ parseGangClause(OpAsmParser &parser,
       if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), gangNum,
                                 gangNumType, needComa, newValue)))
         return failure();
+      if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), gangDim,
+                                gangDimType, needComa, newValue)))
+        return failure();
       if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
                                 gangStatic, gangStaticType, needComa,
                                 newValue)))
@@ -691,9 +695,9 @@ parseGangClause(OpAsmParser &parser,
         break;
     }
 
-    if (!gangNum && !gangStatic) {
+    if (!gangNum && !gangDim && !gangStatic) {
       parser.emitError(parser.getCurrentLocation(),
-                       "expect num and/or static value(s)");
+                       "expect at least one of num, dim or static values");
       return failure();
     }
 
@@ -704,13 +708,19 @@ parseGangClause(OpAsmParser &parser,
 }
 
 void printGangClause(OpAsmPrinter &p, Operation *op, Value gangNum,
-                     Type gangNumType, Value gangStatic, Type gangStaticType,
-                     UnitAttr hasGang) {
-  if (gangNum || gangStatic) {
+                     Type gangNumType, Value gangDim, Type gangDimType,
+                     Value gangStatic, Type gangStaticType, UnitAttr hasGang) {
+  if (gangNum || gangStatic || gangDim) {
     p << "(";
     if (gangNum) {
       p << LoopOp::getGangNumKeyword() << "=" << gangNum << " : "
         << gangNumType;
+      if (gangStatic || gangDim)
+        p << ", ";
+    }
+    if (gangDim) {
+      p << LoopOp::getGangDimKeyword() << "=" << gangDim << " : "
+        << gangDimType;
       if (gangStatic)
         p << ", ";
     }

diff  --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ebd814205d230..31e4b4bdf74da 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -490,7 +490,7 @@ func.func @fct1(%0 : !llvm.ptr<i32>) -> () {
 
 // -----
 
-// expected-error@+1 {{expect num and/or static value(s)}}
+// expected-error@+1 {{expect at least one of num, dim or static values}}
 acc.loop gang() {
   "test.openacc_dummy_op"() : () -> ()
   acc.yield

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 95553ee3885b1..aa2f95a136205 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -276,6 +276,10 @@ func.func @testloopop() -> () {
     "test.openacc_dummy_op"() : () -> ()
     acc.yield
   }
+  acc.loop gang(dim=%i64Value : i64, static=%i64Value: i64) {
+    "test.openacc_dummy_op"() : () -> ()
+    acc.yield
+  }
   return
 }
 
@@ -342,6 +346,10 @@ func.func @testloopop() -> () {
 // CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
+// CHECK:      acc.loop gang(dim=[[I64VALUE]] : i64, static=[[I64VALUE]] : i64) {
+// CHECK-NEXT:   "test.openacc_dummy_op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
 
 // -----
 


        


More information about the flang-commits mailing list