[llvm-branch-commits] [mlir] 05d19a7 - Add Op traversing logic into MIOpen dialect -> C++ translator.

Wen-Heng Chung via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:14 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:19-05:00
New Revision: 05d19a7eb4ea27d4d0e7989b145c2f7458fff54a

URL: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a
DIFF: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a.diff

LOG: Add Op traversing logic into MIOpen dialect -> C++ translator.

Added: 
    mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir

Modified: 
    mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
    mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
    mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp

Removed: 
    mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
index d3e9b8ee09a2..09d2d1166caf 100644
--- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
+++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
@@ -33,7 +33,17 @@ class ModuleOp;
 /// Convert the given MLIR module into MIOpen C++ . In case of error, report it
 /// to the error handler registered with the MLIR context, if any (obtained from
 /// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCPP(ModuleOp m);
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCpp(ModuleOp m);
+
+/// Convert the given MLIR module into an MIOpen C++ header. On error, report
+/// it to the error handler of the module's MLIRContext, if any, and return
+/// `nullptr`. NOTE(review): not defined in ConvertToMIOpenCPP.cpp yet.
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenHeader(ModuleOp m);
+
+/// Convert the given MLIR module into an MIOpen C++ Solver. On error, report
+/// it to the error handler of the module's MLIRContext, if any, and return
+/// `nullptr`. NOTE(review): not defined in ConvertToMIOpenCPP.cpp yet.
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenSolver(ModuleOp m);
 
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
index 855985b4b945..3d37305c60e7 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
@@ -5,6 +5,7 @@ add_llvm_library(MLIRMIOpenCpp
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MIOpenOps
   )
 target_link_libraries(MLIRMIOpenCpp
+  LLVMSupport
   MLIRIR
   MLIRMIOpenOps
   MLIRStandardOps

diff  --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index 5fe33d695cb3..b071565c2b51 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -13,34 +13,396 @@
 #include "mlir/Dialect/MIOpenOps/MIOpenCPP.h"
 #include "mlir/Dialect/MIOpenOps/MIOpenOps.h"
 #include "mlir/Dialect/StandardOps/Ops.h"
-
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
 #include "mlir/Translation.h"
 
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Support/ToolOutputFile.h"
 
 using namespace mlir;
 
-std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCPP(ModuleOp m) {
-  // Check constraints:
+namespace {
+
+// Prefixes for the tensor descriptor variable names, indexed by the order in
+// which "source_layout" transforms are visited. NOTE(review): assumes the
+// filter transform comes first, then input, then output (as in the test).
+constexpr StringLiteral kVarName[] = {"weight", "input", "output"};
+
+// Number of entries read from each layout attribute of a 2-D convolution.
+constexpr int kConv2DTensorDimension = 4;
+
+// Preamble, part 1: the common header include. The layout-specific
+// gridwise_convolution include is written right after it.
+constexpr StringLiteral kCppPreamblePart1 = R"(
+#include "common_header.hpp"
+)";
+
+// Preamble, part 2: float type definitions and the start of the kernel
+// declaration; the launch bounds and layout-specific kernel name follow.
+constexpr StringLiteral kCppPreamblePart2 = R"(
+#include "float_types.h"
+
+extern "C" __global__
+)";
+
+// Preamble, part 3: the kernel parameter list and the opening of its body,
+// which reads convolution, block/grid and GEMM block parameters from the
+// CK_PARAM_* macros. Tensor descriptors are emitted after this text.
+constexpr StringLiteral kCppPreamblePart3 = R"(
+        (const FLOAT* const __restrict__ p_in_global,
+        const FLOAT* const __restrict__ p_wei_global,
+        FLOAT* const __restrict__ p_out_global)
+{
+    using namespace ck;
+
+    constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
+    constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
+
+    constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
+    constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
+
+    constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
+    constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
+
+    constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
+    constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
+
+    constexpr index_t BlockSize = CK_PARAM_TUNABLE_BLOCK_SIZE;
+    constexpr index_t GridSize  = CK_PARAM_DEPENDENT_GRID_SIZE;
+
+    constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
+    constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
+    constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
+
+)";
+
+// Interlude, emitted after the tensor descriptors: stride/dilation/padding
+// sequences plus the GEMM tuning parameters for the A, B and C matrices.
+constexpr StringLiteral kCppInterlude = R"(
+    using ConvStrides   = Sequence<ConvStrideH, ConvStrideW>;
+    using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
+
+    using InLeftPads  = Sequence<InLeftPadH, InLeftPadW>;
+    using InRightPads = Sequence<InRightPadH, InRightPadW>;
+
+    // read and calculate tuning parameter
+    constexpr index_t GemmMPerThreadSubC = CK_PARAM_TUNABLE_GEMM_M_PER_THREAD_SUB_C;
+    constexpr index_t GemmNPerThreadSubC = CK_PARAM_TUNABLE_GEMM_N_PER_THREAD_SUB_C;
+    constexpr index_t GemmMLevel0Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER;
+    constexpr index_t GemmNLevel0Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER;
+    constexpr index_t GemmMLevel1Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER;
+    constexpr index_t GemmNLevel1Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER;
+    constexpr index_t GemmKPerThreadLoop = 1;
+
+    constexpr index_t GemmThreadGemmDataPerReadM = GemmMPerThreadSubC;
+    constexpr index_t GemmThreadGemmDataPerReadN = GemmNPerThreadSubC;
+
+    // A matrix
+    constexpr index_t GemmABlockCopyClusterLengths_GemmK =
+        CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
+
+    constexpr index_t GemmABlockCopyClusterLengths_GemmM =
+        CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
+
+    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
+        GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
+
+    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM =
+        GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM =
+        Sequence<GemmABlockCopyThreadSliceLengths_GemmK, GemmABlockCopyThreadSliceLengths_GemmM>;
+
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM =
+        Sequence<GemmABlockCopyClusterLengths_GemmK, GemmABlockCopyClusterLengths_GemmM>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK =
+        CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K;
+
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM =
+        CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M;
+
+    // B matrix
+    constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
+        CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
+
+    constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
+        CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
+
+    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
+        GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
+
+    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN =
+        GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN;
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN =
+        Sequence<GemmBBlockCopyThreadSliceLengths_GemmK, GemmBBlockCopyThreadSliceLengths_GemmN>;
+
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN =
+        Sequence<GemmBBlockCopyClusterLengths_GemmK, GemmBBlockCopyClusterLengths_GemmN>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
+        CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
+
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN =
+        CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N;
+
+    // C matrix
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 =
+        CK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1;
+)";
+
+// Epilogue, part 1: the leading template arguments of the gridwise
+// convolution instantiation; one decltype(<descriptor>) line per tensor
+// descriptor is written right after it.
+constexpr StringLiteral kCppEpiloguePart1 = R"(
+        <GridSize,
+        BlockSize,
+        FLOAT,
+        FLOAT_ACCUM,
+)";
+
+// Epilogue, part 2: the remaining template arguments, the Run() call on the
+// instantiated convolution and the closing brace of the kernel body.
+constexpr StringLiteral kCppEpiloguePart2 = R"(
+        ConvStrides,
+        ConvDilations,
+        InLeftPads,
+        InRightPads,
+        GemmMPerBlock,
+        GemmNPerBlock,
+        GemmKPerBlock,
+        GemmMPerThreadSubC,
+        GemmNPerThreadSubC,
+        GemmMLevel0Cluster,
+        GemmNLevel0Cluster,
+        GemmMLevel1Cluster,
+        GemmNLevel1Cluster,
+        GemmKPerThreadLoop,
+        GemmThreadGemmDataPerReadM,
+        GemmThreadGemmDataPerReadN,
+        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+        GemmABlockCopySrcDataPerRead_GemmK,
+        GemmABlockCopyDstDataPerWrite_GemmM,
+        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+        GemmBBlockCopySrcDataPerRead_GemmN,
+        GemmBBlockCopyDstDataPerWrite_GemmN,
+        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+
+    gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
+}
+)";
+ 
+/// Emits everything that precedes the tensor descriptors: the include list,
+/// the kernel declaration (whose name ends in \p layoutStr) and the kernel
+/// body's opening constants taken from kCppPreamblePart3.
+void EmitCppPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr) {
+  output << kCppPreamblePart1;
+
+  // Between Preamble Part 1 and Part 2, e.g.:
+  //   #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+  // The closing quote is required for the directive to be well-formed.
+  output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)";
+  output << layoutStr << ".hpp\"";
+
+  output << kCppPreamblePart2;
+
+  // Between Preamble Part 2 and Part 3, e.g.:
+  //   __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void
+  //       gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(
+  // The parameter list, including its opening parenthesis, comes from Part 3.
+  output << R"(
+    __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_)";
+  output << layoutStr;
+
+  output << kCppPreamblePart3;
+}
+
+/// Emits the fixed middle section of the kernel body (stride/dilation/padding
+/// sequences and the GEMM tuning parameters). It does not depend on the
+/// layout, so only the output stream is needed.
+void EmitCppInterlude(llvm::raw_ostream &output) {
+  const llvm::StringRef interlude = kCppInterlude;
+  output << interlude;
+}
+
+/// Emits the gridwise convolution instantiation for \p layoutStr, passing one
+/// decltype(<name>) template argument per entry of \p tensorDescs, followed by
+/// the rest of kCppEpiloguePart2 (Run() call and end of the kernel body).
+///
+/// \p tensorDescs is only read, so it is taken as a non-owning ArrayRef
+/// instead of copying the whole vector of strings.
+void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::ArrayRef<std::string> tensorDescs) {
+  // Before Part1, e.g.:
+  //   constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
+  output << R"(
+    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_)";
+  output << layoutStr;
+
+  output << kCppEpiloguePart1;
+
+  // Between Part1 and Part2, one line per descriptor, e.g.:
+  //   decltype(in_nchw_desc),
+  //   decltype(wei_kcyx_desc),
+  //   decltype(out_nkhw_desc),
+  for (const std::string &desc : tensorDescs)
+    output << "        decltype(" << desc << "),\n";
+
+  output << kCppEpiloguePart2;
+}
+
+/// Writes the dimension names in \p layoutArrayAttr (at most
+/// kConv2DTensorDimension of them), each wrapped as prefix + name + suffix,
+/// with \p delimiter between consecutive positions. Entries that are not a
+/// StringAttr print no name but still occupy a delimiter-separated position.
+void EmitLayoutString(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> layoutArrayAttr, llvm::StringRef prefix, llvm::StringRef suffix, llvm::StringRef delimiter = "") {
+  // Clamp to the attribute's real length: indexing past the end of an
+  // ArrayRef is undefined (and asserts in debug builds) when a layout has
+  // fewer than kConv2DTensorDimension entries.
+  llvm::ArrayRef<mlir::Attribute> dims =
+      layoutArrayAttr.take_front(kConv2DTensorDimension);
+  for (size_t i = 0, e = dims.size(); i < e; ++i) {
+    if (i != 0)
+      output << delimiter;
+    if (auto strAttr = dims[i].dyn_cast<StringAttr>())
+      output << prefix << strAttr.getValue() << suffix;
+  }
+}
+
+/// Declares one `const index_t <name>` per named dimension in
+/// \p layoutArrayAttr (at most kConv2DTensorDimension), initialized from the
+/// CK_PARAM_PROBLEM_* macro whose suffix is the upper-cased first character of
+/// the name, or its first two characters when the name starts with 'h'/'w'
+/// (e.g. "k" -> K, "ni" -> N, "hi" -> HI, "wo" -> WO).
+void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> layoutArrayAttr) {
+  for (mlir::Attribute attr : layoutArrayAttr.take_front(kConv2DTensorDimension)) {
+    auto strAttr = attr.dyn_cast<StringAttr>();
+    if (!strAttr)
+      continue;
+    llvm::StringRef name = strAttr.getValue();
+    // An empty name cannot form an identifier and has no first character.
+    if (name.empty())
+      continue;
+
+    const char lead = llvm::toUpper(name.front());
+    // take_front() clamps to the string's length, so a one-character "h" or
+    // "w" no longer reads past the end of the name.
+    const size_t macroLen = (lead == 'H' || lead == 'W') ? 2 : 1;
+    output << "    const index_t " << name << " = CK_PARAM_PROBLEM_"
+           << name.take_front(macroLen).upper() << ";\n";
+  }
+}
+
+/// Declares `const index_t stride_<name>` for each named dimension in
+/// \p layoutArrayAttr (at most kConv2DTensorDimension), assuming a packed
+/// row-major layout: the last dimension has stride 1 and every outer
+/// dimension's stride is the next inner dimension's size times its stride
+/// (e.g. k,c,y,x -> stride_x = 1, stride_y = x * stride_x, ...).
+///
+/// Row-major matches MLIR memrefs with the default identity layout, as used by
+/// the test functions. NOTE(review): confirm this is what the generated
+/// make_native_tensor_descriptor calls expect.
+void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> layoutArrayAttr) {
+  llvm::ArrayRef<mlir::Attribute> dims =
+      layoutArrayAttr.take_front(kConv2DTensorDimension);
+  // Walk from the innermost (last) dimension outwards, so each stride is
+  // declared before the next outer one refers to it. `inner` names the
+  // closest inner dimension already emitted.
+  bool haveInner = false;
+  llvm::StringRef inner;
+  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
+    auto strAttr = dims[i].dyn_cast<StringAttr>();
+    if (!strAttr)
+      continue;
+    llvm::StringRef name = strAttr.getValue();
+    output << "    const index_t stride_" << name << " = ";
+    if (haveInner)
+      output << inner << " * stride_" << inner << ";\n";
+    else
+      output << "1;\n";
+    inner = name;
+    haveInner = true;
+  }
+}
+
+/// Collects naming information for one function. For every miopen.transform
+/// op in \p f that carries a "source_layout" attribute (in walk order, at most
+/// one per kVarName entry), it appends a descriptor variable name such as
+/// "weight_k_c_y_x_desc" to \p tensorDescs. It also appends the layouts,
+/// joined by "_", to \p layoutStr (e.g. "kcyx_niciwihi_nokohowo").
+///
+/// Only \p f is inspected, so each function in a multi-function module is
+/// described by its own transforms.
+void ObtainModuleInfo(FuncOp f, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
+  // (TBD verifying logic) The Module could contain multiple FuncOp, and inside each FuncOp there
+  // should be exactly:
+  // - 3 input arguments
+  // - 1 result.
   //
-  // The Module should only contain 1 function.
-  // The Function should only contain exactly:
   // - 0 conv2d op.
   // - 5 transform ops (1 for filter, 3 for input, 1 for output).
   // - 1 gridwise gemm op.
-  m.dump();

-  return std::make_unique<llvm::StringRef>("Hello World");
+  size_t srcLayoutAttrCtr = 0;
+  llvm::raw_string_ostream los(layoutStr);
+
+  f.walk([&srcLayoutAttrCtr, &tensorDescs, &los](miopen::TransformOp op) {
+    // get source_layout attribute.
+    auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+    if (!srcLayoutAttr)
+      return;
+    // kVarName only names weight, input and output; ignore any further
+    // source_layout transform instead of reading past the end of the array.
+    if (srcLayoutAttrCtr >= sizeof(kVarName) / sizeof(kVarName[0]))
+      return;
+    auto srcLayout = srcLayoutAttr.getValue();
+
+    // Prepare tensor descriptor variable name, e.g. "weight_k_c_y_x_desc".
+    std::string desc;
+    llvm::raw_string_ostream os(desc);
+    os << kVarName[srcLayoutAttrCtr] << "_";
+    EmitLayoutString(os, srcLayout, "", "", "_");
+    os << "_desc";
+    tensorDescs.push_back(os.str());
+
+    // Prepare layout string: successive tensors' layouts are joined by "_".
+    if (srcLayoutAttrCtr != 0)
+      los << "_";
+    EmitLayoutString(los, srcLayout, "", "");
+    ++srcLayoutAttrCtr;
+  });
+  los.flush();
+}
+
+} // namespace
+
+/// Translates every FuncOp in \p m into MIOpen C++ kernel source, one kernel
+/// per function, and returns the concatenated text.
+///
+/// llvm::StringRef does not own its characters. A local std::string would be
+/// destroyed on return and leave the StringRef dangling, so the text is kept
+/// in function-local static storage instead. The result therefore stays valid
+/// only until the next call, and this function is not reentrant.
+std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
+  static std::string resultStr;
+  resultStr.clear();
+  llvm::raw_string_ostream output(resultStr);
+
+  // Enumerate FuncOp instances inside the ModuleOp.
+  for (auto f : m.getOps<FuncOp>()) {
+    std::string layoutStr;
+    llvm::SmallVector<std::string, 3> tensorDescs;
+
+    // Obtain critical information from this FuncOp only. Scanning the whole
+    // module here would mix in other functions' transforms and misalign the
+    // tensorDescs indices used below.
+    ObtainModuleInfo(f, layoutStr, tensorDescs);
+
+    size_t srcLayoutAttrCtr = 0;
+
+    // Start emitting.
+    EmitCppPreamble(output, layoutStr);
+
+    f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) {
+      // get source_layout attribute.
+      auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+      if (!srcLayoutAttr)
+        return;
+      // Only transforms that ObtainModuleInfo named have a descriptor
+      // variable; never index past the end of tensorDescs.
+      if (srcLayoutAttrCtr >= tensorDescs.size())
+        return;
+      auto srcLayout = srcLayoutAttr.getValue();
+
+      output << "    // ";
+      EmitLayoutString(output, srcLayout, "", "", ", ");
+      output << '\n';
+
+      EmitDimensionVariables(output, srcLayout);
+      output << '\n';
+      EmitStrideVariables(output, srcLayout);
+
+      output << "    constexpr auto " << tensorDescs[srcLayoutAttrCtr++];
+      output << " = make_native_tensor_descriptor(Sequence<";
+      EmitLayoutString(output, srcLayout, "", "", ", ");
+      output << ">{}, Sequence<";
+      EmitLayoutString(output, srcLayout, "stride_", "", ", ");
+      output << ">{});\n\n";
+
+      // TBD: the "layout" ArrayAttr (DictionaryAttr entries with dimensions,
+      // names, source_dimensions, source_names, transformation) is not used
+      // in emitting the C++ source wrapper; it would be used in emitting the
+      // C++ header.
+    });
+
+    EmitCppInterlude(output);
+
+    // TBD: get tuning parameters from miopen::GridwiseGemmOp.
+
+    EmitCppEpilogue(output, layoutStr, tensorDescs);
+  }
+
+  output.flush();
+  return std::make_unique<llvm::StringRef>(resultStr);
 }
 
 static TranslateFromMLIRRegistration
-    toCPP("mlir-to-miopencpp", [](ModuleOp module, llvm::raw_ostream &output) {
-      auto sourceCode = mlir::translateModuleToMIOpenCPP(module);
+    toCpp("mlir-to-miopen-cpp", [](ModuleOp module, llvm::raw_ostream &output) {
+      auto sourceCode = mlir::translateModuleToMIOpenCpp(module);
       if (!sourceCode)
         return failure();

       output << *sourceCode;
       return success();
     });
+// Disabled: translateModuleToMIOpenHeader has no definition in this file yet.
+//static TranslateFromMLIRRegistration
+//    toHeader("mlir-to-miopen-h", [](ModuleOp module, llvm::raw_ostream &output) {
+//      auto sourceCode = mlir::translateModuleToMIOpenHeader(module);
+//      if (!sourceCode)
+//        return failure();
+//
+//      output << *sourceCode;
+//      return success();
+//    });
+

diff  --git a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
similarity index 70%
rename from mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir
rename to mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
index 4b4bc0031717..ca9751191859 100644
--- a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir
+++ b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-translate -mlir-to-miopencpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck %s
 
-// CHECK: Hello World
+// CHECK:  __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_kcyx_niciwihi_nokohowo
 func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
   // filter tensor
   %filter_gemmK_gemmM = miopen.transform(%filter) {
@@ -19,7 +19,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         source_dimensions = [0],
         source_names = ["n"]
       }
-    ]
+    ],
+    source_layout = ["k", "c", "y", "x"]
   } : memref<?x?x?x?xf32> to memref<?x?xf32>
 
   // input tensor
@@ -30,14 +31,14 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         names = ["n"],
         transformation = "passthorugh",
         source_dimensions = [0],
-        source_names = ["n"]
+        source_names = ["ni"]
       },
       {
         dimensions = [1],
         names = ["c"],
         transformation = "passthorugh",
         source_dimensions = [1],
-        source_names = ["c"]
+        source_names = ["ci"]
       },
       {
         dimensions = [2],
@@ -55,7 +56,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         source_dimensions = [3],
         source_names = ["wi"]
       }
-    ]
+    ],
+    source_layout = ["ni", "ci", "wi", "hi"]
   } : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
   
   %input_n_c_y_ho_x_wo = miopen.transform(%input_n_c_hipad_wipad) {
@@ -90,7 +92,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         source_dimensions = [2],
         source_names = ["wipad"]
       }
-    ]
+    ],
+    intermediate_layout = ["n", "c", "hipad", "wipad"]
   } : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?x?xf32>
   
   %input_gemmK_gemmN = miopen.transform(%input_n_c_y_ho_x_wo) {
@@ -107,9 +110,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         names = ["gemmN"],
         transformation = "merge",
         source_dimensions = [0, 3, 5],
-        source_names = ["n", "ho", "wo"]
+        source_names = ["n", "hipad", "wipad"]
       }
-    ]
+    ],
+    intermediate_layout = ["n", "c", "y", "hipad", "x", "wipad"]
   } : memref<?x?x?x?x?x?x?xf32> to memref<?x?xf32>
   
   // output tensor
@@ -120,16 +124,17 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
         names = ["gemmM"],
         transformation = "passthrough",
         source_dimensions = [1],
-        source_names = ["k"]
+        source_names = ["ko"]
       },
       {
         dimensions = [1],
         names = ["gemmN"],
         transformation = "merge",
         source_dimensions = [0, 2, 3],
-        source_names = ["n", "ho", "wo"]
+        source_names = ["no", "ho", "wo"]
       }
-    ]
+    ],
+    source_layout = ["no", "ko", "ho", "wo"]
   } : memref<?x?x?x?xf32> to memref<?x?xf32>
   
   // apply gridwise GEMM
@@ -143,3 +148,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
 
   return
 }
+// CHECK:    constexpr auto weight_k_c_y_x_desc = make_native_tensor_descriptor(Sequence<k, c, y, x>{}, Sequence<stride_k, stride_c, stride_y, stride_x>{});
+// CHECK:     constexpr auto input_ni_ci_wi_hi_desc = make_native_tensor_descriptor(Sequence<ni, ci, wi, hi>{}, Sequence<stride_ni, stride_ci, stride_wi, stride_hi>{});
+// CHECK:     constexpr auto output_no_ko_ho_wo_desc = make_native_tensor_descriptor(Sequence<no, ko, ho, wo>{}, Sequence<stride_no, stride_ko, stride_ho, stride_wo>{});
+// CHECK:         constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_kcyx_niciwihi_nokohowo
+// CHECK:        decltype(weight_k_c_y_x_desc),
+// CHECK:        decltype(input_ni_ci_wi_hi_desc),
+// CHECK:        decltype(output_no_ko_ho_wo_desc),


        


More information about the llvm-branch-commits mailing list