[llvm-branch-commits] [mlir] 8349fa0 - [mlir][spirv] NFC: split deserialization into multiple source files
Lei Zhang via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 12 08:26:47 PST 2021
Author: Lei Zhang
Date: 2021-01-12T11:21:03-05:00
New Revision: 8349fa0fdd3a372f88ea53de6c906d987c1f4fec
URL: https://github.com/llvm/llvm-project/commit/8349fa0fdd3a372f88ea53de6c906d987c1f4fec
DIFF: https://github.com/llvm/llvm-project/commit/8349fa0fdd3a372f88ea53de6c906d987c1f4fec.diff
LOG: [mlir][spirv] NFC: split deserialization into multiple source files
This avoids large source files and gives a better structure. It also
allows leveraging compilation parallelism.
Reviewed By: mravishankar
Differential Revision: https://reviews.llvm.org/D94360
Added:
mlir/lib/Target/SPIRV/CMakeLists.txt
mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt
mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
Modified:
mlir/lib/Target/CMakeLists.txt
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
Removed:
mlir/lib/Target/SPIRV/Deserialization.cpp
mlir/lib/Target/SPIRV/Serialization.cpp
################################################################################
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 1b1a02db5511..51a0e78a4edf 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(SPIRV)
+
add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
LLVMIR/DebugTranslation.cpp
LLVMIR/ModuleTranslation.cpp
@@ -132,52 +134,3 @@ add_mlir_translation_library(MLIRTargetROCDLIR
MLIRROCDLIR
MLIRTargetLLVMIRModuleTranslation
)
-
-add_mlir_translation_library(MLIRSPIRVBinaryUtils
- SPIRV/SPIRVBinaryUtils.cpp
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSPIRV
- MLIRSupport
- )
-
-add_mlir_translation_library(MLIRSPIRVSerialization
- SPIRV/Serialization.cpp
-
- DEPENDS
- MLIRSPIRVSerializationGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSPIRV
- MLIRSPIRVBinaryUtils
- MLIRSupport
- MLIRTranslation
- )
-
-add_mlir_translation_library(MLIRSPIRVDeserialization
- SPIRV/Deserialization.cpp
-
- DEPENDS
- MLIRSPIRVSerializationGen
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSPIRV
- MLIRSPIRVBinaryUtils
- MLIRSupport
- MLIRTranslation
- )
-
-add_mlir_translation_library(MLIRSPIRVTranslateRegistration
- SPIRV/TranslateRegistration.cpp
-
- LINK_LIBS PUBLIC
- MLIRIR
- MLIRSPIRV
- MLIRSPIRVSerialization
- MLIRSPIRVDeserialization
- MLIRSupport
- MLIRTranslation
- )
diff --git a/mlir/lib/Target/SPIRV/CMakeLists.txt b/mlir/lib/Target/SPIRV/CMakeLists.txt
new file mode 100644
index 000000000000..cddbc0971337
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/CMakeLists.txt
@@ -0,0 +1,28 @@
+add_subdirectory(Deserialization) # defines MLIRSPIRVDeserialization
+add_subdirectory(Serialization) # presumably defines MLIRSPIRVSerialization
+
+set(LLVM_OPTIONAL_SOURCES # each file is compiled by only one of the two libs below
+ SPIRVBinaryUtils.cpp
+ TranslateRegistration.cpp
+ )
+
+add_mlir_translation_library(MLIRSPIRVBinaryUtils # also linked by MLIRSPIRVDeserialization
+ SPIRVBinaryUtils.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRV
+ MLIRSupport
+ )
+
+add_mlir_translation_library(MLIRSPIRVTranslateRegistration # links both (de)serialization libs
+ TranslateRegistration.cpp
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRV
+ MLIRSPIRVSerialization
+ MLIRSPIRVDeserialization
+ MLIRSupport
+ MLIRTranslation
+ )
diff --git a/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt b/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
new file mode 100644
index 000000000000..99d40e11baa6
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/CMakeLists.txt
@@ -0,0 +1,16 @@
+# SPIR-V binary -> MLIR SPIR-V module deserializer, built from three sources.
+add_mlir_translation_library(MLIRSPIRVDeserialization
+ DeserializeOps.cpp
+ Deserializer.cpp
+ Deserialization.cpp
+
+ DEPENDS
+ MLIRSPIRVSerializationGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRV
+ MLIRSPIRVBinaryUtils
+ MLIRSupport
+ MLIRTranslation
+ )
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
new file mode 100644
index 000000000000..2eb08669f658
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp
@@ -0,0 +1,23 @@
+//===- Deserialization.cpp - MLIR SPIR-V Deserialization ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/SPIRV/Deserialization.h"
+
+#include "Deserializer.h"
+
+namespace mlir {
+/// Public entry point: runs a Deserializer over `binary` and returns the
+/// module it collected, or a null reference if deserialization fails.
+spirv::OwningSPIRVModuleRef
+spirv::deserialize(ArrayRef<uint32_t> binary, MLIRContext *context) {
+ spirv::Deserializer reader(binary, context);
+ if (succeeded(reader.deserialize()))
+ return reader.collect();
+ return nullptr;
+}
+} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
new file mode 100644
index 000000000000..f11804a11a9a
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -0,0 +1,565 @@
+//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the Deserializer methods for SPIR-V binary instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Deserializer.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "spirv-deserialization"
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Returns the opcode held in the low 16 bits of an instruction's first word.
+static inline spirv::Opcode extractOpcode(uint32_t firstWord) {
+ return static_cast<spirv::Opcode>(static_cast<uint16_t>(firstWord));
+}
+
+//===----------------------------------------------------------------------===//
+// Instruction
+//===----------------------------------------------------------------------===//
+
+/// Returns the MLIR value for result <id> `id`. Constants, global variables,
+/// spec constants (scalar or composite) and undefs get a new op created via
+/// `opBuilder` on each call; spec constant operations are delegated to
+/// materializeSpecConstantOperation(); anything else comes from `valueMap`.
+Value spirv::Deserializer::getValue(uint32_t id) {
+ // Normal constants: materialize a `spv.constant` right here.
+ if (auto constInfo = getConstant(id))
+ return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
+ constInfo->first);
+ // Global variables: take the variable's address by symbol.
+ if (auto varOp = getGlobalVariable(id)) {
+ auto symbol = opBuilder.getSymbolRefAttr(varOp.getOperation());
+ return opBuilder.create<spirv::AddressOfOp>(unknownLoc, varOp.type(), symbol)
+ .pointer();
+ }
+ // Scalar and composite spec constants: reference them by symbol.
+ if (auto constOp = getSpecConstant(id)) {
+ auto symbol = opBuilder.getSymbolRefAttr(constOp.getOperation());
+ return opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, constOp.default_value().getType(), symbol).reference();
+ }
+ if (auto compositeOp = getSpecConstantComposite(id)) {
+ auto symbol = opBuilder.getSymbolRefAttr(compositeOp.getOperation());
+ return opBuilder
+ .create<spirv::ReferenceOfOp>(unknownLoc, compositeOp.type(), symbol)
+ .reference();
+ }
+ // Spec constant operations are handed to a dedicated helper.
+ if (auto opInfo = getSpecConstantOperation(id))
+ return materializeSpecConstantOperation(id, opInfo->enclodesOpcode,
+ opInfo->resultTypeID, opInfo->enclosedOpOperands);
+ if (auto undefType = getUndefType(id))
+ return opBuilder.create<spirv::UndefOp>(unknownLoc, undefType);
+ return valueMap.lookup(id);
+}
+
+/// Slices the next instruction at `curOffset` out of `binary`, returning its
+/// opcode and trailing operand words and advancing `curOffset` past it.
+/// `expectedOpcode` only customizes the "stream exhausted" error message.
+LogicalResult
+spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
+ ArrayRef<uint32_t> &operands,
+ Optional<spirv::Opcode> expectedOpcode) {
+ auto totalWords = binary.size();
+ if (curOffset >= totalWords) {
+ StringRef what = "more";
+ if (expectedOpcode)
+ what = spirv::stringifyOpcode(*expectedOpcode);
+ return emitError(unknownLoc, "expected ") << what << " instruction";
+ }
+
+ // The leading word packs the instruction's word count (high 16 bits) and
+ // its opcode (low 16 bits).
+ uint32_t leadingWord = binary[curOffset];
+ uint32_t wordCount = leadingWord >> 16;
+ if (wordCount == 0)
+ return emitError(unknownLoc, "word count cannot be zero");
+ uint32_t nextOffset = curOffset + wordCount;
+ if (nextOffset > totalWords)
+ return emitError(unknownLoc, "insufficient words for the last instruction");
+ opcode = extractOpcode(leadingWord);
+ operands = binary.slice(curOffset + 1, wordCount - 1);
+ curOffset = nextOffset;
+ return success();
+}
+
+LogicalResult spirv::Deserializer::processInstruction(
+ spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
+ LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
+ << spirv::stringifyOpcode(opcode) << "\n");
+ // Opcodes without a direct SPIR-V dialect op mirror, or that need custom
+ // handling, are dispatched in the switch below; every other opcode falls
+ // through to the autogenerated dispatchToAutogenDeserialization() call.
+ switch (opcode) {
+ case spirv::Opcode::OpCapability:
+ return processCapability(operands);
+ case spirv::Opcode::OpExtension:
+ return processExtension(operands);
+ case spirv::Opcode::OpExtInst:
+ return processExtInst(operands);
+ case spirv::Opcode::OpExtInstImport:
+ return processExtInstImport(operands);
+ case spirv::Opcode::OpMemberName:
+ return processMemberName(operands);
+ case spirv::Opcode::OpMemoryModel:
+ return processMemoryModel(operands);
+ case spirv::Opcode::OpEntryPoint:
+ case spirv::Opcode::OpExecutionMode:
+ if (deferInstructions) { // stash in deferredInstructions; not handled now
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ break; // not deferred: use the autogen dispatch after the switch
+ case spirv::Opcode::OpVariable:
+ if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { // module scope
+ return processGlobalVariable(operands);
+ }
+ break; // not at module scope: use the autogen dispatch
+ case spirv::Opcode::OpLine:
+ return processDebugLine(operands);
+ case spirv::Opcode::OpNoLine:
+ return clearDebugLine();
+ case spirv::Opcode::OpName:
+ return processName(operands);
+ case spirv::Opcode::OpString:
+ return processDebugString(operands);
+ case spirv::Opcode::OpModuleProcessed:
+ case spirv::Opcode::OpSource:
+ case spirv::Opcode::OpSourceContinued:
+ case spirv::Opcode::OpSourceExtension:
+ // TODO: This is debug information embedded in the binary which should be
+ // translated into the spv.module.
+ return success(); // currently dropped without creating anything
+ case spirv::Opcode::OpTypeVoid:
+ case spirv::Opcode::OpTypeBool:
+ case spirv::Opcode::OpTypeInt:
+ case spirv::Opcode::OpTypeFloat:
+ case spirv::Opcode::OpTypeVector:
+ case spirv::Opcode::OpTypeMatrix:
+ case spirv::Opcode::OpTypeArray:
+ case spirv::Opcode::OpTypeFunction:
+ case spirv::Opcode::OpTypeRuntimeArray:
+ case spirv::Opcode::OpTypeStruct:
+ case spirv::Opcode::OpTypePointer:
+ case spirv::Opcode::OpTypeCooperativeMatrixNV:
+ return processType(opcode, operands);
+ case spirv::Opcode::OpTypeForwardPointer:
+ return processTypeForwardPointer(operands);
+ case spirv::Opcode::OpConstant:
+ return processConstant(operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstant:
+ return processConstant(operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantComposite:
+ return processConstantComposite(operands);
+ case spirv::Opcode::OpSpecConstantComposite:
+ return processSpecConstantComposite(operands);
+ case spirv::Opcode::OpSpecConstantOperation:
+ return processSpecConstantOperation(operands);
+ case spirv::Opcode::OpConstantTrue:
+ return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstantTrue:
+ return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantFalse:
+ return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
+ case spirv::Opcode::OpSpecConstantFalse:
+ return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
+ case spirv::Opcode::OpConstantNull:
+ return processConstantNull(operands);
+ case spirv::Opcode::OpDecorate:
+ return processDecoration(operands);
+ case spirv::Opcode::OpMemberDecorate:
+ return processMemberDecoration(operands);
+ case spirv::Opcode::OpFunction:
+ return processFunction(operands);
+ case spirv::Opcode::OpLabel:
+ return processLabel(operands);
+ case spirv::Opcode::OpBranch:
+ return processBranch(operands);
+ case spirv::Opcode::OpBranchConditional:
+ return processBranchConditional(operands);
+ case spirv::Opcode::OpSelectionMerge:
+ return processSelectionMerge(operands);
+ case spirv::Opcode::OpLoopMerge:
+ return processLoopMerge(operands);
+ case spirv::Opcode::OpPhi:
+ return processPhi(operands);
+ case spirv::Opcode::OpUndef:
+ return processUndef(operands);
+ default:
+ break; // no custom handler: use the autogen dispatch
+ }
+ return dispatchToAutogenDeserialization(opcode, operands);
+}
+
+LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
+ ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
+ unsigned numOperands) { // Builds op `opName` generically from `words`.
+ SmallVector<Type, 1> resultTypes;
+ uint32_t valueID = 0; // result <id>; stays 0 when !hasResult
+ // Layout: [result type <id>, result <id>] if hasResult, then numOperands <id>s.
+ size_t wordIndex = 0;
+ if (hasResult) {
+ if (wordIndex >= words.size())
+ return emitError(unknownLoc,
+ "expected result type <id> while deserializing for ")
+ << opName;
+
+ // Decode the result type <id>
+ auto type = getType(words[wordIndex]);
+ if (!type)
+ return emitError(unknownLoc, "unknown type result <id>: ")
+ << words[wordIndex];
+ resultTypes.push_back(type);
+ ++wordIndex;
+
+ // Decode the result <id>
+ if (wordIndex >= words.size())
+ return emitError(unknownLoc,
+ "expected result <id> while deserializing for ")
+ << opName;
+ valueID = words[wordIndex];
+ ++wordIndex;
+ }
+
+ SmallVector<Value, 4> operands;
+ SmallVector<NamedAttribute, 4> attributes;
+
+ // Decode exactly numOperands operand <id>s into SSA values
+ size_t operandIndex = 0;
+ for (; operandIndex < numOperands && wordIndex < words.size();
+ ++operandIndex, ++wordIndex) {
+ auto arg = getValue(words[wordIndex]);
+ if (!arg)
+ return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
+ operands.push_back(arg);
+ }
+ if (operandIndex != numOperands) {
+ return emitError(
+ unknownLoc,
+ "found less operands than expected when deserializing for ")
+ << opName << "; only " << operandIndex << " of " << numOperands
+ << " processed";
+ }
+ if (wordIndex != words.size()) {
+ return emitError(
+ unknownLoc,
+ "found more operands than expected when deserializing for ")
+ << opName << "; only " << wordIndex << " of " << words.size()
+ << " processed";
+ }
+
+ // Attach attributes from decorations recorded on the result <id>
+ if (decorations.count(valueID)) { // NOTE(review): also probes id 0 when !hasResult
+ auto attrs = decorations[valueID].getAttrs();
+ attributes.append(attrs.begin(), attrs.end());
+ }
+
+ // Create the op generically and record its result in valueMap
+ Location loc = createFileLineColLoc(opBuilder);
+ OperationState opState(loc, opName);
+ opState.addOperands(operands);
+ if (hasResult)
+ opState.addTypes(resultTypes);
+ opState.addAttributes(attributes);
+ Operation *op = opBuilder.createOperation(opState);
+ if (hasResult)
+ valueMap[valueID] = op->getResult(0);
+
+ if (op->hasTrait<OpTrait::IsTerminator>())
+ clearDebugLine(); // OpLine info is reset after a block terminator
+
+ return success();
+}
+
+/// Handles OpUndef (<result type id>, <result id>): no op is created here;
+/// the result type is recorded in `undefMap` under the result <id>.
+LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2)
+ return emitError(unknownLoc, "OpUndef instruction must have two operands");
+ Type undefType = getType(operands[0]);
+ if (!undefType)
+ return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
+ undefMap[operands[1]] = undefType;
+ return success();
+}
+
+/// Handles OpExtInst: forwards result type/result <id> plus the trailing
+/// operands, with the set's instruction number, to the set's autogen handler.
+LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 4)
+ return emitError(unknownLoc,
+ "OpExtInst must have at least 4 operands, result type "
+ "<id>, result <id>, set <id> and instruction opcode");
+ uint32_t setID = operands[2];
+ if (!extendedInstSets.count(setID))
+ return emitError(unknownLoc, "undefined set <id> in OpExtInst");
+ SmallVector<uint32_t, 4> forwarded(operands.begin(), operands.begin() + 2);
+ forwarded.append(operands.begin() + 4, operands.end());
+ return dispatchToExtensionSetAutogenDeserialization(
+ extendedInstSets[setID], operands[3], forwarded);
+}
+
+namespace mlir {
+namespace spirv {
+
+template <>
+LogicalResult
+Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { // words: exec model, fn <id>, name, interface <id>s
+ unsigned wordIndex = 0; // cursor into words
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc,
+ "missing Execution Model specification in OpEntryPoint");
+ }
+ auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
+ if (wordIndex >= words.size()) {
+ return emitError(unknownLoc, "missing <id> in OpEntryPoint");
+ }
+ // Get the function <id>
+ auto fnID = words[wordIndex++];
+ // Get the function name; this advances wordIndex past the string literal.
+ auto fnName = decodeStringLiteral(words, wordIndex); // NOTE(review): confirm the decoder cannot read past words
+ // Verify that the function <id> matches the fnName
+ auto parsedFunc = getFunction(fnID);
+ if (!parsedFunc) {
+ return emitError(unknownLoc, "no function matching <id> ") << fnID;
+ }
+ if (parsedFunc.getName() != fnName) {
+ return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
+ "and OpFunction with <id> ")
+ << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
+ }
+ SmallVector<Attribute, 4> interface; // remaining words: global variable <id>s
+ while (wordIndex < words.size()) {
+ auto arg = getGlobalVariable(words[wordIndex]);
+ if (!arg) {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << words[wordIndex] << " while decoding OpEntryPoint";
+ }
+ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
+ wordIndex++;
+ }
+ opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
+ opBuilder.getSymbolRefAttr(fnName),
+ opBuilder.getArrayAttr(interface));
+ return success();
+}
+
+/// OpExecutionMode: <function id> <execution mode> [literal operands...].
+/// The function must already be known; every trailing literal becomes an
+/// i32 entry in the op's value array.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
+ if (words.empty())
+ return emitError(unknownLoc,
+ "missing function result <id> in OpExecutionMode");
+
+ // The first word names the target function, which must already exist.
+ uint32_t fnID = words[0];
+ auto fn = getFunction(fnID);
+ if (!fn)
+ return emitError(unknownLoc, "no function matching <id> ") << fnID;
+
+ if (words.size() < 2)
+ return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
+ IntegerAttr execMode = opBuilder.getI32IntegerAttr(words[1]);
+
+ // Any remaining words are the mode's literal operands.
+ SmallVector<Attribute, 4> literalAttrs;
+ for (uint32_t literal : words.drop_front(2))
+ literalAttrs.push_back(opBuilder.getI32IntegerAttr(literal));
+
+ opBuilder.create<spirv::ExecutionModeOp>(
+ unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode,
+ opBuilder.getArrayAttr(literalAttrs));
+ return success();
+}
+
+/// OpControlBarrier: each of the three operands names an integer constant
+/// (execution scope, memory scope, memory semantics) used as an attribute.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 3)
+ return emitError(
+ unknownLoc,
+ "OpControlBarrier must have execution scope <id>, memory scope <id> "
+ "and memory semantics <id>");
+
+ IntegerAttr scopeAndSemantics[3];
+ for (unsigned i = 0; i != 3; ++i) {
+ IntegerAttr constant = getConstantInt(operands[i]);
+ if (!constant)
+ return emitError(unknownLoc,
+ "expected 32-bit integer constant from <id> ")
+ << operands[i] << " for OpControlBarrier";
+ scopeAndSemantics[i] = constant;
+ }
+
+ opBuilder.create<spirv::ControlBarrierOp>(
+ unknownLoc, scopeAndSemantics[0], scopeAndSemantics[1], scopeAndSemantics[2]);
+ return success();
+}
+
+/// OpFunctionCall: <result type> <result id> <function id> [argument ids...].
+/// A void result type yields a call with no results; otherwise the call's
+/// single result is recorded in `valueMap` under <result id>.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3)
+ return emitError(unknownLoc,
+ "OpFunctionCall must have at least 3 operands");
+
+ uint32_t resultTypeID = operands[0];
+ uint32_t resultID = operands[1];
+ uint32_t functionID = operands[2];
+
+ Type resultType = getType(resultTypeID);
+ if (!resultType)
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << resultTypeID;
+ // A null result type tells the builder the call produces no value.
+ bool hasResult = !isVoidType(resultType);
+ if (!hasResult)
+ resultType = nullptr;
+
+ auto functionName = getFunctionSymbol(functionID);
+
+ SmallVector<Value, 4> arguments;
+ for (uint32_t argID : llvm::drop_begin(operands, 3)) {
+ Value argValue = getValue(argID);
+ if (!argValue)
+ return emitError(unknownLoc, "unknown <id> ")
+ << argID << " used by OpFunctionCall";
+ arguments.push_back(argValue);
+ }
+
+ auto callOp = opBuilder.create<spirv::FunctionCallOp>(
+ unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
+ arguments);
+ if (hasResult)
+ valueMap[resultID] = callOp.getResult(0);
+ return success();
+}
+
+/// OpMemoryBarrier: both operands name integer constants (memory scope and
+/// memory semantics) that become the op's attributes.
+template <>
+LogicalResult
+Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2)
+ return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
+ "and memory semantics <id>");
+
+ IntegerAttr scopeAndSemantics[2];
+ for (unsigned i = 0; i != 2; ++i) {
+ IntegerAttr constant = getConstantInt(operands[i]);
+ if (!constant)
+ return emitError(unknownLoc,
+ "expected 32-bit integer constant from <id> ")
+ << operands[i] << " for OpMemoryBarrier";
+ scopeAndSemantics[i] = constant;
+ }
+
+ opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, scopeAndSemantics[0],
+ scopeAndSemantics[1]);
+ return success();
+}
+
+/// Deserializes OpCopyMemory. Words: target <id>, source <id>, then an
+/// optional target memory-access mask (followed by an alignment literal when
+/// the mask has the Aligned bit), then an optional source mask (likewise).
+template <>
+LogicalResult
+Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
+ // MemoryAccess is a bit mask: Aligned (0x2) may be combined with other bits
+ // (e.g. Volatile|Aligned == 0x3) and still requires an alignment literal.
+ constexpr uint32_t kAlignedBit = 0x2;
+ // OpCopyMemory produces no result, so the result type list stays empty.
+ SmallVector<Type, 1> resultTypes;
+ size_t wordIndex = 0;
+ SmallVector<Value, 4> operands;
+ SmallVector<NamedAttribute, 4> attributes;
+
+ // Decode up to two value operands: the target and source pointers.
+ for (int i = 0; i < 2 && wordIndex < words.size(); ++i, ++wordIndex) {
+ auto arg = getValue(words[wordIndex]);
+ if (!arg) {
+ return emitError(unknownLoc, "unknown result <id> : ")
+ << words[wordIndex];
+ }
+ operands.push_back(arg);
+ }
+
+ // Optional target memory-access mask, plus its alignment when Aligned.
+ bool isAlignedAttr = false;
+ if (wordIndex < words.size()) {
+ auto attrValue = words[wordIndex++];
+ attributes.push_back(opBuilder.getNamedAttr(
+ "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+ // Test the bit; comparing for equality misses combined masks.
+ isAlignedAttr = (attrValue & kAlignedBit) != 0;
+ }
+
+ if (isAlignedAttr && wordIndex < words.size()) {
+ attributes.push_back(opBuilder.getNamedAttr(
+ "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+ }
+
+ // Optional source memory-access mask, decoded the same way.
+ bool isSourceAlignedAttr = false;
+ if (wordIndex < words.size()) {
+ auto attrValue = words[wordIndex++];
+ attributes.push_back(opBuilder.getNamedAttr(
+ "source_memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+ isSourceAlignedAttr = (attrValue & kAlignedBit) != 0;
+ }
+
+ // A source alignment literal is only present with the Aligned bit; any
+ // other trailing word is rejected by the check below.
+ if (isSourceAlignedAttr && wordIndex < words.size()) {
+ attributes.push_back(opBuilder.getNamedAttr(
+ "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
+ }
+
+ // Any word not consumed above is unexpected.
+ if (wordIndex != words.size()) {
+ return emitError(unknownLoc,
+ "found more operands than expected when deserializing "
+ "spirv::CopyMemoryOp, only ")
+ << wordIndex << " of " << words.size() << " processed";
+ }
+
+ Location loc = createFileLineColLoc(opBuilder);
+ opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
+
+ return success();
+}
+
+// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
+// various Deserializer::processOp<...>() specializations.
+#define GET_DESERIALIZATION_FNS
+#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
similarity index 58%
rename from mlir/lib/Target/SPIRV/Deserialization.cpp
rename to mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 07eb3d35e0a4..5ce169a0d47f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1,4 +1,4 @@
-//===- Deserializer.cpp - MLIR SPIR-V Deserialization ---------------------===//
+//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,11 @@
//
//===----------------------------------------------------------------------===//
//
-// This file defines the SPIR-V binary to MLIR SPIR-V module deserialization.
+// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
//
//===----------------------------------------------------------------------===//
-#include "mlir/Target/SPIRV/Deserialization.h"
+#include "Deserializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
@@ -24,7 +24,6 @@
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
-#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/bit.h"
@@ -40,607 +39,22 @@ using namespace mlir;
// Utility Functions
//===----------------------------------------------------------------------===//
-/// Decodes a string literal in `words` starting at `wordIndex`. Update the
-/// latter to point to the position in words after the string literal.
-static inline StringRef decodeStringLiteral(ArrayRef<uint32_t> words,
- unsigned &wordIndex) {
- StringRef str(reinterpret_cast<const char *>(words.data() + wordIndex));
- wordIndex += str.size() / 4 + 1;
- return str;
-}
-
-/// Extracts the opcode from the given first word of a SPIR-V instruction.
-static inline spirv::Opcode extractOpcode(uint32_t word) {
- return static_cast<spirv::Opcode>(word & 0xffff);
-}
-
/// Returns true if the given `block` is a function entry block.
static inline bool isFnEntryBlock(Block *block) {
return block->isEntryBlock() &&
isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
}
-namespace {
-//===----------------------------------------------------------------------===//
-// Utility Definitions
-//===----------------------------------------------------------------------===//
-
-/// A struct for containing a header block's merge and continue targets.
-///
-/// This struct is used to track original structured control flow info from
-/// SPIR-V blob. This info will be used to create spv.selection/spv.loop
-/// later.
-struct BlockMergeInfo {
- Block *mergeBlock;
- Block *continueBlock; // nullptr for spv.selection
- Location loc;
- uint32_t control;
-
- BlockMergeInfo(Location location, uint32_t control)
- : mergeBlock(nullptr), continueBlock(nullptr), loc(location),
- control(control) {}
- BlockMergeInfo(Location location, uint32_t control, Block *m,
- Block *c = nullptr)
- : mergeBlock(m), continueBlock(c), loc(location), control(control) {}
-};
-
-/// A struct for containing OpLine instruction information.
-struct DebugLine {
- uint32_t fileID;
- uint32_t line;
- uint32_t col;
-
- DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum)
- : fileID(fileIDNum), line(lineNum), col(colNum) {}
-};
-
-/// Map from a selection/loop's header block to its merge (and continue) target.
-using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>;
-
-/// A "deferred struct type" is a struct type with one or more member types not
-/// known when the Deserializer first encounters the struct. This happens, for
-/// example, with recursive structs where a pointer to the struct type is
-/// forward declared through OpTypeForwardPointer in the SPIR-V module before
-/// the struct declaration; the actual pointer to struct type should be defined
-/// later through an OpTypePointer. For example, the following C struct:
-///
-/// struct A {
-/// A* next;
-/// };
-///
-/// would be represented in the SPIR-V module as:
-///
-/// OpName %A "A"
-/// OpTypeForwardPointer %APtr Generic
-/// %A = OpTypeStruct %APtr
-/// %APtr = OpTypePointer Generic %A
-///
-/// This means that the spirv::StructType cannot be fully constructed directly
-/// when the Deserializer encounters it. Instead we create a
-/// DeferredStructTypeInfo that contains all the information we know about the
-/// spirv::StructType. Once all forward references for the struct are resolved,
-/// the struct's body is set with all member info.
-struct DeferredStructTypeInfo {
- spirv::StructType deferredStructType;
-
- // A list of all unresolved member types for the struct. First element of each
- // item is operand ID, second element is member index in the struct.
- SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
-
- // The list of member types. For unresolved members, this list contains
- // place-holder empty types that will be updated later.
- SmallVector<Type, 4> memberTypes;
- SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
- SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
-};
-
-/// A struct that collects the info needed to materialize/emit a
-/// SpecConstantOperation op.
-struct SpecConstOperationMaterializationInfo {
- spirv::Opcode enclodesOpcode;
- uint32_t resultTypeID;
- SmallVector<uint32_t> enclosedOpOperands;
-};
-
-//===----------------------------------------------------------------------===//
-// Deserializer Declaration
-//===----------------------------------------------------------------------===//
-
-/// A SPIR-V module serializer.
-///
-/// A SPIR-V binary module is a single linear stream of instructions; each
-/// instruction is composed of 32-bit words. The first word of an instruction
-/// records the total number of words of that instruction using the 16
-/// higher-order bits. So this deserializer uses that to get instruction
-/// boundary and parse instructions and build a SPIR-V ModuleOp gradually.
-///
-// TODO: clean up created ops on errors
-class Deserializer {
-public:
- /// Creates a deserializer for the given SPIR-V `binary` module.
- /// The SPIR-V ModuleOp will be created into `context.
- explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
-
- /// Deserializes the remembered SPIR-V binary module.
- LogicalResult deserialize();
-
- /// Collects the final SPIR-V ModuleOp.
- spirv::OwningSPIRVModuleRef collect();
-
-private:
- //===--------------------------------------------------------------------===//
- // Module structure
- //===--------------------------------------------------------------------===//
-
- /// Initializes the `module` ModuleOp in this deserializer instance.
- spirv::OwningSPIRVModuleRef createModuleOp();
-
- /// Processes SPIR-V module header in `binary`.
- LogicalResult processHeader();
-
- /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping
- /// in the deserializer.
- LogicalResult processCapability(ArrayRef<uint32_t> operands);
-
- /// Processes the SPIR-V OpExtension with `operands` and updates bookkeeping
- /// in the deserializer.
- LogicalResult processExtension(ArrayRef<uint32_t> words);
-
- /// Processes the SPIR-V OpExtInstImport with `operands` and updates
- /// bookkeeping in the deserializer.
- LogicalResult processExtInstImport(ArrayRef<uint32_t> words);
-
- /// Attaches (version, capabilities, extensions) triple to `module` as an
- /// attribute.
- void attachVCETriple();
-
- /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
- LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
-
- /// Process SPIR-V OpName with `operands`.
- LogicalResult processName(ArrayRef<uint32_t> operands);
-
- /// Processes an OpDecorate instruction.
- LogicalResult processDecoration(ArrayRef<uint32_t> words);
-
- // Processes an OpMemberDecorate instruction.
- LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
-
- /// Processes an OpMemberName instruction.
- LogicalResult processMemberName(ArrayRef<uint32_t> words);
-
- /// Gets the function op associated with a result <id> of OpFunction.
- spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
-
- /// Processes the SPIR-V function at the current `offset` into `binary`.
- /// The operands to the OpFunction instruction is passed in as ``operands`.
- /// This method processes each instruction inside the function and dispatches
- /// them to their handler method accordingly.
- LogicalResult processFunction(ArrayRef<uint32_t> operands);
-
- /// Processes OpFunctionEnd and finalizes function. This wires up block
- /// argument created from OpPhi instructions and also structurizes control
- /// flow.
- LogicalResult processFunctionEnd(ArrayRef<uint32_t> operands);
-
- /// Gets the constant's attribute and type associated with the given <id>.
- Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
-
- /// Gets the info needed to materialize the spec constant operation op
- /// associated with the given <id>.
- Optional<SpecConstOperationMaterializationInfo>
- getSpecConstantOperation(uint32_t id);
-
- /// Gets the constant's integer attribute with the given <id>. Returns a
- /// null IntegerAttr if the given is not registered or does not correspond
- /// to an integer constant.
- IntegerAttr getConstantInt(uint32_t id);
-
- /// Returns a symbol to be used for the function name with the given
- /// result <id>. This tries to use the function's OpName if
- /// exists; otherwise creates one based on the <id>.
- std::string getFunctionSymbol(uint32_t id);
-
- /// Returns a symbol to be used for the specialization constant with the given
- /// result <id>. This tries to use the specialization constant's OpName if
- /// exists; otherwise creates one based on the <id>.
- std::string getSpecConstantSymbol(uint32_t id);
-
- /// Gets the specialization constant with the given result <id>.
- spirv::SpecConstantOp getSpecConstant(uint32_t id) {
- return specConstMap.lookup(id);
- }
-
- /// Gets the composite specialization constant with the given result <id>.
- spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
- return specConstCompositeMap.lookup(id);
- }
-
- /// Creates a spirv::SpecConstantOp.
- spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
- Attribute defaultValue);
-
- /// Processes the OpVariable instructions at current `offset` into `binary`.
- /// It is expected that this method is used for variables that are to be
- /// defined at module scope and will be deserialized into a spv.globalVariable
- /// instruction.
- LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
-
- /// Gets the global variable associated with a result <id> of OpVariable.
- spirv::GlobalVariableOp getGlobalVariable(uint32_t id) {
- return globalVariableMap.lookup(id);
- }
-
- //===--------------------------------------------------------------------===//
- // Type
- //===--------------------------------------------------------------------===//
-
- /// Gets type for a given result <id>.
- Type getType(uint32_t id) { return typeMap.lookup(id); }
-
- /// Get the type associated with the result <id> of an OpUndef.
- Type getUndefType(uint32_t id) { return undefMap.lookup(id); }
-
- /// Returns true if the given `type` is for SPIR-V void type.
- bool isVoidType(Type type) const { return type.isa<NoneType>(); }
-
- /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
- /// registers the type into `module`.
- LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
-
- LogicalResult processOpTypePointer(ArrayRef<uint32_t> operands);
-
- LogicalResult processArrayType(ArrayRef<uint32_t> operands);
-
- LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
-
- LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
-
- LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
-
- LogicalResult processStructType(ArrayRef<uint32_t> operands);
-
- LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
-
- //===--------------------------------------------------------------------===//
- // Constant
- //===--------------------------------------------------------------------===//
-
- /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
- /// `operands`. `isSpec` indicates whether this is a specialization constant.
- LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
-
- /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
- /// given `operands`. `isSpec` indicates whether this is a specialization
- /// constant.
- LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
- bool isSpec);
-
- /// Processes a SPIR-V OpConstantComposite instruction with the given
- /// `operands`.
- LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
- /// `operands`.
- LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpSpecConstantOperation instruction with the given
- /// `operands`.
- LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
-
- /// Materializes/emits an OpSpecConstantOperation instruction.
- Value materializeSpecConstantOperation(uint32_t resultID,
- spirv::Opcode enclosedOpcode,
- uint32_t resultTypeID,
- ArrayRef<uint32_t> enclosedOpOperands);
-
- /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
- LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
-
- //===--------------------------------------------------------------------===//
- // Debug
- //===--------------------------------------------------------------------===//
-
- /// Discontinues any source-level location information that might be active
- /// from a previous OpLine instruction.
- LogicalResult clearDebugLine();
-
- /// Creates a FileLineColLoc with the OpLine location information.
- Location createFileLineColLoc(OpBuilder opBuilder);
-
- /// Processes a SPIR-V OpLine instruction with the given `operands`.
- LogicalResult processDebugLine(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpString instruction with the given `operands`.
- LogicalResult processDebugString(ArrayRef<uint32_t> operands);
-
- //===--------------------------------------------------------------------===//
- // Control flow
- //===--------------------------------------------------------------------===//
-
- /// Returns the block for the given label <id>.
- Block *getBlock(uint32_t id) const { return blockMap.lookup(id); }
-
- // In SPIR-V, structured control flow is explicitly declared using merge
- // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect,
- // we use spv.selection and spv.loop to group structured control flow.
- // The deserializer need to turn structured control flow marked with merge
- // instructions into using spv.selection/spv.loop ops.
- //
- // Because structured control flow can nest and the basic block order have
- // flexibility, we cannot isolate a structured selection/loop without
- // deserializing all the blocks. So we use the following approach:
- //
- // 1. Deserialize all basic blocks in a function and create MLIR blocks for
- // them into the function's region. In the meanwhile, keep a map between
- // selection/loop header blocks to their corresponding merge (and continue)
- // target blocks.
- // 2. For each selection/loop header block, recursively get all basic blocks
- // reachable (except the merge block) and put them in a newly created
- // spv.selection/spv.loop's region. Structured control flow guarantees
- // that we enter and exit in structured ways and the construct is nestable.
- // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge
- // block and redirect all branches to the old header block to the old
- // merge block (which contains the spv.selection/spv.loop op now).
-
- /// For OpPhi instructions, we use block arguments to represent them. OpPhi
- /// encodes a list of (value, predecessor) pairs. At the time of handling the
- /// block containing an OpPhi instruction, the predecessor block might not be
- /// processed yet, also the value sent by it. So we need to defer handling
- /// the block argument from the predecessors. We use the following approach:
- ///
- /// 1. For each OpPhi instruction, add a block argument to the current block
- /// in construction. Record the block argument in `valueMap` so its uses
- /// can be resolved. For the list of (value, predecessor) pairs, update
- /// `blockPhiInfo` for bookkeeping.
- /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each
- /// block recorded there to create the proper block arguments on their
- /// terminators.
-
- /// A data structure for containing a SPIR-V block's phi info. It will be
- /// represented as block argument in SPIR-V dialect.
- using BlockPhiInfo =
- SmallVector<uint32_t, 2>; // The result <id> of the values sent
-
- /// Gets or creates the block corresponding to the given label <id>. The newly
- /// created block will always be placed at the end of the current function.
- Block *getOrCreateBlock(uint32_t id);
-
- LogicalResult processBranch(ArrayRef<uint32_t> operands);
-
- LogicalResult processBranchConditional(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpLabel instruction with the given `operands`.
- LogicalResult processLabel(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`.
- LogicalResult processSelectionMerge(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
- LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
-
- /// Processes a SPIR-V OpPhi instruction with the given `operands`.
- LogicalResult processPhi(ArrayRef<uint32_t> operands);
-
- /// Creates block arguments on predecessors previously recorded when handling
- /// OpPhi instructions.
- LogicalResult wireUpBlockArgument();
-
- /// Extracts blocks belonging to a structured selection/loop into a
- /// spv.selection/spv.loop op. This method iterates until all blocks
- /// declared as selection/loop headers are handled.
- LogicalResult structurizeControlFlow();
-
- //===--------------------------------------------------------------------===//
- // Instruction
- //===--------------------------------------------------------------------===//
-
- /// Get the Value associated with a result <id>.
- ///
- /// This method materializes normal constants and inserts "casting" ops
- /// (`spv.mlir.addressof` and `spv.mlir.referenceof`) to turn an symbol into a
- /// SSA value for handling uses of module scope constants/variables in
- /// functions.
- Value getValue(uint32_t id);
-
- /// Slices the first instruction out of `binary` and returns its opcode and
- /// operands via `opcode` and `operands` respectively. Returns failure if
- /// there is no more remaining instructions (`expectedOpcode` will be used to
- /// compose the error message) or the next instruction is malformed.
- LogicalResult
- sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
- Optional<spirv::Opcode> expectedOpcode = llvm::None);
-
- /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
- /// This method is the main entrance for handling SPIR-V instruction; it
- /// checks the instruction opcode and dispatches to the corresponding handler.
- /// Processing of Some instructions (like OpEntryPoint and OpExecutionMode)
- /// might need to be deferred, since they contain forward references to <id>s
- /// in the deserialized binary, but module in SPIR-V dialect expects these to
- /// be ssa-uses.
- LogicalResult processInstruction(spirv::Opcode opcode,
- ArrayRef<uint32_t> operands,
- bool deferInstructions = true);
-
- /// Processes a SPIR-V instruction from the given `operands`. It should
- /// deserialize into an op with the given `opName` and `numOperands`.
- /// This method is a generic one for dispatching any SPIR-V ops without
- /// variadic operands and attributes in TableGen definitions.
- LogicalResult processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
- StringRef opName, bool hasResult,
- unsigned numOperands);
-
- /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
- /// insertion point.
- LogicalResult processUndef(ArrayRef<uint32_t> operands);
-
- LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
-
- /// Method to dispatch to the specialized deserialization function for an
- /// operation in SPIR-V dialect that is a mirror of an instruction in the
- /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
- /// all operations in SPIR-V dialect that have hasOpcode == 1.
- LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
- ArrayRef<uint32_t> words);
-
- /// Processes a SPIR-V OpExtInst with given `operands`. This slices the
- /// entries of `operands` that specify the extended instruction set <id> and
- /// the instruction opcode. The op deserializer is then invoked using the
- /// other entries.
- LogicalResult processExtInst(ArrayRef<uint32_t> operands);
-
- /// Dispatches the deserialization of extended instruction set operation based
- /// on the extended instruction set name, and instruction opcode. This is
- /// autogenerated from ODS.
- LogicalResult
- dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName,
- uint32_t instructionID,
- ArrayRef<uint32_t> words);
-
- /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
- /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
- /// == 1 and autogenSerialization == 1 in ODS.
- template <typename OpTy>
- LogicalResult processOp(ArrayRef<uint32_t> words) {
- return emitError(unknownLoc, "unsupported deserialization for ")
- << OpTy::getOperationName() << " op";
- }
-
-private:
- /// The SPIR-V binary module.
- ArrayRef<uint32_t> binary;
-
- /// Contains the data of the OpLine instruction which precedes the current
- /// processing instruction.
- llvm::Optional<DebugLine> debugLine;
-
- /// The current word offset into the binary module.
- unsigned curOffset = 0;
-
- /// MLIRContext to create SPIR-V ModuleOp into.
- MLIRContext *context;
-
- // TODO: create Location subclass for binary blob
- Location unknownLoc;
-
- /// The SPIR-V ModuleOp.
- spirv::OwningSPIRVModuleRef module;
-
- /// The current function under construction.
- Optional<spirv::FuncOp> curFunction;
-
- /// The current block under construction.
- Block *curBlock = nullptr;
-
- OpBuilder opBuilder;
-
- spirv::Version version;
-
- /// The list of capabilities used by the module.
- llvm::SmallSetVector<spirv::Capability, 4> capabilities;
-
- /// The list of extensions used by the module.
- llvm::SmallSetVector<spirv::Extension, 2> extensions;
-
- // Result <id> to type mapping.
- DenseMap<uint32_t, Type> typeMap;
-
- // Result <id> to constant attribute and type mapping.
- ///
- /// In the SPIR-V binary format, all constants are placed in the module and
- /// shared by instructions at module level and in subsequent functions. But in
- /// the SPIR-V dialect, we materialize the constant to where it's used in the
- /// function. So when seeing a constant instruction in the binary format, we
- /// don't immediately emit a constant op into the module, we keep its value
- /// (and type) here. Later when it's used, we materialize the constant.
- DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
-
- // Result <id> to spec constant mapping.
- DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
-
- // Result <id> to composite spec constant mapping.
- DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
-
- /// Result <id> to info needed to materialize an OpSpecConstantOperation
- /// mapping.
- DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
- specConstOperationMap;
-
- // Result <id> to variable mapping.
- DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
-
- // Result <id> to function mapping.
- DenseMap<uint32_t, spirv::FuncOp> funcMap;
-
- // Result <id> to block mapping.
- DenseMap<uint32_t, Block *> blockMap;
-
- // Header block to its merge (and continue) target mapping.
- BlockMergeInfoMap blockMergeInfo;
-
- // Block to its phi (block argument) mapping.
- DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
-
- // Result <id> to value mapping.
- DenseMap<uint32_t, Value> valueMap;
-
- // Mapping from result <id> to undef value of a type.
- DenseMap<uint32_t, Type> undefMap;
-
- // Result <id> to name mapping.
- DenseMap<uint32_t, StringRef> nameMap;
-
- // Result <id> to debug info mapping.
- DenseMap<uint32_t, StringRef> debugInfoMap;
-
- // Result <id> to decorations mapping.
- DenseMap<uint32_t, NamedAttrList> decorations;
-
- // Result <id> to type decorations.
- DenseMap<uint32_t, uint32_t> typeDecorations;
-
- // Result <id> to member decorations.
- // decorated-struct-type-<id> ->
- // (struct-member-index -> (decoration -> decoration-operands))
- DenseMap<uint32_t,
- DenseMap<uint32_t, DenseMap<spirv::Decoration, ArrayRef<uint32_t>>>>
- memberDecorationMap;
-
- // Result <id> to member name.
- // struct-type-<id> -> (struct-member-index -> name)
- DenseMap<uint32_t, DenseMap<uint32_t, StringRef>> memberNameMap;
-
- // Result <id> to extended instruction set name.
- DenseMap<uint32_t, StringRef> extendedInstSets;
-
- // List of instructions that are processed in a deferred fashion (after an
- // initial processing of the entire binary). Some operations like
- // OpEntryPoint, and OpExecutionMode use forward references to function
- // <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and
- // spv.ExecutionMode) need these references resolved. So these instructions
- // are deserialized and stored for processing once the entire binary is
- // processed.
- SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4>
- deferredInstructions;
-
- /// A list of IDs for all types forward-declared through OpTypeForwardPointer
- /// instructions.
- llvm::SetVector<uint32_t> typeForwardPointerIDs;
-
- /// A list of all structs which have unresolved member types.
- SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// Deserializer Method Definitions
//===----------------------------------------------------------------------===//
-Deserializer::Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context)
+spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
+ MLIRContext *context)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
module(createModuleOp()), opBuilder(module->body()) {}
-LogicalResult Deserializer::deserialize() {
+LogicalResult spirv::Deserializer::deserialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
if (failed(processHeader()))
@@ -674,7 +88,7 @@ LogicalResult Deserializer::deserialize() {
return success();
}
-spirv::OwningSPIRVModuleRef Deserializer::collect() {
+spirv::OwningSPIRVModuleRef spirv::Deserializer::collect() {
return std::move(module);
}
@@ -682,14 +96,14 @@ spirv::OwningSPIRVModuleRef Deserializer::collect() {
// Module structure
//===----------------------------------------------------------------------===//
-spirv::OwningSPIRVModuleRef Deserializer::createModuleOp() {
+spirv::OwningSPIRVModuleRef spirv::Deserializer::createModuleOp() {
OpBuilder builder(context);
OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
spirv::ModuleOp::build(builder, state);
return cast<spirv::ModuleOp>(Operation::create(state));
}
-LogicalResult Deserializer::processHeader() {
+LogicalResult spirv::Deserializer::processHeader() {
if (binary.size() < spirv::kHeaderWordCount)
return emitError(unknownLoc,
"SPIR-V binary module must have a 5-word header");
@@ -728,7 +142,8 @@ LogicalResult Deserializer::processHeader() {
return success();
}
-LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
if (operands.size() != 1)
return emitError(unknownLoc, "OpMemoryModel must have one parameter");
@@ -740,7 +155,7 @@ LogicalResult Deserializer::processCapability(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> words) {
+LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
if (words.empty()) {
return emitError(
unknownLoc,
@@ -760,7 +175,8 @@ LogicalResult Deserializer::processExtension(ArrayRef<uint32_t> words) {
return success();
}
-LogicalResult Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
+LogicalResult
+spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
if (words.size() < 2) {
return emitError(unknownLoc,
"OpExtInstImport must have a result <id> and a literal "
@@ -776,14 +192,15 @@ LogicalResult Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
return success();
}
-void Deserializer::attachVCETriple() {
+void spirv::Deserializer::attachVCETriple() {
(*module)->setAttr(
spirv::ModuleOp::getVCETripleAttrName(),
spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
extensions.getArrayRef(), context));
}
-LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
return emitError(unknownLoc, "OpMemoryModel must have two operands");
@@ -797,7 +214,7 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// TODO: This function should also be auto-generated. For now, since only a
// few decorations are processed/handled in a meaningful manner, going with a
// manual implementation.
@@ -871,7 +288,8 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return success();
}
-LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
+LogicalResult
+spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
// The binary layout of OpMemberDecorate is different comparing to OpDecorate
if (words.size() < 3) {
return emitError(unknownLoc,
@@ -892,7 +310,7 @@ LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
return success();
}
-LogicalResult Deserializer::processMemberName(ArrayRef<uint32_t> words) {
+LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
if (words.size() < 3) {
return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
}
@@ -906,7 +324,8 @@ LogicalResult Deserializer::processMemberName(ArrayRef<uint32_t> words) {
return success();
}
-LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
return emitError(unknownLoc, "found function inside function");
}
@@ -1043,7 +462,8 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return processFunctionEnd(instOperands);
}
-LogicalResult Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
// Process OpFunctionEnd.
if (!operands.empty()) {
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
@@ -1061,22 +481,23 @@ LogicalResult Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}
-Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
+Optional<std::pair<Attribute, Type>>
+spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
if (constIt == constantMap.end())
return llvm::None;
return constIt->getSecond();
}
-Optional<SpecConstOperationMaterializationInfo>
-Deserializer::getSpecConstantOperation(uint32_t id) {
+Optional<spirv::SpecConstOperationMaterializationInfo>
+spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
auto constIt = specConstOperationMap.find(id);
if (constIt == specConstOperationMap.end())
return llvm::None;
return constIt->getSecond();
}
-std::string Deserializer::getFunctionSymbol(uint32_t id) {
+std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
auto funcName = nameMap.lookup(id).str();
if (funcName.empty()) {
funcName = "spirv_fn_" + std::to_string(id);
@@ -1084,7 +505,7 @@ std::string Deserializer::getFunctionSymbol(uint32_t id) {
return funcName;
}
-std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
+std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
constName = "spirv_spec_const_" + std::to_string(id);
@@ -1092,9 +513,9 @@ std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
return constName;
}
-spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc,
- uint32_t resultID,
- Attribute defaultValue) {
+spirv::SpecConstantOp
+spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
+ Attribute defaultValue) {
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
defaultValue);
@@ -1106,7 +527,8 @@ spirv::SpecConstantOp Deserializer::createSpecConstant(Location loc,
return op;
}
-LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
if (operands.size() < 3) {
return emitError(
@@ -1177,7 +599,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
return success();
}
-IntegerAttr Deserializer::getConstantInt(uint32_t id) {
+IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
auto constInfo = getConstant(id);
if (!constInfo) {
return nullptr;
@@ -1185,7 +607,7 @@ IntegerAttr Deserializer::getConstantInt(uint32_t id) {
return constInfo->first.dyn_cast<IntegerAttr>();
}
-LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc, "OpName needs at least 2 operands");
}
@@ -1207,8 +629,8 @@ LogicalResult Deserializer::processName(ArrayRef<uint32_t> operands) {
// Type
//===----------------------------------------------------------------------===//
-LogicalResult Deserializer::processType(spirv::Opcode opcode,
- ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands) {
if (operands.empty()) {
return emitError(unknownLoc, "type instruction with opcode ")
<< spirv::stringifyOpcode(opcode) << " needs at least one <id>";
@@ -1303,7 +725,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return success();
}
-LogicalResult Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 3)
return emitError(unknownLoc, "OpTypePointer must have two parameters");
@@ -1356,7 +779,8 @@ LogicalResult Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 3) {
return emitError(unknownLoc,
"OpTypeArray must have element type and count parameters");
@@ -1388,7 +812,8 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
assert(!operands.empty() && "No operands for processing function type");
if (operands.size() == 1) {
return emitError(unknownLoc, "missing return type for OpTypeFunction");
@@ -1414,7 +839,7 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
}
LogicalResult
-Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 5) {
return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
"type and row x column parameters");
@@ -1443,7 +868,7 @@ Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
}
LogicalResult
-Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
}
@@ -1458,7 +883,8 @@ Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
// TODO: Find a way to handle identified structs when debug info is stripped.
if (operands.empty()) {
@@ -1545,7 +971,8 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
if (operands.size() != 3) {
// Three operands are needed: result_id, column_type, and column_count
return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
@@ -1564,12 +991,25 @@ LogicalResult Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2)
+ return emitError(unknownLoc,
+ "OpTypeForwardPointer instruction must have two operands");
+
+ typeForwardPointerIDs.insert(operands[0]);
+ // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
+ // instruction that defines the actual type.
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
-LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
- bool isSpec) {
+LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
+ bool isSpec) {
StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
if (operands.size() < 2) {
@@ -1682,9 +1122,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
"scalar integer or floating-point type");
}
-LogicalResult Deserializer::processConstantBool(bool isTrue,
- ArrayRef<uint32_t> operands,
- bool isSpec) {
+LogicalResult spirv::Deserializer::processConstantBool(
+ bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
if (operands.size() != 2) {
return emitError(unknownLoc, "Op")
<< (isSpec ? "Spec" : "") << "Constant"
@@ -1706,7 +1145,7 @@ LogicalResult Deserializer::processConstantBool(bool isTrue,
}
LogicalResult
-Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc,
"OpConstantComposite must have type <id> and result <id>");
@@ -1751,7 +1190,7 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
LogicalResult
-Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc,
"OpConstantComposite must have type <id> and result <id>");
@@ -1786,7 +1225,7 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
}
LogicalResult
-Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
if (operands.size() < 3)
return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
"result <id>, and operand opcode");
@@ -1812,7 +1251,7 @@ Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
return success();
}
-Value Deserializer::materializeSpecConstantOperation(
+Value spirv::Deserializer::materializeSpecConstantOperation(
uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
ArrayRef<uint32_t> enclosedOpOperands) {
@@ -1870,7 +1309,8 @@ Value Deserializer::materializeSpecConstantOperation(
return specConstOperationOp.getResult();
}
-LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc,
"OpConstantNull must have type <id> and result <id>");
@@ -1899,7 +1339,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
// Control flow
//===----------------------------------------------------------------------===//
-Block *Deserializer::getOrCreateBlock(uint32_t id) {
+Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
if (auto *block = getBlock(id)) {
LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
<< " @ " << block << "\n");
@@ -1915,7 +1355,7 @@ Block *Deserializer::getOrCreateBlock(uint32_t id) {
return blockMap[id] = block;
}
-LogicalResult Deserializer::processBranch(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpBranch must appear inside a block");
}
@@ -1936,7 +1376,7 @@ LogicalResult Deserializer::processBranch(ArrayRef<uint32_t> operands) {
}
LogicalResult
-Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
+spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc,
"OpBranchConditional must appear inside a block");
@@ -1969,7 +1409,7 @@ Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
if (!curFunction) {
return emitError(unknownLoc, "OpLabel must appear inside a function");
}
@@ -1991,7 +1431,8 @@ LogicalResult Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
}
@@ -2016,7 +1457,8 @@ LogicalResult Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpLoopMerge must appear in a block");
}
@@ -2042,7 +1484,7 @@ LogicalResult Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::processPhi(ArrayRef<uint32_t> operands) {
+LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
if (!curBlock) {
return emitError(unknownLoc, "OpPhi must appear in a block");
}
@@ -2086,7 +1528,7 @@ class ControlFlowStructurizer {
/// This method will also update `mergeInfo` by remapping all blocks inside to
/// the newly cloned ones inside structured control flow op's regions.
static LogicalResult structurize(Location loc, uint32_t control,
- BlockMergeInfoMap &mergeInfo,
+ spirv::BlockMergeInfoMap &mergeInfo,
Block *headerBlock, Block *mergeBlock,
Block *continueBlock) {
return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
@@ -2096,7 +1538,7 @@ class ControlFlowStructurizer {
private:
ControlFlowStructurizer(Location loc, uint32_t control,
- BlockMergeInfoMap &mergeInfo, Block *header,
+ spirv::BlockMergeInfoMap &mergeInfo, Block *header,
Block *merge, Block *cont)
: location(loc), control(control), blockMergeInfo(mergeInfo),
headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
@@ -2115,7 +1557,7 @@ class ControlFlowStructurizer {
Location location;
uint32_t control;
- BlockMergeInfoMap &blockMergeInfo;
+ spirv::BlockMergeInfoMap &blockMergeInfo;
Block *headerBlock;
Block *mergeBlock;
@@ -2339,7 +1781,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
return success();
}
-LogicalResult Deserializer::wireUpBlockArgument() {
+LogicalResult spirv::Deserializer::wireUpBlockArgument() {
LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
OpBuilder::InsertionGuard guard(opBuilder);
@@ -2388,7 +1830,7 @@ LogicalResult Deserializer::wireUpBlockArgument() {
return success();
}
-LogicalResult Deserializer::structurizeControlFlow() {
+LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
while (!blockMergeInfo.empty()) {
@@ -2428,7 +1870,7 @@ LogicalResult Deserializer::structurizeControlFlow() {
// Debug
//===----------------------------------------------------------------------===//
-Location Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
+Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
if (!debugLine)
return unknownLoc;
@@ -2439,7 +1881,8 @@ Location Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
debugLine->line, debugLine->col);
}
-LogicalResult Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
// According to SPIR-V spec:
// "This location information applies to the instructions physically
// following this instruction, up to the first occurrence of any of the
@@ -2451,12 +1894,13 @@ LogicalResult Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
return success();
}
-LogicalResult Deserializer::clearDebugLine() {
+LogicalResult spirv::Deserializer::clearDebugLine() {
debugLine = llvm::None;
return success();
}
-LogicalResult Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
+LogicalResult
+spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
if (operands.size() < 2)
return emitError(unknownLoc, "OpString needs at least 2 operands");
@@ -2474,560 +1918,3 @@ LogicalResult Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
debugInfoMap[operands[0]] = debugString;
return success();
}
-
-//===----------------------------------------------------------------------===//
-// Instruction
-//===----------------------------------------------------------------------===//
-
-Value Deserializer::getValue(uint32_t id) {
- if (auto constInfo = getConstant(id)) {
- // Materialize a `spv.constant` op at every use site.
- return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
- constInfo->first);
- }
- if (auto varOp = getGlobalVariable(id)) {
- auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
- unknownLoc, varOp.type(),
- opBuilder.getSymbolRefAttr(varOp.getOperation()));
- return addressOfOp.pointer();
- }
- if (auto constOp = getSpecConstant(id)) {
- auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constOp.default_value().getType(),
- opBuilder.getSymbolRefAttr(constOp.getOperation()));
- return referenceOfOp.reference();
- }
- if (auto constCompositeOp = getSpecConstantComposite(id)) {
- auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
- unknownLoc, constCompositeOp.type(),
- opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
- return referenceOfOp.reference();
- }
- if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
- return materializeSpecConstantOperation(
- id, specConstOperationInfo->enclodesOpcode,
- specConstOperationInfo->resultTypeID,
- specConstOperationInfo->enclosedOpOperands);
- }
- if (auto undef = getUndefType(id)) {
- return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
- }
- return valueMap.lookup(id);
-}
-
-LogicalResult
-Deserializer::sliceInstruction(spirv::Opcode &opcode,
- ArrayRef<uint32_t> &operands,
- Optional<spirv::Opcode> expectedOpcode) {
- auto binarySize = binary.size();
- if (curOffset >= binarySize) {
- return emitError(unknownLoc, "expected ")
- << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
- : "more")
- << " instruction";
- }
-
- // For each instruction, get its word count from the first word to slice it
- // from the stream properly, and then dispatch to the instruction handler.
-
- uint32_t wordCount = binary[curOffset] >> 16;
-
- if (wordCount == 0)
- return emitError(unknownLoc, "word count cannot be zero");
-
- uint32_t nextOffset = curOffset + wordCount;
- if (nextOffset > binarySize)
- return emitError(unknownLoc, "insufficient words for the last instruction");
-
- opcode = extractOpcode(binary[curOffset]);
- operands = binary.slice(curOffset + 1, wordCount - 1);
- curOffset = nextOffset;
- return success();
-}
-
-LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
- ArrayRef<uint32_t> operands,
- bool deferInstructions) {
- LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction "
- << spirv::stringifyOpcode(opcode) << "\n");
-
- // First dispatch all the instructions whose opcode does not correspond to
- // those that have a direct mirror in the SPIR-V dialect
- switch (opcode) {
- case spirv::Opcode::OpCapability:
- return processCapability(operands);
- case spirv::Opcode::OpExtension:
- return processExtension(operands);
- case spirv::Opcode::OpExtInst:
- return processExtInst(operands);
- case spirv::Opcode::OpExtInstImport:
- return processExtInstImport(operands);
- case spirv::Opcode::OpMemberName:
- return processMemberName(operands);
- case spirv::Opcode::OpMemoryModel:
- return processMemoryModel(operands);
- case spirv::Opcode::OpEntryPoint:
- case spirv::Opcode::OpExecutionMode:
- if (deferInstructions) {
- deferredInstructions.emplace_back(opcode, operands);
- return success();
- }
- break;
- case spirv::Opcode::OpVariable:
- if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
- return processGlobalVariable(operands);
- }
- break;
- case spirv::Opcode::OpLine:
- return processDebugLine(operands);
- case spirv::Opcode::OpNoLine:
- return clearDebugLine();
- case spirv::Opcode::OpName:
- return processName(operands);
- case spirv::Opcode::OpString:
- return processDebugString(operands);
- case spirv::Opcode::OpModuleProcessed:
- case spirv::Opcode::OpSource:
- case spirv::Opcode::OpSourceContinued:
- case spirv::Opcode::OpSourceExtension:
- // TODO: This is debug information embedded in the binary which should be
- // translated into the spv.module.
- return success();
- case spirv::Opcode::OpTypeVoid:
- case spirv::Opcode::OpTypeBool:
- case spirv::Opcode::OpTypeInt:
- case spirv::Opcode::OpTypeFloat:
- case spirv::Opcode::OpTypeVector:
- case spirv::Opcode::OpTypeMatrix:
- case spirv::Opcode::OpTypeArray:
- case spirv::Opcode::OpTypeFunction:
- case spirv::Opcode::OpTypeRuntimeArray:
- case spirv::Opcode::OpTypeStruct:
- case spirv::Opcode::OpTypePointer:
- case spirv::Opcode::OpTypeCooperativeMatrixNV:
- return processType(opcode, operands);
- case spirv::Opcode::OpConstant:
- return processConstant(operands, /*isSpec=*/false);
- case spirv::Opcode::OpSpecConstant:
- return processConstant(operands, /*isSpec=*/true);
- case spirv::Opcode::OpConstantComposite:
- return processConstantComposite(operands);
- case spirv::Opcode::OpSpecConstantComposite:
- return processSpecConstantComposite(operands);
- case spirv::Opcode::OpSpecConstantOperation:
- return processSpecConstantOperation(operands);
- case spirv::Opcode::OpConstantTrue:
- return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
- case spirv::Opcode::OpSpecConstantTrue:
- return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
- case spirv::Opcode::OpConstantFalse:
- return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
- case spirv::Opcode::OpSpecConstantFalse:
- return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
- case spirv::Opcode::OpConstantNull:
- return processConstantNull(operands);
- case spirv::Opcode::OpDecorate:
- return processDecoration(operands);
- case spirv::Opcode::OpMemberDecorate:
- return processMemberDecoration(operands);
- case spirv::Opcode::OpFunction:
- return processFunction(operands);
- case spirv::Opcode::OpLabel:
- return processLabel(operands);
- case spirv::Opcode::OpBranch:
- return processBranch(operands);
- case spirv::Opcode::OpBranchConditional:
- return processBranchConditional(operands);
- case spirv::Opcode::OpSelectionMerge:
- return processSelectionMerge(operands);
- case spirv::Opcode::OpLoopMerge:
- return processLoopMerge(operands);
- case spirv::Opcode::OpPhi:
- return processPhi(operands);
- case spirv::Opcode::OpUndef:
- return processUndef(operands);
- case spirv::Opcode::OpTypeForwardPointer:
- return processTypeForwardPointer(operands);
- default:
- break;
- }
- return dispatchToAutogenDeserialization(opcode, operands);
-}
-
-LogicalResult
-Deserializer::processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
- StringRef opName, bool hasResult,
- unsigned numOperands) {
- SmallVector<Type, 1> resultTypes;
- uint32_t valueID = 0;
-
- size_t wordIndex= 0;
- if (hasResult) {
- if (wordIndex >= words.size())
- return emitError(unknownLoc,
- "expected result type <id> while deserializing for ")
- << opName;
-
- // Decode the type <id>
- auto type = getType(words[wordIndex]);
- if (!type)
- return emitError(unknownLoc, "unknown type result <id>: ")
- << words[wordIndex];
- resultTypes.push_back(type);
- ++wordIndex;
-
- // Decode the result <id>
- if (wordIndex >= words.size())
- return emitError(unknownLoc,
- "expected result <id> while deserializing for ")
- << opName;
- valueID = words[wordIndex];
- ++wordIndex;
- }
-
- SmallVector<Value, 4> operands;
- SmallVector<NamedAttribute, 4> attributes;
-
- // Decode operands
- size_t operandIndex = 0;
- for (; operandIndex < numOperands && wordIndex < words.size();
- ++operandIndex, ++wordIndex) {
- auto arg = getValue(words[wordIndex]);
- if (!arg)
- return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
- operands.push_back(arg);
- }
- if (operandIndex != numOperands) {
- return emitError(
- unknownLoc,
- "found less operands than expected when deserializing for ")
- << opName << "; only " << operandIndex << " of " << numOperands
- << " processed";
- }
- if (wordIndex != words.size()) {
- return emitError(
- unknownLoc,
- "found more operands than expected when deserializing for ")
- << opName << "; only " << wordIndex << " of " << words.size()
- << " processed";
- }
-
- // Attach attributes from decorations
- if (decorations.count(valueID)) {
- auto attrs = decorations[valueID].getAttrs();
- attributes.append(attrs.begin(), attrs.end());
- }
-
- // Create the op and update bookkeeping maps
- Location loc = createFileLineColLoc(opBuilder);
- OperationState opState(loc, opName);
- opState.addOperands(operands);
- if (hasResult)
- opState.addTypes(resultTypes);
- opState.addAttributes(attributes);
- Operation *op = opBuilder.createOperation(opState);
- if (hasResult)
- valueMap[valueID] = op->getResult(0);
-
- if (op->hasTrait<OpTrait::IsTerminator>())
- clearDebugLine();
-
- return success();
-}
-
-LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
- if (operands.size() != 2) {
- return emitError(unknownLoc, "OpUndef instruction must have two operands");
- }
- auto type = getType(operands[0]);
- if (!type) {
- return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
- }
- undefMap[operands[1]] = type;
- return success();
-}
-
-LogicalResult
-Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
- if (operands.size() != 2)
- return emitError(unknownLoc,
- "OpTypeForwardPointer instruction must have two operands");
-
- typeForwardPointerIDs.insert(operands[0]);
- // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
- // instruction that defines the actual type.
-
- return success();
-}
-
-LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
- if (operands.size() < 4) {
- return emitError(unknownLoc,
- "OpExtInst must have at least 4 operands, result type "
- "<id>, result <id>, set <id> and instruction opcode");
- }
- if (!extendedInstSets.count(operands[2])) {
- return emitError(unknownLoc, "undefined set <id> in OpExtInst");
- }
- SmallVector<uint32_t, 4> slicedOperands;
- slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
- slicedOperands.append(std::next(operands.begin(), 4), operands.end());
- return dispatchToExtensionSetAutogenDeserialization(
- extendedInstSets[operands[2]], operands[3], slicedOperands);
-}
-
-namespace {
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
- unsigned wordIndex = 0;
- if (wordIndex >= words.size()) {
- return emitError(unknownLoc,
- "missing Execution Model specification in OpEntryPoint");
- }
- auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]);
- if (wordIndex >= words.size()) {
- return emitError(unknownLoc, "missing <id> in OpEntryPoint");
- }
- // Get the function <id>
- auto fnID = words[wordIndex++];
- // Get the function name
- auto fnName = decodeStringLiteral(words, wordIndex);
- // Verify that the function <id> matches the fnName
- auto parsedFunc = getFunction(fnID);
- if (!parsedFunc) {
- return emitError(unknownLoc, "no function matching <id> ") << fnID;
- }
- if (parsedFunc.getName() != fnName) {
- return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
- "and OpFunction with <id> ")
- << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
- }
- SmallVector<Attribute, 4> interface;
- while (wordIndex < words.size()) {
- auto arg = getGlobalVariable(words[wordIndex]);
- if (!arg) {
- return emitError(unknownLoc, "undefined result <id> ")
- << words[wordIndex] << " while decoding OpEntryPoint";
- }
- interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation()));
- wordIndex++;
- }
- opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel,
- opBuilder.getSymbolRefAttr(fnName),
- opBuilder.getArrayAttr(interface));
- return success();
-}
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
- unsigned wordIndex = 0;
- if (wordIndex >= words.size()) {
- return emitError(unknownLoc,
- "missing function result <id> in OpExecutionMode");
- }
- // Get the function <id> to get the name of the function
- auto fnID = words[wordIndex++];
- auto fn = getFunction(fnID);
- if (!fn) {
- return emitError(unknownLoc, "no function matching <id> ") << fnID;
- }
- // Get the Execution mode
- if (wordIndex >= words.size()) {
- return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
- }
- auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]);
-
- // Get the values
- SmallVector<Attribute, 4> attrListElems;
- while (wordIndex < words.size()) {
- attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
- }
- auto values = opBuilder.getArrayAttr(attrListElems);
- opBuilder.create<spirv::ExecutionModeOp>(
- unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values);
- return success();
-}
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
- if (operands.size() != 3) {
- return emitError(
- unknownLoc,
- "OpControlBarrier must have execution scope <id>, memory scope <id> "
- "and memory semantics <id>");
- }
-
- SmallVector<IntegerAttr, 3> argAttrs;
- for (auto operand : operands) {
- auto argAttr = getConstantInt(operand);
- if (!argAttr) {
- return emitError(unknownLoc,
- "expected 32-bit integer constant from <id> ")
- << operand << " for OpControlBarrier";
- }
- argAttrs.push_back(argAttr);
- }
-
- opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0],
- argAttrs[1], argAttrs[2]);
- return success();
-}
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
- if (operands.size() < 3) {
- return emitError(unknownLoc,
- "OpFunctionCall must have at least 3 operands");
- }
-
- Type resultType = getType(operands[0]);
- if (!resultType) {
- return emitError(unknownLoc, "undefined result type from <id> ")
- << operands[0];
- }
-
- // Use null type to mean no result type.
- if (isVoidType(resultType))
- resultType = nullptr;
-
- auto resultID = operands[1];
- auto functionID = operands[2];
-
- auto functionName = getFunctionSymbol(functionID);
-
- SmallVector<Value, 4> arguments;
- for (auto operand : llvm::drop_begin(operands, 3)) {
- auto value = getValue(operand);
- if (!value) {
- return emitError(unknownLoc, "unknown <id> ")
- << operand << " used by OpFunctionCall";
- }
- arguments.push_back(value);
- }
-
- auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
- unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName),
- arguments);
-
- if (resultType)
- valueMap[resultID] = opFunctionCall.getResult(0);
- return success();
-}
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
- if (operands.size() != 2) {
- return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
- "and memory semantics <id>");
- }
-
- SmallVector<IntegerAttr, 2> argAttrs;
- for (auto operand : operands) {
- auto argAttr = getConstantInt(operand);
- if (!argAttr) {
- return emitError(unknownLoc,
- "expected 32-bit integer constant from <id> ")
- << operand << " for OpMemoryBarrier";
- }
- argAttrs.push_back(argAttr);
- }
-
- opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0],
- argAttrs[1]);
- return success();
-}
-
-template <>
-LogicalResult
-Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
- SmallVector<Type, 1> resultTypes;
- size_t wordIndex = 0;
- SmallVector<Value, 4> operands;
- SmallVector<NamedAttribute, 4> attributes;
-
- if (wordIndex < words.size()) {
- auto arg = getValue(words[wordIndex]);
-
- if (!arg) {
- return emitError(unknownLoc, "unknown result <id> : ")
- << words[wordIndex];
- }
-
- operands.push_back(arg);
- wordIndex++;
- }
-
- if (wordIndex < words.size()) {
- auto arg = getValue(words[wordIndex]);
-
- if (!arg) {
- return emitError(unknownLoc, "unknown result <id> : ")
- << words[wordIndex];
- }
-
- operands.push_back(arg);
- wordIndex++;
- }
-
- bool isAlignedAttr = false;
-
- if (wordIndex < words.size()) {
- auto attrValue = words[wordIndex++];
- attributes.push_back(opBuilder.getNamedAttr(
- "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
- isAlignedAttr = (attrValue == 2);
- }
-
- if (isAlignedAttr && wordIndex < words.size()) {
- attributes.push_back(opBuilder.getNamedAttr(
- "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
- }
-
- if (wordIndex < words.size()) {
- attributes.push_back(opBuilder.getNamedAttr(
- "source_memory_access",
- opBuilder.getI32IntegerAttr(words[wordIndex++])));
- }
-
- if (wordIndex < words.size()) {
- attributes.push_back(opBuilder.getNamedAttr(
- "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
- }
-
- if (wordIndex != words.size()) {
- return emitError(unknownLoc,
- "found more operands than expected when deserializing "
- "spirv::CopyMemoryOp, only ")
- << wordIndex << " of " << words.size() << " processed";
- }
-
- Location loc = createFileLineColLoc(opBuilder);
- opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);
-
- return success();
-}
-
-// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
-// various Deserializer::processOp<...>() specializations.
-#define GET_DESERIALIZATION_FNS
-#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
-
-} // namespace
-
-namespace mlir {
-spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
- MLIRContext *context) {
- Deserializer deserializer(binary, context);
-
- if (failed(deserializer.deserialize()))
- return nullptr;
-
- return deserializer.collect();
-}
-} // namespace mlir
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
new file mode 100644
index 000000000000..826441da1dc0
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -0,0 +1,613 @@
+//===- Deserializer.h - MLIR SPIR-V Deserializer ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the SPIR-V binary to MLIR SPIR-V module deserializer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_DESERIALIZER_H
+#define MLIR_TARGET_SPIRV_DESERIALIZER_H
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Decodes a SPIR-V literal string in `words` starting at `wordIndex` and
+/// advances `wordIndex` to the first word after the literal.
+///
+/// A SPIR-V literal string is a sequence of octets terminated by a NUL and
+/// padded with further NULs to a whole number of 32-bit words, so a string of
+/// `n` characters occupies `n / 4 + 1` words.
+///
+/// The search for the terminating NUL is bounded by the end of `words`, so a
+/// malformed binary (missing terminator, or `wordIndex` already at/after the
+/// end) cannot cause an out-of-bounds read. In that case the bytes that are
+/// available (possibly none) are returned and `wordIndex` ends up strictly
+/// greater than `words.size()`, so a caller that compares `wordIndex` with
+/// `words.size()` afterwards sees the overrun instead of accepting the input.
+static inline llvm::StringRef
+decodeStringLiteral(llvm::ArrayRef<uint32_t> words, unsigned &wordIndex) {
+  // No words left to decode: report an empty literal and step past the end.
+  if (wordIndex >= words.size()) {
+    ++wordIndex;
+    return llvm::StringRef();
+  }
+
+  const char *bytes = reinterpret_cast<const char *>(words.data() + wordIndex);
+  size_t maxBytes = (words.size() - wordIndex) * sizeof(uint32_t);
+
+  // Only look for the terminator inside the remaining words. `substr` clamps
+  // to the whole range when `find` returns npos (no terminator present).
+  llvm::StringRef available(bytes, maxBytes);
+  llvm::StringRef str = available.substr(0, available.find('\0'));
+
+  // The literal plus its NUL terminator/padding spans `size / 4 + 1` words.
+  wordIndex += str.size() / 4 + 1;
+  return str;
+}
+
+namespace mlir {
+namespace spirv {
+
+//===----------------------------------------------------------------------===//
+// Utility Definitions
+//===----------------------------------------------------------------------===//
+
+/// Records the merge target (and, for loops, the continue target) declared by
+/// a structured header block.
+///
+/// One of these is kept per selection/loop header found while reading the
+/// SPIR-V blob; the recorded targets are later used to build the matching
+/// spv.selection/spv.loop op.
+struct BlockMergeInfo {
+  /// Merge target of the construct.
+  Block *mergeBlock;
+  /// Continue target; left as nullptr for spv.selection.
+  Block *continueBlock;
+  /// Location associated with this construct.
+  Location loc;
+  /// Control operand recorded as a raw 32-bit value.
+  uint32_t control;
+
+  /// Builds an entry with neither a merge nor a continue target set.
+  BlockMergeInfo(Location location, uint32_t control)
+      : BlockMergeInfo(location, control, /*m=*/nullptr, /*c=*/nullptr) {}
+
+  /// Builds an entry with merge target `m` and optional continue target `c`.
+  BlockMergeInfo(Location location, uint32_t control, Block *m,
+                 Block *c = nullptr)
+      : mergeBlock{m}, continueBlock{c}, loc{location}, control{control} {}
+};
+
+/// Source position information carried by an OpLine instruction.
+struct DebugLine {
+  /// <id> given as the file operand of OpLine.
+  uint32_t fileID;
+  /// Line number.
+  uint32_t line;
+  /// Column number.
+  uint32_t col;
+
+  DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum)
+      : fileID{fileIDNum}, line{lineNum}, col{colNum} {}
+};
+
+/// Associates each selection/loop header block with its merge (and continue)
+/// targets.
+using BlockMergeInfoMap = DenseMap<Block *, BlockMergeInfo>;
+
+/// Bookkeeping for a "deferred" struct type: one whose member types are not
+/// all known when the Deserializer first reaches the struct declaration.
+///
+/// Recursive structs are the typical source of this. A pointer to the struct
+/// may be forward declared via OpTypeForwardPointer before the struct itself,
+/// with the real pointer-to-struct type only defined afterwards by an
+/// OpTypePointer. For instance, the C type
+///
+///   struct A {
+///     A* next;
+///   };
+///
+/// can appear in a SPIR-V module as
+///
+///   OpName %A "A"
+///   OpTypeForwardPointer %APtr Generic
+///   %A = OpTypeStruct %APtr
+///   %APtr = OpTypePointer Generic %A
+///
+/// Since the spirv::StructType cannot be completed when it is first seen,
+/// everything known about it is gathered here; once all of the struct's
+/// forward references are resolved, its body is set from this information.
+struct DeferredStructTypeInfo {
+  /// The struct type whose body is still pending.
+  spirv::StructType deferredStructType;
+
+  /// Members whose types are still unknown, stored as
+  /// (operand <id>, member index within the struct) pairs.
+  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
+
+  /// All member types; entries for unresolved members hold empty placeholder
+  /// types until they are patched later.
+  SmallVector<Type, 4> memberTypes;
+
+  /// Offset information for the struct members.
+  SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
+
+  /// Decoration information for the struct members.
+  SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+};
+
+/// Everything required to materialize (emit) the op corresponding to an
+/// OpSpecConstantOperation instruction.
+struct SpecConstOperationMaterializationInfo {
+  /// Opcode of the enclosed operation.
+  /// NOTE(review): the name is a misspelling of `enclosedOpcode`; it is kept
+  /// unchanged because code outside this header presumably refers to it —
+  /// rename together with those uses.
+  spirv::Opcode enclodesOpcode;
+  /// <id> of the result type.
+  uint32_t resultTypeID;
+  /// Operand <id>s of the enclosed operation.
+  SmallVector<uint32_t> enclosedOpOperands;
+};
+
+//===----------------------------------------------------------------------===//
+// Deserializer Declaration
+//===----------------------------------------------------------------------===//
+
+/// A SPIR-V module deserializer.
+///
+/// A SPIR-V binary module is a single linear stream of instructions; each
+/// instruction is composed of 32-bit words. The first word of an instruction
+/// records the total number of words of that instruction in its 16
+/// higher-order bits. This deserializer uses that count to find instruction
+/// boundaries, parse each instruction, and build a SPIR-V ModuleOp gradually.
+///
+// TODO: clean up created ops on errors
+class Deserializer {
+public:
+ /// Creates a deserializer for the given SPIR-V `binary` module.
+ /// The SPIR-V ModuleOp will be created into `context`.
+ explicit Deserializer(ArrayRef<uint32_t> binary, MLIRContext *context);
+
+ /// Deserializes the remembered SPIR-V binary module.
+ LogicalResult deserialize();
+
+ /// Collects the final SPIR-V ModuleOp.
+ spirv::OwningSPIRVModuleRef collect();
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Module structure
+ //===--------------------------------------------------------------------===//
+
+ /// Initializes the `module` ModuleOp in this deserializer instance.
+ spirv::OwningSPIRVModuleRef createModuleOp();
+
+ /// Processes the SPIR-V module header in `binary`.
+ LogicalResult processHeader();
+
+ /// Processes the SPIR-V OpCapability with `operands` and updates bookkeeping
+ /// in the deserializer.
+ LogicalResult processCapability(ArrayRef<uint32_t> operands);
+
+ /// Processes the SPIR-V OpExtension with `words` and updates bookkeeping
+ /// in the deserializer.
+ LogicalResult processExtension(ArrayRef<uint32_t> words);
+
+ /// Processes the SPIR-V OpExtInstImport with `words` and updates
+ /// bookkeeping in the deserializer.
+ LogicalResult processExtInstImport(ArrayRef<uint32_t> words);
+
+ /// Attaches (version, capabilities, extensions) triple to `module` as an
+ /// attribute.
+ void attachVCETriple();
+
+ /// Processes the SPIR-V OpMemoryModel with `operands` and updates `module`.
+ LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
+
+ /// Processes the SPIR-V OpName with `operands`.
+ LogicalResult processName(ArrayRef<uint32_t> operands);
+
+ /// Processes an OpDecorate instruction.
+ LogicalResult processDecoration(ArrayRef<uint32_t> words);
+
+ /// Processes an OpMemberDecorate instruction.
+ LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
+
+ /// Processes an OpMemberName instruction.
+ LogicalResult processMemberName(ArrayRef<uint32_t> words);
+
+ /// Gets the function op associated with a result <id> of OpFunction.
+ spirv::FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
+
+ /// Processes the SPIR-V function at the current `offset` into `binary`.
+ /// The operands of the OpFunction instruction are passed in as `operands`.
+ /// This method processes each instruction inside the function and dispatches
+ /// them to their handler method accordingly.
+ LogicalResult processFunction(ArrayRef<uint32_t> operands);
+
+ /// Processes OpFunctionEnd and finalizes function. This wires up block
+ /// arguments created from OpPhi instructions and also structurizes control
+ /// flow.
+ LogicalResult processFunctionEnd(ArrayRef<uint32_t> operands);
+
+ /// Gets the constant's attribute and type associated with the given <id>.
+ Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
+
+ /// Gets the info needed to materialize the spec constant operation op
+ /// associated with the given <id>.
+ Optional<SpecConstOperationMaterializationInfo>
+ getSpecConstantOperation(uint32_t id);
+
+ /// Gets the constant's integer attribute with the given <id>. Returns a
+ /// null IntegerAttr if the given <id> is not registered or does not
+ /// correspond to an integer constant.
+ IntegerAttr getConstantInt(uint32_t id);
+
+ /// Returns a symbol to be used for the function name with the given
+ /// result <id>. This tries to use the function's OpName if one
+ /// exists; otherwise creates one based on the <id>.
+ std::string getFunctionSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the given
+ /// result <id>. This tries to use the specialization constant's OpName if
+ /// one exists; otherwise creates one based on the <id>.
+ std::string getSpecConstantSymbol(uint32_t id);
+
+ /// Gets the specialization constant with the given result <id>.
+ spirv::SpecConstantOp getSpecConstant(uint32_t id) {
+ return specConstMap.lookup(id);
+ }
+
+ /// Gets the composite specialization constant with the given result <id>.
+ spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
+ return specConstCompositeMap.lookup(id);
+ }
+
+ /// Creates a spirv::SpecConstantOp.
+ spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
+ Attribute defaultValue);
+
+ /// Processes an OpVariable instruction given its `operands`.
+ /// It is expected that this method is used for variables that are to be
+ /// defined at module scope and will be deserialized into a spv.globalVariable
+ /// op.
+ LogicalResult processGlobalVariable(ArrayRef<uint32_t> operands);
+
+ /// Gets the global variable associated with a result <id> of OpVariable.
+ spirv::GlobalVariableOp getGlobalVariable(uint32_t id) {
+ return globalVariableMap.lookup(id);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Type
+ //===--------------------------------------------------------------------===//
+
+ /// Gets the type for a given result <id>.
+ Type getType(uint32_t id) { return typeMap.lookup(id); }
+
+ /// Get the type associated with the result <id> of an OpUndef.
+ Type getUndefType(uint32_t id) { return undefMap.lookup(id); }
+
+ /// Returns true if the given `type` is for SPIR-V void type.
+ bool isVoidType(Type type) const { return type.isa<NoneType>(); }
+
+ /// Processes a SPIR-V type instruction with given `opcode` and `operands` and
+ /// registers the type into `module`.
+ LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
+
+ LogicalResult processOpTypePointer(ArrayRef<uint32_t> operands);
+
+ LogicalResult processArrayType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processCooperativeMatrixType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processStructType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
+
+ LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Constant
+ //===--------------------------------------------------------------------===//
+
+ /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
+ /// `operands`. `isSpec` indicates whether this is a specialization constant.
+ LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
+
+ /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
+ /// given `operands`. `isSpec` indicates whether this is a specialization
+ /// constant.
+ LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
+ bool isSpec);
+
+ /// Processes a SPIR-V OpConstantComposite instruction with the given
+ /// `operands`.
+ LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpSpecConstantComposite instruction with the given
+ /// `operands`.
+ LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpSpecConstantOperation instruction with the given
+ /// `operands`.
+ LogicalResult processSpecConstantOperation(ArrayRef<uint32_t> operands);
+
+ /// Materializes/emits an OpSpecConstantOperation instruction.
+ Value materializeSpecConstantOperation(uint32_t resultID,
+ spirv::Opcode enclosedOpcode,
+ uint32_t resultTypeID,
+ ArrayRef<uint32_t> enclosedOpOperands);
+
+ /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
+ LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Debug
+ //===--------------------------------------------------------------------===//
+
+ /// Discontinues any source-level location information that might be active
+ /// from a previous OpLine instruction.
+ LogicalResult clearDebugLine();
+
+ /// Creates a FileLineColLoc with the OpLine location information.
+ Location createFileLineColLoc(OpBuilder opBuilder);
+
+ /// Processes a SPIR-V OpLine instruction with the given `operands`.
+ LogicalResult processDebugLine(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpString instruction with the given `operands`.
+ LogicalResult processDebugString(ArrayRef<uint32_t> operands);
+
+ //===--------------------------------------------------------------------===//
+ // Control flow
+ //===--------------------------------------------------------------------===//
+
+ /// Returns the block for the given label <id>.
+ Block *getBlock(uint32_t id) const { return blockMap.lookup(id); }
+
+ // In SPIR-V, structured control flow is explicitly declared using merge
+ // instructions (OpSelectionMerge and OpLoopMerge). In the SPIR-V dialect,
+ // we use spv.selection and spv.loop to group structured control flow.
+ // The deserializer needs to turn structured control flow marked with merge
+ // instructions into using spv.selection/spv.loop ops.
+ //
+ // Because structured control flow can nest and the basic block order has
+ // flexibility, we cannot isolate a structured selection/loop without
+ // deserializing all the blocks. So we use the following approach:
+ //
+ // 1. Deserialize all basic blocks in a function and create MLIR blocks for
+ // them into the function's region. Meanwhile, keep a map from
+ // selection/loop header blocks to their corresponding merge (and continue)
+ // target blocks.
+ // 2. For each selection/loop header block, recursively get all basic blocks
+ // reachable (except the merge block) and put them in a newly created
+ // spv.selection/spv.loop's region. Structured control flow guarantees
+ // that we enter and exit in structured ways and the construct is nestable.
+ // 3. Put the new spv.selection/spv.loop op at the beginning of the old merge
+ // block and redirect all branches to the old header block to the old
+ // merge block (which contains the spv.selection/spv.loop op now).
+
+ /// For OpPhi instructions, we use block arguments to represent them. OpPhi
+ /// encodes a list of (value, predecessor) pairs. At the time of handling the
+ /// block containing an OpPhi instruction, the predecessor block might not be
+ /// processed yet, and neither might the value it sends. So we need to defer
+ /// handling the block argument from the predecessors. We use the following
+ /// approach:
+ ///
+ /// 1. For each OpPhi instruction, add a block argument to the current block
+ /// in construction. Record the block argument in `valueMap` so its uses
+ /// can be resolved. For the list of (value, predecessor) pairs, update
+ /// `blockPhiInfo` for bookkeeping.
+ /// 2. After processing all blocks, loop over `blockPhiInfo` to fix up each
+ /// block recorded there to create the proper block arguments on their
+ /// terminators.
+
+ /// A data structure for containing a SPIR-V block's phi info. It will be
+ /// represented as block arguments in the SPIR-V dialect.
+ using BlockPhiInfo =
+ SmallVector<uint32_t, 2>; // The result <id> of the values sent
+
+ /// Gets or creates the block corresponding to the given label <id>. The newly
+ /// created block will always be placed at the end of the current function.
+ Block *getOrCreateBlock(uint32_t id);
+
+ LogicalResult processBranch(ArrayRef<uint32_t> operands);
+
+ LogicalResult processBranchConditional(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpLabel instruction with the given `operands`.
+ LogicalResult processLabel(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpSelectionMerge instruction with the given `operands`.
+ LogicalResult processSelectionMerge(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpLoopMerge instruction with the given `operands`.
+ LogicalResult processLoopMerge(ArrayRef<uint32_t> operands);
+
+ /// Processes a SPIR-V OpPhi instruction with the given `operands`.
+ LogicalResult processPhi(ArrayRef<uint32_t> operands);
+
+ /// Creates block arguments on predecessors previously recorded when handling
+ /// OpPhi instructions.
+ LogicalResult wireUpBlockArgument();
+
+ /// Extracts blocks belonging to a structured selection/loop into a
+ /// spv.selection/spv.loop op. This method iterates until all blocks
+ /// declared as selection/loop headers are handled.
+ LogicalResult structurizeControlFlow();
+
+ //===--------------------------------------------------------------------===//
+ // Instruction
+ //===--------------------------------------------------------------------===//
+
+ /// Get the Value associated with a result <id>.
+ ///
+ /// This method materializes normal constants and inserts "casting" ops
+ /// (`spv.mlir.addressof` and `spv.mlir.referenceof`) to turn a symbol into an
+ /// SSA value for handling uses of module scope constants/variables in
+ /// functions.
+ Value getValue(uint32_t id);
+
+ /// Slices the first instruction out of `binary` and returns its opcode and
+ /// operands via `opcode` and `operands` respectively. Returns failure if
+ /// there are no remaining instructions (`expectedOpcode` will be used to
+ /// compose the error message) or the next instruction is malformed.
+ LogicalResult
+ sliceInstruction(spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
+ Optional<spirv::Opcode> expectedOpcode = llvm::None);
+
+ /// Processes a SPIR-V instruction with the given `opcode` and `operands`.
+ /// This method is the main entrance for handling SPIR-V instruction; it
+ /// checks the instruction opcode and dispatches to the corresponding handler.
+ /// Processing of some instructions (like OpEntryPoint and OpExecutionMode)
+ /// might need to be deferred, since they contain forward references to <id>s
+ /// in the deserialized binary, but the SPIR-V dialect expects these to be
+ /// SSA uses.
+ LogicalResult processInstruction(spirv::Opcode opcode,
+ ArrayRef<uint32_t> operands,
+ bool deferInstructions = true);
+
+ /// Processes a SPIR-V instruction from the given `words`. It should
+ /// deserialize into an op with the given `opName` and `numOperands`.
+ /// This method is a generic one for dispatching any SPIR-V ops without
+ /// variadic operands and attributes in TableGen definitions.
+ LogicalResult processOpWithoutGrammarAttr(ArrayRef<uint32_t> words,
+ StringRef opName, bool hasResult,
+ unsigned numOperands);
+
+ /// Processes an OpUndef instruction. Adds a spv.Undef operation at the
+ /// current insertion point.
+ LogicalResult processUndef(ArrayRef<uint32_t> operands);
+
+ /// Method to dispatch to the specialized deserialization function for an
+ /// operation in SPIR-V dialect that is a mirror of an instruction in the
+ /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
+ /// all operations in SPIR-V dialect that have hasOpcode == 1.
+ LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
+ ArrayRef<uint32_t> words);
+
+ /// Processes a SPIR-V OpExtInst with given `operands`. This slices the
+ /// entries of `operands` that specify the extended instruction set <id> and
+ /// the instruction opcode. The op deserializer is then invoked using the
+ /// other entries.
+ LogicalResult processExtInst(ArrayRef<uint32_t> operands);
+
+ /// Dispatches the deserialization of extended instruction set operation based
+ /// on the extended instruction set name, and instruction opcode. This is
+ /// autogenerated from ODS.
+ LogicalResult
+ dispatchToExtensionSetAutogenDeserialization(StringRef extensionSetName,
+ uint32_t instructionID,
+ ArrayRef<uint32_t> words);
+
+ /// Method to deserialize an operation in the SPIR-V dialect that is a mirror
+ /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
+ /// == 1 and autogenSerialization == 1 in ODS. This primary template only
+ /// reports an "unsupported deserialization" error for `OpTy`.
+ template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
+ return emitError(unknownLoc, "unsupported deserialization for ")
+ << OpTy::getOperationName() << " op";
+ }
+
+private:
+ // NOTE(review): non-static data members are initialized in declaration
+ // order, and the constructor is defined out of line (not visible here). It
+ // may depend on that order, so confirm before reordering any member below.
+
+ /// The SPIR-V binary module.
+ ArrayRef<uint32_t> binary;
+
+ /// Contains the data of the OpLine instruction which precedes the current
+ /// processing instruction.
+ llvm::Optional<DebugLine> debugLine;
+
+ /// The current word offset into the binary module.
+ unsigned curOffset = 0;
+
+ /// MLIRContext to create SPIR-V ModuleOp into.
+ MLIRContext *context;
+
+ // TODO: create Location subclass for binary blob
+ Location unknownLoc;
+
+ /// The SPIR-V ModuleOp.
+ spirv::OwningSPIRVModuleRef module;
+
+ /// The current function under construction.
+ Optional<spirv::FuncOp> curFunction;
+
+ /// The current block under construction.
+ Block *curBlock = nullptr;
+
+ OpBuilder opBuilder;
+
+ spirv::Version version;
+
+ /// The list of capabilities used by the module.
+ llvm::SmallSetVector<spirv::Capability, 4> capabilities;
+
+ /// The list of extensions used by the module.
+ llvm::SmallSetVector<spirv::Extension, 2> extensions;
+
+ /// Result <id> to type mapping.
+ DenseMap<uint32_t, Type> typeMap;
+
+ /// Result <id> to constant attribute and type mapping.
+ ///
+ /// In the SPIR-V binary format, all constants are placed in the module and
+ /// shared by instructions at module level and in subsequent functions. But in
+ /// the SPIR-V dialect, we materialize the constant to where it's used in the
+ /// function. So when seeing a constant instruction in the binary format, we
+ /// don't immediately emit a constant op into the module, we keep its value
+ /// (and type) here. Later when it's used, we materialize the constant.
+ DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
+
+ /// Result <id> to spec constant mapping.
+ DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
+
+ /// Result <id> to composite spec constant mapping.
+ DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
+
+ /// Result <id> to info needed to materialize an OpSpecConstantOperation
+ /// mapping.
+ DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
+ specConstOperationMap;
+
+ /// Result <id> to global variable mapping.
+ DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
+
+ /// Result <id> to function mapping.
+ DenseMap<uint32_t, spirv::FuncOp> funcMap;
+
+ /// Result <id> to block mapping.
+ DenseMap<uint32_t, Block *> blockMap;
+
+ /// Header block to its merge (and continue) target mapping.
+ BlockMergeInfoMap blockMergeInfo;
+
+ /// Block to its phi (block argument) mapping.
+ DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
+
+ /// Result <id> to value mapping.
+ DenseMap<uint32_t, Value> valueMap;
+
+ /// Result <id> of an OpUndef to the type of the undefined value.
+ DenseMap<uint32_t, Type> undefMap;
+
+ /// Result <id> to name mapping.
+ DenseMap<uint32_t, StringRef> nameMap;
+
+ /// Result <id> to debug info mapping.
+ DenseMap<uint32_t, StringRef> debugInfoMap;
+
+ /// Result <id> to decorations mapping.
+ DenseMap<uint32_t, NamedAttrList> decorations;
+
+ /// Result <id> to type decorations.
+ DenseMap<uint32_t, uint32_t> typeDecorations;
+
+ /// Result <id> to member decorations.
+ /// decorated-struct-type-<id> ->
+ /// (struct-member-index -> (decoration -> decoration-operands))
+ DenseMap<uint32_t,
+ DenseMap<uint32_t, DenseMap<spirv::Decoration, ArrayRef<uint32_t>>>>
+ memberDecorationMap;
+
+ /// Result <id> to member name.
+ /// struct-type-<id> -> (struct-member-index -> name)
+ DenseMap<uint32_t, DenseMap<uint32_t, StringRef>> memberNameMap;
+
+ /// Result <id> to extended instruction set name.
+ DenseMap<uint32_t, StringRef> extendedInstSets;
+
+ /// List of instructions that are processed in a deferred fashion (after an
+ /// initial processing of the entire binary). Some operations like
+ /// OpEntryPoint, and OpExecutionMode use forward references to function
+ /// <id>s. In SPIR-V dialect the corresponding operations (spv.EntryPoint and
+ /// spv.ExecutionMode) need these references resolved. So these instructions
+ /// are recorded here (opcode and operands) and processed once the entire
+ /// binary has been processed.
+ SmallVector<std::pair<spirv::Opcode, ArrayRef<uint32_t>>, 4>
+ deferredInstructions;
+
+ /// A list of IDs for all types forward-declared through OpTypeForwardPointer
+ /// instructions.
+ llvm::SetVector<uint32_t> typeForwardPointerIDs;
+
+ /// A list of all structs which have unresolved member types.
+ SmallVector<DeferredStructTypeInfo, 0> deferredStructTypesInfos;
+};
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_TARGET_SPIRV_DESERIALIZER_H
diff --git a/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt
new file mode 100644
index 000000000000..c4120960a22b
--- /dev/null
+++ b/mlir/lib/Target/SPIRV/Serialization/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_translation_library(MLIRSPIRVSerialization
+ Serialization.cpp
+
+ DEPENDS
+ MLIRSPIRVSerializationGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSPIRV
+ MLIRSPIRVBinaryUtils
+ MLIRSupport
+ MLIRTranslation
+ )
+
+
diff --git a/mlir/lib/Target/SPIRV/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
similarity index 100%
rename from mlir/lib/Target/SPIRV/Serialization.cpp
rename to mlir/lib/Target/SPIRV/Serialization/Serialization.cpp
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 74fe1e0fdb08..20bf8773b137 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -996,11 +996,10 @@ static void emitDeserializationFunction(const Record *attrClass,
/// based on the `opcode`.
static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
raw_ostream &os) {
- os << formatv(
- "LogicalResult "
- "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, "
- "ArrayRef<uint32_t> {1}) {{\n",
- opcode, words);
+ os << formatv("LogicalResult spirv::Deserializer::"
+ "dispatchToAutogenDeserialization(spirv::Opcode {0},"
+ " ArrayRef<uint32_t> {1}) {{\n",
+ opcode, words);
os << formatv(" switch ({0}) {{\n", opcode);
}
@@ -1043,8 +1042,8 @@ static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
StringRef instructionID,
StringRef words,
raw_ostream &os) {
- os << formatv("LogicalResult "
- "Deserializer::dispatchToExtensionSetAutogenDeserialization("
+ os << formatv("LogicalResult spirv::Deserializer::"
+ "dispatchToExtensionSetAutogenDeserialization("
"StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
extensionSetName, instructionID, words);
}
More information about the llvm-branch-commits
mailing list