[llvm-branch-commits] [mlir] a1e3fec - Generalized op transformation logic for output tensor.
Wen-Heng Chung via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:33 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: a1e3fec79420164b7cd398872d525f03c4436e96
URL: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96
DIFF: https://github.com/llvm/llvm-project/commit/a1e3fec79420164b7cd398872d525f03c4436e96.diff
LOG: Generalized op transformation logic for output tensor.
Add more op lowering test cases.
Added:
mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir
Modified:
mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
index 46083be58a35..f1d8e914c3ec 100644
--- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
@@ -66,7 +66,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs;
- // TBD: set layout attribute.
+ // set layout attribute.
// Weight tensor transformation:
// - Part 1: Merge non-K dimensions to dimension 0, name it as gemmK.
// - Part 2: PassThrough K dimension to dimension 1, name it as gemmM.
@@ -414,7 +414,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
StringAttr::get("wo", op.getContext())
}, op.getContext()));
transformedInputAttrs.push_back(transformedInputImmLayoutAttr);
- // TBD: set output_layout attribute.
+ // set output_layout attribute.
auto transformedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
ArrayAttr::get({
StringAttr::get("gemmK", op.getContext()),
@@ -442,49 +442,59 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs;
- // TBD: set layout attribute.
- // TBD: Part 1: Passthrough.
- llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
- transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
- transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
- transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
- transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
- ArrayAttr::get({
- IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
- }, op.getContext())));
- transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
- ArrayAttr::get({
- StringAttr::get("ko", op.getContext())
- }, op.getContext())));
+ // set layout attribute.
+ // Output tensor transformation:
+ // - Part 1: PassThrough K dimension to dimension 0, name it as gemmM.
+ // - Part 2: Merge non-K dimensions to dimension 1, name it as gemmN.
+ {
+ llvm::SmallVector<IntegerAttr, 3> nonKDims;
+ IntegerAttr kDim;
+ llvm::SmallVector<StringAttr, 3> nonKDimNames;
+ StringAttr kDimName;
+ for (unsigned i = 0; i < outputLayoutAttr.size(); ++i) {
+ if (auto strAttr = outputLayoutAttr.getValue()[i].dyn_cast<StringAttr>()) {
+ if (strAttr.getValue() == "ko") {
+ kDim = IntegerAttr::get(IntegerType::get(32, op.getContext()), i);
+ kDimName = StringAttr::get(strAttr.getValue(), op.getContext());
+ } else {
+ nonKDims.push_back(IntegerAttr::get(IntegerType::get(32, op.getContext()), i));
+ nonKDimNames.push_back(StringAttr::get(strAttr.getValue(), op.getContext()));
+ }
+ }
+ }
+
+ // Part 1: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({kDim}, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({kDimName}, op.getContext())));
- // TBD: Part 2: Merge.
- llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
- transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
- transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
- transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
- transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
- ArrayAttr::get({
- IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
- IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
- IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
- }, op.getContext())));
- transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
- ArrayAttr::get({ StringAttr::get("no", op.getContext()),
- StringAttr::get("ho", op.getContext()),
- StringAttr::get("wo", op.getContext())
- }, op.getContext())));
+ // Part 2: Merge.
+ llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get(ArrayRef<Attribute>(nonKDims.begin(), nonKDims.end()), op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get(ArrayRef<Attribute>(nonKDimNames.begin(), nonKDimNames.end()), op.getContext())));
- auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
- ArrayAttr::get({
- DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
- DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
- }, op.getContext()));
- transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
+ auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
+ }, op.getContext()));
+ transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
+ }
// set source_layout attribute.
auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", outputLayoutAttr);
transformedOutputAttrs.push_back(outputSrcLayoutAttr);
- // TBD: set output_layout attribute.
+ // set output_layout attribute.
auto transformedOutputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
ArrayAttr::get({
StringAttr::get("gemmM", op.getContext()),
@@ -492,7 +502,7 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
}, op.getContext()));
transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr);
- // TBD: set gridwise_gemm_argument_pos attribute.
+ // set gridwise_gemm_argument_pos attribute.
auto outputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position",
IntegerAttr::get(IntegerType::get(32, op.getContext()), 2));
transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr);
diff --git a/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
new file mode 100644
index 000000000000..4f6222237d59
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_ckyx_cnhw_knhw.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_ckyx_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+ miopen.conv2d(%filter, %input, %output) {
+ filter_layout = ["c", "k", "y", "x"],
+ input_layout = ["ci", "ni", "hi", "wi"],
+ output_layout = ["ko", "no", "ho", "wo"],
+ dilations = [1, 1],
+ strides = [1, 1],
+ padding = [0, 0]
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @miopen_conv2d
+// CHECK-NOT: miopen.conv2d
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.gridwise_gemm
diff --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
new file mode 100644
index 000000000000..d5d0d9836bbd
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_chwn_khwn.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_cyxk_chwn_khwn(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+ miopen.conv2d(%filter, %input, %output) {
+ filter_layout = ["c", "y", "x", "k"],
+ input_layout = ["ci", "hi", "wi", "ni"],
+ output_layout = ["ko", "ho", "wo", "no"],
+ dilations = [1, 1],
+ strides = [1, 1],
+ padding = [0, 0]
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @miopen_conv2d
+// CHECK-NOT: miopen.conv2d
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.gridwise_gemm
diff --git a/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir
new file mode 100644
index 000000000000..cec6463d783c
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/lowering_cyxk_cnhw_knhw.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -miopen-lowering %s | FileCheck %s
+
+func @miopen_conv2d_cyxk_cnhw_knhw(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+ miopen.conv2d(%filter, %input, %output) {
+ filter_layout = ["c", "y", "x", "k"],
+ input_layout = ["ci", "ni", "hi", "wi"],
+ output_layout = ["ko", "no", "ho", "wo"],
+ dilations = [1, 1],
+ strides = [1, 1],
+ padding = [0, 0]
+ } : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
+ return
+}
+// CHECK-LABEL: func @miopen_conv2d
+// CHECK-NOT: miopen.conv2d
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.transform
+// CHECK-NEXT: miopen.gridwise_gemm
More information about the llvm-branch-commits
mailing list