[llvm-branch-commits] [mlir] 1c3be7e - Add Op transform logic. Improve Op translate logic.

Wen-Heng Chung via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:25 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: 1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd

URL: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd
DIFF: https://github.com/llvm/llvm-project/commit/1c3be7ec0838048d0a5a8f2ebf3dfa5e831370cd.diff

LOG: Add Op transform logic. Improve Op translate logic.

Revise tests.

Added: 
    

Modified: 
    mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
    mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
    mlir/test/Dialect/MIOpen/lowering.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index 2bd64efa77b6..cda706c4112c 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -343,12 +343,17 @@ struct GridwiseConvolutionImplicitGemm_v4r4_)";
   output << kHeaderPreamblePart2;
   output << kHeaderPreamblePart3;
   output << '\n';
-  output << R"(
-        constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
-  output << R"(
-        constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
-  output << R"(
-        constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
+
+  // TBD: remove these interim checks.
+  if (tensorDescs.size() > 0)
+    output << R"(
+          constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
+  if (tensorDescs.size() > 1)
+    output << R"(
+          constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
+  if (tensorDescs.size() > 2)
+    output << R"(
+          constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
   output << '\n';
 }
 
@@ -358,7 +363,7 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t,
 //                                                   decltype(wei_e_k_global_desc),
 //                                                   decltype(in_e_b_global_desc),
 //                                                   decltype(out_k_b_global_desc),
-  for (int i = 0; i < 3; ++i) {
+  for (unsigned i = 0; i < args.size(); ++i) {
     output << R"(
                                                      decltype()" << args[i] << "),";
   }
@@ -396,7 +401,9 @@ void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attr
           case 'H':
           case 'W':
             output << llvm::toUpper(strAttr.getValue()[0]);
-            output << llvm::toUpper(strAttr.getValue()[1]);
+            // XXX: fix this. 
+            if (strAttr.getValue().size() > 1)
+              output << llvm::toUpper(strAttr.getValue()[1]);
             break;
           default:
             output << llvm::toUpper(strAttr.getValue()[0]);

diff  --git a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
index 2a00ed675122..27311cb8cfb9 100644
--- a/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/LowerMIOpenOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/MIOpenOps/Passes.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Operation.h"
@@ -37,6 +38,8 @@
 #include "mlir/Transforms/Passes.h"
 #include "mlir/Support/LogicalResult.h"
 
+#include "llvm/ADT/SmallVector.h"
+
 using namespace mlir;
 
 struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
@@ -44,15 +47,450 @@ struct Conv2DOpRewritePattern : public OpRewritePattern<miopen::Conv2DOp> {
 
   PatternMatchResult
   matchAndRewrite(miopen::Conv2DOp op, PatternRewriter &rewriter) const override {
-    rewriter.create<miopen::TransformOp>(op.getLoc(), op.filter().getType(), op.filter());
+    auto filterLayoutAttr = op.getAttrOfType<ArrayAttr>("filter_layout");
+    auto inputLayoutAttr = op.getAttrOfType<ArrayAttr>("input_layout");
+    auto outputLayoutAttr = op.getAttrOfType<ArrayAttr>("output_layout");
+
+    // TBD: handle dilations, strides, padding.
+
+    // Transform filter tensor.
+    auto filterType = op.filter().getType().dyn_cast<MemRefType>();
+    auto filterShape = filterType.getShape();
+    auto filterElementType = filterType.getElementType();
+    
+    llvm::SmallVector<int64_t, 2> transformedFilterShape;
+    transformedFilterShape.set_size(filterShape.size() - 2);
+    // TBD: compute transformed filter shape dimensions.
+    std::fill(transformedFilterShape.begin(), transformedFilterShape.end(), -1);
+    auto transformedFilterMemRefType = MemRefType::get(transformedFilterShape, filterElementType);
+
+    llvm::SmallVector<NamedAttribute, 3> transformedFilterAttrs;
+
+    // TBD: set layout attribute.
+    // TBD: Merge part.
+    llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart1Specs;
+    transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+    transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext())));
+    transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+    transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+                                                }, op.getContext())));
+    transformedFilterLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("c", op.getContext()),
+                                                    StringAttr::get("y", op.getContext()),
+                                                    StringAttr::get("x", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: Passthrough part.
+    llvm::SmallVector<NamedAttribute, 5> transformedFilterLayoutPart2Specs;
+    transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+    transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+    transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                }, op.getContext())));
+    transformedFilterLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("k", op.getContext())
+                                                }, op.getContext())));
+
+    auto transformedFilterLayoutAttr = rewriter.getNamedAttr("layout",
+                                                             ArrayAttr::get({
+                                                                 DictionaryAttr::get(transformedFilterLayoutPart1Specs, op.getContext()),
+                                                                 DictionaryAttr::get(transformedFilterLayoutPart2Specs, op.getContext())
+                                                             }, op.getContext()));
+    transformedFilterAttrs.push_back(transformedFilterLayoutAttr);
+
+    // set source_layout attribute.
+    auto filterSrcLayoutAttr = rewriter.getNamedAttr("source_layout", filterLayoutAttr);
+    transformedFilterAttrs.push_back(filterSrcLayoutAttr);
+    // set output_layout attribute.
+    auto filterOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("gemmK", op.getContext()),
+                                                            StringAttr::get("gemmM", op.getContext())
+                                                        }, op.getContext()));
+    transformedFilterAttrs.push_back(filterOutputLayoutAttr);
+    // set gridwise_gemm_argument_position attribute.
+    auto filterGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", 
+                                                              IntegerAttr::get(IntegerType::get(32, op.getContext()), 0));
+    transformedFilterAttrs.push_back(filterGridwiseGemmArgPosAttr);
+    auto gemmA = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedFilterMemRefType, op.filter(), transformedFilterAttrs);
+
+
+    // Transform input tensor.
+    // Input tensor step 1: padded input.
+    auto inputType = op.input().getType().dyn_cast<MemRefType>();
+    auto inputShape = inputType.getShape();
+    auto inputElementType = inputType.getElementType();
+
+    // TBD: compute padded input shape dimensions.
+
+    llvm::SmallVector<NamedAttribute, 3> paddedInputAttrs;
+
+    // TBD: set layout attribute.
+    // TBD: part 1: Passthrough.
+    llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart1Specs;
+    paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+    paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext())));
+    paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                }, op.getContext())));
+    paddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ni", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: part 2: Passthrough.
+    llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart2Specs;
+    paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+    paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext())));
+    paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                }, op.getContext())));
+    paddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ci", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: part 3: Pad.
+    llvm::SmallVector<NamedAttribute, 5> paddedInputLayoutPart3Specs;
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+                                                }, op.getContext())));
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("hipad", op.getContext()),
+                                                    StringAttr::get("wipad", op.getContext()),
+                                                }, op.getContext())));
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Pad", op.getContext())));
+    // TBD: padding parameters.
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+                                                }, op.getContext())));
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+                                                }, op.getContext())));
+    paddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("hi", op.getContext()),
+                                                    StringAttr::get("wi", op.getContext())
+                                                }, op.getContext())));
+
+    auto paddedInputLayoutAttr = rewriter.getNamedAttr("layout",
+                                                             ArrayAttr::get({
+                                                                 DictionaryAttr::get(paddedInputLayoutPart1Specs, op.getContext()),
+                                                                 DictionaryAttr::get(paddedInputLayoutPart2Specs, op.getContext()),
+                                                                 DictionaryAttr::get(paddedInputLayoutPart3Specs, op.getContext())
+                                                             }, op.getContext()));
+    paddedInputAttrs.push_back(paddedInputLayoutAttr);
+
+    // set source_layout attribute.
+    auto inputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", inputLayoutAttr);
+    paddedInputAttrs.push_back(inputSrcLayoutAttr);
+    // TBD: set output_layout attribute.
+    auto paddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("ni", op.getContext()),
+                                                            StringAttr::get("ci", op.getContext()),
+                                                            StringAttr::get("hi", op.getContext()),
+                                                            StringAttr::get("wi", op.getContext())
+                                                        }, op.getContext()));
+    paddedInputAttrs.push_back(paddedInputOutputLayoutAttr);
+    auto paddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), inputType, op.input(), paddedInputAttrs);
+
+    // Input tensor step 2 : embedded input.
+    llvm::SmallVector<int64_t, 6> embeddedInputShape;
+    embeddedInputShape.set_size(inputShape.size() + 2);
+    // TBD: compute embedded input shape dimensions.
+    std::fill(embeddedInputShape.begin(), embeddedInputShape.end(), -1);
+    auto embeddedInputMemRefType = MemRefType::get(embeddedInputShape, inputElementType);
+
+    llvm::SmallVector<NamedAttribute, 3> embeddedInputAttrs;
+
+    // TBD: set layout attribute.
+    // TBD: part 1: Passthrough.
+    llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart1Specs;
+    embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+    embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ni", op.getContext())}, op.getContext())));
+    embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                }, op.getContext())));
+    embeddedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ni", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: part 2: Passthrough.
+    llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart2Specs;
+    embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+    embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("ci", op.getContext())}, op.getContext())));
+    embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                }, op.getContext())));
+    embeddedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ci", op.getContext())
+                                                }, op.getContext())));
+    // TBD: part 3: Embed.
+    llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart3Specs;
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("y", op.getContext()),
+                                                    StringAttr::get("ho", op.getContext()),
+                                                }, op.getContext())));
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext())));
+    // TBD: embed parameters.
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("parameters",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart3Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("hipad", op.getContext()),
+                                                }, op.getContext())));
+
+    // TBD: part 4: Embed.
+    llvm::SmallVector<NamedAttribute, 5> embeddedInputLayoutPart4Specs;
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 4),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 5)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("x", op.getContext()),
+                                                    StringAttr::get("wo", op.getContext()),
+                                                }, op.getContext())));
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Embed", op.getContext())));
+    // TBD: embed parameters.
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("parameters",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3)
+                                                }, op.getContext())));
+    embeddedInputLayoutPart4Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("wipad", op.getContext())
+                                                }, op.getContext())));
+
+    auto embeddedInputLayoutAttr = rewriter.getNamedAttr("layout",
+                                                             ArrayAttr::get({
+                                                                 DictionaryAttr::get(embeddedInputLayoutPart1Specs, op.getContext()),
+                                                                 DictionaryAttr::get(embeddedInputLayoutPart2Specs, op.getContext()),
+                                                                 DictionaryAttr::get(embeddedInputLayoutPart3Specs, op.getContext()),
+                                                                 DictionaryAttr::get(embeddedInputLayoutPart4Specs, op.getContext())
+                                                             }, op.getContext()));
+    embeddedInputAttrs.push_back(embeddedInputLayoutAttr);
+
+
+    // TBD: set intermediate_layout attribute.
+    auto embeddedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("ni", op.getContext()),
+                                                            StringAttr::get("ci", op.getContext()),
+                                                            StringAttr::get("hipad", op.getContext()),
+                                                            StringAttr::get("wipad", op.getContext())
+                                                        }, op.getContext()));
+    embeddedInputAttrs.push_back(embeddedInputImmLayoutAttr);
+    // TBD: set output_layout attribute.
+    auto embeddedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("ni", op.getContext()),
+                                                            StringAttr::get("ci", op.getContext()),
+                                                            StringAttr::get("y", op.getContext()),
+                                                            StringAttr::get("ho", op.getContext()),
+                                                            StringAttr::get("x", op.getContext()),
+                                                            StringAttr::get("wo", op.getContext())
+                                                        }, op.getContext()));
+    embeddedInputAttrs.push_back(embeddedInputOutputLayoutAttr);
+    auto embeddedInput = rewriter.create<miopen::TransformOp>(op.getLoc(), embeddedInputMemRefType, ArrayRef<Value>(paddedInput), embeddedInputAttrs);
+
+    // Input tensor step 3: transformed input.
+    llvm::SmallVector<int64_t, 2> transformedInputShape;
+    transformedInputShape.set_size(inputShape.size() - 2);
+    // TBD: compute transformed input shape dimensions.
+    std::fill(transformedInputShape.begin(), transformedInputShape.end(), -1);
+    auto transformedInputMemRefType = MemRefType::get(transformedInputShape, inputElementType);
+
+    llvm::SmallVector<NamedAttribute, 3> transformedInputAttrs;
+
+    // TBD: set layout attribute.
+    // TBD: Part 1: Merge.
+    llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart1Specs;
+    transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+    transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmK", op.getContext())}, op.getContext())));
+    transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+    transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 4)
+                                                }, op.getContext())));
+    transformedInputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ci", op.getContext()),
+                                                    StringAttr::get("y", op.getContext()),
+                                                    StringAttr::get("x", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: Part 2: Merge.
+    llvm::SmallVector<NamedAttribute, 5> transformedInputLayoutPart2Specs;
+    transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+    transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+    transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+    transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 5)
+                                                }, op.getContext())));
+    transformedInputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ni", op.getContext()),
+                                                    StringAttr::get("ho", op.getContext()),
+                                                    StringAttr::get("wo", op.getContext())
+                                                }, op.getContext())));
+
+    auto transformedInputLayoutAttr = rewriter.getNamedAttr("layout",
+                                                             ArrayAttr::get({
+                                                                 DictionaryAttr::get(transformedInputLayoutPart1Specs, op.getContext()),
+                                                                 DictionaryAttr::get(transformedInputLayoutPart2Specs, op.getContext())
+                                                             }, op.getContext()));
+    transformedInputAttrs.push_back(transformedInputLayoutAttr);
+
+    // TBD: set intermediate_layout attribute.
+    auto transformedInputImmLayoutAttr = rewriter.getNamedAttr("intermediate_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("ni", op.getContext()),
+                                                            StringAttr::get("ci", op.getContext()),
+                                                            StringAttr::get("y", op.getContext()),
+                                                            StringAttr::get("ho", op.getContext()),
+                                                            StringAttr::get("x", op.getContext()),
+                                                            StringAttr::get("wo", op.getContext())
+                                                        }, op.getContext()));
+    transformedInputAttrs.push_back(transformedInputImmLayoutAttr);
+    // TBD: set output_layout attribute.
+    auto transformedInputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("gemmK", op.getContext()),
+                                                            StringAttr::get("gemmN", op.getContext()),
+                                                        }, op.getContext()));
+    transformedInputAttrs.push_back(transformedInputOutputLayoutAttr);
+
+    // set gridwise_gemm_argument_position attribute.
+    auto inputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", 
+                                                             IntegerAttr::get(IntegerType::get(32, op.getContext()), 1));
+    transformedInputAttrs.push_back(inputGridwiseGemmArgPosAttr);
+    auto gemmB = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedInputMemRefType, ArrayRef<Value>(embeddedInput), transformedInputAttrs);
+
+
+    // Transform output tensor.
+    auto outputType = op.output().getType().dyn_cast<MemRefType>();
+    auto outputShape = outputType.getShape();
+    auto outputElementType = outputType.getElementType();
+
+    llvm::SmallVector<int64_t, 2> transformedOutputShape;
+    transformedOutputShape.set_size(outputShape.size() - 2);
+    // TBD: compute transformed output shape dimensions.
+    std::fill(transformedOutputShape.begin(), transformedOutputShape.end(), -1);
+    auto transformedOutputMemRefType = MemRefType::get(transformedOutputShape, outputElementType);
+
+    llvm::SmallVector<NamedAttribute, 3> transformedOutputAttrs;
+
+    // TBD: set layout attribute.
+    // TBD: Part 1: Passthrough.
+    llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart1Specs;
+    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 0)}, op.getContext())));
+    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmM", op.getContext())}, op.getContext())));
+    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("PassThrough", op.getContext())));
+    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 1),
+                                                }, op.getContext())));
+    transformedOutputLayoutPart1Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({
+                                                    StringAttr::get("ko", op.getContext())
+                                                }, op.getContext())));
+
+    // TBD: Part 2: Merge.
+    llvm::SmallVector<NamedAttribute, 5> transformedOutputLayoutPart2Specs;
+    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("dimensions", ArrayAttr::get({IntegerAttr::get(IntegerType::get(32, op.getContext()), 1)}, op.getContext())));
+    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("names", ArrayAttr::get({StringAttr::get("gemmN", op.getContext())}, op.getContext())));
+    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("transformation", StringAttr::get("Merge", op.getContext())));
+    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_dimensions",
+                                                ArrayAttr::get({
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 0),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 2),
+                                                    IntegerAttr::get(IntegerType::get(32, op.getContext()), 3),
+                                                }, op.getContext())));
+    transformedOutputLayoutPart2Specs.push_back(rewriter.getNamedAttr("source_names",
+                                                ArrayAttr::get({                                                                                             StringAttr::get("no", op.getContext()),
+                                                    StringAttr::get("ho", op.getContext()),
+                                                    StringAttr::get("wo", op.getContext())
+                                                }, op.getContext())));
+
+    auto transformedOutputLayoutAttr = rewriter.getNamedAttr("layout",
+                                                             ArrayAttr::get({
+                                                                 DictionaryAttr::get(transformedOutputLayoutPart1Specs, op.getContext()),
+                                                                 DictionaryAttr::get(transformedOutputLayoutPart2Specs, op.getContext())
+                                                             }, op.getContext()));
+    transformedOutputAttrs.push_back(transformedOutputLayoutAttr);
 
-    rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
-    rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
-    rewriter.create<miopen::TransformOp>(op.getLoc(), op.input().getType(), op.input());
+    // set source_layout attribute.
+    auto outputSrcLayoutAttr = rewriter.getNamedAttr("source_layout", outputLayoutAttr);
+    transformedOutputAttrs.push_back(outputSrcLayoutAttr);
+    // TBD: set output_layout attribute.
+    auto transformedOutputOutputLayoutAttr = rewriter.getNamedAttr("output_layout",
+                                                        ArrayAttr::get({
+                                                            StringAttr::get("gemmM", op.getContext()),
+                                                            StringAttr::get("gemmN", op.getContext()),
+                                                        }, op.getContext()));
+    transformedOutputAttrs.push_back(transformedOutputOutputLayoutAttr);
 
-    rewriter.create<miopen::TransformOp>(op.getLoc(), op.output().getType(), op.output());
+    // TBD: set gridwise_gemm_argument_position attribute.
+    auto outputGridwiseGemmArgPosAttr = rewriter.getNamedAttr("gridwise_gemm_argument_position", 
+                                                             IntegerAttr::get(IntegerType::get(32, op.getContext()), 2));
+    transformedOutputAttrs.push_back(outputGridwiseGemmArgPosAttr);
+    auto gemmC = rewriter.create<miopen::TransformOp>(op.getLoc(), transformedOutputMemRefType, op.output(), transformedOutputAttrs);
 
-    //rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), op.filter(), op.input(), op.output());
+    // Emit miopen.gridwise_gemm op.
+    rewriter.create<miopen::GridwiseGemmOp>(op.getLoc(), gemmA, gemmB, gemmC);
 
     // Finally, erase the original Conv2D op.
     op.erase();

diff  --git a/mlir/test/Dialect/MIOpen/lowering.mlir b/mlir/test/Dialect/MIOpen/lowering.mlir
index e7734cef5a29..5907fbd41ebd 100644
--- a/mlir/test/Dialect/MIOpen/lowering.mlir
+++ b/mlir/test/Dialect/MIOpen/lowering.mlir
@@ -3,8 +3,8 @@
 func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
   miopen.conv2d(%filter, %input, %output) {
     filter_layout = ["k", "c", "y", "x"],
-    input_layout = ["n", "c", "hi", "wi"],
-    output_layout = ["n", "k", "ho", "wo"],
+    input_layout = ["ni", "ci", "hi", "wi"],
+    output_layout = ["no", "ko", "ho", "wo"],
     dilations = [1, 1],
     strides = [1, 1],
     padding = [0, 0]
@@ -18,4 +18,4 @@ func @miopen_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>,
 //  CHECK-NEXT: miopen.transform
 //  CHECK-NEXT: miopen.transform
 //  CHECK-NEXT: miopen.transform
-//  TBD-CHECK-NEXT: miopen.gridwise_gemm
+//  CHECK-NEXT: miopen.gridwise_gemm


        


More information about the llvm-branch-commits mailing list