[Mlir-commits] [mlir] 6d3cabd - [mlir][openacc] Change operand type from index to AnyInteger in parallel op

llvmlistbot at llvm.org
Thu Sep 17 08:34:02 PDT 2020


Author: Valentin Clement
Date: 2020-09-17T11:33:55-04:00
New Revision: 6d3cabd90eedee07a6e6cbf2dfa952e23cef192c

URL: https://github.com/llvm/llvm-project/commit/6d3cabd90eedee07a6e6cbf2dfa952e23cef192c
DIFF: https://github.com/llvm/llvm-project/commit/6d3cabd90eedee07a6e6cbf2dfa952e23cef192c.diff

LOG: [mlir][openacc] Change operand type from index to AnyInteger in parallel op

This patch changes the type of the async, wait, numGangs, numWorkers, and vectorLength operands from
index to AnyInteger to match acc.loop and the OpenACC specification.
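
For example, the async clause now accepts an operand of any integer type (or index), printed with an
explicit type. A minimal sketch mirroring the updated tests below; the value names are illustrative:

  %i32value = constant 1 : i32
  %idxValue = constant 1 : index
  // An i32 async operand; an index operand remains valid as well.
  acc.parallel async(%i32value: i32) {
  }
  acc.parallel async(%idxValue: index) {
  }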

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87712

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index f6350dbdf0db..3fa26f932bd9 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -36,7 +36,7 @@ class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-// Reduction operation enumeration
+// Reduction operation enumeration.
 def OpenACC_ReductionOpAdd : StrEnumAttrCase<"redop_add">;
 def OpenACC_ReductionOpMul : StrEnumAttrCase<"redop_mul">;
 def OpenACC_ReductionOpMax : StrEnumAttrCase<"redop_max">;
@@ -60,6 +60,9 @@ def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr",
   let cppNamespace = "::mlir::acc";
 }
 
+// Type used in operation below.
+def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;
+
 //===----------------------------------------------------------------------===//
 // 2.5.1 parallel Construct
 //===----------------------------------------------------------------------===//
@@ -90,11 +93,11 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     ```
   }];
 
-  let arguments = (ins Optional<Index>:$async,
-                       Variadic<Index>:$waitOperands,
-                       Optional<Index>:$numGangs,
-                       Optional<Index>:$numWorkers,
-                       Optional<Index>:$vectorLength,
+  let arguments = (ins Optional<IntOrIndex>:$async,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       Optional<IntOrIndex>:$numGangs,
+                       Optional<IntOrIndex>:$numWorkers,
+                       Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,
                        Optional<I1>:$selfCond,
                        OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 614951225042..3cae3c8feb8f 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -101,6 +101,22 @@ static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
   return success();
 }
 
+static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
+                                                       StringRef keyword,
+                                                       OperationState &result) {
+  OpAsmParser::OperandType operand;
+  Type type;
+  if (succeeded(parser.parseOptionalKeyword(keyword))) {
+    if (parser.parseLParen() || parser.parseOperand(operand) ||
+        parser.parseColonType(type) ||
+        parser.resolveOperand(operand, type, result.operands) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -142,17 +158,17 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
       firstprivateOperandTypes;
-  OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond,
-      selfCond;
-  bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false;
-  bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false;
 
-  Type indexType = builder.getIndexType();
+  SmallVector<Type, 8> operandTypes;
+  OpAsmParser::OperandType ifCond, selfCond;
+  bool hasIfCond = false, hasSelfCond = false;
+  OptionalParseResult async, numGangs, numWorkers, vectorLength;
   Type i1Type = builder.getI1Type();
 
   // async()?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async,
-                                  indexType, hasAsync, result)))
+  async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
+                                      result);
+  if (async.hasValue() && failed(*async))
     return failure();
 
   // wait()?
@@ -161,20 +177,21 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
     return failure();
 
   // num_gangs(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(),
-                                  numGangs, indexType, hasNumGangs, result)))
+  numGangs = parseOptionalOperandAndType(
+      parser, ParallelOp::getNumGangsKeyword(), result);
+  if (numGangs.hasValue() && failed(*numGangs))
     return failure();
 
   // num_workers(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(),
-                                  numWorkers, indexType, hasNumWorkers,
-                                  result)))
+  numWorkers = parseOptionalOperandAndType(
+      parser, ParallelOp::getNumWorkersKeyword(), result);
+  if (numWorkers.hasValue() && failed(*numWorkers))
     return failure();
 
   // vector_length(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(),
-                                  vectorLength, indexType, hasVectorLength,
-                                  result)))
+  vectorLength = parseOptionalOperandAndType(
+      parser, ParallelOp::getVectorLengthKeyword(), result);
+  if (vectorLength.hasValue() && failed(*vectorLength))
     return failure();
 
   // if()?
@@ -267,29 +284,30 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
   if (failed(parseRegions<ParallelOp>(parser, result)))
     return failure();
 
-  result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(),
-                      builder.getI32VectorAttr(
-                          {static_cast<int32_t>(hasAsync ? 1 : 0),
-                           static_cast<int32_t>(waitOperands.size()),
-                           static_cast<int32_t>(hasNumGangs ? 1 : 0),
-                           static_cast<int32_t>(hasNumWorkers ? 1 : 0),
-                           static_cast<int32_t>(hasVectorLength ? 1 : 0),
-                           static_cast<int32_t>(hasIfCond ? 1 : 0),
-                           static_cast<int32_t>(hasSelfCond ? 1 : 0),
-                           static_cast<int32_t>(reductionOperands.size()),
-                           static_cast<int32_t>(copyOperands.size()),
-                           static_cast<int32_t>(copyinOperands.size()),
-                           static_cast<int32_t>(copyinReadonlyOperands.size()),
-                           static_cast<int32_t>(copyoutOperands.size()),
-                           static_cast<int32_t>(copyoutZeroOperands.size()),
-                           static_cast<int32_t>(createOperands.size()),
-                           static_cast<int32_t>(createZeroOperands.size()),
-                           static_cast<int32_t>(noCreateOperands.size()),
-                           static_cast<int32_t>(presentOperands.size()),
-                           static_cast<int32_t>(devicePtrOperands.size()),
-                           static_cast<int32_t>(attachOperands.size()),
-                           static_cast<int32_t>(privateOperands.size()),
-                           static_cast<int32_t>(firstprivateOperands.size())}));
+  result.addAttribute(
+      ParallelOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr(
+          {static_cast<int32_t>(async.hasValue() ? 1 : 0),
+           static_cast<int32_t>(waitOperands.size()),
+           static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
+           static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
+           static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
+           static_cast<int32_t>(hasIfCond ? 1 : 0),
+           static_cast<int32_t>(hasSelfCond ? 1 : 0),
+           static_cast<int32_t>(reductionOperands.size()),
+           static_cast<int32_t>(copyOperands.size()),
+           static_cast<int32_t>(copyinOperands.size()),
+           static_cast<int32_t>(copyinReadonlyOperands.size()),
+           static_cast<int32_t>(copyoutOperands.size()),
+           static_cast<int32_t>(copyoutZeroOperands.size()),
+           static_cast<int32_t>(createOperands.size()),
+           static_cast<int32_t>(createZeroOperands.size()),
+           static_cast<int32_t>(noCreateOperands.size()),
+           static_cast<int32_t>(presentOperands.size()),
+           static_cast<int32_t>(devicePtrOperands.size()),
+           static_cast<int32_t>(attachOperands.size()),
+           static_cast<int32_t>(privateOperands.size()),
+           static_cast<int32_t>(firstprivateOperands.size())}));
 
   // Additional attributes
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -303,7 +321,8 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) {
 
   // async()?
   if (Value async = op.async())
-    printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")";
+    printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
+            << async.getType() << ")";
 
   // wait()?
   printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
@@ -311,17 +330,17 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) {
   // num_gangs()?
   if (Value numGangs = op.numGangs())
     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
-            << ")";
+            << ": " << numGangs.getType() << ")";
 
   // num_workers()?
   if (Value numWorkers = op.numWorkers())
     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
-            << ")";
+            << ": " << numWorkers.getType() << ")";
 
   // vector_length()?
   if (Value vectorLength = op.vectorLength())
     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
-            << vectorLength << ")";
+            << vectorLength << ": " << vectorLength.getType() << ")";
 
   // if()?
   if (Value ifCond = op.ifCond())

diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 3398f95bf607..196949839db4 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -8,8 +8,9 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
   %c0 = constant 0 : index
   %c10 = constant 10 : index
   %c1 = constant 1 : index
+  %async = constant 1 : i64
 
-  acc.parallel async(%c1) {
+  acc.parallel async(%async: i64) {
     acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
@@ -35,7 +36,8 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:   %{{.*}} = constant 0 : index
 //  CHECK-NEXT:   %{{.*}} = constant 10 : index
 //  CHECK-NEXT:   %{{.*}} = constant 1 : index
-//  CHECK-NEXT:   acc.parallel async(%{{.*}}) {
+//  CHECK-NEXT:   [[ASYNC:%.*]] = constant 1 : i64
+//  CHECK-NEXT:   acc.parallel async([[ASYNC]]: i64) {
 //  CHECK-NEXT:     acc.loop gang vector {
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -113,9 +115,11 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
   %lb = constant 0 : index
   %st = constant 1 : index
   %c10 = constant 10 : index
+  %numGangs = constant 10 : i64
+  %numWorkers = constant 10 : i64
 
   acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) {
-    acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
+    acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
       acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
           acc.loop worker {
@@ -154,8 +158,10 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:   [[C0:%.*]] = constant 0 : index
 // CHECK-NEXT:   [[C1:%.*]] = constant 1 : index
 // CHECK-NEXT:   [[C10:%.*]] = constant 10 : index
+// CHECK-NEXT:   [[NUMGANG:%.*]] = constant 10 : i64
+// CHECK-NEXT:   [[NUMWORKERS:%.*]] = constant 10 : i64
 // CHECK-NEXT:   acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) {
-// CHECK-NEXT:     acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) {
+// CHECK-NEXT:     acc.parallel num_gangs([[NUMGANG]]: i64) num_workers([[NUMWORKERS]]: i64) private([[ARG2]]: memref<10xf32>) {
 // CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:           acc.loop worker {
@@ -265,9 +271,42 @@ func @testop(%a: memref<10xf32>) -> () {
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
 
+
 func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
-  %vectorLength = constant 128 : index
-  acc.parallel vector_length(%vectorLength) {
+  %i64value = constant 1 : i64
+  %i32value = constant 1 : i32
+  %idxValue = constant 1 : index
+  acc.parallel async(%i64value: i64) {
+  }
+  acc.parallel async(%i32value: i32) {
+  }
+  acc.parallel async(%idxValue: index) {
+  }
+  acc.parallel wait(%i64value: i64) {
+  }
+  acc.parallel wait(%i32value: i32) {
+  }
+  acc.parallel wait(%idxValue: index) {
+  }
+  acc.parallel wait(%i64value: i64, %i32value: i32, %idxValue: index) {
+  }
+  acc.parallel num_gangs(%i64value: i64) {
+  }
+  acc.parallel num_gangs(%i32value: i32) {
+  }
+  acc.parallel num_gangs(%idxValue: index) {
+  }
+  acc.parallel num_workers(%i64value: i64) {
+  }
+  acc.parallel num_workers(%i32value: i32) {
+  }
+  acc.parallel num_workers(%idxValue: index) {
+  }
+  acc.parallel vector_length(%i64value: i64) {
+  }
+  acc.parallel vector_length(%i32value: i32) {
+  }
+  acc.parallel vector_length(%idxValue: index) {
   }
   acc.parallel copyin(%a: memref<10xf32>, %b: memref<10xf32>) {
   }
@@ -293,26 +332,58 @@ func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf3
 }
 
 // CHECK:      func @testparallelop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
-// CHECK:        [[VECTORLENGTH:%.*]] = constant 128 : index
-// CHECK:        acc.parallel vector_length([[VECTORLENGTH]]) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin_readonly([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>) copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyout([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create([[ARGA]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create_zero([[ARGA]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel no_create([[ARGA]]: memref<10xf32>) present([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel deviceptr([[ARGA]]: memref<10xf32>) attach([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel private([[ARGA]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) firstprivate([[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel {
-// CHECK-NEXT:   } attributes {defaultAttr = "none"}
-// CHECK:        acc.parallel {
-// CHECK-NEXT:   } attributes {defaultAttr = "present"}
+// CHECK:      [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK:      [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK:      [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK:      acc.parallel async([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel async([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel async([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I64VALUE]]: i64, [[I32VALUE]]: i32, [[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin_readonly([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin([[ARGA]]: memref<10xf32>) copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyout([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create_zero([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel no_create([[ARGA]]: memref<10xf32>) present([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel deviceptr([[ARGA]]: memref<10xf32>) attach([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel private([[ARGA]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) firstprivate([[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel {
+// CHECK-NEXT: } attributes {defaultAttr = "none"}
+// CHECK:      acc.parallel {
+// CHECK-NEXT: } attributes {defaultAttr = "present"}
