[llvm-branch-commits] [mlir] ff39c4c - Add parse / print logic to MIOpen ops.

Wen-Heng Chung via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:10 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:19-05:00
New Revision: ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc

URL: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc
DIFF: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc.diff

LOG: Add parse / print logic to MIOpen ops.

Revise test cases along the way.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
    mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
    mlir/test/Dialect/MIOpen/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
index 1304f16f3b30..8ffd66647f3f 100644
--- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
+++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
@@ -35,10 +35,36 @@ class MIOpen_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def MIOpen_Conv2DOp : MIOpen_Op<"conv2d">;
+def MIOpen_Conv2DOp :
+    MIOpen_Op<"conv2d">,
+    Arguments<(ins MemRefRankOf<[F32], [4]>,     // ops.mlir passes %filter,
+                   MemRefRankOf<[F32], [4]>,     // then %input,
+                   MemRefRankOf<[F32], [4]>)> {  // then %output; no results.
+  let summary = "2D convolution";
+  let description = [{
+    The `miopen.conv2d` op computes 2D convolution.
+  }];
+}
 
-def MIOpen_TransformOp : MIOpen_Op<"transform">;
+def MIOpen_TransformOp :
+    MIOpen_Op<"transform">,
+    Arguments<(ins AnyMemRef)>,    // source memref
+    Results<(outs AnyMemRef)> {    // result memref; its rank may differ from the source
+  let summary = "Tensor transformation";
+  let description = [{
+    The `miopen.transform` op transforms tensor coordinates.
+  }];
+}
 
-def MIOpen_GridwiseGemmOp : MIOpen_Op<"gridwise_gemm">;
+def MIOpen_GridwiseGemmOp :
+    MIOpen_Op<"gridwise_gemm">,
+    Arguments<(ins MemRefRankOf<[F32], [2]>,     // ops.mlir passes %A, %B, %C
+                   MemRefRankOf<[F32], [2]>,
+                   MemRefRankOf<[F32], [2]>)> {  // no results
+  let summary = "Gridwise GEMM";
+  let description = [{
+    The `miopen.gridwise_gemm` op computes gridwise GEMM.
+  }];
+}
 
 #endif // MIOPEN_OPS

diff  --git a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
index 9408e17b2831..b41423435e33 100644
--- a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
@@ -40,8 +40,6 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context)
 #define GET_OP_LIST
 #include "mlir/Dialect/MIOpenOps/MIOpenOps.cpp.inc"
       >();
-
-  //addInterfaces<LoopSideEffectsInterface>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -49,11 +47,19 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context)
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseConv2DOp(OpAsmParser &parser, OperationState &result) {
-  return success();
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<Type, 3> types; // one per operand; no result types are added
+  return failure( // `(` operand-list `)` attr-dict? `:` type-list
+      parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonTypeList(types) ||
+      parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
 }
 
 static void print(OpAsmPrinter &p, Conv2DOp op) {
-  p << Conv2DOp::getOperationName();
+  p << op.getOperationName() << "(" << op.getOperands() << ")";
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getOperandTypes(); // same form parseConv2DOp reads
 }
 
 static LogicalResult verify(Conv2DOp op) {
@@ -65,11 +71,24 @@ static LogicalResult verify(Conv2DOp op) {
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseTransformOp(OpAsmParser &parser, OperationState &result) {
+  // Syntax: `(` operand `)` attr-dict? `:` src-type `to` dst-type
+  OpAsmParser::OperandType src;
+  Type srcType, dstType;
+  if (parser.parseLParen() || parser.parseOperand(src) ||
+      parser.parseRParen() ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(srcType) ||
+      parser.resolveOperand(src, srcType, result.operands) ||
+      parser.parseKeywordType("to", dstType) ||
+      parser.addTypeToList(dstType, result.types))
+    return failure();
   return success();
 }
 
 static void print(OpAsmPrinter &p, TransformOp op) {
-  p << TransformOp::getOperationName();
+  p << op.getOperationName() << "(" << op.getOperand() << ")";
+  p.printOptionalAttrDict(op.getAttrs()); // form read back by parseTransformOp
+  p << " : " << op.getOperand().getType() << " to " << op.getType();
 }
 
 static LogicalResult verify(TransformOp op) {
@@ -81,11 +100,19 @@ static LogicalResult verify(TransformOp op) {
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseGridwiseGemmOp(OpAsmParser &parser, OperationState &result) {
-  return success();
+  SmallVector<OpAsmParser::OperandType, 3> ops;
+  SmallVector<Type, 3> types; // one per operand; no result types are added
+  return failure( // `(` operand-list `)` attr-dict? `:` type-list
+      parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonTypeList(types) ||
+      parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
 }
 
 static void print(OpAsmPrinter &p, GridwiseGemmOp op) {
-  p << GridwiseGemmOp::getOperationName();
+  p << op.getOperationName() << "(" << op.getOperands() << ")";
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getOperandTypes(); // same form parseGridwiseGemmOp reads
 }
 
 static LogicalResult verify(GridwiseGemmOp op) {

diff  --git a/mlir/test/Dialect/MIOpen/ops.mlir b/mlir/test/Dialect/MIOpen/ops.mlir
index a37e54110186..9b3b2e3db27e 100644
--- a/mlir/test/Dialect/MIOpen/ops.mlir
+++ b/mlir/test/Dialect/MIOpen/ops.mlir
@@ -3,22 +3,132 @@
 // Run: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
 
 func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
-  miopen.conv2d
+  miopen.conv2d(%filter, %input, %output) { // attributes below are not declared in ODS Arguments
+    filter_layout = ["k", "c", "y", "x"],
+    input_layout = ["n", "c", "hi", "wi"],
+    output_layout = ["n", "k", "ho", "wo"],
+    dilations = [1, 1],
+    strides = [1, 1],
+    padding = [0, 0]
+  } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
   return
 }
 // CHECK-LABEL: func @miopen_conv2d
 //  CHECK-NEXT: miopen.conv2d
 
-func @miopen_transform(%memref : memref<?x?x?x?xf32>) {
-  miopen.transform
+// test 1-1 dimension mappings (passthrough for n/c, pad for hi/wi).
+func @miopen_transform_1_to_1(%memref: memref<?x?x?x?xf32>) {
+  %transformed_memref = miopen.transform(%memref) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["n"],
+        transformation = "passthrough",
+        source_dimensions = [0],
+        source_names = ["n"]
+      },
+      {
+        dimensions = [1],
+        names = ["c"],
+        transformation = "passthrough",
+        source_dimensions = [1],
+        source_names = ["c"]
+      },
+      {
+        dimensions = [2],
+        names = ["hipad"],
+        transformation = "pad",
+        parameters = [0, 0],
+        source_dimensions = [2],
+        source_names = ["hi"]
+      },
+      {
+        dimensions = [3],
+        names = ["wipad"],
+        transformation = "pad",
+        parameters = [0, 0],
+        source_dimensions = [3],
+        source_names = ["wi"]
+      }
+    ]
+  } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @miopen_transform_1_to_1
+//  CHECK-NEXT: miopen.transform
+
+// test that multiple source dimensions map to 1 target dimension.
+func @miopen_transform_n_to_1(%memref : memref<?x?x?x?xf32>) {
+  %transformed_memref = miopen.transform(%memref) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["gemmK"],
+        transformation = "merge",
+        source_dimensions = [1, 2, 3],
+        source_names = ["c", "y", "x"]
+      },
+      {
+        dimensions = [1],
+        names = ["gemmM"],
+        transformation = "passthrough",
+        source_dimensions = [0],
+        source_names = ["n"]
+      }
+    ]
+  } : memref<?x?x?x?xf32> to memref<?x?xf32> // 4-D source, 2-D result
+  return
+}
+// CHECK-LABEL: func @miopen_transform_n_to_1
+//  CHECK-NEXT: miopen.transform
+
+// test that 1 source dimension maps to multiple target dimensions.
+func @miopen_transform_1_to_n(%memref : memref<?x?x?x?xf32>) {
+  %transformed_memref = miopen.transform(%memref) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["n"],
+        transformation = "passthrough",
+        source_dimensions = [0],
+        source_names = ["n"]
+      },
+      {
+        dimensions = [1],
+        names = ["c"],
+        transformation = "passthrough",
+        source_dimensions = [1],
+        source_names = ["c"]
+      },
+      {
+        dimensions = [2, 3],
+        names = ["y", "ho"],
+        transformation = "embed",
+        parameters = [1, 1, 0],
+        source_dimensions = [2],
+        source_names = ["hipad"]
+      },
+      {
+        dimensions = [4, 5],
+        names = ["x", "wo"],
+        transformation = "embed",
+        parameters = [1, 1, 0],
+        source_dimensions = [3],
+        source_names = ["wipad"]
+      }
+    ]
+  } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?xf32> // 4-D source, 6-D result
   return
 }
 
-// CHECK-LABEL: func @miopen_transform
+// CHECK-LABEL: func @miopen_transform_1_to_n
 //  CHECK-NEXT: miopen.transform
 
 func @miopen_gridwise_gemm(%A : memref<?x?xf32>, %B : memref<?x?xf32>, %C : memref<?x?xf32>) {
-  miopen.gridwise_gemm
+  miopen.gridwise_gemm(%A, %B, %C) {
+    parameters = [ // empty array attribute
+    ]
+  } : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
   return
 }
 


        


More information about the llvm-branch-commits mailing list