[Mlir-commits] [mlir] 22dde1f - [mlir][openacc] Support Index and AnyInteger in loop op
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Sep 18 08:37:56 PDT 2020
Author: Valentin Clement
Date: 2020-09-18T11:37:49-04:00
New Revision: 22dde1f92f68b4249dbae30c119972a17753236a
URL: https://github.com/llvm/llvm-project/commit/22dde1f92f68b4249dbae30c119972a17753236a
DIFF: https://github.com/llvm/llvm-project/commit/22dde1f92f68b4249dbae30c119972a17753236a.diff
LOG: [mlir][openacc] Support Index and AnyInteger in loop op
Following patch D87712, this patch switches the type constraint of the operands gangNum, gangStatic,
workerNum, vectorLength and tileOperands from AnyInteger to IntOrIndex, so each may be either an Index or an AnyInteger value.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D87848
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 3fa26f932bd9..bd685f90ad4a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -244,14 +244,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let arguments = (ins OptionalAttr<I64Attr>:$collapse,
- Optional<AnyInteger>:$gangNum,
- Optional<AnyInteger>:$gangStatic,
- Optional<AnyInteger>:$workerNum,
- Optional<AnyInteger>:$vectorLength,
+ Optional<IntOrIndex>:$gangNum,
+ Optional<IntOrIndex>:$gangStatic,
+ Optional<IntOrIndex>:$workerNum,
+ Optional<IntOrIndex>:$vectorLength,
UnitAttr:$seq,
UnitAttr:$independent,
UnitAttr:$auto_,
- Variadic<AnyInteger>:$tileOperands,
+ Variadic<IntOrIndex>:$tileOperands,
Variadic<AnyType>:$privateOperands,
OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands,
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 3cae3c8feb8f..efd7f866c491 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -545,10 +545,14 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
//===----------------------------------------------------------------------===//
/// Parse acc.loop operation
-/// operation := `acc.loop` `gang`? `vector`? `worker`?
-/// `private` `(` value-list `)`?
-/// `reduction` `(` value-list `)`?
-/// region attr-dict?
+/// operation := `acc.loop`
+/// (`gang` ( `(` (`num=` value `:` type)? `,`? (`static=` value `:` type)? `)` )? )?
+/// (`worker` ( `(` value `:` type `)` )? )? (`vector` ( `(` value `:` type `)` )? )?
+/// (`vector_length` `(` value `)`)?
+/// (`tile` `(` value `:` type (`,` value `:` type)* `)`)?
+/// (`private` `(` value-list `)`)?
+/// (`reduction` `(` value-list `)`)?
+/// region attr-dict?
static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
Builder &builder = parser.getBuilder();
unsigned executionMapping = OpenACCExecMapping::NONE;
@@ -558,7 +562,7 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
bool hasGangStatic = false;
OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
- Type intType = builder.getI64Type();
+ Type gangNumType, gangStaticType, workerType, vectorLengthType;
// gang?
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
@@ -568,9 +572,9 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
hasGangNum = true;
- parser.parseColon();
- if (parser.parseOperand(gangNum) ||
- parser.resolveOperand(gangNum, intType, result.operands)) {
+ parser.parseEqual();
+ if (parser.parseOperand(gangNum) || parser.parseColonType(gangNumType) ||
+ parser.resolveOperand(gangNum, gangNumType, result.operands)) {
return failure();
}
}
@@ -578,9 +582,10 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
if (succeeded(
parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
hasGangStatic = true;
- parser.parseColon();
+ parser.parseEqual();
if (parser.parseOperand(gangStatic) ||
- parser.resolveOperand(gangStatic, intType, result.operands)) {
+ parser.parseColonType(gangStaticType) ||
+ parser.resolveOperand(gangStatic, gangStaticType, result.operands)) {
return failure();
}
}
@@ -595,8 +600,8 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
// optional worker operand
if (succeeded(parser.parseOptionalLParen())) {
hasWorkerNum = true;
- if (parser.parseOperand(workerNum) ||
- parser.resolveOperand(workerNum, intType, result.operands) ||
+ if (parser.parseOperand(workerNum) || parser.parseColonType(workerType) ||
+ parser.resolveOperand(workerNum, workerType, result.operands) ||
parser.parseRParen()) {
return failure();
}
@@ -610,7 +615,9 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
if (succeeded(parser.parseOptionalLParen())) {
hasVectorLength = true;
if (parser.parseOperand(vectorLength) ||
- parser.resolveOperand(vectorLength, intType, result.operands) ||
+ parser.parseColonType(vectorLengthType) ||
+ parser.resolveOperand(vectorLength, vectorLengthType,
+ result.operands) ||
parser.parseRParen()) {
return failure();
}
@@ -671,12 +678,14 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
if (gangNum || gangStatic) {
printer << "(";
if (gangNum) {
- printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
+ printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
+ << gangNum.getType();
if (gangStatic)
printer << ", ";
}
if (gangStatic)
- printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
+ printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
+ << gangStatic.getType();
printer << ")";
}
}
@@ -686,7 +695,7 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
// Print optional worker operand if present
if (Value workerNum = op.workerNum())
- printer << "(" << workerNum << ")";
+ printer << "(" << workerNum << ": " << workerNum.getType() << ")";
}
if (execMapping & OpenACCExecMapping::VECTOR) {
@@ -694,7 +703,7 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
// Print optional vector operand if present
if (Value vectorLength = op.vectorLength())
- printer << "(" << vectorLength << ")";
+ printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
}
// tile()?
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 196949839db4..07ec198b4736 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | FileCheck %s
// Verify the printed output can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
// Verify the generic form can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = constant 0 : index
@@ -58,6 +58,8 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
// CHECK-NEXT: }
+// -----
+
func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = constant 0 : index
%c10 = constant 10 : index
@@ -110,6 +112,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
// CHECK-NEXT: return %{{.*}} : memref<10x10xf32>
// CHECK-NEXT: }
+// -----
func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
%lb = constant 0 : index
@@ -192,85 +195,133 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
// CHECK-NEXT: return %{{.*}} : memref<10xf32>
// CHECK-NEXT: }
-func @testop(%a: memref<10xf32>) -> () {
- %workerNum = constant 1 : i64
- %vectorLength = constant 128 : i64
- %gangNum = constant 8 : i64
- %gangStatic = constant 2 : i64
- %tileSize = constant 2 : i64
+// -----
+
+func @testloopop() -> () {
+ %i64Value = constant 1 : i64
+ %i32Value = constant 128 : i32
+ %idxValue = constant 8 : index
+
acc.loop gang worker vector {
"some.op"() : () -> ()
acc.yield
}
- acc.loop gang(num: %gangNum) {
+ acc.loop gang(num=%i64Value: i64) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop gang(static=%i64Value: i64) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop worker(%i64Value: i64) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop worker(%i32Value: i32) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop worker(%idxValue: index) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop vector(%i64Value: i64) {
+ "some.op"() : () -> ()
+ acc.yield
+ }
+ acc.loop vector(%i32Value: i32) {
"some.op"() : () -> ()
acc.yield
}
- acc.loop gang(static: %gangStatic) {
+ acc.loop vector(%idxValue: index) {
"some.op"() : () -> ()
acc.yield
}
- acc.loop worker(%workerNum) {
+ acc.loop gang(num=%i64Value: i64) worker vector {
"some.op"() : () -> ()
acc.yield
}
- acc.loop vector(%vectorLength) {
+ acc.loop gang(num=%i64Value: i64, static=%i64Value: i64) worker(%i64Value: i64) vector(%i64Value: i64) {
"some.op"() : () -> ()
acc.yield
}
- acc.loop gang(num: %gangNum) worker vector {
+ acc.loop gang(num=%i32Value: i32, static=%idxValue: index) {
"some.op"() : () -> ()
acc.yield
}
- acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
+ acc.loop tile(%i64Value: i64, %i64Value: i64) {
"some.op"() : () -> ()
acc.yield
}
- acc.loop tile(%tileSize : i64, %tileSize : i64) {
+ acc.loop tile(%i32Value: i32, %i32Value: i32) {
"some.op"() : () -> ()
acc.yield
}
return
}
-// CHECK: [[WORKERNUM:%.*]] = constant 1 : i64
-// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
-// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
-// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
-// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
-// CHECK-NEXT: acc.loop gang worker vector {
+// CHECK: [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK-NEXT: [[I32VALUE:%.*]] = constant 128 : i32
+// CHECK-NEXT: [[IDXVALUE:%.*]] = constant 8 : index
+// CHECK: acc.loop gang worker vector {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop gang(num=[[I64VALUE]]: i64) {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop gang(static=[[I64VALUE]]: i64) {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop worker([[I64VALUE]]: i64) {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop worker([[I32VALUE]]: i32) {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop worker([[IDXVALUE]]: index) {
+// CHECK-NEXT: "some.op"() : () -> ()
+// CHECK-NEXT: acc.yield
+// CHECK-NEXT: }
+// CHECK: acc.loop vector([[I64VALUE]]: i64) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
+// CHECK: acc.loop vector([[I32VALUE]]: i32) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
+// CHECK: acc.loop vector([[IDXVALUE]]: index) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
+// CHECK: acc.loop gang(num=[[I64VALUE]]: i64) worker vector {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
+// CHECK: acc.loop gang(num=[[I64VALUE]]: i64, static=[[I64VALUE]]: i64) worker([[I64VALUE]]: i64) vector([[I64VALUE]]: i64) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
+// CHECK: acc.loop gang(num=[[I32VALUE]]: i32, static=[[IDXVALUE]]: index) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
+// CHECK: acc.loop tile([[I64VALUE]]: i64, [[I64VALUE]]: i64) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
-// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
+// CHECK: acc.loop tile([[I32VALUE]]: i32, [[I32VALUE]]: i32) {
// CHECK-NEXT: "some.op"() : () -> ()
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
+// -----
func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
%i64value = constant 1 : i64
More information about the Mlir-commits
mailing list