[llvm-branch-commits] [mlir] ae895ac - [mlir][spirv] De-template deserialization

Lei Zhang via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 23 11:50:28 PST 2020


Author: Lei Zhang
Date: 2020-12-23T14:45:46-05:00
New Revision: ae895ac4b9fa86b9471617357f66c0cd6cdb70b8

URL: https://github.com/llvm/llvm-project/commit/ae895ac4b9fa86b9471617357f66c0cd6cdb70b8
DIFF: https://github.com/llvm/llvm-project/commit/ae895ac4b9fa86b9471617357f66c0cd6cdb70b8.diff

LOG: [mlir][spirv] De-template deserialization

Previously we generated a separate deserialization method for each
op. Those deserialization methods duplicated the logic of parsing
operands/results/attributes and such.

This commit creates a generic method and lets suitable op-specific
deserialization methods call into it.

wc -l SPIRVSerialization.inc: before: 13290; after: 8304 (i.e., -4986 lines)

Reviewed By: hanchung, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D93504

Added: 
    

Modified: 
    mlir/lib/Target/SPIRV/Deserialization.cpp
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp
index 94ea19f94123..30f46f6fc605 100644
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -34,6 +34,10 @@ using namespace mlir;
 
 #define DEBUG_TYPE "spirv-deserialization"
 
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
 /// Decodes a string literal in `words` starting at `wordIndex`. Update the
 /// latter to point to the position in words after the string literal.
 static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
@@ -55,6 +59,10 @@ static inline bool isFnEntryBlock(Block *block) {
 }
 
 namespace {
+//===----------------------------------------------------------------------===//
+// Utility Definitions
+//===----------------------------------------------------------------------===//
+
 /// A struct for containing a header block's merge and continue targets.
 ///
 /// This struct is used to track original structured control flow info from
@@ -124,6 +132,10 @@ struct DeferredStructTypeInfo {
   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
 };
 
+//===----------------------------------------------------------------------===//
+// Deserializer Declaration
+//===----------------------------------------------------------------------===//
+
 /// A SPIR-V module serializer.
 ///
 /// A SPIR-V binary module is a single linear stream of instructions; each
@@ -423,6 +435,14 @@ class Deserializer {
                                    ArrayRef<uint32_t> operands,
                                    bool deferInstructions = true);
 
+  /// Deserializes the instruction operands in `words` into an `opName` op.
+  /// `words` holds the result type <id> and result <id> (only if `hasResult`)
+  /// followed by exactly `numOperands` operand <id>s. This is the generic path
+  /// for SPIR-V ops without attributes or variadic operands in TableGen.
+  LogicalResult processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+                                            StringRef opName, bool hasResult,
+                                            unsigned numOperands);
+
   /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
   /// insertion point.
   LogicalResult processUndef(ArrayRef<uint32_t> operands);
@@ -580,6 +600,10 @@ class Deserializer {
 };
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Deserializer Method Definitions
+//===----------------------------------------------------------------------===//
+
 Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
     : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
       module(createModuleOp()), opBuilder(module->body()) {}
@@ -2497,6 +2521,87 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   return dispatchToAutogenDeserialization(opcode, operands);
 }
 
+LogicalResult
+Deserializer::processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+                                          StringRef opName, bool hasResult,
+                                          unsigned numOperands) {
+  SmallVector<Type, 1> resultTypes;
+  uint32_t valueID = 0;
+
+  size_t wordIndex = 0;
+  if (hasResult) {
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result type <id> while deserializing for ")
+             << opName;
+
+    // Decode the result type <id>, which must resolve to a known type.
+    auto type = getType(words[wordIndex]);
+    if (!type)
+      return emitError(unknownLoc, "unknown type result <id>: ")
+             << words[wordIndex];
+    resultTypes.push_back(type);
+    ++wordIndex;
+
+    // Decode the result <id>; the created op's result is mapped to it below.
+    if (wordIndex >= words.size())
+      return emitError(unknownLoc,
+                       "expected result <id> while deserializing for ")
+             << opName;
+    valueID = words[wordIndex];
+    ++wordIndex;
+  }
+
+  SmallVector<Value, 4> operands;
+  SmallVector<NamedAttribute, 4> attributes;
+
+  // Decode exactly `numOperands` operand <id>s; each must resolve to a value.
+  size_t operandIndex = 0;
+  for (; operandIndex < numOperands && wordIndex < words.size();
+       ++operandIndex, ++wordIndex) {
+    auto arg = getValue(words[wordIndex]);
+    if (!arg)
+      return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
+    operands.push_back(arg);
+  }
+  if (operandIndex != numOperands) {
+    return emitError(
+               unknownLoc,
+               "found less operands than expected when deserializing for ")
+           << opName << "; only " << operandIndex << " of " << numOperands
+           << " processed";
+  }
+  if (wordIndex != words.size()) {
+    return emitError(
+               unknownLoc,
+               "found more operands than expected when deserializing for ")
+           << opName << "; only " << wordIndex << " of " << words.size()
+           << " processed";
+  }
+
+  // Attach attributes from decorations recorded on the result <id>, if any.
+  if (decorations.count(valueID)) {
+    auto attrs = decorations[valueID].getAttrs();
+    attributes.append(attrs.begin(), attrs.end());
+  }
+
+  // Create the op and map its result (if any) to the result <id> in valueMap.
+  Location loc = createFileLineColLoc(opBuilder);
+  OperationState opState(loc, opName);
+  opState.addOperands(operands);
+  if (hasResult)
+    opState.addTypes(resultTypes);
+  opState.addAttributes(attributes);
+  Operation *op = opBuilder.createOperation(opState);
+  if (hasResult)
+    valueMap[valueID] = op->getResult(0);
+
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    clearDebugLine();
+
+  return success();
+}
+
 LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc, "OpUndef instruction must have two operands");
@@ -2779,6 +2884,7 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
 // various Deserializer::processOp<...>() specializations.
 #define GET_DESERIALIZATION_FNS
 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
+
 } // namespace
 
 namespace mlir {

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index cd8be8a984f5..ab520114e7f5 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -914,16 +914,28 @@ static void emitDeserializationFunction(const Record *attrClass,
                                         const Record *record,
                                         const Operator &op, raw_ostream &os) {
   // If the record has 'autogenSerialization' set to 0, nothing to do
-  if (!record->getValueAsBit("autogenSerialization")) {
+  if (!record->getValueAsBit("autogenSerialization"))
     return;
-  }
+
   StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
       wordIndex("wordIndex"), opVar("op"), operands("operands"),
       attributes("attributes");
+
+  // Emit the specialization's signature and opening brace.
+  os << formatv("template <> "
+                "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
+                "uint32_t> {1}) {{\n",
+                op.getQualCppClassName(), words);
+
+  // Ops with no attributes/variadic operands use processOpWithoutGrammarAttr.
+  if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
+    os << formatv("  return processOpWithoutGrammarAttr("
+                  "{0}, \"{1}\", {2}, {3});\n}\n\n",
+                  words, op.getOperationName(),
+                  op.getNumResults() ? "true" : "false", op.getNumOperands());
+    return;
+  }
+
   os << formatv("  SmallVector<Type, 1> {0};\n", resultTypes);
   os << formatv("  size_t {0} = 0; (void){0};\n", wordIndex);
   os << formatv("  uint32_t {0} = 0; (void){0};\n", valueID);
@@ -938,6 +950,9 @@ static void emitDeserializationFunction(const Record *attrClass,
   emitOperandDeserialization(op, record->getLoc(), "  ", words, wordIndex,
                              operands, attributes, os);
 
+  // Decorations
+  emitDecorationDeserialization(op, "  ", valueID, attributes, os);
+
   os << formatv("  Location loc = createFileLineColLoc(opBuilder);\n");
   os << formatv("  auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
                 "(void){1};\n",
@@ -953,9 +968,6 @@ static void emitDeserializationFunction(const Record *attrClass,
   // next end of block.
   os << formatv("  if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
   os << formatv("    clearDebugLine();\n");
-
-  // Decorations
-  emitDecorationDeserialization(op, "  ", valueID, attributes, os);
   os << "  return success();\n";
   os << "}\n\n";
 }


        


More information about the llvm-branch-commits mailing list