[llvm-branch-commits] [mlir] 05d19a7 - Add Op traversing logic into MIOpen dialect -> C++ translator.
Wen-Heng Chung via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:14 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:19-05:00
New Revision: 05d19a7eb4ea27d4d0e7989b145c2f7458fff54a
URL: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a
DIFF: https://github.com/llvm/llvm-project/commit/05d19a7eb4ea27d4d0e7989b145c2f7458fff54a.diff
LOG: Add Op traversing logic into MIOpen dialect -> C++ translator.
Added:
mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
Modified:
mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
Removed:
mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
index d3e9b8ee09a2..09d2d1166caf 100644
--- a/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
+++ b/mlir/include/mlir/Dialect/MIOpenOps/MIOpenCPP.h
@@ -33,7 +33,17 @@ class ModuleOp;
/// Convert the given MLIR module into MIOpen C++ . In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCPP(ModuleOp m);
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenCpp(ModuleOp m);
+
+/// Convert the given MLIR module into MIOpen C++ Header. In case of error, report it
+/// to the error handler registered with the MLIR context, if any (obtained from
+/// the MLIR module), and return `nullptr`.
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenHeader(ModuleOp m);
+
+/// Convert the given MLIR module into MIOpen C++ Solver. In case of error, report it
+/// to the error handler registered with the MLIR context, if any (obtained from
+/// the MLIR module), and return `nullptr`.
+std::unique_ptr<llvm::StringRef> translateModuleToMIOpenSolver(ModuleOp m);
} // namespace mlir
diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
index 855985b4b945..3d37305c60e7 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/CMakeLists.txt
@@ -5,6 +5,7 @@ add_llvm_library(MLIRMIOpenCpp
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MIOpenOps
)
target_link_libraries(MLIRMIOpenCpp
+ LLVMSupport
MLIRIR
MLIRMIOpenOps
MLIRStandardOps
diff --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index 5fe33d695cb3..b071565c2b51 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -13,34 +13,396 @@
#include "mlir/Dialect/MIOpenOps/MIOpenCPP.h"
#include "mlir/Dialect/MIOpenOps/MIOpenOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
-
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
#include "mlir/Translation.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
-std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCPP(ModuleOp m) {
- // Check constraints:
+namespace {
+
+static constexpr StringLiteral kVarName[3] = {"weight", "input", "output"};
+
+static constexpr int kConv2DTensorDimension = 4;
+
+static constexpr StringLiteral kCppPreamblePart1 = R"(
+#include "common_header.hpp"
+)";
+
+static constexpr StringLiteral kCppPreamblePart2 = R"(
+#include "float_types.h"
+
+extern "C" __global__
+)";
+
+static constexpr StringLiteral kCppPreamblePart3 = R"(
+ (const FLOAT* const __restrict__ p_in_global,
+ const FLOAT* const __restrict__ p_wei_global,
+ FLOAT* const __restrict__ p_out_global)
+{
+ using namespace ck;
+
+ constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
+ constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;
+
+ constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
+ constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;
+
+ constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
+ constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;
+
+ constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
+ constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;
+
+ constexpr index_t BlockSize = CK_PARAM_TUNABLE_BLOCK_SIZE;
+ constexpr index_t GridSize = CK_PARAM_DEPENDENT_GRID_SIZE;
+
+ constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
+ constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
+ constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
+
+)";
+
+static constexpr StringLiteral kCppInterlude = R"(
+ using ConvStrides = Sequence<ConvStrideH, ConvStrideW>;
+ using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;
+
+ using InLeftPads = Sequence<InLeftPadH, InLeftPadW>;
+ using InRightPads = Sequence<InRightPadH, InRightPadW>;
+
+ // read and calculate tuning parameter
+ constexpr index_t GemmMPerThreadSubC = CK_PARAM_TUNABLE_GEMM_M_PER_THREAD_SUB_C;
+ constexpr index_t GemmNPerThreadSubC = CK_PARAM_TUNABLE_GEMM_N_PER_THREAD_SUB_C;
+ constexpr index_t GemmMLevel0Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER;
+ constexpr index_t GemmNLevel0Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER;
+ constexpr index_t GemmMLevel1Cluster = CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER;
+ constexpr index_t GemmNLevel1Cluster = CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER;
+ constexpr index_t GemmKPerThreadLoop = 1;
+
+ constexpr index_t GemmThreadGemmDataPerReadM = GemmMPerThreadSubC;
+ constexpr index_t GemmThreadGemmDataPerReadN = GemmNPerThreadSubC;
+
+ // A matrix
+ constexpr index_t GemmABlockCopyClusterLengths_GemmK =
+ CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
+
+ constexpr index_t GemmABlockCopyClusterLengths_GemmM =
+ CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
+
+ constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
+ GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
+
+ constexpr index_t GemmABlockCopyThreadSliceLengths_GemmM =
+ GemmMPerBlock / GemmABlockCopyClusterLengths_GemmM;
+
+ using GemmABlockCopyThreadSliceLengths_GemmK_GemmM =
+ Sequence<GemmABlockCopyThreadSliceLengths_GemmK, GemmABlockCopyThreadSliceLengths_GemmM>;
+
+ using GemmABlockCopyThreadClusterLengths_GemmK_GemmM =
+ Sequence<GemmABlockCopyClusterLengths_GemmK, GemmABlockCopyClusterLengths_GemmM>;
+
+ constexpr index_t GemmABlockCopySrcDataPerRead_GemmK =
+ CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K;
+
+ constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM =
+ CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M;
+
+ // B matrix
+ constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
+ CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
+
+ constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
+ CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
+
+ constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
+ GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
+
+ constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmN =
+ GemmNPerBlock / GemmBBlockCopyClusterLengths_GemmN;
+ using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN =
+ Sequence<GemmBBlockCopyThreadSliceLengths_GemmK, GemmBBlockCopyThreadSliceLengths_GemmN>;
+
+ using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN =
+ Sequence<GemmBBlockCopyClusterLengths_GemmK, GemmBBlockCopyClusterLengths_GemmN>;
+
+ constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
+ CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
+
+ constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN =
+ CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N;
+
+ // C matrix
+ constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 =
+ CK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1;
+)";
+
+static constexpr StringLiteral kCppEpiloguePart1 = R"(
+ <GridSize,
+ BlockSize,
+ FLOAT,
+ FLOAT_ACCUM,
+)";
+
+static constexpr StringLiteral kCppEpiloguePart2 =R"(
+ ConvStrides,
+ ConvDilations,
+ InLeftPads,
+ InRightPads,
+ GemmMPerBlock,
+ GemmNPerBlock,
+ GemmKPerBlock,
+ GemmMPerThreadSubC,
+ GemmNPerThreadSubC,
+ GemmMLevel0Cluster,
+ GemmNLevel0Cluster,
+ GemmMLevel1Cluster,
+ GemmNLevel1Cluster,
+ GemmKPerThreadLoop,
+ GemmThreadGemmDataPerReadM,
+ GemmThreadGemmDataPerReadN,
+ GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+ GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+ GemmABlockCopySrcDataPerRead_GemmK,
+ GemmABlockCopyDstDataPerWrite_GemmM,
+ GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+ GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+ GemmBBlockCopySrcDataPerRead_GemmN,
+ GemmBBlockCopyDstDataPerWrite_GemmN,
+ GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+
+ gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);
+}
+)";
+
+void EmitCppPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr) {
+ output << kCppPreamblePart1;
+
+// Between Preamble Part 1 and Part 2:
+// #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+ output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)";
+ output << layoutStr << ".hpp";
+
+ output << kCppPreamblePart2;
+
+// Between Preamble Part 2 and Part 3:
+// __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(
+ output << R"(
+ __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_)";
+ output << layoutStr;
+
+ output << kCppPreamblePart3;
+}
+
+void EmitCppInterlude(llvm::raw_ostream &output) {
+ output << kCppInterlude;
+}
+
+void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::SmallVector<std::string, 3> tensorDescs) {
+// Before Part1:
+// constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
+ output << R"(
+ constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_)";
+ output << layoutStr;
+
+ output << kCppEpiloguePart1;
+
+// Between Part1 and Part2:
+// decltype(in_nchw_desc),
+// decltype(wei_kcyx_desc),
+// decltype(out_nkhw_desc),
+ for (auto desc : tensorDescs) {
+ output << " decltype(" << desc << "),\n";
+ }
+
+ output << kCppEpiloguePart2;
+}
+
+void EmitLayoutString(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr, llvm::StringRef prefix, llvm::StringRef suffix, llvm::StringRef delimiter = "") {
+ for (int i = 0; i < kConv2DTensorDimension; ++i) {
+ auto attr = layoutArrayAttr[i];
+ if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+ output << prefix << strAttr.getValue() << suffix;
+ }
+ if (i < kConv2DTensorDimension - 1) {
+ output << delimiter;
+ }
+ }
+}
+
+void EmitDimensionVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr) {
+ for (int i = 0; i < kConv2DTensorDimension; ++i) {
+ auto attr = layoutArrayAttr[i];
+ if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+ output << " const index_t " << strAttr.getValue() << " = CK_PARAM_PROBLEM_";
+
+ switch (llvm::toUpper(strAttr.getValue()[0])) {
+ case 'H':
+ case 'W':
+ output << llvm::toUpper(strAttr.getValue()[0]);
+ output << llvm::toUpper(strAttr.getValue()[1]);
+ break;
+ default:
+ output << llvm::toUpper(strAttr.getValue()[0]);
+ }
+ output << ";\n";
+ }
+ }
+}
+
+void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribute> &layoutArrayAttr) {
+ for (int i = 0; i < kConv2DTensorDimension; ++i) {
+ auto attr = layoutArrayAttr[i];
+ if (auto strAttr = attr.dyn_cast<StringAttr>()) {
+ output << " const index_t stride_" << strAttr.getValue() << " = ";
+
+ if (i == 0) {
+ output << "1;\n";
+ } else {
+ auto prevAttr = layoutArrayAttr[i - 1];
+ if (auto strPrevAttr = prevAttr.dyn_cast<StringAttr>()) {
+ output << strPrevAttr.getValue() << " * stride_" << strPrevAttr.getValue() << ";\n";
+ }
+ }
+ }
+ }
+}
+
+void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
+  // (TBD verifying logic) The Module could contain multiple FuncOp, and inside each FuncOp there
+ // should be exactly:
+ // - 3 input arguments
+ // - 1 result.
//
- // The Module should only contain 1 function.
- // The Function should only contain exactly:
// - 0 conv2d op.
// - 5 transform ops (1 for filter, 3 for input, 1 for output).
// - 1 gridwise gemm op.
- m.dump();
- return std::make_unique<llvm::StringRef>("Hello World");
+ // Enumerate FuncOp instances inside the ModuleOp.
+ for (auto f : m.getOps<FuncOp>()) {
+ int srcLayoutAttrCtr = 0;
+ llvm::raw_string_ostream los(layoutStr);
+
+ // First iteration. Construct tensor descriptor names.
+ f.walk([&srcLayoutAttrCtr, &tensorDescs, &los](miopen::TransformOp op) {
+ // get source_layout attribute.
+ auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+ if (srcLayoutAttr) {
+ auto srcLayout = srcLayoutAttr.getValue();
+
+ // Prepare tensor descriptor variable name.
+ std::string desc{};
+ llvm::raw_string_ostream os(desc);
+ os << kVarName[srcLayoutAttrCtr++] << "_";
+ EmitLayoutString(os, srcLayout, "", "", "_");
+ os << "_desc";
+ os.flush();
+ tensorDescs.push_back(desc);
+
+ // Prepare layout string.
+ if (srcLayoutAttrCtr != 1)
+ los << "_";
+ EmitLayoutString(los, srcLayout, "", "");
+ }
+ });
+ los.flush();
+ }
+}
+
+}
+
+std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCpp(ModuleOp m) {
+ std::string resultStr;
+ llvm::raw_string_ostream output(resultStr);
+
+ // Enumerate FuncOp instances inside the ModuleOp.
+ for (auto f : m.getOps<FuncOp>()) {
+ std::string layoutStr;
+ llvm::SmallVector<std::string, 3> tensorDescs;
+
+ // Obtain critical information from ModuleOp.
+ ObtainModuleInfo(m, layoutStr, tensorDescs);
+
+ int srcLayoutAttrCtr = 0;
+
+ // Start emitting.
+
+ EmitCppPreamble(output, layoutStr);
+
+ f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) {
+
+ // get source_layout attribute.
+ auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+ if (srcLayoutAttr) {
+ auto srcLayout = srcLayoutAttr.getValue();
+ output << " // ";
+ EmitLayoutString(output, srcLayout, "", "", ", ");
+ output << '\n';
+
+ EmitDimensionVariables(output, srcLayout);
+ output << '\n';
+ EmitStrideVariables(output, srcLayout);
+
+ output << " constexpr auto " << tensorDescs[srcLayoutAttrCtr++];
+ output << " = make_native_tensor_descriptor(Sequence<";
+ EmitLayoutString(output, srcLayout, "", "", ", ");
+ output << ">{}, Sequence<";
+ EmitLayoutString(output, srcLayout, "stride_", "", ", ");
+ output << ">{});\n\n";
+ }
+
+ //// get layout attribute.
+ // TBD not used in emitting C++ source wrapper.
+ // would be used in emitting C++ header.
+ //auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout");
+ //for (auto layoutSpec : layoutAttr) {
+ // if (auto layoutSpecDict = layoutSpec.dyn_cast<DictionaryAttr>()) {
+ // //output << "dimensions: " << layoutSpecDict.get("dimensions") << "\n";
+ // //output << "names: " << layoutSpecDict.get("names") << "\n";
+ // //output << "source_dimensions: " << layoutSpecDict.get("source_dimensions") << "\n";
+ // //output << "source_names: " << layoutSpecDict.get("source_names") << "\n";
+ // //output << "transformation: " << layoutSpecDict.get("transformation") << "\n";
+ // }
+ //}
+ });
+
+ EmitCppInterlude(output);
+
+ // TBD get tuning parameters.
+ //f.walk([&output](miopen::GridwiseGemmOp op) {
+ // // get op name.
+ // //output << "op name: " << op.getOperationName() << "\n";
+ // //op.dump();
+ //});
+
+ EmitCppEpilogue(output, layoutStr, tensorDescs);
+ }
+
+ output.flush();
+ return std::make_unique<llvm::StringRef>(resultStr);
}
static TranslateFromMLIRRegistration
- toCPP("mlir-to-miopencpp", [](ModuleOp module, llvm::raw_ostream &output) {
- auto sourceCode = mlir::translateModuleToMIOpenCPP(module);
+ toCpp("mlir-to-miopen-cpp", [](ModuleOp module, llvm::raw_ostream &output) {
+ auto sourceCode = mlir::translateModuleToMIOpenCpp(module);
if (!sourceCode)
return failure();
output << *sourceCode;
return success();
});
+
+//static TranslateFromMLIRRegistration
+// toHeader("mlir-to-miopen-h", [](ModuleOp module, llvm::raw_ostream &output) {
+// auto sourceCode = mlir::translateModuleToMIOpenHeader(module);
+// if (!sourceCode)
+// return failure();
+//
+// output << *sourceCode;
+// return success();
+// });
+
diff --git a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
similarity index 70%
rename from mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir
rename to mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
index 4b4bc0031717..ca9751191859 100644
--- a/mlir/test/Dialect/MIOpen/CppOutput/miopencpp.mlir
+++ b/mlir/test/Dialect/MIOpen/CppOutput/transformed.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-translate -mlir-to-miopencpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-miopen-cpp %s | FileCheck %s
-// CHECK: Hello World
+// CHECK: __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_kcyx_niciwihi_nokohowo
func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?x?x?x?xf32>, %output : memref<?x?x?x?xf32>) {
// filter tensor
%filter_gemmK_gemmM = miopen.transform(%filter) {
@@ -19,7 +19,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
source_dimensions = [0],
source_names = ["n"]
}
- ]
+ ],
+ source_layout = ["k", "c", "y", "x"]
} : memref<?x?x?x?xf32> to memref<?x?xf32>
// input tensor
@@ -30,14 +31,14 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
names = ["n"],
transformation = "passthorugh",
source_dimensions = [0],
- source_names = ["n"]
+ source_names = ["ni"]
},
{
dimensions = [1],
names = ["c"],
transformation = "passthorugh",
source_dimensions = [1],
- source_names = ["c"]
+ source_names = ["ci"]
},
{
dimensions = [2],
@@ -55,7 +56,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
source_dimensions = [3],
source_names = ["wi"]
}
- ]
+ ],
+ source_layout = ["ni", "ci", "wi", "hi"]
} : memref<?x?x?x?xf32> to memref<?x?x?x?xf32>
%input_n_c_y_ho_x_wo = miopen.transform(%input_n_c_hipad_wipad) {
@@ -90,7 +92,8 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
source_dimensions = [2],
source_names = ["wipad"]
}
- ]
+ ],
+ intermediate_layout = ["n", "c", "hipad", "wipad"]
} : memref<?x?x?x?xf32> to memref<?x?x?x?x?x?x?xf32>
%input_gemmK_gemmN = miopen.transform(%input_n_c_y_ho_x_wo) {
@@ -107,9 +110,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
names = ["gemmN"],
transformation = "merge",
source_dimensions = [0, 3, 5],
- source_names = ["n", "ho", "wo"]
+ source_names = ["n", "hipad", "wipad"]
}
- ]
+ ],
+ intermediate_layout = ["n", "c", "y", "hipad", "x", "wipad"]
} : memref<?x?x?x?x?x?x?xf32> to memref<?x?xf32>
// output tensor
@@ -120,16 +124,17 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
names = ["gemmM"],
transformation = "passthrough",
source_dimensions = [1],
- source_names = ["k"]
+ source_names = ["ko"]
},
{
dimensions = [1],
names = ["gemmN"],
transformation = "merge",
source_dimensions = [0, 2, 3],
- source_names = ["n", "ho", "wo"]
+ source_names = ["no", "ho", "wo"]
}
- ]
+ ],
+ source_layout = ["no", "ko", "ho", "wo"]
} : memref<?x?x?x?xf32> to memref<?x?xf32>
// apply gridwise GEMM
@@ -143,3 +148,10 @@ func @miopen_transformed_conv2d(%filter : memref<?x?x?x?xf32>, %input : memref<?
return
}
+// CHECK: constexpr auto weight_k_c_y_x_desc = make_native_tensor_descriptor(Sequence<k, c, y, x>{}, Sequence<stride_k, stride_c, stride_y, stride_x>{});
+// CHECK: constexpr auto input_ni_ci_wi_hi_desc = make_native_tensor_descriptor(Sequence<ni, ci, wi, hi>{}, Sequence<stride_ni, stride_ci, stride_wi, stride_hi>{});
+// CHECK: constexpr auto output_no_ko_ho_wo_desc = make_native_tensor_descriptor(Sequence<no, ko, ho, wo>{}, Sequence<stride_no, stride_ko, stride_ho, stride_wo>{});
+// CHECK: constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_kcyx_niciwihi_nokohowo
+// CHECK: decltype(weight_k_c_y_x_desc),
+// CHECK: decltype(input_ni_ci_wi_hi_desc),
+// CHECK: decltype(output_no_ko_ho_wo_desc),
More information about the llvm-branch-commits
mailing list