[llvm-branch-commits] [mlir] 1c3be7e - Add Op transform logic. Improve Op translate logic.
Wen-Heng Chung via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:25 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: 1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd
URL: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd
DIFF: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd.diff
LOG: Add Op transform logic. Improve Op translate logic.
Revise tests.
Added:
Modified:
mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
mlir/test/Dialect/MIOpen/lowering.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index 2bd64efa77b6..cda706c4112c 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -343,12 +343,17 @@ struct GridwiseConvolutionImplicitGemm_v4r4_)";
output << kHeaderPreamblePart2;
output << kHeaderPreamblePart3;
output << '\n';
- output << R"(
- constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
- output << R"(
- constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
- output << R"(
- constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
+
+ // TBD: remove these interim checks.
+ if (tensorDescs.size() > 0)
+ output << R"(
+ constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
+ if (tensorDescs.size() > 1)
+ output << R"(
+ constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
+ if (tensorDescs.size() > 2)
+ output << R"(
+ constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
output << '\n';
}
@@ -358,7 +363,7 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t,
// decltype(wei_e_k_global_desc),
// decltype(in_e_b_global_desc),
// decltype(out_k_b_global_desc),
- for (int i = 0; i < 3; ++i) {
+ for (unsigned i = 0; i < args.size(); ++i) {
output << R"(
decltype()" << args[i] << "),";
}
@@ -396,7 +401,9 @@ void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attr
case 'H':
case 'W':
output << llvm::toUpper(strAttr.getValue()[0]);
- output << llvm::toUpper(strAttr.getValue()[1]);
+ // XXX: fix this.
+ if (strAttr.getValue().size() > 1)
+ output << llvm::toUpper(strAttr.getValue()[1]);
break;
default:
output << llvm::toUpper(strAttr.getValue()[0]);
diff --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
index 2a00ed675122..27311cb8cfb9 100644
--- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/MIOpenOps/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
@@ -37,6 +38,8 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/SmallVector.h"
+
using namespace mlir;
struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
@@ -44,15 +47,450 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
PatternMatchResult
matchAndRewrite(miopen::Conv2DOp op, PatternRewriter &rewriter) const override {
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.filter().getType(), op.filter());
+ auto filterLayoutAttr = op.getAttrOfType<ArrayAttr>("filter_layout");
+ auto inputLayoutAttr = op.getAttrOfType<ArrayAttr>("input_layout");
+ auto outputLayoutAttr = op.getAttrOfType<ArrayAttr>("output_layout");
+
+ // TBD: handle dilations, strides, padding.
+
+ // Transform filter tensor.
+ auto filterType = op.filter().getType().dyn_cast<MemRefType>();
+ auto filterShape = filterType.getShape();
+ auto filterElementType = filterType.getElementType();
+
+ llvm::SmallVector<int64_t, 2> transformedFilterShape;
+ transformedFilterShape.set_size(filterShape.size() - 2);
+ // TBD: compute transformed filter shape dimensions.
+ std::fill(transformedFilterShape.begin(), transformedFilterShape.end(), -1);
+ auto transformedFilterMemRefType = MemRefType::get(transformedFilterShape, filterElementType);
+
+ llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs;
+
+ // TBD: set layout attribute.
+ // TBD: Merge part.
+ llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart1Specs;
+ transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext())));
+ transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+ transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+ }, op.getContext())));
+ transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("c", op.getContext()),
+ StringAttr::get("y", op.getContext()),
+ StringAttr::get("x", op.getContext())
+ }, op.getContext())));
+
+ // TBD: Passthrough part.
+ llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart2Specs;
+ transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+ transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ }, op.getContext())));
+ transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("k", op.getContext())
+ }, op.getContext())));
+
+ auto transformedFilterLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(transformedFilterLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(transformedFilterLayoutPart2Specs, op.getContext())
+ }, op.getContext()));
+ transformedFilterAttrs.push_back(transformedFilterLayoutAttr);
+
+ // set source_layout attribute.
+ auto filterSrcLayoutAttr = rewriter.getNamedAttr("source_layout", filterLayoutAttr);
+ transformedFilterAttrs.push_back(filterSrcLayoutAttr);
+ // set output_layout attribute.
+ auto filterOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+ ArrayAttr::get({
+ StringAttr::get("gemmK", op.getContext()),
+ StringAttr::get("gemmM", op.getContext())
+ }, op.getContext()));
+ transformedFilterAttrs.push_back(filterOutputLayoutAttr);
+ // set gridwise_gemm_argument_pos attribute.
+ auto filterGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position",
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0));
+ transformedFilterAttrs.push_back(filterGridwiseGemmArgPosAttr);
+ auto gemmA = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedFilterMemRefType, op.filter(), transformedFilterAttrs);
+
+
+ // Transform input tensor.
+ // Input tensor step 1: padded input.
+ auto inputType = op.input().getType().dyn_cast<MemRefType>();
+ auto inputShape = inputType.getShape();
+ auto inputElementType = inputType.getElementType();
+
+ // TBD: compute padded input shape dimensions.
+
+ llvm::SmallVector<NamedAttribute, 3> paddedInputAttrs;
+
+ // TBD: set layout attribute.
+ // TBD: part 1: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart1Specs;
+ paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext())));
+ paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ }, op.getContext())));
+ paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext())
+ }, op.getContext())));
+
+ // TBD: part 2: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart2Specs;
+ paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext())));
+ paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ }, op.getContext())));
+ paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ci", op.getContext())
+ }, op.getContext())));
+
+ // TBD: part 3: Pad.
+ llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart3Specs;
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+ }, op.getContext())));
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names",
+ ArrayAttr::get({
+ StringAttr::get("hipad", op.getContext()),
+ StringAttr::get("wipad", op.getContext()),
+ }, op.getContext())));
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Pad", op.getContext())));
+ // TBD: padding parmeters.
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+ }, op.getContext())));
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+ }, op.getContext())));
+ paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("hi", op.getContext()),
+ StringAttr::get("wi", op.getContext())
+ }, op.getContext())));
+
+ auto paddedInputLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(paddedInputLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(paddedInputLayoutPart2Specs, op.getContext()),
+ DictionaryAttr::get(paddedInputLayoutPart3Specs, op.getContext())
+ }, op.getContext()));
+ paddedInputAttrs.push_back(paddedInputLayoutAttr);
+
+ // set source_layout attribute.
+ auto inputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", inputLayoutAttr);
+ paddedInputAttrs.push_back(inputSrcLayoutAttr);
+ // TBD: set output_layout attribute.
+ auto paddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext()),
+ StringAttr::get("ci", op.getContext()),
+ StringAttr::get("hi", op.getContext()),
+ StringAttr::get("wi", op.getContext())
+ }, op.getContext()));
+ paddedInputAttrs.push_back(paddedInputOutputLayoutAttr);
+ auto paddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), inputType, op.input(), paddedInputAttrs);
+
+ // Input tensor step 2 : embedded input.
+ llvm::SmallVector<int64_t, 6> embeddedInputShape;
+ embeddedInputShape.set_size(inputShape.size() + 2);
+ // TBD: compute embedded input shape dimensions.
+ std::fill(embeddedInputShape.begin(), embeddedInputShape.end(), -1);
+ auto embeddedInputMemRefType = MemRefType::get(embeddedInputShape, inputElementType);
+
+ llvm::SmallVector<NamedAttribute, 3> embeddedInputAttrs;
+
+ // TBD: set layout attribute.
+ // TBD: part 1: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart1Specs;
+ embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext())));
+ embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ }, op.getContext())));
+ embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext())
+ }, op.getContext())));
+
+ // TBD: part 2: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart2Specs;
+ embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext())));
+ embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ }, op.getContext())));
+ embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ci", op.getContext())
+ }, op.getContext())));
+ // TBD: part 3: Embed.
+ llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart3Specs;
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+ }, op.getContext())));
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names",
+ ArrayAttr::get({
+ StringAttr::get("y", op.getContext()),
+ StringAttr::get("ho", op.getContext()),
+ }, op.getContext())));
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext())));
+ // TBD: padding parmeters.
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+ }, op.getContext())));
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2)
+ }, op.getContext())));
+ embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("hipad", op.getContext()),
+ }, op.getContext())));
+
+ // TBD: part 4: Embed.
+ llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart4Specs;
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 4),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 5)
+ }, op.getContext())));
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("names",
+ ArrayAttr::get({
+ StringAttr::get("x", op.getContext()),
+ StringAttr::get("wo", op.getContext()),
+ }, op.getContext())));
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext())));
+ // TBD: embed parmeters.
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("parameters",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+ }, op.getContext())));
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+ }, op.getContext())));
+ embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("wipad", op.getContext())
+ }, op.getContext())));
+
+ auto embeddedInputLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(embeddedInputLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(embeddedInputLayoutPart2Specs, op.getContext()),
+ DictionaryAttr::get(embeddedInputLayoutPart3Specs, op.getContext()),
+ DictionaryAttr::get(embeddedInputLayoutPart4Specs, op.getContext())
+ }, op.getContext()));
+ embeddedInputAttrs.push_back(embeddedInputLayoutAttr);
+
+
+ // TBD: set intermediate_layout attribute.
+ auto embeddedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext()),
+ StringAttr::get("ci", op.getContext()),
+ StringAttr::get("hipad", op.getContext()),
+ StringAttr::get("wipad", op.getContext())
+ }, op.getContext()));
+ embeddedInputAttrs.push_back(embeddedInputImmLayoutAttr);
+ // TBD: set output_layout attribute.
+ auto embeddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext()),
+ StringAttr::get("ci", op.getContext()),
+ StringAttr::get("y", op.getContext()),
+ StringAttr::get("ho", op.getContext()),
+ StringAttr::get("x", op.getContext()),
+ StringAttr::get("wo", op.getContext())
+ }, op.getContext()));
+ embeddedInputAttrs.push_back(embeddedInputOutputLayoutAttr);
+ auto embeddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), embeddedInputMemRefType, ArrayRef<Value>(paddedInput), embeddedInputAttrs);
+
+ // Input tensor step 3: transformed input.
+ llvm::SmallVector<int64_t, 2> transformedInputShape;
+ transformedInputShape.set_size(inputShape.size() - 2);
+ // TBD: compute transformed input shape dimensions.
+ std::fill(transformedInputShape.begin(), transformedInputShape.end(), -1);
+ auto transformedInputMemRefType = MemRefType::get(transformedInputShape, inputElementType);
+
+ llvm::SmallVector<NamedAttribute, 3> transformedInputAttrs;
+
+ // TBD: set layout attribute.
+ // TBD: Part 1: Merge.
+ llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart1Specs;
+ transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext())));
+ transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+ transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 4)
+ }, op.getContext())));
+ transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ci", op.getContext()),
+ StringAttr::get("y", op.getContext()),
+ StringAttr::get("x", op.getContext())
+ }, op.getContext())));
+
+ // TBD: Part 2: Merge.
+ llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart2Specs;
+ transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+ transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+ transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 5)
+ }, op.getContext())));
+ transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext()),
+ StringAttr::get("ho", op.getContext()),
+ StringAttr::get("wo", op.getContext())
+ }, op.getContext())));
+
+ auto transformedInputLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(transformedInputLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(transformedInputLayoutPart2Specs, op.getContext())
+ }, op.getContext()));
+ transformedInputAttrs.push_back(transformedInputLayoutAttr);
+
+ // TBD: set intermediate_layout attribute.
+ auto transformedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout",
+ ArrayAttr::get({
+ StringAttr::get("ni", op.getContext()),
+ StringAttr::get("ci", op.getContext()),
+ StringAttr::get("y", op.getContext()),
+ StringAttr::get("ho", op.getContext()),
+ StringAttr::get("x", op.getContext()),
+ StringAttr::get("wo", op.getContext())
+ }, op.getContext()));
+ transformedInputAttrs.push_back(transformedInputImmLayoutAttr);
+ // TBD: set output_layout attribute.
+ auto transformedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+ ArrayAttr::get({
+ StringAttr::get("gemmK", op.getContext()),
+ StringAttr::get("gemmN", op.getContext()),
+ }, op.getContext()));
+ transformedInputAttrs.push_back(transformedInputOutputLayoutAttr);
+
+ // set gridwise_gemm_argument_pos attribute.
+ auto inputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position",
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1));
+ transformedInputAttrs.push_back(inputGridwiseGemmArgPosAttr);
+ auto gemmB = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedInputMemRefType, ArrayRef<Value>(embeddedInput), transformedInputAttrs);
+
+
+ // Transform output tensor.
+ auto outputType = op.output().getType().dyn_cast<MemRefType>();
+ auto outputShape = outputType.getShape();
+ auto outputElementType = outputType.getElementType();
+
+ llvm::SmallVector<int64_t, 2> transformedOutputShape;
+ transformedOutputShape.set_size(outputShape.size() - 2);
+ // TBD: compute transformed output shape dimensions.
+ std::fill(transformedOutputShape.begin(), transformedOutputShape.end(), -1);
+ auto transformedOutputMemRefType = MemRefType::get(transformedOutputShape, outputElementType);
+
+ llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs;
+
+ // TBD: set layout attribute.
+ // TBD: Part 1: Passthrough.
+ llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+ }, op.getContext())));
+ transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({
+ StringAttr::get("ko", op.getContext())
+ }, op.getContext())));
+
+ // TBD: Part 2: Merge.
+ llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+ ArrayAttr::get({
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+ }, op.getContext())));
+ transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+ ArrayAttr::get({ StringAttr::get("no", op.getContext()),
+ StringAttr::get("ho", op.getContext()),
+ StringAttr::get("wo", op.getContext())
+ }, op.getContext())));
+
+ auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
+ ArrayAttr::get({
+ DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
+ DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
+ }, op.getContext()));
+ transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
+ // set source_layout attribute.
+ auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", outputLayoutAttr);
+ transformedOutputAttrs.push_back(outputSrcLayoutAttr);
+ // TBD: set output_layout attribute.
+ auto transformedOutputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+ ArrayAttr::get({
+ StringAttr::get("gemmM", op.getContext()),
+ StringAttr::get("gemmN", op.getContext()),
+ }, op.getContext()));
+ transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr);
- rewriter.create<miopen::TransformOp>(op.getLoc(), op.output().getType(), op.output());
+ // TBD: set gridwise_gemm_argument_pos attribute.
+ auto outputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position",
+ IntegerAttr::get(IntegerType::get(32, op.getContext()), 2));
+ transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr);
+ auto gemmC = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedOutputMemRefType, op.output(), transformedOutputAttrs);
- //rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), op.filter(), op.input(), op.output());
+ // Emit miopen.gridwise_gemm op.
+ rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), gemmA, gemmB, gemmC);
// Finally, erase the original Conv2D op.
op.erase();
diff --git a/mlir/test/Dialect/MIOpen/lowering.mlir b/mlir/test/Dialect/MIOpen/lowering.mlir
index e7734cef5a29..5907fbd41ebd 100644
--- a/mlir/test/Dialect/MIOpen/lowering.mlir
+++ b/mlir/test/Dialect/MIOpen/lowering.mlir
@@ -3,8 +3,8 @@
func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
miopen.conv2d(%filter, %input, %output) {
filter_layout = ["k", "c", "y", "x"],
- input_layout = ["n", "c", "hi", "wi"],
- output_layout = ["n", "k", "ho", "wo"],
+ input_layout = ["ni", "ci", "hi", "wi"],
+ output_layout = ["no", "ko", "ho", "wo"],
dilations = [1, 1],
strides = [1, 1],
padding = [0, 0]
@@ -18,4 +18,4 @@ func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>,
// CHECK-NEXT: miopen.transform
// CHECK-NEXT: miopen.transform
// CHECK-NEXT: miopen.transform
-// TBD-CHECK-NEXT: miopen.gridwise_gemm
+// CHECK-NEXT: miopen.gridwise_gemm
More information about the llvm-branch-commits mailing list