[Mlir-commits] [mlir] d9953d1 - [mlir][openacc] Add missing operands for acc.parallel operation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 16 07:49:12 PDT 2020


Author: Valentin Clement
Date: 2020-09-16T10:49:03-04:00
New Revision: d9953d155493bf11a2276e202800f844a1d02396

URL: https://github.com/llvm/llvm-project/commit/d9953d155493bf11a2276e202800f844a1d02396
DIFF: https://github.com/llvm/llvm-project/commit/d9953d155493bf11a2276e202800f844a1d02396.diff

LOG: [mlir][openacc] Add missing operands for acc.parallel operation

Add missing operands to represent copyin with the readonly modifier, copyout with the
zero modifier, and create with the zero modifier, plus an optional attribute for the default clause.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87733

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 0d37215ea4e5..f6350dbdf0db 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -64,6 +64,15 @@ def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr",
 // 2.5.1 parallel Construct
 //===----------------------------------------------------------------------===//
 
+// Parallel op default enumeration
+def OpenACC_DefaultNone : StrEnumAttrCase<"none">;
+def OpenACC_DefaultPresent : StrEnumAttrCase<"present">;
+def OpenACC_DefaultAttr : StrEnumAttr<"DefaultAttr",
+    "default attribute value for parallel op",
+    [OpenACC_DefaultNone, OpenACC_DefaultPresent]> {
+  let cppNamespace = "::mlir::acc";
+}
+
 def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     [AttrSizedOperandSegments]> {
   let summary = "parallel construct";
@@ -92,14 +101,18 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
                        Variadic<AnyType>:$reductionOperands,
                        Variadic<AnyType>:$copyOperands,
                        Variadic<AnyType>:$copyinOperands,
+                       Variadic<AnyType>:$copyinReadonlyOperands,
                        Variadic<AnyType>:$copyoutOperands,
+                       Variadic<AnyType>:$copyoutZeroOperands,
                        Variadic<AnyType>:$createOperands,
+                       Variadic<AnyType>:$createZeroOperands,
                        Variadic<AnyType>:$noCreateOperands,
                        Variadic<AnyType>:$presentOperands,
                        Variadic<AnyType>:$devicePtrOperands,
                        Variadic<AnyType>:$attachOperands,
                        Variadic<AnyType>:$gangPrivateOperands,
-                       Variadic<AnyType>:$gangFirstPrivateOperands);
+                       Variadic<AnyType>:$gangFirstPrivateOperands,
+                       OptionalAttr<OpenACC_DefaultAttr>:$defaultAttr);
 
   let regions = (region AnyRegion:$region);
 
@@ -114,8 +127,11 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     static StringRef getReductionKeyword() { return "reduction"; }
     static StringRef getCopyKeyword() { return "copy"; }
     static StringRef getCopyinKeyword() { return "copyin"; }
+    static StringRef getCopyinReadonlyKeyword() { return "copyin_readonly"; }
     static StringRef getCopyoutKeyword() { return "copyout"; }
+    static StringRef getCopyoutZeroKeyword() { return "copyout_zero"; }
     static StringRef getCreateKeyword() { return "create"; }
+    static StringRef getCreateZeroKeyword() { return "create_zero"; }
     static StringRef getNoCreateKeyword() { return "no_create"; }
     static StringRef getPresentKeyword() { return "present"; }
     static StringRef getDevicePtrKeyword() { return "deviceptr"; }

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 3e4d1c3f0e7d..614951225042 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -116,8 +116,11 @@ static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
 ///                             `reduction` `(` value-list `)`?
 ///                             `copy` `(` value-list `)`?
 ///                             `copyin` `(` value-list `)`?
+///                             `copyin_readonly` `(` value-list `)`?
 ///                             `copyout` `(` value-list `)`?
+///                             `copyout_zero` `(` value-list `)`?
 ///                             `create` `(` value-list `)`?
+///                             `create_zero` `(` value-list `)`?
 ///                             `no_create` `(` value-list `)`?
 ///                             `present` `(` value-list `)`?
 ///                             `deviceptr` `(` value-list `)`?
@@ -129,10 +132,16 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
                                    OperationState &result) {
   Builder &builder = parser.getBuilder();
   SmallVector<OpAsmParser::OperandType, 8> privateOperands,
-      firstprivateOperands, createOperands, copyOperands, copyinOperands,
-      copyoutOperands, noCreateOperands, presentOperands, devicePtrOperands,
-      attachOperands, waitOperands, reductionOperands;
-  SmallVector<Type, 8> operandTypes;
+      firstprivateOperands, copyOperands, copyinOperands,
+      copyinReadonlyOperands, copyoutOperands, copyoutZeroOperands,
+      createOperands, createZeroOperands, noCreateOperands, presentOperands,
+      devicePtrOperands, attachOperands, waitOperands, reductionOperands;
+  SmallVector<Type, 8> waitOperandTypes, reductionOperandTypes,
+      copyOperandTypes, copyinOperandTypes, copyinReadonlyOperandTypes,
+      copyoutOperandTypes, copyoutZeroOperandTypes, createOperandTypes,
+      createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
+      deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
+      firstprivateOperandTypes;
   OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond,
       selfCond;
   bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false;
@@ -148,7 +157,7 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
 
   // wait()?
   if (failed(parseOperandList(parser, ParallelOp::getWaitKeyword(),
-                              waitOperands, operandTypes, result)))
+                              waitOperands, waitOperandTypes, result)))
     return failure();
 
   // num_gangs(value)?
@@ -180,57 +189,78 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
 
   // reduction()?
   if (failed(parseOperandList(parser, ParallelOp::getReductionKeyword(),
-                              reductionOperands, operandTypes, result)))
+                              reductionOperands, reductionOperandTypes,
+                              result)))
     return failure();
 
   // copy()?
   if (failed(parseOperandList(parser, ParallelOp::getCopyKeyword(),
-                              copyOperands, operandTypes, result)))
+                              copyOperands, copyOperandTypes, result)))
     return failure();
 
   // copyin()?
   if (failed(parseOperandList(parser, ParallelOp::getCopyinKeyword(),
-                              copyinOperands, operandTypes, result)))
+                              copyinOperands, copyinOperandTypes, result)))
+    return failure();
+
+  // copyin_readonly()?
+  if (failed(parseOperandList(parser, ParallelOp::getCopyinReadonlyKeyword(),
+                              copyinReadonlyOperands,
+                              copyinReadonlyOperandTypes, result)))
     return failure();
 
   // copyout()?
   if (failed(parseOperandList(parser, ParallelOp::getCopyoutKeyword(),
-                              copyoutOperands, operandTypes, result)))
+                              copyoutOperands, copyoutOperandTypes, result)))
+    return failure();
+
+  // copyout_zero()?
+  if (failed(parseOperandList(parser, ParallelOp::getCopyoutZeroKeyword(),
+                              copyoutZeroOperands, copyoutZeroOperandTypes,
+                              result)))
     return failure();
 
   // create()?
   if (failed(parseOperandList(parser, ParallelOp::getCreateKeyword(),
-                              createOperands, operandTypes, result)))
+                              createOperands, createOperandTypes, result)))
+    return failure();
+
+  // create_zero()?
+  if (failed(parseOperandList(parser, ParallelOp::getCreateZeroKeyword(),
+                              createZeroOperands, createZeroOperandTypes,
+                              result)))
     return failure();
 
   // no_create()?
   if (failed(parseOperandList(parser, ParallelOp::getNoCreateKeyword(),
-                              noCreateOperands, operandTypes, result)))
+                              noCreateOperands, noCreateOperandTypes, result)))
     return failure();
 
   // present()?
   if (failed(parseOperandList(parser, ParallelOp::getPresentKeyword(),
-                              presentOperands, operandTypes, result)))
+                              presentOperands, presentOperandTypes, result)))
     return failure();
 
   // deviceptr()?
   if (failed(parseOperandList(parser, ParallelOp::getDevicePtrKeyword(),
-                              devicePtrOperands, operandTypes, result)))
+                              devicePtrOperands, deviceptrOperandTypes,
+                              result)))
     return failure();
 
   // attach()?
   if (failed(parseOperandList(parser, ParallelOp::getAttachKeyword(),
-                              attachOperands, operandTypes, result)))
+                              attachOperands, attachOperandTypes, result)))
     return failure();
 
   // private()?
   if (failed(parseOperandList(parser, ParallelOp::getPrivateKeyword(),
-                              privateOperands, operandTypes, result)))
+                              privateOperands, privateOperandTypes, result)))
     return failure();
 
   // firstprivate()?
   if (failed(parseOperandList(parser, ParallelOp::getFirstPrivateKeyword(),
-                              firstprivateOperands, operandTypes, result)))
+                              firstprivateOperands, firstprivateOperandTypes,
+                              result)))
     return failure();
 
   // Parallel op region
@@ -249,8 +279,11 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
                            static_cast<int32_t>(reductionOperands.size()),
                            static_cast<int32_t>(copyOperands.size()),
                            static_cast<int32_t>(copyinOperands.size()),
+                           static_cast<int32_t>(copyinReadonlyOperands.size()),
                            static_cast<int32_t>(copyoutOperands.size()),
+                           static_cast<int32_t>(copyoutZeroOperands.size()),
                            static_cast<int32_t>(createOperands.size()),
+                           static_cast<int32_t>(createZeroOperands.size()),
                            static_cast<int32_t>(noCreateOperands.size()),
                            static_cast<int32_t>(presentOperands.size()),
                            static_cast<int32_t>(devicePtrOperands.size()),
@@ -309,14 +342,26 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) {
   printOperandList(op.copyinOperands(), ParallelOp::getCopyinKeyword(),
                    printer);
 
+  // copyin_readonly()?
+  printOperandList(op.copyinReadonlyOperands(),
+                   ParallelOp::getCopyinReadonlyKeyword(), printer);
+
   // copyout()?
   printOperandList(op.copyoutOperands(), ParallelOp::getCopyoutKeyword(),
                    printer);
 
+  // copyout_zero()?
+  printOperandList(op.copyoutZeroOperands(),
+                   ParallelOp::getCopyoutZeroKeyword(), printer);
+
   // create()?
   printOperandList(op.createOperands(), ParallelOp::getCreateKeyword(),
                    printer);
 
+  // create_zero()?
+  printOperandList(op.createZeroOperands(), ParallelOp::getCreateZeroKeyword(),
+                   printer);
+
   // no_create()?
   printOperandList(op.noCreateOperands(), ParallelOp::getNoCreateKeyword(),
                    printer);

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index b1a78c61d65d..3398f95bf607 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -265,14 +265,54 @@ func @testop(%a: memref<10xf32>) -> () {
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
 
-
-func @testparallelop() -> () {
+func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
   %vectorLength = constant 128 : index
   acc.parallel vector_length(%vectorLength) {
   }
+  acc.parallel copyin(%a: memref<10xf32>, %b: memref<10xf32>) {
+  }
+  acc.parallel copyin_readonly(%a: memref<10xf32>, %b: memref<10xf32>) {
+  }
+  acc.parallel copyin(%a: memref<10xf32>) copyout_zero(%b: memref<10xf32>, %c: memref<10x10xf32>) {
+  }
+  acc.parallel copyout(%b: memref<10xf32>, %c: memref<10x10xf32>) create(%a: memref<10xf32>) {
+  }
+  acc.parallel copyout_zero(%b: memref<10xf32>, %c: memref<10x10xf32>) create_zero(%a: memref<10xf32>) {
+  }
+  acc.parallel no_create(%a: memref<10xf32>) present(%b: memref<10xf32>, %c: memref<10x10xf32>) {
+  }
+  acc.parallel deviceptr(%a: memref<10xf32>) attach(%b: memref<10xf32>, %c: memref<10x10xf32>) {
+  }
+  acc.parallel private(%a: memref<10xf32>, %c: memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
+  }
+  acc.parallel {
+  } attributes {defaultAttr = "none"}
+  acc.parallel {
+  } attributes {defaultAttr = "present"}
   return
 }
 
-// CHECK:      [[VECTORLENGTH:%.*]] = constant 128 : index
-// CHECK-NEXT: acc.parallel vector_length([[VECTORLENGTH]]) {
-// CHECK-NEXT: }
+// CHECK:      func @testparallelop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK:        [[VECTORLENGTH:%.*]] = constant 128 : index
+// CHECK:        acc.parallel vector_length([[VECTORLENGTH]]) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel copyin_readonly([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>) copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel copyout([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create_zero([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel no_create([[ARGA]]: memref<10xf32>) present([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel deviceptr([[ARGA]]: memref<10xf32>) attach([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel private([[ARGA]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) firstprivate([[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT:   }
+// CHECK:        acc.parallel {
+// CHECK-NEXT:   } attributes {defaultAttr = "none"}
+// CHECK:        acc.parallel {
+// CHECK-NEXT:   } attributes {defaultAttr = "present"}


        


More information about the Mlir-commits mailing list