[llvm-branch-commits] [mlir] ff39c4c - Add parse / print logic to MIOpen ops.
Wen-Heng Chung via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:10 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:19-05:00
New Revision: ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc
URL: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc
DIFF: https://github.com/llvm/llvm-project/commit/ff39c4c709ac1603d21f7baab75dbfbb13ae6fbc.diff
LOG: Add parse / print logic to MIOpen ops.
Revise test cases along the way.
Added:
Modified:
mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
mlir/test/Dialect/MIOpen/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
index 1304f16f3b30..8ffd66647f3f 100644
--- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
+++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenOps.td
@@ -35,10 +35,36 @@ class MIOpen_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
-def MIOpen_Conv2DOp : MIOpen_Op<"conv2d">;
+// miopen.conv2d: three rank-4 f32 memref operands and no results.
+// The operand order used by the tests in ops.mlir is (filter, input, output).
+// Layout / dilation / stride / padding attributes are not declared here, so
+// ODS generates no attribute checks for them.
+def MIOpen_Conv2DOp :
+ MIOpen_Op<"conv2d">,
+ Arguments<(ins MemRefRankOf<[F32], [4]>,   // filter
+ MemRefRankOf<[F32], [4]>,   // input
+ MemRefRankOf<[F32], [4]>)> {   // output
+ let summary = "2D convolution";
+ let description = [{
+ The `miopen.conv2d` op computes 2D convolution.
+ }];
+}
-def MIOpen_TransformOp : MIOpen_Op<"transform">;
+// miopen.transform: one memref operand in, one memref result out. Neither is
+// constrained in element type or rank, and the tests in ops.mlir change rank
+// (4 -> 2, 4 -> 6). The coordinate mapping is carried by a `layout`
+// attribute in the tests, which this definition does not declare.
+def MIOpen_TransformOp :
+ MIOpen_Op<"transform">,
+ Arguments<(ins AnyMemRef)>,   // source memref
+ Results<(outs AnyMemRef)> {   // transformed memref
+ let summary = "Tensor transformation";
+ let description = [{
+ The `miopen.transform` op transforms tensor coordinates.
+ }];
+}
-def MIOpen_GridwiseGemmOp : MIOpen_Op<"gridwise_gemm">;
+// miopen.gridwise_gemm: three rank-2 f32 memref operands (named %A, %B, %C
+// in ops.mlir) and no results. Attributes are not declared here.
+def MIOpen_GridwiseGemmOp :
+ MIOpen_Op<"gridwise_gemm">,
+ Arguments<(ins MemRefRankOf<[F32], [2]>,   // A
+ MemRefRankOf<[F32], [2]>,   // B
+ MemRefRankOf<[F32], [2]>)> {   // C
+ let summary = "Gridwise GEMM";
+ let description = [{
+ The `miopen.gridwise_gemm` op computes gridwise GEMM.
+ }];
+}
#endif // MIOPEN_OPS
diff --git a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
index 9408e17b2831..b41423435e33 100644
--- a/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/MIOpenOps.cpp
@@ -40,8 +40,6 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/MIOpenOps/MIOpenOps.cpp.inc"
>();
-
- //addInterfaces<LoopSideEffectsInterface>();
}
//===----------------------------------------------------------------------===//
@@ -49,11 +47,19 @@ MIOpenOpsDialect::MIOpenOpsDialect(MLIRContext *context)
//===----------------------------------------------------------------------===//
+// Parses the custom form of miopen.conv2d:
+//   miopen.conv2d(%a, %b, %c) {optional-attr-dict} : type-a, type-b, type-c
+// One type is expected per operand. The op-name location is passed to
+// resolveOperands, so an operand/type mismatch is reported there.
static ParseResult parseConv2DOp(OpAsmParser &parser, OperationState &result) {
- return success();
+ SmallVector<OpAsmParser::OperandType, 3> ops;
+ SmallVector<Type, 3> types;
+ // The || chain stops at the first failing sub-parser; failure(bool) turns
+ // "some step failed" into failure and "all steps succeeded" into success.
+ return failure(
+ parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonTypeList(types) ||
+ parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}
+// Prints miopen.conv2d in the custom form accepted by parseConv2DOp:
+//   miopen.conv2d(%a, %b, %c) {attrs} : type-a, type-b, type-c
static void print(OpAsmPrinter &p, Conv2DOp op) {
- p << Conv2DOp::getOperationName();
+ p << op.getOperationName() << "(" << op.getOperands() << ")";
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getOperandTypes();
}
static LogicalResult verify(Conv2DOp op) {
@@ -65,11 +71,24 @@ static LogicalResult verify(Conv2DOp op) {
//===----------------------------------------------------------------------===//
+// Parses the custom form of miopen.transform:
+//   miopen.transform(%src) {optional-attr-dict} : src-type to dst-type
+// The source operand is resolved against src-type; dst-type becomes the
+// op's single result type.
static ParseResult parseTransformOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType src;
+ Type srcType, dstType;
+ // Bail out at the first failing sub-parser; otherwise fall through to the
+ // success return below (no unreachable code after an early return).
+ if (parser.parseLParen() ||
+     parser.parseOperand(src) ||
+     parser.parseRParen() ||
+     parser.parseOptionalAttrDict(result.attributes) ||
+     parser.parseColonType(srcType) ||
+     parser.resolveOperand(src, srcType, result.operands) ||
+     parser.parseKeywordType("to", dstType) ||
+     parser.addTypeToList(dstType, result.types))
+   return failure();
return success();
}
+// Prints miopen.transform in the custom form accepted by parseTransformOp:
+//   miopen.transform(%src) {attrs} : src-type to result-type
static void print(OpAsmPrinter &p, TransformOp op) {
- p << TransformOp::getOperationName();
+ p << op.getOperationName() << "(" << op.getOperand() << ")";
+ p.printOptionalAttrDict(op.getAttrs());
+ // NOTE(review): `getOperand()->getType()` relies on Value supporting `->`
+ // (pointer-typed Value or the transitional operator-> shim). On an MLIR base
+ // with purely value-typed Value this must be `getOperand().getType()` —
+ // confirm against this branch's MLIR revision.
+ p << " : " << op.getOperand()->getType() << " to " << op.getType();
}
static LogicalResult verify(TransformOp op) {
@@ -81,11 +100,19 @@ static LogicalResult verify(TransformOp op) {
//===----------------------------------------------------------------------===//
+// Parses the custom form of miopen.gridwise_gemm:
+//   miopen.gridwise_gemm(%a, %b, %c) {optional-attr-dict} : type-a, type-b, type-c
+// NOTE(review): this body is identical to parseConv2DOp; a shared helper
+// would keep the two formats from drifting apart.
static ParseResult parseGridwiseGemmOp(OpAsmParser &parser, OperationState &result) {
- return success();
+ SmallVector<OpAsmParser::OperandType, 3> ops;
+ SmallVector<Type, 3> types;
+ // The || chain stops at the first failing sub-parser; failure(bool) turns
+ // "some step failed" into failure and "all steps succeeded" into success.
+ return failure(
+ parser.parseOperandList(ops, OpAsmParser::Delimiter::Paren) ||
+ parser.parseOptionalAttrDict(result.attributes) ||
+ parser.parseColonTypeList(types) ||
+ parser.resolveOperands(ops, types, parser.getNameLoc(), result.operands));
}
+// Prints miopen.gridwise_gemm in the custom form accepted by
+// parseGridwiseGemmOp:
+//   miopen.gridwise_gemm(%a, %b, %c) {attrs} : type-a, type-b, type-c
static void print(OpAsmPrinter &p, GridwiseGemmOp op) {
- p << GridwiseGemmOp::getOperationName();
+ p << op.getOperationName() << "(" << op.getOperands() << ")";
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getOperandTypes();
}
static LogicalResult verify(GridwiseGemmOp op) {
diff --git a/mlir/test/Dialect/MIOpen/ops.mlir b/mlir/test/Dialect/MIOpen/ops.mlir
index a37e54110186..9b3b2e3db27e 100644
--- a/mlir/test/Dialect/MIOpen/ops.mlir
+++ b/mlir/test/Dialect/MIOpen/ops.mlir
@@ -3,22 +3,132 @@
// Run: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
- miopen.conv2d
+ // Custom form: parenthesized operands, an attribute dictionary, then one
+ // type per operand. The layout/dilation/stride/padding entries are parsed
+ // through the generic attribute dictionary.
+ miopen.conv2d(%filter, %input, %output) {
+ filter_layout = ["k", "c", "y", "x"],
+ input_layout = ["n", "c", "hi", "wi"],
+ output_layout = ["n", "k", "ho", "wo"],
+ dilations = [1, 1],
+ strides = [1, 1],
+ padding = [0, 0]
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
// CHECK-LABEL: func @miopen_conv2d
// CHECK-NEXT: miopen.conv2d
-func @miopen_transform(%memref : memref<?x?x?x?xf32>) {
- miopen.transform
+// Test 1-to-1 dimension mappings (passthrough and pad).
+func @miopen_transform_1_to_1(%memref: memref<?x?x?x?xf32>) {
+ %transformed_memref = miopen.transform(%memref) {
+ layout = [
+ {
+ dimensions = [0],
+ names = ["n"],
+ transformation = "passthrough",
+ source_dimensions = [0],
+ source_names = ["n"]
+ },
+ {
+ dimensions = [1],
+ names = ["c"],
+ transformation = "passthrough",
+ source_dimensions = [1],
+ source_names = ["c"]
+ },
+ {
+ dimensions = [2],
+ names = ["hipad"],
+ transformation = "pad",
+ parameters = [0, 0],
+ source_dimensions = [2],
+ source_names = ["hi"]
+ },
+ {
+ dimensions = [3],
+ names = ["wipad"],
+ transformation = "pad",
+ parameters = [0, 0],
+ source_dimensions = [3],
+ source_names = ["wi"]
+ }
+ ]
+ } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @miopen_transform_1_to_1
+// CHECK-NEXT: miopen.transform
+
+// Test several source dimensions merged into one target dimension
+// (rank 4 -> rank 2): gemmK merges c, y, x; gemmM passes n through.
+func @miopen_transform_n_to_1(%memref : memref<?x?x?x?xf32>) {
+ %transformed_memref = miopen.transform(%memref) {
+ layout = [
+ {
+ dimensions = [0],
+ names = ["gemmK"],
+ transformation = "merge",
+ source_dimensions = [1, 2, 3],
+ source_names = ["c", "y", "x"]
+ },
+ {
+ dimensions = [1],
+ names = ["gemmM"],
+ transformation = "passthrough",
+ source_dimensions = [0],
+ source_names = ["n"]
+ }
+ ]
+ } : memref<?x?x?x?xf32> to memref<?x?xf32>
+ return
+}
+// CHECK-LABEL: func @miopen_transform_n_to_1
+// CHECK-NEXT: miopen.transform
+
+// Test one source dimension mapped to several target dimensions
+// (rank 4 -> rank 6) via "embed". NOTE(review): the source names
+// "hipad"/"wipad" match the outputs of the pad entries in the 1-to-1 test,
+// so this presumably models a transform applied to an already-padded tensor.
+func @miopen_transform_1_to_n(%memref : memref<?x?x?x?xf32>) {
+ %transformed_memref = miopen.transform(%memref) {
+ layout = [
+ {
+ dimensions = [0],
+ names = ["n"],
+ transformation = "passthrough",
+ source_dimensions = [0],
+ source_names = ["n"]
+ },
+ {
+ dimensions = [1],
+ names = ["c"],
+ transformation = "passthrough",
+ source_dimensions = [1],
+ source_names = ["c"]
+ },
+ {
+ dimensions = [2, 3],
+ names = ["y", "ho"],
+ transformation = "embed",
+ parameters = [1, 1, 0],
+ source_dimensions = [2],
+ source_names = ["hipad"]
+ },
+ {
+ dimensions = [4, 5],
+ names = ["x", "wo"],
+ transformation = "embed",
+ parameters = [1, 1, 0],
+ source_dimensions = [3],
+ source_names = ["wipad"]
+ }
+ ]
+ } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?xf32>
return
}
-// CHECK-LABEL: func @miopen_transform
+// CHECK-LABEL: func @miopen_transform_1_to_n
// CHECK-NEXT: miopen.transform
func @miopen_gridwise_gemm(%A : memref<?x?xf32>, %B : memref<?x?xf32>, %C : memref<?x?xf32>) {
- miopen.gridwise_gemm
+ // Custom form: parenthesized operands, an attribute dictionary, then one
+ // type per operand. NOTE(review): `parameters` is an empty array here —
+ // presumably a placeholder; confirm the intended contents.
+ miopen.gridwise_gemm(%A, %B, %C) {
+ parameters = [
+ ]
+ } : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return
}
More information about the llvm-branch-commits
mailing list