[llvm-branch-commits] [mlir] f0500d1 - Add Op traversing logic into MIOpen dialect -> C++ header translator.

Wen-Heng Chung via llvm-branch-commits <llvm-branch-commits at lists.llvm.org>
Thu Oct 22 13:20:17 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:19-05:00
New Revision: f0500d18ee6f0072487526e53fc9a7bd1c1d7235

URL: https://github.com/llvm/llvm-project/commit/f0500d18ee6f0072487526e53fc9a7bd1c1d7235
DIFF: https://github.com/llvm/llvm-project/commit/f0500d18ee6f0072487526e53fc9a7bd1c1d7235.diff

LOG: Add Op traversing logic into MIOpen dialect -> C++ header translator.
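
This change adds a second translation entry point, mlir::translateModuleToMIOpenHeader,
alongside the existing mlir::translateModuleToMIOpenCpp. The header path walks the
miopen.transform ops in each FuncOp, emits constexpr dimension lengths and
transform_tensor_descriptor calls into a gridwise convolution implicit GEMM C++ header,
records the descriptors tagged with gridwise_gemm_argument_position for the GEMM
epilogue, and is registered under the mlir-to-miopen-hpp translation. The old
CppOutput/transformed.mlir test is replaced by translate.mlir, whose RUN lines
(quoted from the new test below) exercise both outputs:

    // RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck -check-prefix=MIOPEN-CPP %s
    // RUN: mlir-translate -mlir-to-miopen-hpp %s | FileCheck -check-prefix=MIOPEN-HPP %s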

Added: 
    mlir/test/Dialect/MIOpen/translate.mlir

Modified: 
    mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp

Removed: 
    mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir


################################################################################
diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index b071565c2b51..4eb8d7de7181 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -15,8 +15,10 @@
 #include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
+#include "mlir/Support/STLExtras.h"
 #include "mlir/Translation.h"
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/raw_ostream.h"
@@ -222,6 +224,160 @@ void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm:
   output << kCppEpiloguePart2;
 }
 
+static constexpr StringLiteral kHeaderPreamblePart1 = R"(
+#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_HPP
+#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_HPP
+
+#include "common_header.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "gridwise_gemm.hpp"
+
+namespace ck {
+
+// GemmM = K
+// GemmN = N * Ho * Wo
+// GemmK = C * Y * X
+template <index_t GridSize,
+          index_t BlockSize,
+          typename Float,
+          typename AccFloat,
+          typename InGlobalDesc,
+          typename WeiGlobalDesc,
+          typename OutGlobalDesc,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads,
+          index_t GemmMPerBlock,
+          index_t GemmNPerBlock,
+          index_t GemmKPerBlock,
+          index_t GemmMPerThreadSubC,
+          index_t GemmNPerThreadSubC,
+          index_t GemmMLevel0Cluster,
+          index_t GemmNLevel0Cluster,
+          index_t GemmMLevel1Cluster,
+          index_t GemmNLevel1Cluster,
+          index_t GemmKPerThreadLoop,
+          index_t GemmThreadGemmDataPerReadM,
+          index_t GemmThreadGemmDataPerReadN,
+          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+          index_t GemmABlockCopySrcDataPerRead_GemmK,
+          index_t GemmABlockCopyDstDataPerWrite_GemmM,
+          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+          index_t GemmBBlockCopySrcDataPerRead_GemmN,
+          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
+          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
+)";
+
+static constexpr StringLiteral kHeaderPreamblePart2 = R"(
+{
+    __device__ void Run(const Float* const __restrict__ p_in_global,
+                        const Float* const __restrict__ p_wei_global,
+                        Float* const __restrict__ p_out_global) const
+    {
+)";
+
+static constexpr StringLiteral kHeaderPreamblePart3 = R"(
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t ConvStrideH = ConvStrides{}[0];
+        constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+        constexpr index_t ConvDilationH = ConvDilations{}[0];
+        constexpr index_t ConvDilationW = ConvDilations{}[1];
+)";
+
+static constexpr StringLiteral kHeaderEpiloguePart1 = R"(
+        // GEMM
+        constexpr auto gridwise_gemm =
+            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
+                                                     BlockSize,
+                                                     Float,
+                                                     AccFloat,
+)";
+
+static constexpr StringLiteral kHeaderEpiloguePart2 = R"(
+                                                     InMemoryDataOperation::none,
+                                                     GemmMPerBlock,
+                                                     GemmNPerBlock,
+                                                     GemmKPerBlock,
+                                                     GemmMPerThreadSubC,
+                                                     GemmNPerThreadSubC,
+                                                     GemmMLevel0Cluster,
+                                                     GemmNLevel0Cluster,
+                                                     GemmMLevel1Cluster,
+                                                     GemmNLevel1Cluster,
+                                                     GemmKPerThreadLoop,
+                                                     GemmThreadGemmDataPerReadM,
+                                                     GemmThreadGemmDataPerReadN,
+                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+                                                     Sequence<1, 0>,
+                                                     Sequence<1, 0>,
+                                                     0,
+                                                     GemmABlockCopySrcDataPerRead_GemmK,
+                                                     GemmABlockCopyDstDataPerWrite_GemmM,
+                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+                                                     Sequence<0, 1>,
+                                                     Sequence<0, 1>,
+                                                     1,
+                                                     GemmBBlockCopySrcDataPerRead_GemmN,
+                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
+                                                     Sequence<0, 1, 2, 3>,
+                                                     3,
+                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+
+        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
+    }
+};
+
+} // namespace ck
+#endif
+)";
+
+void EmitHeaderPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
+  output << kHeaderPreamblePart1;
+
+  output << R"(
+struct GridwiseConvolutionImplicitGemm_v4r4_)";
+  output << layoutStr;
+
+  output << kHeaderPreamblePart2;
+
+  output << kHeaderPreamblePart3;
+
+  output << '\n';
+
+  output << R"(
+        constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
+  output << R"(
+        constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
+  output << R"(
+        constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
+}
+
+void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t, std::string> &args) {
+  output << kHeaderEpiloguePart1;
+
+// Between Part1 and Part2 emit:
+//                                                   decltype(wei_e_k_global_desc),
+//                                                   decltype(in_e_b_global_desc),
+//                                                   decltype(out_k_b_global_desc),
+  for (int i = 0; i < 3; ++i) {
+    output << R"(
+                                                     decltype()" << args[i] << "),";
+  }
+
+  output << kHeaderEpiloguePart2;
+}
+
 void EmitLayoutString(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr, llvm::StringRef prefix, llvm::StringRef suffix, llvm::StringRef delimiter = "") {
   for (int i = 0; i < kConv2DTensorDimension; ++i) {
     auto attr = layoutArrayAttr[i];
@@ -234,11 +390,20 @@ void EmitLayoutString(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute>
   }
 }
 
+void EmitHeaderDimensionLengths(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr, llvm::StringRef tensorDesc) {
+  for (int i = 0; i < kConv2DTensorDimension; ++i) {
+    auto attr = layoutArrayAttr[i];
+    if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+      output << "        constexpr index_t " << strAttr.getValue() << " = " << tensorDesc << ".GetLengths()[" << i << "];\n";
+    }
+  }
+}
+
 void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr) {
   for (int i = 0; i < kConv2DTensorDimension; ++i) {
     auto attr = layoutArrayAttr[i];
     if (auto strAttr = attr.dyn_cast<StringAttr>()) {
-      output << "    const index_t " << strAttr.getValue() << " = CK_PARAM_PROBLEM_";
+      output << "    constexpr index_t " << strAttr.getValue() << " = CK_PARAM_PROBLEM_";
 
       switch (llvm::toUpper(strAttr.getValue()[0])) {
           case 'H':
@@ -258,7 +423,7 @@ void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribu
   for (int i = 0; i < kConv2DTensorDimension; ++i) {
     auto attr = layoutArrayAttr[i];
     if (auto strAttr = attr.dyn_cast<StringAttr>()) {
-      output << "    const index_t stride_" << strAttr.getValue() << " = ";
+      output << "    constexpr index_t stride_" << strAttr.getValue() << " = ";
 
       if (i == 0) {
         output << "1;\n";
@@ -272,6 +437,19 @@ void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribu
   }
 }
 
+void EmitInterleaveArrayAttrOfStringAttrWithSeparator(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr, const StringRef &separator) {
+  if (arrayAttr) {
+    interleave(arrayAttr, os, [&](Attribute attr) {
+      if (auto strAttr = attr.dyn_cast<StringAttr>())
+        os << strAttr.getValue();
+    }, separator);
+  }
+}
+
+void EmitInterleaveCommaArrayAttrOfStringAttr(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr) {
+  EmitInterleaveArrayAttrOfStringAttrWithSeparator(os, arrayAttr, ", ");
+}
+
 void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
   // (TBD verifiying logic) The Module could contain multiple FuncOp, and inside each FuncOp there
   // should be exactly:
@@ -315,7 +493,7 @@ void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std
 
 }
 
-std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
+std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m) {
   std::string resultStr;
   llvm::raw_string_ostream output(resultStr);
 
@@ -323,6 +501,7 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
   for (auto f : m.getOps<FuncOp>()) {
     std::string layoutStr;
     llvm::SmallVector<std::string, 3> tensorDescs;
+    llvm::SmallDenseMap<int64_t, std::string> gridwiseGemmArguments;
 
     // Obtain critical information from ModuleOp.
     ObtainModuleInfo(m, layoutStr, tensorDescs);
@@ -330,17 +509,154 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
     int srcLayoutAttrCtr = 0;
 
     // Start emitting.
+    EmitHeaderPreamble(output, layoutStr, tensorDescs);
+
+    f.walk([&output, &srcLayoutAttrCtr, &tensorDescs, &gridwiseGemmArguments](miopen::TransformOp op) {
+      // get source_layout attribute.
+      auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+      if (srcLayoutAttr) {
+        auto srcLayout = srcLayoutAttr.getValue();
+        output << "\n        // ";
+        EmitLayoutString(output, srcLayout, "", "", ", ");
+        output << '\n';
+
+        EmitHeaderDimensionLengths(output, srcLayout, tensorDescs[srcLayoutAttrCtr]);
+      }
+      output << '\n';
+ 
+      // get layout attribute.
+      auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout");
+      std::string inputTensorName;
+      std::string outputTensorName;
+      std::string operationSpec;
+      std::string srcDimSpec;
+      std::string dstDimSpec;
+      llvm::raw_string_ostream ins(inputTensorName);
+      llvm::raw_string_ostream outs(outputTensorName);
+      llvm::raw_string_ostream ops(operationSpec);
+      llvm::raw_string_ostream srcs(srcDimSpec);
+      llvm::raw_string_ostream dsts(dstDimSpec);
+
+      // determine input and output tensor name.
+      auto immLayoutAttr = op.getAttrOfType<ArrayAttr>("intermediate_layout");
+      auto outputLayoutAttr = op.getAttrOfType<ArrayAttr>("output_layout");
+      if (srcLayoutAttr) {
+        inputTensorName = tensorDescs[srcLayoutAttrCtr];
+        outs << kVarName[srcLayoutAttrCtr] << "_";
+
+        srcLayoutAttrCtr++;
+      } else {
+        // get intermediate_layout attribute.
+        if (immLayoutAttr) {
+          ins << kVarName[srcLayoutAttrCtr - 1] << "_";
+          EmitInterleaveArrayAttrOfStringAttrWithSeparator(ins, immLayoutAttr, "_");
+          ins << "_desc";
+          ins.flush();
+
+          outs << kVarName[srcLayoutAttrCtr - 1] << "_";
+        }
+      }
+      EmitInterleaveArrayAttrOfStringAttrWithSeparator(outs, outputLayoutAttr, "_");
+      outs << "_desc";
+      outs.flush();
+
+      // determine gridwise GEMM arguments.
+      auto gridwiseGemmArgPosAttr = op.getAttrOfType<IntegerAttr>("gridwise_gemm_argument_position");
+      if (gridwiseGemmArgPosAttr) {
+        llvm::errs() << "gridwise gemm argument pos: " << gridwiseGemmArgPosAttr.getValue() << "\n";
+        llvm::errs() << "tensor: " << outputTensorName << "\n";
+        gridwiseGemmArguments[gridwiseGemmArgPosAttr.getInt()] = outputTensorName;
+      }  
+
+      ops << "            make_tuple(";
+      srcs << "            make_tuple(";
+      dsts << "            make_tuple(";
+
+      for (auto layoutSpec = layoutAttr.begin(); layoutSpec != layoutAttr.end(); ) {
+        if (auto layoutSpecDict = layoutSpec->dyn_cast<DictionaryAttr>()) {
+          auto srcNames = layoutSpecDict.get("source_names").dyn_cast<ArrayAttr>();
+          auto dstNames = layoutSpecDict.get("names").dyn_cast<ArrayAttr>();
+
+          if (auto transform = layoutSpecDict.get("transformation").dyn_cast<StringAttr>()) {
+            if (transform.getValue() == "PassThrough" ||
+                transform.getValue() == "Merge") {
+              ops << transform.getValue() << "<";
+              EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames);
+              ops << ">{}";
+            } else if (transform.getValue() == "Pad") {
+              ops << transform.getValue() << "<"
+                  << "Sequence<";
+              EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames);
+              ops << ">, InLeftPads, InRightPads" << ">{}";
+            } else if (transform.getValue() == "Embed") {
+              ops << transform.getValue() << "<"
+                  << "Sequence<";
+              EmitInterleaveCommaArrayAttrOfStringAttr(ops, dstNames);
+              ops << ">, Sequence<ConvDilationTBD, ConvDilationTBD, 0>>{}";
+            }
+            srcs << "Sequence<" << layoutSpecDict.get("source_dimensions") << ">{}";
+            dsts << "Sequence<" << layoutSpecDict.get("dimensions") << ">{}";
+          }
+        }
+
+        ++layoutSpec;
+        if (layoutSpec != layoutAttr.end()) {
+          ops << ", ";
+          srcs << ", ";
+          dsts << ", ";
+        }
+      }
+      ops << "),\n";
+      ops.flush();
+      srcs << "),\n";
+      srcs.flush();
+      dsts << ")";
+      dsts.flush();
+
+      output << "        constexpr auto " << outputTensorName << " = transform_tensor_descriptor(\n";
+      output << "            " << inputTensorName << ",\n";
+      output << operationSpec << srcDimSpec << dstDimSpec;
+      output << ");\n";
+    });
+
+    // TBD get tuning parameters.
+    //f.walk([&output](miopen::GridwiseGemmOp op) {
+    //  // get op name.
+    //  //output << "op name: " << op.getOperationName() << "\n";
+    //  //op.dump();
+    //});
+
+    EmitHeaderEpilogue(output, gridwiseGemmArguments);
+  }
+
+  output.flush();
+  return std::make_unique<llvm::StringRef>(resultStr);
+}
 
+std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
+  std::string resultStr;
+  llvm::raw_string_ostream output(resultStr);
+
+  // Enumerate FuncOp instances inside the ModuleOp.
+  for (auto f : m.getOps<FuncOp>()) {
+    std::string layoutStr;
+    llvm::SmallVector<std::string, 3> tensorDescs;
+
+    // Obtain critical information from ModuleOp.
+    ObtainModuleInfo(m, layoutStr, tensorDescs);
+
+    int srcLayoutAttrCtr = 0;
+
+    // Start emitting.
     EmitCppPreamble(output, layoutStr);
 
     f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) {
-
       // get source_layout attribute.
       auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
       if (srcLayoutAttr) {
         auto srcLayout = srcLayoutAttr.getValue();
         output << "    // ";
-        EmitLayoutString(output, srcLayout, "", "", ", ");
+        EmitLayoutString(output, srcLayout, "", "", ",");
         output << '\n';
 
         EmitDimensionVariables(output, srcLayout);
@@ -354,20 +670,6 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
         EmitLayoutString(output, srcLayout, "stride_", "", ", ");
         output << ">{});\n\n";
       }
-
-      //// get layout attribute.
-      // TBD not used in emitting C++ source wrapper.
-      // would be used in emitting C++ header.
-      //auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout");
-      //for (auto layoutSpec : layoutAttr) {
-      //  if (auto layoutSpecDict = layoutSpec.dyn_cast<DictionaryAttr>()) {
-      //    //output << "dimensions: " << layoutSpecDict.get("dimensions") << "\n";
-      //    //output << "names: " << layoutSpecDict.get("names") << "\n";
-      //    //output << "source_dimensions: " << layoutSpecDict.get("source_dimensions") << "\n";
-      //    //output << "source_names: " << layoutSpecDict.get("source_names") << "\n";
-      //    //output << "transformation: " << layoutSpecDict.get("transformation") << "\n";
-      //  }
-      //}
     });
 
     EmitCppInterlude(output);
@@ -396,13 +698,13 @@ static TranslateFromMLIRRegistration
       return success();
     });
 
-//static TranslateFromMLIRRegistration
-//    toHeader("mlir-to-miopen-h", [](ModuleOp module, llvm::raw_ostream &output) {
-//      auto sourceCode = mlir::translateModuleToMIOpenHeader(module);
-//      if (!sourceCode)
-//        return failure();
-//
-//      output << *sourceCode;
-//      return success();
-//    });
+static TranslateFromMLIRRegistration
+    toHeader("mlir-to-miopen-hpp", [](ModuleOp module, llvm::raw_ostream &output) {
+      auto sourceCode = mlir::translateModuleToMIOpenHeader(module);
+      if (!sourceCode)
+        return failure();
+
+      output << *sourceCode;
+      return success();
+    });
 

diff --git a/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
deleted file mode 100644
index ca9751191859..000000000000
--- a/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
+++ /dev/null
@@ -1,157 +0,0 @@
-// RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck %s
-
-// CHECK:  __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_kcyx_niciwihi_nokohowo
-func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
-  // filter tensor
-  %filter_gemmK_gemmM = miopen.transform(%filter) {
-    layout = [
-      {
-        dimensions = [0],
-        names = ["gemmK"],
-        transformation = "merge",
-        source_dimensions = [1, 2, 3],
-        source_names = ["c", "y", "x"]
-      },
-      {
-        dimensions = [1],
-        names = ["gemmM"],
-        transformation = "passthrough",
-        source_dimensions = [0],
-        source_names = ["n"]
-      }
-    ],
-    source_layout = ["k", "c", "y", "x"]
-  } : memref<?x?x?x?xf32> to memref<?x?xf32>
-
-  // input tensor
-  %input_n_c_hipad_wipad = miopen.transform(%input) {
-    layout = [
-      {
-        dimensions = [0],
-        names = ["n"],
-        transformation = "passthorugh",
-        source_dimensions = [0],
-        source_names = ["ni"]
-      },
-      {
-        dimensions = [1],
-        names = ["c"],
-        transformation = "passthorugh",
-        source_dimensions = [1],
-        source_names = ["ci"]
-      },
-      {
-        dimensions = [2],
-        names = ["hipad"],
-        transformation = "pad",
-        parameters = [0, 0],
-        source_dimensions = [2],
-        source_names = ["hi"]
-      },
-      {
-        dimensions = [3],
-        names = ["wipad"],
-        transformation = "pad",
-        parameters = [0, 0],
-        source_dimensions = [3],
-        source_names = ["wi"]
-      }
-    ],
-    source_layout = ["ni", "ci", "wi", "hi"]
-  } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
-  
-  %input_n_c_y_ho_x_wo = miopen.transform(%input_n_c_hipad_wipad) {
-    layout = [
-      {
-        dimensions = [0],
-        names = ["n"],
-        transformation = "passthrough",
-        source_dimensions = [0],
-        source_names = ["n"]
-      },
-      {
-        dimensions = [1],
-        names = ["c"],
-        transformation = "passthrough",
-        source_dimensions = [1],
-        source_names = ["c"]
-      },
-      {
-        dimensions = [2, 3],
-        names = ["y", "ho"],
-        transformation = "embed",
-        parameters = [2, [1, 1, 0]],
-        source_dimensions = [2],
-        source_names = ["hipad"]
-      },
-      {
-        dimensions = [4, 5],
-        names = ["x", "wo"],
-        transformation = "embed",
-        parameters = [2, [1, 1, 0]],
-        source_dimensions = [2],
-        source_names = ["wipad"]
-      }
-    ],
-    intermediate_layout = ["n", "c", "hipad", "wipad"]
-  } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?x?xf32>
-  
-  %input_gemmK_gemmN = miopen.transform(%input_n_c_y_ho_x_wo) {
-    layout = [
-      {
-        dimensions = [0],
-        names = ["gemmK"],
-        transformation = "merge",
-        source_dimensions = [1, 2, 4],
-        source_names = ["c", "y", "x"]
-      },
-      {
-        dimensions = [1],
-        names = ["gemmN"],
-        transformation = "merge",
-        source_dimensions = [0, 3, 5],
-        source_names = ["n", "hipad", "wipad"]
-      }
-    ],
-    intermediate_layout = ["n", "c", "y", "hipad", "x", "wipad"]
-  } : memref<?x?x?x?x?x?x?xf32> to memref<?x?xf32>
-  
-  // output tensor
-  %output_gemmM_gemmN = miopen.transform(%output) {
-    layout = [
-      {
-        dimensions = [0],
-        names = ["gemmM"],
-        transformation = "passthrough",
-        source_dimensions = [1],
-        source_names = ["ko"]
-      },
-      {
-        dimensions = [1],
-        names = ["gemmN"],
-        transformation = "merge",
-        source_dimensions = [0, 2, 3],
-        source_names = ["no", "ho", "wo"]
-      }
-    ],
-    source_layout = ["no", "ko", "ho", "wo"]
-  } : memref<?x?x?x?xf32> to memref<?x?xf32>
-  
-  // apply gridwise GEMM
-  miopen.gridwise_gemm(%filter_gemmK_gemmM, %input_gemmK_gemmN, %output_gemmM_gemmN) {
-    parameters = [
-      // tuning parameters
-    ]
-  } : memref<?x?xf32>,
-      memref<?x?xf32>,
-      memref<?x?xf32>
-
-  return
-}
-// CHECK:    constexpr auto weight_k_c_y_x_desc = make_native_tensor_descriptor(Sequence<k, c, y, x>{}, Sequence<stride_k, stride_c, stride_y, stride_x>{});
-// CHECK:     constexpr auto input_ni_ci_wi_hi_desc = make_native_tensor_descriptor(Sequence<ni, ci, wi, hi>{}, Sequence<stride_ni, stride_ci, stride_wi, stride_hi>{});
-// CHECK:     constexpr auto output_no_ko_ho_wo_desc = make_native_tensor_descriptor(Sequence<no, ko, ho, wo>{}, Sequence<stride_no, stride_ko, stride_ho, stride_wo>{});
-// CHECK:         constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_kcyx_niciwihi_nokohowo
-// CHECK:        decltype(weight_k_c_y_x_desc),
-// CHECK:        decltype(input_ni_ci_wi_hi_desc),
-// CHECK:        decltype(output_no_ko_ho_wo_desc),

diff --git a/mlir/test/Dialect/MIOpen/translate.mlir b/mlir/test/Dialect/MIOpen/translate.mlir
new file mode 100644
index 000000000000..1139100262e2
--- /dev/null
+++ b/mlir/test/Dialect/MIOpen/translate.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck -check-prefix=MIOPEN-CPP %s
+// RUN: mlir-translate -mlir-to-miopen-hpp %s | FileCheck -check-prefix=MIOPEN-HPP %s
+
+// MIOPEN-CPP:  __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_kcyx_nicihiwi_nokohowo
+// MIOPEN-HPP: struct GridwiseConvolutionImplicitGemm_v4r4_kcyx_nicihiwi_nokohowo
+func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
+  // filter tensor
+  %filter_gemmK_gemmM = miopen.transform(%filter) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["gemmK"],
+        transformation = "Merge",
+        source_dimensions = [1, 2, 3],
+        source_names = ["c", "y", "x"]
+      },
+      {
+        dimensions = [1],
+        names = ["gemmM"],
+        transformation = "PassThrough",
+        source_dimensions = [0],
+        source_names = ["k"]
+      }
+    ],
+    source_layout = ["k", "c", "y", "x"],
+    output_layout = ["gemmK", "gemmM"],
+    gridwise_gemm_argument_position = 0
+  } : memref<?x?x?x?xf32> to memref<?x?xf32>
+
+  // input tensor
+  %input_ni_ci_hipad_wipad = miopen.transform(%input) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["ni"],
+        transformation = "PassThrough",
+        source_dimensions = [0],
+        source_names = ["ni"]
+      },
+      {
+        dimensions = [1],
+        names = ["ci"],
+        transformation = "PassThrough",
+        source_dimensions = [1],
+        source_names = ["ci"]
+      },
+      {
+        dimensions = [2, 3],
+        names = ["hipad", "wipad"],
+        transformation = "Pad",
+        parameters = [0, 0],
+        source_dimensions = [2, 3],
+        source_names = ["hi", "wi"]
+      }
+    ],
+    source_layout = ["ni", "ci", "hi", "wi"],
+    output_layout = ["ni", "ci", "hipad", "wipad"]
+  } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
+  
+  %input_ni_ci_y_ho_x_wo = miopen.transform(%input_ni_ci_hipad_wipad) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["ni"],
+        transformation = "PassThrough",
+        source_dimensions = [0],
+        source_names = ["ni"]
+      },
+      {
+        dimensions = [1],
+        names = ["ci"],
+        transformation = "PassThrough",
+        source_dimensions = [1],
+        source_names = ["ci"]
+      },
+      {
+        dimensions = [2, 3],
+        names = ["y", "ho"],
+        transformation = "Embed",
+        parameters = [2, [1, 1, 0]],
+        source_dimensions = [2],
+        source_names = ["hipad"]
+      },
+      {
+        dimensions = [4, 5],
+        names = ["x", "wo"],
+        transformation = "Embed",
+        parameters = [2, [1, 1, 0]],
+        source_dimensions = [3],
+        source_names = ["wipad"]
+      }
+    ],
+    intermediate_layout = ["ni", "ci", "hipad", "wipad"],
+    output_layout = ["ni", "ci", "y", "ho", "x", "wo"]
+  } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?x?xf32>
+  
+  %input_gemmK_gemmN = miopen.transform(%input_ni_ci_y_ho_x_wo) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["gemmK"],
+        transformation = "Merge",
+        source_dimensions = [1, 2, 4],
+        source_names = ["ci", "y", "x"]
+      },
+      {
+        dimensions = [1],
+        names = ["gemmN"],
+        transformation = "Merge",
+        source_dimensions = [0, 3, 5],
+        source_names = ["ni", "ho", "wo"]
+      }
+    ],
+    intermediate_layout = ["ni", "ci", "y", "ho", "x", "wo"],
+    output_layout = ["gemmK", "gemmN"],
+    gridwise_gemm_argument_position = 1
+  } : memref<?x?x?x?x?x?x?xf32> to memref<?x?xf32>
+  
+  // output tensor
+  %output_gemmM_gemmN = miopen.transform(%output) {
+    layout = [
+      {
+        dimensions = [0],
+        names = ["gemmM"],
+        transformation = "PassThrough",
+        source_dimensions = [1],
+        source_names = ["ko"]
+      },
+      {
+        dimensions = [1],
+        names = ["gemmN"],
+        transformation = "Merge",
+        source_dimensions = [0, 2, 3],
+        source_names = ["no", "ho", "wo"]
+      }
+    ],
+    source_layout = ["no", "ko", "ho", "wo"],
+    output_layout = ["gemmM", "gemmN"],
+    gridwise_gemm_argument_position = 2
+  } : memref<?x?x?x?xf32> to memref<?x?xf32>
+  
+  // apply gridwise GEMM
+  miopen.gridwise_gemm(%filter_gemmK_gemmM, %input_gemmK_gemmN, %output_gemmM_gemmN) {
+    parameters = [
+      // tuning parameters
+    ]
+  } : memref<?x?xf32>,
+      memref<?x?xf32>,
+      memref<?x?xf32>
+
+  return
+}
+// MIOPEN-CPP:    constexpr auto weight_k_c_y_x_desc = make_native_tensor_descriptor(Sequence<k, c, y, x>{}, Sequence<stride_k, stride_c, stride_y, stride_x>{});
+// MIOPEN-CPP:     constexpr auto input_ni_ci_hi_wi_desc = make_native_tensor_descriptor(Sequence<ni, ci, hi, wi>{}, Sequence<stride_ni, stride_ci, stride_hi, stride_wi>{});
+// MIOPEN-CPP:     constexpr auto output_no_ko_ho_wo_desc = make_native_tensor_descriptor(Sequence<no, ko, ho, wo>{}, Sequence<stride_no, stride_ko, stride_ho, stride_wo>{});
+// MIOPEN-CPP:         constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_kcyx_nicihiwi_nokohowo
+// MIOPEN-CPP:        decltype(weight_k_c_y_x_desc),
+// MIOPEN-CPP:        decltype(input_ni_ci_hi_wi_desc),
+// MIOPEN-CPP:        decltype(output_no_ko_ho_wo_desc),

