[Mlir-commits] [mlir] c4a63b8 - [mlir][openacc] Switch numGangs to a variadic operand

Valentin Clement llvmlistbot at llvm.org
Tue Jun 27 11:08:50 PDT 2023


Author: Valentin Clement
Date: 2023-06-27T11:08:44-07:00
New Revision: c4a63b8ee11f7e5a849f50d00018654c43046e3c

URL: https://github.com/llvm/llvm-project/commit/c4a63b8ee11f7e5a849f50d00018654c43046e3c
DIFF: https://github.com/llvm/llvm-project/commit/c4a63b8ee11f7e5a849f50d00018654c43046e3c.diff

LOG: [mlir][openacc] Switch numGangs to a variadic operand

In the latest spec, the `num_gangs` clause accepts up to three
arguments. Update the dialect to switch the `numGangs` operand from an
optional single operand to a variadic operand. The verifier limits
the number of operands to three, as specified in the spec.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D153796

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/invalid.mlir
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 32882bd9c9d20..076faa76fcd31 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -656,7 +656,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
                        UnitAttr:$asyncAttr,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$waitAttr,
-                       Optional<IntOrIndex>:$numGangs,
+                       Variadic<IntOrIndex>:$numGangs,
                        Optional<IntOrIndex>:$numWorkers,
                        Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,
@@ -802,7 +802,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
                        UnitAttr:$asyncAttr,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$waitAttr,
-                       Optional<IntOrIndex>:$numGangs,
+                       Variadic<IntOrIndex>:$numGangs,
                        Optional<IntOrIndex>:$numWorkers,
                        Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 9b04d6cdf326e..77de45281170e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -573,7 +573,7 @@ unsigned ParallelOp::getNumDataOperands() {
 
 Value ParallelOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsync() ? 1 : 0;
-  numOptional += getNumGangs() ? 1 : 0;
+  numOptional += getNumGangs().size();
   numOptional += getNumWorkers() ? 1 : 0;
   numOptional += getVectorLength() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
@@ -590,6 +590,8 @@ LogicalResult acc::ParallelOp::verify() {
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
     return failure();
+  if (getNumGangs().size() > 3)
+    return emitOpError() << "num_gangs expects a maximum of 3 values";
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
@@ -631,12 +633,18 @@ unsigned KernelsOp::getNumDataOperands() {
 
 Value KernelsOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsync() ? 1 : 0;
+  numOptional += getWaitOperands().size();
+  numOptional += getNumGangs().size();
+  numOptional += getNumWorkers() ? 1 : 0;
+  numOptional += getVectorLength() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
-  return getOperand(getWaitOperands().size() + numOptional + i);
+  return getOperand(numOptional + i);
 }
 
 LogicalResult acc::KernelsOp::verify() {
+  if (getNumGangs().size() > 3)
+    return emitOpError() << "num_gangs expects a maximum of 3 values";
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 

diff  --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index a3d938658e0e0..2a36ccabbdd9c 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -486,3 +486,10 @@ acc.loop gang() {
   "test.openacc_dummy_op"() : () -> ()
   acc.yield
 }
+
+// -----
+
+%i64value = arith.constant 1 : i64
+// expected-error at +1 {{num_gangs expects a maximum of 3 values}}
+acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
+}

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 92d2a9271ed0d..e07ab8c3d31f2 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -443,6 +443,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
   }
   acc.parallel num_gangs(%idxValue: index) {
   }
+  acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
+  }
   acc.parallel num_workers(%i64value: i64) {
   }
   acc.parallel num_workers(%i32value: i32) {
@@ -494,6 +496,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
 // CHECK-NEXT: }
 // CHECK:      acc.parallel num_gangs([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
+// CHECK-NEXT: }
 // CHECK:      acc.parallel num_workers([[I64VALUE]] : i64) {
 // CHECK-NEXT: }
 // CHECK:      acc.parallel num_workers([[I32VALUE]] : i32) {


        


More information about the Mlir-commits mailing list