[Mlir-commits] [mlir] 2bbbcae - [mlir][openacc] Add missing attributes and operands for acc.loop

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 31 16:50:13 PDT 2020


Author: Valentin Clement
Date: 2020-08-31T19:50:05-04:00
New Revision: 2bbbcae782adbea20ae50f9f5471056a91498ffc

URL: https://github.com/llvm/llvm-project/commit/2bbbcae782adbea20ae50f9f5471056a91498ffc
DIFF: https://github.com/llvm/llvm-project/commit/2bbbcae782adbea20ae50f9f5471056a91498ffc.diff

LOG: [mlir][openacc] Add missing attributes and operands for acc.loop

This patch adds the missing attributes and operands to the acc.loop operation.
Only the device_type information is not yet part of the operation.

Reviewed By: rriddle, kiranchandramohan

Differential Revision: https://reviews.llvm.org/D86753

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACC.h
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 399db74fae2c..8f5e1daf9aeb 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -31,12 +31,11 @@ namespace acc {
 /// 2.9.2. gang
 /// 2.9.3. worker
 /// 2.9.4. vector
-/// 2.9.5. seq
 ///
 /// Value can be combined bitwise to reflect the mapping applied to the
 /// construct. e.g. `acc.loop gang vector`, the `gang` and `vector` could be
 /// combined and the final mapping value would be 5 (4 | 1).
-enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4, SEQ = 8 };
+enum OpenACCExecMapping { NONE = 0, VECTOR = 1, WORKER = 2, GANG = 4 };
 
 } // end namespace acc
 } // end namespace mlir

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 8dd0f849ddd2..30d6f435b75f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -224,6 +224,14 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
 
   let arguments = (ins OptionalAttr<I64Attr>:$collapse,
+                       Optional<AnyInteger>:$gangNum,
+                       Optional<AnyInteger>:$gangStatic,
+                       Optional<AnyInteger>:$workerNum,
+                       Optional<AnyInteger>:$vectorLength,
+                       UnitAttr:$loopSeq,
+                       UnitAttr:$loopIndependent,
+                       UnitAttr:$loopAuto,
+                       Variadic<AnyInteger>:$tileOperands,
                        Variadic<AnyType>:$privateOperands,
                        OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
                        Variadic<AnyType>:$reductionOperands);
@@ -234,11 +242,16 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
   let extraClassDeclaration = [{
     static StringRef getCollapseAttrName() { return "collapse"; }
-    static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
-    static StringRef getGangAttrName() { return "gang"; }
     static StringRef getSeqAttrName() { return "seq"; }
-    static StringRef getVectorAttrName() { return "vector"; }
-    static StringRef getWorkerAttrName() { return "worker"; }
+    static StringRef getIndependentAttrName() { return "independent"; }
+    static StringRef getAutoAttrName() { return "auto"; }
+    static StringRef getExecutionMappingAttrName() { return "exec_mapping"; }
+    static StringRef getGangKeyword() { return "gang"; }
+    static StringRef getGangNumKeyword() { return "num"; }
+    static StringRef getGangStaticKeyword() { return "static"; }
+    static StringRef getVectorKeyword() { return "vector"; }
+    static StringRef getWorkerKeyword() { return "worker"; }
+    static StringRef getTileKeyword() { return "tile"; }
     static StringRef getPrivateKeyword() { return "private"; }
     static StringRef getReductionKeyword() { return "reduction"; }
   }];

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6ac411d456af..b5dfa2c13358 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -476,32 +476,81 @@ static void print(OpAsmPrinter &printer, DataOp &op) {
 //===----------------------------------------------------------------------===//
 
 /// Parse acc.loop operation
-/// operation := `acc.loop` `gang`? `vector`? `worker`? `seq`?
+/// operation := `acc.loop` `gang`? `vector`? `worker`?
 ///                         `private` `(` value-list `)`?
 ///                         `reduction` `(` value-list `)`?
 ///                         region attr-dict?
 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
-
   Builder &builder = parser.getBuilder();
   unsigned executionMapping = 0;
   SmallVector<Type, 8> operandTypes;
   SmallVector<OpAsmParser::OperandType, 8> privateOperands, reductionOperands;
+  SmallVector<OpAsmParser::OperandType, 8> tileOperands;
+  bool hasWorkerNum = false, hasVectorLength = false, hasGangNum = false;
+  bool hasGangStatic = false;
+  OpAsmParser::OperandType workerNum, vectorLength, gangNum, gangStatic;
+  Type intType = builder.getI64Type();
 
   // gang?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangAttrName())))
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangKeyword())))
     executionMapping |= OpenACCExecMapping::GANG;
 
-  // vector?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorAttrName())))
-    executionMapping |= OpenACCExecMapping::VECTOR;
+  // optional gang operand
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (succeeded(parser.parseOptionalKeyword(LoopOp::getGangNumKeyword()))) {
+      hasGangNum = true;
+      parser.parseColon();
+      if (parser.parseOperand(gangNum) ||
+          parser.resolveOperand(gangNum, intType, result.operands)) {
+        return failure();
+      }
+    }
+    parser.parseOptionalComma();
+    if (succeeded(
+            parser.parseOptionalKeyword(LoopOp::getGangStaticKeyword()))) {
+      hasGangStatic = true;
+      parser.parseColon();
+      if (parser.parseOperand(gangStatic) ||
+          parser.resolveOperand(gangStatic, intType, result.operands)) {
+        return failure();
+      }
+    }
+    if (failed(parser.parseRParen()))
+      return failure();
+  }
 
   // worker?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerAttrName())))
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getWorkerKeyword())))
     executionMapping |= OpenACCExecMapping::WORKER;
 
-  // seq?
-  if (succeeded(parser.parseOptionalKeyword(LoopOp::getSeqAttrName())))
-    executionMapping |= OpenACCExecMapping::SEQ;
+  // optional worker operand
+  if (succeeded(parser.parseOptionalLParen())) {
+    hasWorkerNum = true;
+    if (parser.parseOperand(workerNum) ||
+        parser.resolveOperand(workerNum, intType, result.operands) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  }
+
+  // vector?
+  if (succeeded(parser.parseOptionalKeyword(LoopOp::getVectorKeyword())))
+    executionMapping |= OpenACCExecMapping::VECTOR;
+
+  // optional vector operand
+  if (succeeded(parser.parseOptionalLParen())) {
+    hasVectorLength = true;
+    if (parser.parseOperand(vectorLength) ||
+        parser.resolveOperand(vectorLength, intType, result.operands) ||
+        parser.parseRParen()) {
+      return failure();
+    }
+  }
+
+  // tile()?
+  if (failed(parseOperandList(parser, LoopOp::getTileKeyword(), tileOperands,
+                              operandTypes, result)))
+    return failure();
 
   // private()?
   if (failed(parseOperandList(parser, LoopOp::getPrivateKeyword(),
@@ -526,7 +575,12 @@ static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &result) {
 
   result.addAttribute(LoopOp::getOperandSegmentSizeAttr(),
                       builder.getI32VectorAttr(
-                          {static_cast<int32_t>(privateOperands.size()),
+                          {static_cast<int32_t>(hasGangNum ? 1 : 0),
+                           static_cast<int32_t>(hasGangStatic ? 1 : 0),
+                           static_cast<int32_t>(hasWorkerNum ? 1 : 0),
+                           static_cast<int32_t>(hasVectorLength ? 1 : 0),
+                           static_cast<int32_t>(tileOperands.size()),
+                           static_cast<int32_t>(privateOperands.size()),
                            static_cast<int32_t>(reductionOperands.size())}));
 
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -544,17 +598,44 @@ static void print(OpAsmPrinter &printer, LoopOp &op) {
           ? op.getAttrOfType<IntegerAttr>(LoopOp::getExecutionMappingAttrName())
                 .getInt()
           : 0;
-  if ((execMapping & OpenACCExecMapping::GANG) == OpenACCExecMapping::GANG)
-    printer << " " << LoopOp::getGangAttrName();
 
-  if ((execMapping & OpenACCExecMapping::WORKER) == OpenACCExecMapping::WORKER)
-    printer << " " << LoopOp::getWorkerAttrName();
+  if (execMapping & OpenACCExecMapping::GANG) {
+    printer << " " << LoopOp::getGangKeyword();
+    Value gangNum = op.gangNum();
+    Value gangStatic = op.gangStatic();
+
+    // Print optional gang operands
+    if (gangNum || gangStatic) {
+      printer << "(";
+      if (gangNum) {
+        printer << LoopOp::getGangNumKeyword() << ": " << gangNum;
+        if (gangStatic)
+          printer << ", ";
+      }
+      if (gangStatic)
+        printer << LoopOp::getGangStaticKeyword() << ": " << gangStatic;
+      printer << ")";
+    }
+  }
 
-  if ((execMapping & OpenACCExecMapping::VECTOR) == OpenACCExecMapping::VECTOR)
-    printer << " " << LoopOp::getVectorAttrName();
+  if (execMapping & OpenACCExecMapping::WORKER) {
+    printer << " " << LoopOp::getWorkerKeyword();
+
+    // Print optional worker operand if present
+    if (Value workerNum = op.workerNum())
+      printer << "(" << workerNum << ")";
+  }
+
+  if (execMapping & OpenACCExecMapping::VECTOR) {
+    printer << " " << LoopOp::getVectorKeyword();
+
+    // Print optional vector operand if present
+    if (Value vectorLength = op.vectorLength())
+      printer << "(" << vectorLength << ")";
+  }
 
-  if ((execMapping & OpenACCExecMapping::SEQ) == OpenACCExecMapping::SEQ)
-    printer << " " << LoopOp::getSeqAttrName();
+  // tile()?
+  printOperandList(op.tileOperands(), LoopOp::getTileKeyword(), printer);
 
   // private()?
   printOperandList(op.privateOperands(), LoopOp::getPrivateKeyword(), printer);

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 89431d8926e2..6cdba227d5da 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -62,7 +62,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
   %c1 = constant 1 : index
 
   acc.parallel {
-    acc.loop seq {
+    acc.loop {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
           scf.for %arg5 = %c0 to %c10 step %c1 {
@@ -76,7 +76,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
         }
       }
       acc.yield
-    }
+    } attributes {seq}
     acc.yield
   }
 
@@ -88,7 +88,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:   %{{.*}} = constant 10 : index
 //  CHECK-NEXT:   %{{.*}} = constant 1 : index
 //  CHECK-NEXT:   acc.parallel {
-//  CHECK-NEXT:     acc.loop seq {
+//  CHECK-NEXT:     acc.loop {
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:           scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -102,7 +102,7 @@ func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:         }
 //  CHECK-NEXT:       }
 //  CHECK-NEXT:       acc.yield
-//  CHECK-NEXT:     }
+//  CHECK-NEXT:     } attributes {seq}
 //  CHECK-NEXT:     acc.yield
 //  CHECK-NEXT:   }
 //  CHECK-NEXT:   return %{{.*}} : memref<10x10xf32>
@@ -128,7 +128,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
             acc.yield
           }
 
-          acc.loop seq {
+          acc.loop {
             // for i = 0 to 10 step 1
             //   d[x] += c[i]
             scf.for %i = %lb to %c10 step %st {
@@ -138,7 +138,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
               store %z, %d[%x] : memref<10xf32>
             }
             acc.yield
-          }
+          } attributes {seq}
         }
         acc.yield
       }
@@ -167,7 +167,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
 // CHECK-NEXT:           }
-// CHECK-NEXT:           acc.loop seq {
+// CHECK-NEXT:           acc.loop {
 // CHECK-NEXT:             scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:               %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:               %{{.*}} = load %{{.*}}[%{{.*}}] : memref<10xf32>
@@ -175,7 +175,7 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:               store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
 // CHECK-NEXT:             }
 // CHECK-NEXT:             acc.yield
-// CHECK-NEXT:           }
+// CHECK-NEXT:           } attributes {seq}
 // CHECK-NEXT:         }
 // CHECK-NEXT:         acc.yield
 // CHECK-NEXT:       }
@@ -184,4 +184,51 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:     acc.terminator
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return %{{.*}} : memref<10xf32>
-// CHECK-NEXT: }
\ No newline at end of file
+// CHECK-NEXT: }
+
+func @testop() -> () {
+  %workerNum = constant 1 : i64
+  %vectorLength = constant 128 : i64
+  %gangNum = constant 8 : i64
+  %gangStatic = constant 2 : i64
+  %tileSize = constant 2 : i64
+  acc.loop gang worker vector {
+  }
+  acc.loop gang(num: %gangNum) {
+  }
+  acc.loop gang(static: %gangStatic) {
+  }
+  acc.loop worker(%workerNum) {
+  }
+  acc.loop vector(%vectorLength) {
+  }
+  acc.loop gang(num: %gangNum) worker vector {
+  }
+  acc.loop gang(num: %gangNum, static: %gangStatic) worker(%workerNum) vector(%vectorLength) {
+  }
+  acc.loop tile(%tileSize : i64, %tileSize : i64) {
+  }
+  return
+}
+
+// CHECK:      [[WORKERNUM:%.*]] = constant 1 : i64
+// CHECK-NEXT: [[VECTORLENGTH:%.*]] = constant 128 : i64
+// CHECK-NEXT: [[GANGNUM:%.*]] = constant 8 : i64
+// CHECK-NEXT: [[GANGSTATIC:%.*]] = constant 2 : i64
+// CHECK-NEXT: [[TILESIZE:%.*]] = constant 2 : i64
+// CHECK-NEXT: acc.loop gang worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(static: [[GANGSTATIC]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop worker([[WORKERNUM]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]]) worker vector {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop gang(num: [[GANGNUM]], static: [[GANGSTATIC]]) worker([[WORKERNUM]]) vector([[VECTORLENGTH]]) {
+// CHECK-NEXT: }
+// CHECK-NEXT: acc.loop tile([[TILESIZE]]: i64, [[TILESIZE]]: i64) {
+// CHECK-NEXT: }


        


More information about the Mlir-commits mailing list