[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (PR #156665)
Jakub Kuderski
llvmlistbot at llvm.org
Wed Sep 10 09:35:01 PDT 2025
================
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}
+/// Deserializes an OpGraphEntryPointARM instruction.
+///
+/// `operands` holds the <id> of a previously deserialized graph, followed by
+/// the entry point name as a string literal, followed by zero or more
+/// global-variable <id>s that form the entry point interface. The referenced
+/// graph is renamed and marked as an entry point, and a
+/// spirv.ARM.GraphEntryPoint op is created just before it.
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "missing graph definition in OpGraphEntryPointARM");
+  }
+
+  unsigned wordIndex = 0;
+  uint32_t grID = operands[wordIndex++];
+  // Single lookup rather than count() followed by operator[].
+  auto graphIt = graphMap.find(grID);
+  if (graphIt == graphMap.end()) {
+    return emitError(unknownLoc,
+                     "missing graph definition/declaration with id ")
+           << grID;
+  }
+
+  spirv::GraphARMOp graphARM = graphIt->second;
+  // decodeStringLiteral advances wordIndex past the literal's words.
+  StringRef name = decodeStringLiteral(operands, wordIndex);
+  graphARM.setSymName(name);
+  graphARM.setEntryPoint(true);
+
+  // Every remaining word must name an already-deserialized global variable.
+  SmallVector<Attribute, 4> interface;
+  for (; wordIndex < operands.size(); ++wordIndex) {
+    uint32_t varID = operands[wordIndex];
+    spirv::GlobalVariableOp arg = getGlobalVariable(varID);
+    if (!arg) {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << varID << " while decoding OpGraphEntryPointARM";
+    }
+    interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+  }
+
+  // RAII guard to reset the insertion point to previous value when done.
+  OpBuilder::InsertionGuard insertionGuard(opBuilder);
+  opBuilder.setInsertionPoint(graphARM);
+  spirv::GraphEntryPointARMOp::create(
+      opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+      opBuilder.getArrayAttr(interface));
+
+  return success();
+}
+
+/// Deserializes an OpGraphARM instruction together with the graph body that
+/// follows it.
+///
+/// `operands` is expected to be <result type id>, <result id>. One
+/// OpGraphInputARM instruction is consumed per graph input, and each one
+/// binds a result <id> to a graph block argument. After that, every
+/// instruction up to and including OpGraphEndARM is processed.
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+  if (curGraph) {
+    return emitError(unknownLoc, "found graph inside graph");
+  }
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+  }
+
+  // Get the result type; it must be a previously deserialized graph type.
+  auto graphType = dyn_cast_if_present<GraphType>(getType(operands[0]));
+  if (!graphType) {
+    return emitError(unknownLoc, "unknown graph type from <id> ")
+           << operands[0];
+  }
+  if (graphType.getNumResults() == 0) {
+    return emitError(unknownLoc, "expected at least one result");
+  }
+
+  uint32_t grID = operands[1];
+  if (graphMap.count(grID)) {
+    return emitError(unknownLoc, "duplicate graph definition/declaration");
+  }
+
+  std::string grName = getGraphSymbol(grID);
+  auto graphOp =
+      spirv::GraphARMOp::create(opBuilder, unknownLoc, grName, graphType);
+  curGraph = graphMap[grID] = graphOp;
+  Block *entryBlock = graphOp.addEntryBlock();
+  LLVM_DEBUG({
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+    logger.startLine() << "[graph] name: " << grName << "\n";
+    logger.startLine() << "[graph] type: " << graphType << "\n";
+    logger.startLine() << "[graph] ID: " << grID << "\n";
+    logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+    logger.indent();
+  });
+
+  // Parse one OpGraphInputARM instruction per graph input.
+  auto inputTypes = graphType.getInputs();
+  for (size_t argNo = 0, numInputs = inputTypes.size(); argNo < numInputs;
+       ++argNo) {
+    spirv::Opcode inputOpcode;
+    ArrayRef<uint32_t> inputOperands;
+    if (failed(sliceInstruction(inputOpcode, inputOperands,
+                                spirv::Opcode::OpGraphInputARM))) {
+      return failure();
+    }
+    if (inputOperands.size() != 3) {
+      return emitError(unknownLoc, "expected result type, result <id> and "
+                                   "input index for OpGraphInputARM");
+    }
+
+    Type argDefinedType = getType(inputOperands[0]);
+    if (!argDefinedType) {
+      return emitError(unknownLoc, "unknown operand type <id> ")
+             << inputOperands[0];
+    }
+    if (getValue(inputOperands[1])) {
+      return emitError(unknownLoc, "duplicate definition of result <id> ")
+             << inputOperands[1];
+    }
+
+    IntegerAttr inputIndexAttr = getConstantInt(inputOperands[2]);
+    if (!inputIndexAttr) {
+      return emitError(unknownLoc,
+                       "unable to read inputIndex value from constant op ")
+             << inputOperands[2];
+    }
+    // The index comes from the binary, so validate it before using it to
+    // select a block argument.
+    int64_t inputIndex = inputIndexAttr.getInt();
+    if (inputIndex < 0 || static_cast<uint64_t>(inputIndex) >= numInputs) {
+      return emitError(unknownLoc, "input index ")
+             << inputIndex << " is out of range for graph with " << numInputs
+             << " inputs";
+    }
+
+    // Check the type against the argument this input actually binds to.
+    Type expectedType = inputTypes[inputIndex];
+    if (argDefinedType != expectedType) {
+      return emitError(unknownLoc,
+                       "mismatch in argument type between graph type "
+                       "definition ")
+             << graphType << " and argument type definition "
+             << argDefinedType << " at argument " << inputIndex;
+    }
+
+    valueMap[inputOperands[1]] = graphOp.getArgument(inputIndex);
+  }
+
+  // Reset the outputs so no value from a previously deserialized graph
+  // survives into this one.
+  graphOutputs.assign(graphType.getNumResults(), Value());
+
+  // RAII guard to reset the insertion point to the module's region after
+  // deserializing the body of this graph.
+  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+  blockMap[grID] = entryBlock;
+  if (failed(createGraphBlock(grID))) {
+    return failure();
+  }
+
+  // Process all the instructions in the graph until and including
+  // OpGraphEndARM.
+  spirv::Opcode opcode;
+  ArrayRef<uint32_t> instOperands;
+  do {
+    if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+      return failure();
+    }
+
+    if (failed(processInstruction(opcode, instOperands))) {
+      return failure();
+    }
+  } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+  return success();
+}
+
+/// Deserializes an OpGraphSetOutputARM instruction.
+///
+/// `operands` is <value id>, <output index constant id>. The value is stored
+/// in the pending graph output slot selected by the index.
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(
+        unknownLoc,
+        "expected value id and output index for OpGraphSetOutputARM");
+  }
+
+  uint32_t id = operands[0];
+  Value value = getValue(id);
+  if (!value) {
+    return emitError(unknownLoc, "could not find result <id> ") << id;
+  }
+
+  IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+  if (!outputIndexAttr) {
+    return emitError(unknownLoc,
+                     "unable to read outputIndex value from constant op ")
+           << operands[1];
+  }
+
+  // The index comes from the binary: reject anything that would write outside
+  // graphOutputs (including the case where no graph has sized it yet).
+  int64_t outputIndex = outputIndexAttr.getInt();
+  if (outputIndex < 0 ||
+      static_cast<uint64_t>(outputIndex) >= graphOutputs.size()) {
+    return emitError(unknownLoc, "output index ")
+           << outputIndex << " is out of range for graph with "
+           << graphOutputs.size() << " results";
+  }
+  graphOutputs[outputIndex] = value;
+  return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+ // Create GraphOutputsARM instruction.
+ opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
----------------
kuhar wrote:
also here
https://github.com/llvm/llvm-project/pull/156665
More information about the Mlir-commits
mailing list