[llvm-branch-commits] [mlir] 7054cfc - Fix translation logic.

Wen-Heng Chung via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Oct 22 13:20:21 PDT 2020


Author: Wen-Heng (Jack) Chung
Date: 2020-06-05T22:18:20-05:00
New Revision: 7054cfc71ac450ab5ac9ee505d1096bb1252f9c6

URL: https://github.com/llvm/llvm-project/commit/7054cfc71ac450ab5ac9ee505d1096bb1252f9c6
DIFF: https://github.com/llvm/llvm-project/commit/7054cfc71ac450ab5ac9ee505d1096bb1252f9c6.diff

LOG: Fix translation logic.

Added: 
    

Modified: 
    mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
index 4eb8d7de7181..2bd64efa77b6 100644
--- a/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
+++ b/mlir/lib/Dialect/MIOpenOps/CppOutput/ConvertToMIOpenCPP.cpp
@@ -183,20 +183,16 @@ static constexpr StringLiteral kCppEpiloguePart2 =R"(
  
 void EmitCppPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr) {
   output << kCppPreamblePart1;
-
 // Between Preamble Part 1 and Part 2:
 // #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
   output << R"(#include "gridwise_convolution_implicit_gemm_v4r4_)";
-  output << layoutStr << ".hpp";
-
+  output << layoutStr << R"(.hpp")";
   output << kCppPreamblePart2;
-
 // Between Preamble Part 2 and Par 3:
 //    __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(
    output << R"(
     __launch_bounds__(CK_PARAM_TUNABLE_BLOCK_SIZE, 2) void gridwise_convolution_implicit_gemm_v4r4_)";
    output << layoutStr;
-
   output << kCppPreamblePart3;
 }
 
@@ -210,9 +206,7 @@ void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm:
   output << R"(
     constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_)";
   output << layoutStr;
-
   output << kCppEpiloguePart1;
-
 // Between Part1 and Part2:
 //        decltype(in_nchw_desc),
 //        decltype(wei_kcyx_desc),
@@ -220,7 +214,6 @@ void EmitCppEpilogue(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm:
   for (auto desc : tensorDescs) {
     output << "        decltype(" << desc << "),\n";
   }
-
   output << kCppEpiloguePart2;
 }
 
@@ -344,28 +337,23 @@ static constexpr StringLiteral kHeaderEpiloguePart2 = R"(
 
 void EmitHeaderPreamble(llvm::raw_ostream &output, llvm::StringRef layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
   output << kHeaderPreamblePart1;
-
   output << R"(
 struct GridwiseConvolutionImplicitGemm_v4r4_)";
   output << layoutStr;
-
   output << kHeaderPreamblePart2;
-
   output << kHeaderPreamblePart3;
-
   output << '\n';
-
   output << R"(
         constexpr auto )" << tensorDescs[0] << " = InGlobalDesc{};";
   output << R"(
         constexpr auto )" << tensorDescs[1] << " = WeiGlobalDesc{};";
   output << R"(
         constexpr auto )" << tensorDescs[2] << " = OutGlobalDesc{};";
+  output << '\n';
 }
 
 void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t, std::string> &args) {
   output << kHeaderEpiloguePart1;
-
 // Between Part1 and Part2 emit:
 //                                                   decltype(wei_e_k_global_desc),
 //                                                   decltype(in_e_b_global_desc),
@@ -374,7 +362,6 @@ void EmitHeaderEpilogue(llvm::raw_ostream &output, llvm::SmallDenseMap<int64_t,
     output << R"(
                                                      decltype()" << args[i] << "),";
   }
-
   output << kHeaderEpiloguePart2;
 }
 
@@ -437,17 +424,19 @@ void EmitStrideVariables(llvm::raw_ostream &output, llvm::ArrayRef<mlir::Attribu
   }
 }
 
-void EmitInterleaveArrayAttrOfStringAttrWithSeparator(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr, const StringRef &separator) {
+template<typename T>
+void EmitInterleaveArrayAttrWithSeparator(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr, const StringRef &separator) {
   if (arrayAttr) {
     interleave(arrayAttr, os, [&](Attribute attr) {
-      if (auto strAttr = attr.dyn_cast<StringAttr>())
-        os << strAttr.getValue();
+      if (auto typedAttr = attr.dyn_cast<T>())
+        os << typedAttr.getValue();
     }, separator);
   }
 }
 
-void EmitInterleaveCommaArrayAttrOfStringAttr(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr) {
-  EmitInterleaveArrayAttrOfStringAttrWithSeparator(os, arrayAttr, ", ");
+template<typename T>
+void EmitInterleaveCommaArrayAttr(llvm::raw_ostream &os, mlir::ArrayAttr &arrayAttr) {
+  EmitInterleaveArrayAttrWithSeparator<T>(os, arrayAttr, ", ");
 }
 
 void ObtainModuleInfo(ModuleOp &m, std::string &layoutStr, llvm::SmallVector<std::string, 3> &tensorDescs) {
@@ -511,7 +500,8 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
     // Start emitting.
     EmitHeaderPreamble(output, layoutStr, tensorDescs);
 
-    f.walk([&output, &srcLayoutAttrCtr, &tensorDescs, &gridwiseGemmArguments](miopen::TransformOp op) {
+    // First iteration. Output source dimensions.
+    f.walk([&output, &srcLayoutAttrCtr, &tensorDescs](miopen::TransformOp op) {
       // get source_layout attribute.
       auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
       if (srcLayoutAttr) {
@@ -520,10 +510,17 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
         EmitLayoutString(output, srcLayout, "", "", ", ");
         output << '\n';
 
-        EmitHeaderDimensionLengths(output, srcLayout, tensorDescs[srcLayoutAttrCtr]);
+        EmitHeaderDimensionLengths(output, srcLayout, tensorDescs[srcLayoutAttrCtr++]);
       }
-      output << '\n';
+    });
+    output << '\n';
  
+    srcLayoutAttrCtr = 0;
+    // Second iteration. Output the rest.
+    f.walk([&output, &srcLayoutAttrCtr, &tensorDescs, &gridwiseGemmArguments](miopen::TransformOp op) {
+      // get source_layout attribute.
+      auto srcLayoutAttr = op.getAttrOfType<ArrayAttr>("source_layout");
+
       // get layout attribute.
       auto layoutAttr = op.getAttrOfType<ArrayAttr>("layout");
       std::string inputTensorName;
@@ -549,22 +546,20 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
         // get intermediate_layout attribute.
         if (immLayoutAttr) {
           ins << kVarName[srcLayoutAttrCtr - 1] << "_";
-          EmitInterleaveArrayAttrOfStringAttrWithSeparator(ins, immLayoutAttr, "_");
+          EmitInterleaveArrayAttrWithSeparator<StringAttr>(ins, immLayoutAttr, "_");
           ins << "_desc";
           ins.flush();
 
           outs << kVarName[srcLayoutAttrCtr - 1] << "_";
         }
       }
-      EmitInterleaveArrayAttrOfStringAttrWithSeparator(outs, outputLayoutAttr, "_");
+      EmitInterleaveArrayAttrWithSeparator<StringAttr>(outs, outputLayoutAttr, "_");
       outs << "_desc";
       outs.flush();
 
       // determine gridwise GEMM arguments.
       auto gridwiseGemmArgPosAttr = op.getAttrOfType<IntegerAttr>("gridwise_gemm_argument_position");
       if (gridwiseGemmArgPosAttr) {
-        llvm::errs() << "gridwise gemm argument pos: " << gridwiseGemmArgPosAttr.getValue() << "\n";
-        llvm::errs() << "tensor: " << outputTensorName << "\n";
         gridwiseGemmArguments[gridwiseGemmArgPosAttr.getInt()] = outputTensorName;
       }  
 
@@ -572,30 +567,48 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
       srcs << "            make_tuple(";
       dsts << "            make_tuple(";
 
+      // XXX see if we can get better than this.
+      int convDilationCtr = 0;
+
       for (auto layoutSpec = layoutAttr.begin(); layoutSpec != layoutAttr.end(); ) {
         if (auto layoutSpecDict = layoutSpec->dyn_cast<DictionaryAttr>()) {
           auto srcNames = layoutSpecDict.get("source_names").dyn_cast<ArrayAttr>();
           auto dstNames = layoutSpecDict.get("names").dyn_cast<ArrayAttr>();
+          auto srcDims = layoutSpecDict.get("source_dimensions").dyn_cast<ArrayAttr>();
+          auto dstDims = layoutSpecDict.get("dimensions").dyn_cast<ArrayAttr>();
 
           if (auto transform = layoutSpecDict.get("transformation").dyn_cast<StringAttr>()) {
-            if (transform.getValue() == "PassThrough" ||
-                transform.getValue() == "Merge") {
+            if (transform.getValue() == "PassThrough") {
               ops << transform.getValue() << "<";
-              EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames);
+              EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames);
               ops << ">{}";
+            } else if (transform.getValue() == "Merge") {
+              ops << transform.getValue() << "<"
+                  << "Sequence<";
+              EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames);
+              ops << ">" << ">{}";
             } else if (transform.getValue() == "Pad") {
               ops << transform.getValue() << "<"
                   << "Sequence<";
-              EmitInterleaveCommaArrayAttrOfStringAttr(ops, srcNames);
+              EmitInterleaveCommaArrayAttr<StringAttr>(ops, srcNames);
               ops << ">, InLeftPads, InRightPads" << ">{}";
             } else if (transform.getValue() == "Embed") {
               ops << transform.getValue() << "<"
                   << "Sequence<";
-              EmitInterleaveCommaArrayAttrOfStringAttr(ops, dstNames);
-              ops << ">, Sequence<ConvDilationTBD, ConvDilationTBD, 0>>{}";
+              EmitInterleaveCommaArrayAttr<StringAttr>(ops, dstNames);
+              if (convDilationCtr == 0) {
+                ops << ">, Sequence<ConvDilationH, ConvDilationH, 0>>{}";
+                convDilationCtr++;
+              } else {
+                ops << ">, Sequence<ConvDilationW, ConvDilationW, 0>>{}";
+              }
             }
-            srcs << "Sequence<" << layoutSpecDict.get("source_dimensions") << ">{}";
-            dsts << "Sequence<" << layoutSpecDict.get("dimensions") << ">{}";
+            srcs << "Sequence<";
+            EmitInterleaveCommaArrayAttr<IntegerAttr>(srcs, srcDims);
+            srcs << ">{}";
+            dsts << "Sequence<";
+            EmitInterleaveCommaArrayAttr<IntegerAttr>(dsts, dstDims);
+            dsts << ">{}";
           }
         }
 
@@ -616,7 +629,7 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
       output << "        constexpr auto " << outputTensorName << " = transform_tensor_descriptor(\n";
       output << "            " << inputTensorName << ",\n";
       output << operationSpec << srcDimSpec << dstDimSpec;
-      output << ");\n";
+      output << ");\n\n";
     });
 
     // TBD get tuning parameters.


        


More information about the llvm-branch-commits mailing list