[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (PR #156665)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 3 06:15:11 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Davide Grohmann (davidegrohmann)
<details>
<summary>Changes</summary>
This is the second patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources.
Part 2 of the implementation includes:
- Serialization and deserialization support for:
  - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM`
  - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM`
- Tests covering binary round-tripping.
Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images.
Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947
---
Patch is 35.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156665.diff
7 Files Affected:
- (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+22)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+287)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+49-2)
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+122)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+80-2)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+36-3)
- (added) mlir/test/Target/SPIRV/graph-ops.mlir (+25)
``````````diff
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index ee18cf815e4a7..4c49ec868bbc8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
if (auto undef = getUndefType(id)) {
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
}
+ if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantARMInfo = getGraphConstantARM(id)) {
+ IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
+ Type resultType = graphConstantARMInfo->resultType;
+ return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
+ graphConstantID);
+ }
return valueMap.lookup(id);
}
@@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeTensorARM:
+ case spirv::Opcode::OpTypeGraphARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
+ case spirv::Opcode::OpGraphConstantARM:
+ return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
+ case spirv::Opcode::OpGraphEntryPointARM:
+ if (deferInstructions) {
+ deferredInstructions.emplace_back(opcode, operands);
+ return success();
+ }
+ return processGraphEntryPointARM(operands);
+ case spirv::Opcode::OpGraphARM:
+ return processGraphARM(operands);
+ case spirv::Opcode::OpGraphSetOutputARM:
+ return processOpGraphSetOutputARM(operands);
+ case spirv::Opcode::OpGraphEndARM:
+ return processGraphEndARM(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
case spirv::Opcode::OpBranch:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3625dd2eb7dd3..37b5d348b0a1c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 2) {
+ return emitError(unknownLoc,
+ "missing graph defintion in OpGraphEntryPointARM");
+ }
+
+ unsigned wordIndex = 0;
+ uint32_t grID = operands[wordIndex++];
+ if (!graphMap.count(grID)) {
+ return emitError(unknownLoc,
+ "missing graph definition/declaration with id ")
+ << grID;
+ }
+
+ spirv::GraphARMOp graphARM = graphMap[grID];
+ StringRef name = decodeStringLiteral(operands, wordIndex);
+ graphARM.setSymName(name);
+ graphARM.setEntryPoint(true);
+
+ SmallVector<Attribute, 4> interface;
+ for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
+ if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+ interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+ } else {
+ return emitError(unknownLoc, "undefined result <id> ")
+ << operands[wordIndex] << " while decoding OpGraphEntryPoint";
+ }
+ }
+
+ // RAII guard to reset the insertion point to previous value when done.
+ OpBuilder::InsertionGuard insertionGuard(opBuilder);
+ opBuilder.setInsertionPoint(graphARM);
+ opBuilder.create<spirv::GraphEntryPointARMOp>(
+ unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+ opBuilder.getArrayAttr(interface));
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+ if (curGraph) {
+ return emitError(unknownLoc, "found graph inside graph");
+ }
+ // Get the result type.
+ if (operands.size() < 2) {
+ return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+ }
+
+ Type type = getType(operands[0]);
+ if (!type || !isa<GraphType>(type)) {
+ return emitError(unknownLoc, "unknown graph type from <id> ")
+ << operands[0];
+ }
+ auto graphType = cast<GraphType>(type);
+ if (graphType.getNumResults() <= 0) {
+ return emitError(unknownLoc, "expected at least one result");
+ }
+
+ uint32_t grID = operands[1];
+ if (graphMap.count(grID)) {
+ return emitError(unknownLoc, "duplicate graph definition/declaration");
+ }
+
+ std::string grName = getGraphSymbol(grID);
+ auto graphOp =
+ opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
+ curGraph = graphMap[grID] = graphOp;
+ Block *entryBlock = graphOp.addEntryBlock();
+ LLVM_DEBUG({
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ logger.startLine() << "[graph] name: " << grName << "\n";
+ logger.startLine() << "[graph] type: " << graphType << "\n";
+ logger.startLine() << "[graph] ID: " << grID << "\n";
+ logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+ logger.indent();
+ });
+
+ // Parse the op argument instructions.
+ for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> operands;
+ if (failed(sliceInstruction(opcode, operands,
+ spirv::Opcode::OpGraphInputARM))) {
+ return failure();
+ }
+ if (operands.size() != 3) {
+ return emitError(unknownLoc, "expected result type, result <id> and "
+ "input index for OpGraphInputARM");
+ }
+
+ Type argDefinedType = getType(operands[0]);
+ if (!argDefinedType) {
+ return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
+ }
+
+ if (argDefinedType != argType) {
+ return emitError(unknownLoc,
+ "mismatch in argument type between graph type "
+ "definition ")
+ << graphType << " and argument type definition " << argDefinedType
+ << " at argument " << index;
+ }
+ if (getValue(operands[1])) {
+ return emitError(unknownLoc, "duplicate definition of result <id> ")
+ << operands[1];
+ }
+
+ IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
+ if (!inputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read inputIndex value from constant op ")
+ << operands[2];
+ }
+ BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
+ valueMap[operands[1]] = argValue;
+ }
+
+ graphOutputs.resize(graphType.getNumResults());
+
+ // RAII guard to reset the insertion point to the module's region after
+ // deserializing the body of this function.
+ OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+ blockMap[grID] = entryBlock;
+ if (failed(createGraphBlock(grID))) {
+ return failure();
+ }
+
+ // Process all the instructions in the graph until and including
+ // OpGraphEndARM.
+ spirv::Opcode opcode;
+ ArrayRef<uint32_t> instOperands;
+ do {
+ if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+ return failure();
+ }
+
+ if (failed(processInstruction(opcode, instOperands))) {
+ return failure();
+ }
+ } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() != 2) {
+ return emitError(
+ unknownLoc,
+ "expected value id and output index for OpGraphSetOutputARM");
+ }
+
+ uint32_t id = operands[0];
+ Value value = getValue(id);
+ if (!value) {
+ return emitError(unknownLoc, "could not find result <id> ") << id;
+ }
+
+ IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+ if (!outputIndexAttr) {
+ return emitError(unknownLoc,
+ "unable to read outputIndex value from constant op ")
+ << operands[1];
+ }
+ graphOutputs[outputIndexAttr.getInt()] = value;
+ return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+ // Create GraphOutputsARM instruction.
+ opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+
+ // Process OpGraphEndARM.
+ if (!operands.empty()) {
+ return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+ }
+
+ curBlock = nullptr;
+ curGraph = std::nullopt;
+ graphOutputs.clear();
+
+ LLVM_DEBUG({
+ logger.unindent();
+ logger.startLine()
+ << "//===-------------------------------------------===//\n";
+ });
+ return success();
+}
+
std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
return funcName;
}
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+ std::string graphName = nameMap.lookup(id).str();
+ if (graphName.empty()) {
+ graphName = "spirv_graph_" + std::to_string(id);
+ }
+ return graphName;
+}
+
std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
return op;
}
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+ auto graphConstIt = graphConstantMap.find(id);
+ if (graphConstIt == graphConstantMap.end())
+ return std::nullopt;
+ return graphConstIt->getSecond();
+}
+
LogicalResult
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processMatrixType(operands);
case spirv::Opcode::OpTypeTensorARM:
return processTensorARMType(operands);
+ case spirv::Opcode::OpTypeGraphARM:
+ return processGraphTypeARM(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2) {
+ return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+ "(result_id, num_inputs, (inout0_type, "
+ "inout1_type, ...))")
+ << size;
+ }
+ uint32_t numInputs = operands[1];
+ SmallVector<Type, 1> argTypes;
+ SmallVector<Type, 1> returnTypes;
+ for (unsigned i = 2; i < size; i++) {
+ Type inOutTy = getType(operands[i]);
+ if (!inOutTy) {
+ return emitError(unknownLoc,
+ "OpTypeGraphARM references undefined element type.")
+ << operands[i];
+ }
+ if (i - 2 >= numInputs) {
+ returnTypes.push_back(inOutTy);
+ } else {
+ argTypes.push_back(inOutTy);
+ }
+ }
+ typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
<< resultType;
}
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+ if (operands.size() < 3) {
+ return emitError(unknownLoc)
+ << "OpGraphConstantARM must have at least 2 operands";
+ }
+
+ Type resultType = getType(operands[0]);
+ if (!resultType) {
+ return emitError(unknownLoc, "undefined result type from <id> ")
+ << operands[0];
+ }
+
+ uint32_t resultID = operands[1];
+
+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
+ return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+ }
+
+ APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
+ Type i32Ty = opBuilder.getIntegerType(32);
+ IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
+ graphConstantMap.try_emplace(
+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+ if (!curGraph) {
+ return emitError(unknownLoc, "a graph block must appear inside a graph");
+ }
+
+ // We may have forward declared this block.
+ Block *block = getOrCreateBlock(graphID);
+ LLVM_DEBUG(logger.startLine()
+ << "[block] populating block " << block << "\n");
+ // If we have seen this block, make sure it was just a forward declaration.
+ assert(block->empty() && "re-deserialize the same block!");
+
+ opBuilder.setInsertionPointToStart(block);
+ blockMap[graphID] = curBlock = block;
+
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index db1cc3f8d79c2..6027f1ac94c23 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -106,6 +106,13 @@ struct SpecConstOperationMaterializationInfo {
SmallVector<uint32_t> enclosedOpOperands;
};
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+ Type resultType;
+ IntegerAttr graphConstantID;
+};
+
//===----------------------------------------------------------------------===//
// Deserializer Declaration
//===----------------------------------------------------------------------===//
@@ -211,9 +218,14 @@ class Deserializer {
/// exists; otherwise creates one based on the <id>.
std::string getFunctionSymbol(uint32_t id);
- /// Returns a symbol to be used for the specialization constant with the given
- /// result <id>. This tries to use the specialization constant's OpName if
+ /// Returns a symbol to be used for the graph name with the given
+ /// result <id>. This tries to use the graph's OpName if
/// exists; otherwise creates one based on the <id>.
+ std::string getGraphSymbol(uint32_t id);
+
+ /// Returns a symbol to be used for the specialization constant with the
+ /// given result <id>. This tries to use the specialization constant's
+ /// OpName if exists; otherwise creates one based on the <id>.
std::string getSpecConstantSymbol(uint32_t id);
/// Gets the specialization constant with the given result <id>.
@@ -237,6 +249,11 @@ class Deserializer {
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
TypedAttr defaultValue);
+ /// Gets the GraphConstantARM ID attribute and result type with the given
+ /// result <id>.
+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+ getGraphConstantARM(uint32_t id);
+
/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
/// defined at module scope and will be deserialized into a
@@ -306,6 +323,16 @@ class Deserializer {
LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+ LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+ LogicalResult processGraphEndARM(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
@@ -353,6 +380,10 @@ class Deserializer {
/// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
+ /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+ /// `operands`.
+ LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
//===--------------------------------------------------------------------===//
// Debug
//===--------------------------------------------------------------------===//
@@ -450,6 +481,9 @@ class Deserializer {
/// blocks declared as selection/loop headers are handled.
LogicalResult structurizeControlFlow();
+ /// Creates a block for graph with the given graphID.
+ LogicalResult createGraphBlock(uint32_t graphID);
+
//===--------------------------------------------------------------------===//
// Instruction
//===--------------------------------------------------------------------===//
@@ -546,6 +580,9 @@ class Deserializer {
/// The current function under construction.
std::optional<spirv::FuncOp> curFunction;
+ /// The current graph under construction.
+ std::optional<spirv::GraphARMOp> curGraph;
+
/// The current block under construction.
Block *curBlock = nullptr;
@@ -599,12 +636,19 @@ class Deserializer {
DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
specConstOperationMap;
+ // Result <id> to GraphConstantARM ID attribute and result type.
+ DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+ graphConstantMap;
+
// Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
// Result <id> to function mapping.
DenseMap<uint32_t, spirv::FuncOp> funcMap;
+ // Result <id> to function mapping.
+ DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
// Result <id> to block mapping.
DenseMap<uint32_t, Block *> blockMap;
@@ -668,6 +712,9 @@ class Deserializer {
/// Deserialization options.
DeserializationOptions options;
+ /// List of IDs assigned to graph outputs.
+ SmallVector<Value> graphOutputs;
+
#ifndef NDEBUG
/// A logger used to emit information during the deserialzation process.
llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d62529b85b3aa..e9b180a70bb23 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -203,6 +203,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
return success();
}
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+ if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
+ op.getGraphConstantIdAttr())) {
+ valueIDMap[op.getResult()] = resultID;
+ return success();
+ }
+ return failure();
+}
+
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
auto undefType = op.getType();
auto &id = undefValIDMap[undefType];
@@ -368,6 +378,118 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
return success();
}
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+ if (op.getNumResults() < 1) {
+ return op.emitError("cannot serialize graph with no return types");
+ }
+
+ LLVM_DEBUG(llvm::dbgs() <...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/156665
More information about the Mlir-commits mailing list