[Mlir-commits] [mlir] 22dde1f - [mlir][openacc] Support Index and AnyInteger in loop op

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Sep 18 08:37:56 PDT 2020


Author: Valentin Clement
Date: 2020-09-18T11:37:49-04:00
New Revision: 22dde1f92f68b4249dbae30c119972a17753236a

URL: https://github.com/llvm/llvm-project/commit/22dde1f92f68b4249dbae30c119972a17753236a
DIFF: https://github.com/llvm/llvm-project/commit/22dde1f92f68b4249dbae30c119972a17753236a.diff

LOG: [mlir][openacc] Support Index and AnyInteger in loop op

Following patch D87712, this patch switches the type constraint of the operands
gangNum, gangStatic, workerNum, vectorLength and tileOperands from AnyInteger
to IntOrIndex, so that they accept either Index or AnyInteger values.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87848

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 3fa26f932bd9..bd685f90ad4a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -244,14 +244,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
 
   let arguments = (ins OptionalAttr<I64Attr>:$collapse,
-                       Optional<AnyInteger>:$gangNum,
-                       Optional<AnyInteger>:$gangStatic,
-                       Optional<AnyInteger>:$workerNum,
-                       Optional<AnyInteger>:$vectorLength,
+                       Optional<IntOrIndex>:$gangNum,
+                       Optional<IntOrIndex>:$gangStatic,
+                       Optional<IntOrIndex>:$workerNum,
+                       Optional<IntOrIndex>:$vectorLength,
                        UnitAttr:$seq,
                        UnitAttr:$independent,
                        UnitAttr:$auto_,
-                       Variadic<AnyInteger>:$tileOperands,
+                       Variadic<IntOrIndex>:$tileOperands,
                        Variadic<AnyType>:$privateOperands,
                        OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
                        Variadic<AnyType>:$reductionOperands,

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 3cae3c8feb8f..efd7f866c491 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -545,10 +545,14 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
 //===----------------------------------------------------------------------===//
 
 /// Parse acc.loop operation
-/// operation := `acc.loop` `gang`? `vector`? `worker`?
-///                         `private` `(` value-list `)`?
-///                         `reduction` `(` value-list `)`?
-///                         region attr-dict?
+/// operation := `acc.loop`
+///              (`gang` ( `(` (`num=` value)? (`,` `static=` value `)`)? )? )?
+///              (`vector` ( `(` value `)` )? )? (`worker` (`(` value `)`)? )?
+///              (`vector_length` `(` value `)`)?
+///              (`tile` `(` value-list `)`)?
+///              (`private` `(` value-list `)`)?
+///              (`reduction` `(` value-list `)`)?
+///              region attr-dict?
 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
   Builder &builder = parser.getBuilder();
   unsigned executionMapping = OpenACCExecMapping::NONE;
@@ -558,7 +562,7 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
   bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
   bool hasGangStatic = false;
   OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
-  Type intType = builder.getI64Type();
+  Type gangNumType, gangStaticType, workerType, vectorLengthType;
 
   // gang?
   if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
@@ -568,9 +572,9 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
   if (succeeded(parser.parseOptionalLParen())) {
     if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
       hasGangNum = true;
-      parser.parseColon();
-      if (parser.parseOperand(gangNum) ||
-          parser.resolveOperand(gangNum, intType, result.operands)) {
+      parser.parseEqual();
+      if (parser.parseOperand(gangNum) || parser.parseColonType(gangNumType) ||
+          parser.resolveOperand(gangNum, gangNumType, result.operands)) {
         return failure();
       }
     }
@@ -578,9 +582,10 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
     if (succeeded(
             parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
       hasGangStatic = true;
-      parser.parseColon();
+      parser.parseEqual();
       if (parser.parseOperand(gangStatic) ||
-          parser.resolveOperand(gangStatic, intType, result.operands)) {
+          parser.parseColonType(gangStaticType) ||
+          parser.resolveOperand(gangStatic, gangStaticType, result.operands)) {
         return failure();
       }
     }
@@ -595,8 +600,8 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
   // optional worker operand
   if (succeeded(parser.parseOptionalLParen())) {
     hasWorkerNum = true;
-    if (parser.parseOperand(workerNum) ||
-        parser.resolveOperand(workerNum, intType, result.operands) ||
+    if (parser.parseOperand(workerNum) || parser.parseColonType(workerType) ||
+        parser.resolveOperand(workerNum, workerType, result.operands) ||
         parser.parseRParen()) {
       return failure();
     }
@@ -610,7 +615,9 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
   if (succeeded(parser.parseOptionalLParen())) {
     hasVectorLength = true;
     if (parser.parseOperand(vectorLength) ||
-        parser.resolveOperand(vectorLength, intType, result.operands) ||
+        parser.parseColonType(vectorLengthType) ||
+        parser.resolveOperand(vectorLength, vectorLengthType,
+                              result.operands) ||
         parser.parseRParen()) {
       return failure();
     }
@@ -671,12 +678,14 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
     if (gangNum || gangStatic) {
       printer << "(";
       if (gangNum) {
-        printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
+        printer << LoopOp::getGangNumKeyword() << "=" << gangNum << ": "
+                << gangNum.getType();
         if (gangStatic)
           printer << ", ";
       }
       if (gangStatic)
-        printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
+        printer << LoopOp::getGangStaticKeyword() << "=" << gangStatic << ": "
+                << gangStatic.getType();
       printer << ")";
     }
   }
@@ -686,7 +695,7 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
 
     // Print optional worker operand if present
     if (Value workerNum = op.workerNum())
-      printer << "(" << workerNum << ")";
+      printer << "(" << workerNum << ": " << workerNum.getType() << ")";
   }
 
   if (execMapping & OpenACCExecMapping::VECTOR) {
@@ -694,7 +703,7 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
 
     // Print optional vector operand if present
     if (Value vectorLength = op.vectorLength())
-      printer << "(" << vectorLength << ")";
+      printer << "(" << vectorLength << ": " << vectorLength.getType() << ")";
   }
 
   // tile()?

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 196949839db4..07ec198b4736 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | FileCheck %s
 // Verify the printed output can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect  | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect  | FileCheck %s
 // Verify the generic form can be parsed.
-// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -mlir-print-op-generic %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
 
 func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
   %c0 = constant 0 : index
@@ -58,6 +58,8 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
 //  CHECK-NEXT: }
 
+// -----
+
 func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf32>) -> memref<10x10xf32> {
   %c0 = constant 0 : index
   %c10 = constant 10 : index
@@ -110,6 +112,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
 //  CHECK-NEXT: }
 
+// -----
 
 func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
   %lb = constant 0 : index
@@ -192,85 +195,133 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:   return %{{.*}} : memref<10xf32>
 // CHECK-NEXT: }
 
-func @testop(%a: memref<10xf32>) -> () {
-  %workerNum = constant 1 : i64
-  %vectorLength = constant 128 : i64
-  %gangNum = constant 8 : i64
-  %gangStatic = constant 2 : i64
-  %tileSize = constant 2 : i64
+// -----
+
+func @testloopop() -> () {
+  %i64Value = constant 1 : i64
+  %i32Value = constant 128 : i32
+  %idxValue = constant 8 : index
+
   acc.loop gang worker vector {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop gang(num: %gangNum) {
+  acc.loop gang(num=%i64Value: i64) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop gang(static=%i64Value: i64) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop worker(%i64Value: i64) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop worker(%i32Value: i32) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop worker(%idxValue: index) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop vector(%i64Value: i64) {
+    "some.op"() : () -> ()
+    acc.yield
+  }
+  acc.loop vector(%i32Value: i32) {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop gang(static: %gangStatic) {
+  acc.loop vector(%idxValue: index) {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop worker(%workerNum) {
+  acc.loop gang(num=%i64Value: i64) worker vector {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop vector(%vectorLength) {
+  acc.loop gang(num=%i64Value: i64, static=%i64Value: i64) worker(%i64Value: i64) vector(%i64Value: i64) {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop gang(num: %gangNum) worker vector {
+  acc.loop gang(num=%i32Value: i32, static=%idxValue: index) {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
+  acc.loop tile(%i64Value: i64, %i64Value: i64) {
     "some.op"() : () -> ()
     acc.yield
   }
-  acc.loop tile(%tileSize : i64, %tileSize : i64) {
+  acc.loop tile(%i32Value: i32, %i32Value: i32) {
     "some.op"() : () -> ()
     acc.yield
   }
   return
 }
 
-// CHECK:      [[WORKERNUM:%.*]] = constant 1 : i64
-// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
-// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
-// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
-// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
-// CHECK-NEXT: acc.loop gang worker vector {
+// CHECK:      [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK-NEXT: [[I32VALUE:%.*]] = constant 128 : i32
+// CHECK-NEXT: [[IDXVALUE:%.*]] = constant 8 : index
+// CHECK:      acc.loop gang worker vector {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop gang(num=[[I64VALUE]]: i64) {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop gang(static=[[I64VALUE]]: i64) {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop worker([[I64VALUE]]: i64) {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop worker([[I32VALUE]]: i32) {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop worker([[IDXVALUE]]: index) {
+// CHECK-NEXT:   "some.op"() : () -> ()
+// CHECK-NEXT:   acc.yield
+// CHECK-NEXT: }
+// CHECK:      acc.loop vector([[I64VALUE]]: i64) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
+// CHECK:      acc.loop vector([[I32VALUE]]: i32) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
+// CHECK:      acc.loop vector([[IDXVALUE]]: index) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
+// CHECK:      acc.loop gang(num=[[I64VALUE]]: i64) worker vector {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
+// CHECK:      acc.loop gang(num=[[I64VALUE]]: i64, static=[[I64VALUE]]: i64) worker([[I64VALUE]]: i64) vector([[I64VALUE]]: i64) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
+// CHECK:      acc.loop gang(num=[[I32VALUE]]: i32, static=[[IDXVALUE]]: index) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
+// CHECK:      acc.loop tile([[I64VALUE]]: i64, [[I64VALUE]]: i64) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
+// CHECK:      acc.loop tile([[I32VALUE]]: i32, [[I32VALUE]]: i32) {
 // CHECK-NEXT:   "some.op"() : () -> ()
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
 
+// -----
 
 func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
   %i64value = constant 1 : i64


        


More information about the Mlir-commits mailing list