[Mlir-commits] [mlir] 2bbbcae - [mlir][openacc] Add missing attributes and operands for acc.loop
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 31 16:50:13 PDT 2020
Author: Valentin Clement
Date: 2020-08-31T19:50:05-04:00
New Revision: 2bbbcae782adbea20ae50f9f5471056a91498ffc
URL: https://github.com/llvm/llvm-project/commit/2bbbcae782adbea20ae50f9f5471056a91498ffc
DIFF: https://github.com/llvm/llvm-project/commit/2bbbcae782adbea20ae50f9f5471056a91498ffc.diff
LOG: [mlir][openacc] Add missing attributes and operands for acc.loop
This patch adds the missing operands to the acc.loop operation. Only the device_type
information is not yet part of the operation.
Reviewed By: rriddle, kiranchandramohan
Differential Revision: https://reviews.llvm.org/D86753
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACC.h
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 399db74fae2c..8f5e1daf9aeb 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -31,12 +31,11 @@ namespace acc {
/// 2.9.2. gang
/// 2.9.3. worker
/// 2.9.4. vector
-/// 2.9.5. seq
///
/// Value can be combined bitwise to reflect the mapping applied to the
/// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
/// combined and the final mapping value would be 5 (4 | 1).
-enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
+enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
} // end namespace acc
} // end namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 8dd0f849ddd2..30d6f435b75f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -224,6 +224,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let arguments = (ins OptionalAttr<I64Attr>:$collapse,
+ Optional<AnyInteger>:$gangNum,
+ Optional<AnyInteger>:$gangStatic,
+ Optional<AnyInteger>:$workerNum,
+ Optional<AnyInteger>:$vectorLength,
+ UnitAttr:$loopSeq,
+ UnitAttr:$loopIndependent,
+ UnitAttr:$loopAuto,
+ Variadic<AnyInteger>:$tileOperands,
Variadic<AnyType>:$privateOperands,
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands);
@@ -234,11 +242,16 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let extraClassDeclaration = [{
static StringRef getCollapseAttrName() { return "collapse"; }
- static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
- static StringRef getGangAttrName() { return "gang"; }
static StringRef getSeqAttrName() { return "seq"; }
- static StringRef getVectorAttrName() { return "vector"; }
- static StringRef getWorkerAttrName() { return "worker"; }
+ static StringRef getIndependentAttrName() { return "independent"; }
+ static StringRef getAutoAttrName() { return "auto"; }
+ static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
+ static StringRef getGangKeyword() { return "gang"; }
+ static StringRef getGangNumKeyword() { return "num"; }
+ static StringRef getGangStaticKeyword() { return "static"; }
+ static StringRef getVectorKeyword() { return "vector"; }
+ static StringRef getWorkerKeyword() { return "worker"; }
+ static StringRef getTileKeyword() { return "tile"; }
static StringRef getPrivateKeyword() { return "private"; }
static StringRef getReductionKeyword() { return "reduction"; }
}];
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6ac411d456af..b5dfa2c13358 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -476,32 +476,81 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
//===----------------------------------------------------------------------===//
/// Parse acc.loop operation
-/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
+/// operation := `acc.loop` `gang`? `vector`? `worker`?
/// `private` `(` value-list `)`?
/// `reduction` `(` value-list `)`?
/// region attr-dict?
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
-
Builder &builder = parser.getBuilder();
unsigned executionMapping = 0;
SmallVector<Type, 8> operandTypes;
SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
+ SmallVector<OpAsmParser::OperandType, 8> tileOperands;
+ bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
+ bool hasGangStatic = false;
+ OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
+ Type intType = builder.getI64Type();
// gang?
- if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
+ if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
executionMapping |= OpenACCExecMapping::GANG;
- // vector?
- if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
- executionMapping |= OpenACCExecMapping::VECTOR;
+ // optional gang operand
+ if (succeeded(parser.parseOptionalLParen())) {
+ if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
+ hasGangNum = true;
+ parser.parseColon();
+ if (parser.parseOperand(gangNum) ||
+ parser.resolveOperand(gangNum, intType, result.operands)) {
+ return failure();
+ }
+ }
+ parser.parseOptionalComma();
+ if (succeeded(
+ parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
+ hasGangStatic = true;
+ parser.parseColon();
+ if (parser.parseOperand(gangStatic) ||
+ parser.resolveOperand(gangStatic, intType, result.operands)) {
+ return failure();
+ }
+ }
+ if (failed(parser.parseRParen()))
+ return failure();
+ }
// worker?
- if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
+ if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
executionMapping |= OpenACCExecMapping::WORKER;
- // seq?
- if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
- executionMapping |= OpenACCExecMapping::SEQ;
+ // optional worker operand
+ if (succeeded(parser.parseOptionalLParen())) {
+ hasWorkerNum = true;
+ if (parser.parseOperand(workerNum) ||
+ parser.resolveOperand(workerNum, intType, result.operands) ||
+ parser.parseRParen()) {
+ return failure();
+ }
+ }
+
+ // vector?
+ if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
+ executionMapping |= OpenACCExecMapping::VECTOR;
+
+ // optional vector operand
+ if (succeeded(parser.parseOptionalLParen())) {
+ hasVectorLength = true;
+ if (parser.parseOperand(vectorLength) ||
+ parser.resolveOperand(vectorLength, intType, result.operands) ||
+ parser.parseRParen()) {
+ return failure();
+ }
+ }
+
+ // tile()?
+ if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
+ operandTypes, result)))
+ return failure();
// private()?
if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
@@ -526,7 +575,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(
- {static_cast<int32_t>(privateOperands.size()),
+ {static_cast<int32_t>(hasGangNum ? 1 : 0),
+ static_cast<int32_t>(hasGangStatic ? 1 : 0),
+ static_cast<int32_t>(hasWorkerNum ? 1 : 0),
+ static_cast<int32_t>(hasVectorLength ? 1 : 0),
+ static_cast<int32_t>(tileOperands.size()),
+ static_cast<int32_t>(privateOperands.size()),
static_cast<int32_t>(reductionOperands.size())}));
if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -544,17 +598,44 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
.getInt()
: 0;
- if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
- printer << " " << LoopOp::getGangAttrName();
- if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
- printer << " " << LoopOp::getWorkerAttrName();
+ if (execMapping & OpenACCExecMapping::GANG) {
+ printer << " " << LoopOp::getGangKeyword();
+ Value gangNum = op.gangNum();
+ Value gangStatic = op.gangStatic();
+
+ // Print optional gang operands
+ if (gangNum || gangStatic) {
+ printer << "(";
+ if (gangNum) {
+ printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
+ if (gangStatic)
+ printer << ", ";
+ }
+ if (gangStatic)
+ printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
+ printer << ")";
+ }
+ }
- if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
- printer << " " << LoopOp::getVectorAttrName();
+ if (execMapping & OpenACCExecMapping::WORKER) {
+ printer << " " << LoopOp::getWorkerKeyword();
+
+ // Print optional worker operand if present
+ if (Value workerNum = op.workerNum())
+ printer << "(" << workerNum << ")";
+ }
+
+ if (execMapping & OpenACCExecMapping::VECTOR) {
+ printer << " " << LoopOp::getVectorKeyword();
+
+ // Print optional vector operand if present
+ if (Value vectorLength = op.vectorLength())
+ printer << "(" << vectorLength << ")";
+ }
- if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
- printer << " " << LoopOp::getSeqAttrName();
+ // tile()?
+ printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
// private()?
printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 89431d8926e2..6cdba227d5da 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -62,7 +62,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
%c1 = constant 1 : index
acc.parallel {
- acc.loop seq {
+ acc.loop {
scf.for %arg3 = %c0 to %c10 step %c1 {
scf.for %arg4 = %c0 to %c10 step %c1 {
scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -76,7 +76,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
}
}
acc.yield
- }
+ } attributes {seq}
acc.yield
}
@@ -88,7 +88,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: %{{.*}} = constant 10 : index
// CHECK-NEXT: %{{.*}} = constant 1 : index
// CHECK-NEXT: acc.parallel {
-// CHECK-NEXT: acc.loop seq {
+// CHECK-NEXT: acc.loop {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -102,7 +102,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
-// CHECK-NEXT: }
+// CHECK-NEXT: } attributes {seq}
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
@@ -128,7 +128,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
acc.yield
}
- acc.loop seq {
+ acc.loop {
// for i = 0 to 10 step 1
// d[x] += c[i]
scf.for %i = %lb to %c10 step %st {
@@ -138,7 +138,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
store %z, %d[%x] : memref<10xf32>
}
acc.yield
- }
+ } attributes {seq}
}
acc.yield
}
@@ -167,7 +167,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop seq {
+// CHECK-NEXT: acc.loop {
// CHECK-NEXT: scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
@@ -175,7 +175,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
-// CHECK-NEXT: }
+// CHECK-NEXT: } attributes {seq}
// CHECK-NEXT: }
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
@@ -184,4 +184,51 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: acc.terminator
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
-// CHECK-NEXT: }
\ No newline at end of file
+// CHECK-NEXT: }
+
+func @testop() -> () {
+ %workerNum = constant 1 : i64
+ %vectorLength = constant 128 : i64
+ %gangNum = constant 8 : i64
+ %gangStatic = constant 2 : i64
+ %tileSize = constant 2 : i64
+ acc.loop gang worker vector {
+ }
+ acc.loop gang(num: %gangNum) {
+ }
+ acc.loop gang(static: %gangStatic) {
+ }
+ acc.loop worker(%workerNum) {
+ }
+ acc.loop vector(%vectorLength) {
+ }
+ acc.loop gang(num: %gangNum) worker vector {
+ }
+ acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
+ }
+ acc.loop tile(%tileSize : i64, %tileSize : i64) {
+ }
+ return
+}
+
+// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64
+// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
+// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
+// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
+// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
+// CHECK-NEXT: acc.loop gang worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
+// CHECK-NEXT: }
More information about the Mlir-commits
mailing list