[llvm-branch-commits] [mlir] ae895ac - [mlir][spirv] De-template deserialization
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Dec 23 11:50:28 PST 2020
Author: Lei Zhang
Date: 2020-12-23T14:45:46-05:00
New Revision: ae895ac4b9fa86b9471617357f66c0cd6cdb70b8
URL: https://github.com/llvm/llvm-project/commit/ae895ac4b9fa86b9471617357f66c0cd6cdb70b8
DIFF: https://github.com/llvm/llvm-project/commit/ae895ac4b9fa86b9471617357f66c0cd6cdb70b8.diff
LOG: [mlir][spirv] De-template deserialization
Previously, we generated a separate deserialization method for
each op. Those methods duplicated the logic for parsing operands,
results, attributes, and such.
This commit creates a generic method and lets suitable op-specific
deserialization methods call into it.
Line count of SPIRVSerialization.inc (wc -l): 13290 before, 8304 after (a reduction of 4986 lines).
Reviewed By: hanchung, ThomasRaoux
Differential Revision: https://reviews.llvm.org/D93504
Added:
Modified:
mlir/lib/Target/SPIRV/Deserialization.cpp
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization.cpp
index 94ea19f94123..30f46f6fc605 100644
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization.cpp
@@ -34,6 +34,10 @@ using namespace mlir;
#define DEBUG_TYPE "spirv-deserialization"
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
/// Decodes a string literal in `words` starting at `wordIndex`. Update the
/// latter to point to the position in words after the string literal.
static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
@@ -55,6 +59,10 @@ static inline bool isFnEntryBlock(Block *block) {
}
namespace {
+//===----------------------------------------------------------------------===//
+// Utility Definitions
+//===----------------------------------------------------------------------===//
+
/// A struct for containing a header block's merge and continue targets.
///
/// This struct is used to track original structured control flow info from
@@ -124,6 +132,10 @@ struct DeferredStructTypeInfo {
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
};
+//===----------------------------------------------------------------------===//
+// Deserializer Declaration
+//===----------------------------------------------------------------------===//
+
/// A SPIR-V module serializer.
///
/// A SPIR-V binary module is a single linear stream of instructions; each
@@ -423,6 +435,14 @@ class Deserializer {
ArrayRef<uint32_t> operands,
bool deferInstructions = true);
+ /// Processes a SPIR-V instruction from the given `operands`. It should
+ /// deserialize into an op with the given `opName` and `numOperands`.
+ /// This method is a generic one for dispatching any SPIR-V ops without
+ /// variadic operands and attributes in TableGen definitions.
+ LogicalResult processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+ StringRef opName, bool hasResult,
+ unsigned numOperands);
+
/// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
/// insertion point.
LogicalResult processUndef(ArrayRef<uint32_t> operands);
@@ -580,6 +600,10 @@ class Deserializer {
};
} // namespace
+//===----------------------------------------------------------------------===//
+// Deserializer Method Definitions
+//===----------------------------------------------------------------------===//
+
Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
module(createModuleOp()), opBuilder(module->body()) {}
@@ -2497,6 +2521,87 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return dispatchToAutogenDeserialization(opcode, operands);
}
+LogicalResult
+Deserializer::processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+ StringRef opName, bool hasResult,
+ unsigned numOperands) {
+ SmallVector<Type, 1> resultTypes;
+ uint32_t valueID = 0;
+
+ size_t wordIndex= 0;
+ if (hasResult) {
+ if (wordIndex >= words.size())
+ return emitError(unknownLoc,
+ "expected result type <id> while deserializing for ")
+ << opName;
+
+ // Decode the type <id>
+ auto type = getType(words[wordIndex]);
+ if (!type)
+ return emitError(unknownLoc, "unknown type result <id>: ")
+ << words[wordIndex];
+ resultTypes.push_back(type);
+ ++wordIndex;
+
+ // Decode the result <id>
+ if (wordIndex >= words.size())
+ return emitError(unknownLoc,
+ "expected result <id> while deserializing for ")
+ << opName;
+ valueID = words[wordIndex];
+ ++wordIndex;
+ }
+
+ SmallVector<Value, 4> operands;
+ SmallVector<NamedAttribute, 4> attributes;
+
+ // Decode operands
+ size_t operandIndex = 0;
+ for (; operandIndex < numOperands && wordIndex < words.size();
+ ++operandIndex, ++wordIndex) {
+ auto arg = getValue(words[wordIndex]);
+ if (!arg)
+ return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
+ operands.push_back(arg);
+ }
+ if (operandIndex != numOperands) {
+ return emitError(
+ unknownLoc,
+ "found less operands than expected when deserializing for ")
+ << opName << "; only " << operandIndex << " of " << numOperands
+ << " processed";
+ }
+ if (wordIndex != words.size()) {
+ return emitError(
+ unknownLoc,
+ "found more operands than expected when deserializing for ")
+ << opName << "; only " << wordIndex << " of " << words.size()
+ << " processed";
+ }
+
+ // Attach attributes from decorations
+ if (decorations.count(valueID)) {
+ auto attrs = decorations[valueID].getAttrs();
+ attributes.append(attrs.begin(), attrs.end());
+ }
+
+ // Create the op and update bookkeeping maps
+ Location loc = createFileLineColLoc(opBuilder);
+ OperationState opState(loc, opName);
+ opState.addOperands(operands);
+ if (hasResult)
+ opState.addTypes(resultTypes);
+ opState.addAttributes(attributes);
+ Operation *op = opBuilder.createOperation(opState);
+ if (hasResult)
+ valueMap[valueID] = op->getResult(0);
+
+ if (op->hasTrait<OpTrait::IsTerminator>())
+ clearDebugLine();
+
+ return success();
+}
+
LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpUndef instruction must have two operands");
@@ -2779,6 +2884,7 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
+
} // namespace
namespace mlir {
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index cd8be8a984f5..ab520114e7f5 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -914,16 +914,28 @@ static void emitDeserializationFunction(const Record *attrClass,
const Record *record,
const Operator &op, raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
- if (!record->getValueAsBit("autogenSerialization")) {
+ if (!record->getValueAsBit("autogenSerialization"))
return;
- }
+
StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
wordIndex("wordIndex"), opVar("op"), operands("operands"),
attributes("attributes");
+
+ // Method declaration
os << formatv("template <> "
"LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
"uint32_t> {1}) {{\n",
op.getQualCppClassName(), words);
+
+ // Special case for ops without attributes in TableGen definitions
+ if (op.getNumAttributes() == 0 && op.getNumVariableLengthOperands() == 0) {
+ os << formatv(" return processOpWithoutGrammarAttr("
+ "{0}, \"{1}\", {2}, {3});\n}\n\n",
+ words, op.getOperationName(),
+ op.getNumResults() ? "true" : "false", op.getNumOperands());
+ return;
+ }
+
os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
@@ -938,6 +950,9 @@ static void emitDeserializationFunction(const Record *attrClass,
emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
operands, attributes, os);
+ // Decorations
+ emitDecorationDeserialization(op, " ", valueID, attributes, os);
+
os << formatv(" Location loc = createFileLineColLoc(opBuilder);\n");
os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); "
"(void){1};\n",
@@ -953,9 +968,6 @@ static void emitDeserializationFunction(const Record *attrClass,
// next end of block.
os << formatv(" if ({0}.hasTrait<OpTrait::IsTerminator>())\n", opVar);
os << formatv(" clearDebugLine();\n");
-
- // Decorations
- emitDecorationDeserialization(op, " ", valueID, attributes, os);
os << " return success();\n";
os << "}\n\n";
}
More information about the llvm-branch-commits
mailing list