[Mlir-commits] [mlir] c4a63b8 - [mlir][openacc] Switch numGangs to a variadic operand
Valentin Clement
llvmlistbot at llvm.org
Tue Jun 27 11:08:50 PDT 2023
Author: Valentin Clement
Date: 2023-06-27T11:08:44-07:00
New Revision: c4a63b8ee11f7e5a849f50d00018654c43046e3c
URL: https://github.com/llvm/llvm-project/commit/c4a63b8ee11f7e5a849f50d00018654c43046e3c
DIFF: https://github.com/llvm/llvm-project/commit/c4a63b8ee11f7e5a849f50d00018654c43046e3c.diff
LOG: [mlir][openacc] Switch numGangs to a variadic operand
In the latest spec, the `num_gangs` clause accepts up to three
arguments. Update the dialect to switch `numGangs` from an
optional single operand to a variadic operand. The verifier limits
the number of operands to three, as specified in the spec.
Reviewed By: razvanlupusoru
Differential Revision: https://reviews.llvm.org/D153796
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/invalid.mlir
mlir/test/Dialect/OpenACC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 32882bd9c9d20..076faa76fcd31 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -656,7 +656,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
UnitAttr:$asyncAttr,
Variadic<IntOrIndex>:$waitOperands,
UnitAttr:$waitAttr,
- Optional<IntOrIndex>:$numGangs,
+ Variadic<IntOrIndex>:$numGangs,
Optional<IntOrIndex>:$numWorkers,
Optional<IntOrIndex>:$vectorLength,
Optional<I1>:$ifCond,
@@ -802,7 +802,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
UnitAttr:$asyncAttr,
Variadic<IntOrIndex>:$waitOperands,
UnitAttr:$waitAttr,
- Optional<IntOrIndex>:$numGangs,
+ Variadic<IntOrIndex>:$numGangs,
Optional<IntOrIndex>:$numWorkers,
Optional<IntOrIndex>:$vectorLength,
Optional<I1>:$ifCond,
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 9b04d6cdf326e..77de45281170e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -573,7 +573,7 @@ unsigned ParallelOp::getNumDataOperands() {
Value ParallelOp::getDataOperand(unsigned i) {
unsigned numOptional = getAsync() ? 1 : 0;
- numOptional += getNumGangs() ? 1 : 0;
+ numOptional += getNumGangs().size();
numOptional += getNumWorkers() ? 1 : 0;
numOptional += getVectorLength() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
@@ -590,6 +590,8 @@ LogicalResult acc::ParallelOp::verify() {
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
return failure();
+ if (getNumGangs().size() > 3)
+ return emitOpError() << "num_gangs expects a maximum of 3 values";
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
@@ -631,12 +633,18 @@ unsigned KernelsOp::getNumDataOperands() {
Value KernelsOp::getDataOperand(unsigned i) {
unsigned numOptional = getAsync() ? 1 : 0;
+ numOptional += getWaitOperands().size();
+ numOptional += getNumGangs().size();
+ numOptional += getNumWorkers() ? 1 : 0;
+ numOptional += getVectorLength() ? 1 : 0;
numOptional += getIfCond() ? 1 : 0;
numOptional += getSelfCond() ? 1 : 0;
- return getOperand(getWaitOperands().size() + numOptional + i);
+ return getOperand(numOptional + i);
}
LogicalResult acc::KernelsOp::verify() {
+ if (getNumGangs().size() > 3)
+ return emitOpError() << "num_gangs expects a maximum of 3 values";
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index a3d938658e0e0..2a36ccabbdd9c 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -486,3 +486,10 @@ acc.loop gang() {
"test.openacc_dummy_op"() : () -> ()
acc.yield
}
+
+// -----
+
+%i64value = arith.constant 1 : i64
+// expected-error at +1 {{num_gangs expects a maximum of 3 values}}
+acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 92d2a9271ed0d..e07ab8c3d31f2 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -443,6 +443,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
}
acc.parallel num_gangs(%idxValue: index) {
}
+ acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
+ }
acc.parallel num_workers(%i64value: i64) {
}
acc.parallel num_workers(%i32value: i32) {
@@ -494,6 +496,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
// CHECK-NEXT: }
// CHECK: acc.parallel num_gangs([[IDXVALUE]] : index) {
// CHECK-NEXT: }
+// CHECK: acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
+// CHECK-NEXT: }
// CHECK: acc.parallel num_workers([[I64VALUE]] : i64) {
// CHECK-NEXT: }
// CHECK: acc.parallel num_workers([[I32VALUE]] : i32) {
More information about the Mlir-commits
mailing list