[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (PR #156665)

Davide Grohmann llvmlistbot at llvm.org
Thu Sep 11 03:03:43 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/156665

>From 3fd3a228e3df16ef131240b57aaabe696d3fa85e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 11 Aug 2025 13:37:34 +0200
Subject: [PATCH 1/2] [mlir][spirv] Add support for SPV_ARM_graph extension -
 part 2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is the second patch to add support for the `SPV_ARM_graph` SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
`Graph` abstraction for expressing dataflow computations over full
resources.

Part 2 of the implementation includes:

- Serialization and deserialization support for:
  - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM`
  - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM`
- Tests covering binary round-tripping.

Graphs currently support only `SPV_ARM_tensors`, but are designed to
generalize to other resource types, such as images.

Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I88a5ff0298e0d30f649798111785ea984db56515
---
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  22 ++
 .../SPIRV/Deserialization/Deserializer.cpp    | 287 ++++++++++++++++++
 .../SPIRV/Deserialization/Deserializer.h      |  51 +++-
 .../SPIRV/Serialization/SerializeOps.cpp      | 122 ++++++++
 .../Target/SPIRV/Serialization/Serializer.cpp |  82 ++++-
 .../Target/SPIRV/Serialization/Serializer.h   |  39 ++-
 mlir/test/Target/SPIRV/graph-ops.mlir         |  25 ++
 7 files changed, 621 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Target/SPIRV/graph-ops.mlir

diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index ee18cf815e4a7..4c49ec868bbc8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
   if (auto undef = getUndefType(id)) {
     return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
   }
+  if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+          graphConstantARMInfo = getGraphConstantARM(id)) {
+    IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
+    Type resultType = graphConstantARMInfo->resultType;
+    return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
+                                                       graphConstantID);
+  }
   return valueMap.lookup(id);
 }
 
@@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
   case spirv::Opcode::OpTypeTensorARM:
+  case spirv::Opcode::OpTypeGraphARM:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantNull:
     return processConstantNull(operands);
+  case spirv::Opcode::OpGraphConstantARM:
+    return processGraphConstantARM(operands);
   case spirv::Opcode::OpDecorate:
     return processDecoration(operands);
   case spirv::Opcode::OpMemberDecorate:
     return processMemberDecoration(operands);
   case spirv::Opcode::OpFunction:
     return processFunction(operands);
+  case spirv::Opcode::OpGraphEntryPointARM:
+    if (deferInstructions) {
+      deferredInstructions.emplace_back(opcode, operands);
+      return success();
+    }
+    return processGraphEntryPointARM(operands);
+  case spirv::Opcode::OpGraphARM:
+    return processGraphARM(operands);
+  case spirv::Opcode::OpGraphSetOutputARM:
+    return processOpGraphSetOutputARM(operands);
+  case spirv::Opcode::OpGraphEndARM:
+    return processGraphEndARM(operands);
   case spirv::Opcode::OpLabel:
     return processLabel(operands);
   case spirv::Opcode::OpBranch:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3625dd2eb7dd3..37b5d348b0a1c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes an OpGraphEntryPointARM instruction: marks the referenced graph
+/// as an entry point, renames it to the entry point name, and materializes a
+/// spirv.GraphEntryPointARM op (with its interface variables) right before
+/// the graph op.
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "missing graph definition in OpGraphEntryPointARM");
+  }
+
+  unsigned wordIndex = 0;
+  uint32_t grID = operands[wordIndex++];
+  auto graphIt = graphMap.find(grID);
+  if (graphIt == graphMap.end()) {
+    return emitError(unknownLoc,
+                     "missing graph definition/declaration with id ")
+           << grID;
+  }
+
+  spirv::GraphARMOp graphARM = graphIt->second;
+  StringRef name = decodeStringLiteral(operands, wordIndex);
+  graphARM.setSymName(name);
+  graphARM.setEntryPoint(true);
+
+  // The words after the name are the <id>s of the interface global variables.
+  SmallVector<Attribute, 4> interface;
+  for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
+    if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+      interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+    } else {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << operands[wordIndex] << " while decoding OpGraphEntryPointARM";
+    }
+  }
+
+  // RAII guard to reset the insertion point to previous value when done.
+  OpBuilder::InsertionGuard insertionGuard(opBuilder);
+  opBuilder.setInsertionPoint(graphARM);
+  spirv::GraphEntryPointARMOp::create(
+      opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+      opBuilder.getArrayAttr(interface));
+
+  return success();
+}
+
+/// Processes an OpGraphARM instruction and the whole graph body following it:
+/// one OpGraphInputARM per graph input (binding result <id>s to the graph's
+/// block arguments), then every instruction up to and including OpGraphEndARM.
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+  if (curGraph) {
+    return emitError(unknownLoc, "found graph inside graph");
+  }
+  // Get the result type.
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+  }
+
+  Type type = getType(operands[0]);
+  if (!type || !isa<GraphType>(type)) {
+    return emitError(unknownLoc, "unknown graph type from <id> ")
+           << operands[0];
+  }
+  auto graphType = cast<GraphType>(type);
+  if (graphType.getNumResults() == 0) {
+    return emitError(unknownLoc, "expected at least one result");
+  }
+
+  uint32_t grID = operands[1];
+  if (graphMap.count(grID)) {
+    return emitError(unknownLoc, "duplicate graph definition/declaration");
+  }
+
+  std::string grName = getGraphSymbol(grID);
+  auto graphOp =
+      spirv::GraphARMOp::create(opBuilder, unknownLoc, grName, graphType);
+  curGraph = graphMap[grID] = graphOp;
+  Block *entryBlock = graphOp.addEntryBlock();
+  LLVM_DEBUG({
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+    logger.startLine() << "[graph] name: " << grName << "\n";
+    logger.startLine() << "[graph] type: " << graphType << "\n";
+    logger.startLine() << "[graph] ID: " << grID << "\n";
+    logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+    logger.indent();
+  });
+
+  // Parse the OpGraphInputARM instructions, one per graph input. Each binds a
+  // result <id> to the block argument selected by its input index.
+  auto inputTypes = graphType.getInputs();
+  size_t numInputs = inputTypes.size();
+  for (size_t i = 0; i < numInputs; ++i) {
+    spirv::Opcode inputOpcode;
+    ArrayRef<uint32_t> inputOperands;
+    if (failed(sliceInstruction(inputOpcode, inputOperands,
+                                spirv::Opcode::OpGraphInputARM))) {
+      return failure();
+    }
+    if (inputOperands.size() != 3) {
+      return emitError(unknownLoc, "expected result type, result <id> and "
+                                   "input index for OpGraphInputARM");
+    }
+
+    Type argDefinedType = getType(inputOperands[0]);
+    if (!argDefinedType) {
+      return emitError(unknownLoc, "unknown operand type <id> ")
+             << inputOperands[0];
+    }
+
+    IntegerAttr inputIndexAttr = getConstantInt(inputOperands[2]);
+    if (!inputIndexAttr) {
+      return emitError(unknownLoc,
+                       "unable to read inputIndex value from constant op ")
+             << inputOperands[2];
+    }
+    // Reject indices that do not name a graph input before using them to
+    // select a block argument.
+    int64_t inputIndex = inputIndexAttr.getInt();
+    if (inputIndex < 0 || inputIndex >= static_cast<int64_t>(numInputs)) {
+      return emitError(unknownLoc, "input index ")
+             << inputIndex << " out of range for OpGraphInputARM (graph has "
+             << numInputs << " inputs)";
+    }
+    auto argIndex = static_cast<unsigned>(inputIndex);
+
+    // Compare against the input the instruction actually binds.
+    if (argDefinedType != inputTypes[argIndex]) {
+      return emitError(unknownLoc,
+                       "mismatch in argument type between graph type "
+                       "definition ")
+             << graphType << " and argument type definition " << argDefinedType
+             << " at argument " << argIndex;
+    }
+    if (getValue(inputOperands[1])) {
+      return emitError(unknownLoc, "duplicate definition of result <id> ")
+             << inputOperands[1];
+    }
+
+    valueMap[inputOperands[1]] = graphOp.getArgument(argIndex);
+  }
+
+  graphOutputs.resize(graphType.getNumResults());
+
+  // RAII guard to reset the insertion point to the module's region after
+  // deserializing the body of this graph.
+  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+  blockMap[grID] = entryBlock;
+  if (failed(createGraphBlock(grID))) {
+    return failure();
+  }
+
+  // Process all the instructions in the graph until and including
+  // OpGraphEndARM.
+  spirv::Opcode opcode;
+  ArrayRef<uint32_t> instOperands;
+  do {
+    if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+      return failure();
+    }
+
+    if (failed(processInstruction(opcode, instOperands))) {
+      return failure();
+    }
+  } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+  return success();
+}
+
+/// Processes an OpGraphSetOutputARM instruction by recording the value for
+/// the given output index of the graph under construction. The recorded
+/// values are turned into a spirv.GraphOutputsARM op at OpGraphEndARM.
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+  if (!curGraph) {
+    return emitError(unknownLoc,
+                     "OpGraphSetOutputARM must appear inside a graph");
+  }
+  if (operands.size() != 2) {
+    return emitError(
+        unknownLoc,
+        "expected value id and output index for OpGraphSetOutputARM");
+  }
+
+  uint32_t id = operands[0];
+  Value value = getValue(id);
+  if (!value) {
+    return emitError(unknownLoc, "could not find result <id> ") << id;
+  }
+
+  IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+  if (!outputIndexAttr) {
+    return emitError(unknownLoc,
+                     "unable to read outputIndex value from constant op ")
+           << operands[1];
+  }
+  // graphOutputs was sized to the graph's result count; an index outside it
+  // would otherwise write out of bounds.
+  int64_t outputIndex = outputIndexAttr.getInt();
+  if (outputIndex < 0 ||
+      outputIndex >= static_cast<int64_t>(graphOutputs.size())) {
+    return emitError(unknownLoc, "output index ")
+           << outputIndex << " out of range for OpGraphSetOutputARM (graph has "
+           << graphOutputs.size() << " outputs)";
+  }
+  graphOutputs[static_cast<size_t>(outputIndex)] = value;
+  return success();
+}
+
+/// Processes an OpGraphEndARM instruction: emits the spirv.GraphOutputsARM
+/// terminator from the values recorded by OpGraphSetOutputARM and resets the
+/// per-graph deserialization state.
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+  // Validate OpGraphEndARM before emitting anything.
+  if (!operands.empty()) {
+    return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+  }
+  if (!curGraph) {
+    return emitError(unknownLoc, "OpGraphEndARM must appear inside a graph");
+  }
+
+  // Every output slot must have been filled by OpGraphSetOutputARM; an unset
+  // slot is a null Value and would become a null operand.
+  for (auto [index, output] : llvm::enumerate(graphOutputs)) {
+    if (!output) {
+      return emitError(unknownLoc, "graph output ")
+             << index << " was never set by OpGraphSetOutputARM";
+    }
+  }
+
+  // Create GraphOutputsARM instruction.
+  spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
+
+  curBlock = nullptr;
+  curGraph = std::nullopt;
+  graphOutputs.clear();
+
+  LLVM_DEBUG({
+    logger.unindent();
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+  });
+  return success();
+}
+
 std::optional<std::pair<Attribute, Type>>
 spirv::Deserializer::getConstant(uint32_t id) {
   auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
   return funcName;
 }
 
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+  // Prefer the name recorded by OpName; otherwise synthesize one from <id>.
+  StringRef recordedName = nameMap.lookup(id);
+  if (!recordedName.empty())
+    return recordedName.str();
+  return "spirv_graph_" + std::to_string(id);
+}
+
 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
   auto constName = nameMap.lookup(id).str();
   if (constName.empty()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
   return op;
 }
 
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+  // Only <id>s recorded by processGraphConstantARM have materialization info;
+  // any other <id> yields an empty optional.
+  if (auto entry = graphConstantMap.find(id); entry != graphConstantMap.end())
+    return entry->second;
+  return std::nullopt;
+}
+
 LogicalResult
 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
   unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processMatrixType(operands);
   case spirv::Opcode::OpTypeTensorARM:
     return processTensorARMType(operands);
+  case spirv::Opcode::OpTypeGraphARM:
+    return processGraphTypeARM(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes an OpTypeGraphARM instruction. Operands: the result <id>, the
+/// number of graph inputs, then the type <id>s of all inputs followed by the
+/// type <id>s of all outputs.
+LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+  unsigned size = operands.size();
+  if (size < 2) {
+    return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+                                 "(result_id, num_inputs, (inout0_type, "
+                                 "inout1_type, ...)), got ")
+           << size;
+  }
+  uint32_t numInputs = operands[1];
+  unsigned numInOutTypes = size - 2;
+  // The input count must not exceed the number of listed types; otherwise
+  // every type would silently be treated as an input.
+  if (numInputs > numInOutTypes) {
+    return emitError(unknownLoc, "OpTypeGraphARM declares ")
+           << numInputs << " inputs but lists only " << numInOutTypes
+           << " input/output types";
+  }
+  SmallVector<Type, 1> argTypes;
+  SmallVector<Type, 1> returnTypes;
+  for (unsigned i = 2; i < size; i++) {
+    Type inOutTy = getType(operands[i]);
+    if (!inOutTy) {
+      return emitError(unknownLoc, "OpTypeGraphARM references undefined "
+                                   "element type <id> ")
+             << operands[i];
+    }
+    if (i - 2 >= numInputs) {
+      returnTypes.push_back(inOutTy);
+    } else {
+      argTypes.push_back(inOutTy);
+    }
+  }
+  typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2)
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
          << resultType;
 }
 
+/// Processes an OpGraphConstantARM instruction. Nothing is materialized here;
+/// the result type and graph constant ID are recorded in graphConstantMap so
+/// that getValue() can create a spirv.GraphConstantARM op on use.
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 3) {
+    return emitError(unknownLoc)
+           << "OpGraphConstantARM must have at least 3 operands";
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  if (!isa<spirv::TensorArmType>(resultType)) {
+    return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+  }
+
+  uint32_t resultID = operands[1];
+
+  // The GraphConstantID is a raw 32-bit literal word. Build the APInt as
+  // unsigned: with isSigned=true the constructor expects a value in the
+  // signed 32-bit range, which a zero-extended word >= 2^31 is not.
+  APInt graphConstantID(32, operands[2]);
+  Type i32Ty = opBuilder.getIntegerType(32);
+  IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graphConstantID);
+  bool inserted =
+      graphConstantMap
+          .try_emplace(resultID,
+                       GraphConstantARMOpMaterializationInfo{resultType, attr})
+          .second;
+  if (!inserted) {
+    return emitError(unknownLoc, "duplicate definition of result <id> ")
+           << resultID;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Control flow
 //===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+  // Graph blocks can only be populated while a graph is being deserialized.
+  if (!curGraph)
+    return emitError(unknownLoc, "a graph block must appear inside a graph");
+
+  // The block may already exist as a forward declaration; in that case it
+  // must still be empty.
+  Block *graphBlock = getOrCreateBlock(graphID);
+  LLVM_DEBUG(logger.startLine()
+             << "[block] populating block " << graphBlock << "\n");
+  assert(graphBlock->empty() && "re-deserialize the same block!");
+
+  // Make it the current block and start inserting at its beginning.
+  curBlock = graphBlock;
+  blockMap[graphID] = graphBlock;
+  opBuilder.setInsertionPointToStart(graphBlock);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
   if (!curBlock) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index db1cc3f8d79c2..6027f1ac94c23 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -106,6 +106,13 @@ struct SpecConstOperationMaterializationInfo {
   SmallVector<uint32_t> enclosedOpOperands;
 };
 
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+  /// Result type read from the OpGraphConstantARM instruction.
+  Type resultType;
+  /// Integer attribute holding the GraphConstantID literal (built as i32 by
+  /// the deserializer).
+  IntegerAttr graphConstantID;
+};
+
 //===----------------------------------------------------------------------===//
 // Deserializer Declaration
 //===----------------------------------------------------------------------===//
@@ -211,9 +218,14 @@ class Deserializer {
   /// exists; otherwise creates one based on the <id>.
   std::string getFunctionSymbol(uint32_t id);
 
-  /// Returns a symbol to be used for the specialization constant with the given
-  /// result <id>. This tries to use the specialization constant's OpName if
+  /// Returns a symbol to be used for the graph name with the given
+  /// result <id>. This tries to use the graph's OpName if
   /// exists; otherwise creates one based on the <id>.
+  std::string getGraphSymbol(uint32_t id);
+
+  /// Returns a symbol to be used for the specialization constant with the
+  /// given result <id>. This tries to use the specialization constant's
+  /// OpName if exists; otherwise creates one based on the <id>.
   std::string getSpecConstantSymbol(uint32_t id);
 
   /// Gets the specialization constant with the given result <id>.
@@ -237,6 +249,11 @@ class Deserializer {
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
                                            TypedAttr defaultValue);
 
+  /// Gets the GraphConstantARM ID attribute and result type with the given
+  /// result <id>.
+  std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+  getGraphConstantARM(uint32_t id);
+
   /// Processes the OpVariable instructions at current `offset` into `binary`.
   /// It is expected that this method is used for variables that are to be
   /// defined at module scope and will be deserialized into a
@@ -306,6 +323,16 @@ class Deserializer {
 
   LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+  LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+  LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+  LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+  LogicalResult processGraphEndARM(ArrayRef<uint32_t> operands);
+
   LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
@@ -353,6 +380,10 @@ class Deserializer {
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+  /// `operands`.
+  LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
   //===--------------------------------------------------------------------===//
   // Debug
   //===--------------------------------------------------------------------===//
@@ -450,6 +481,9 @@ class Deserializer {
   /// blocks declared as selection/loop headers are handled.
   LogicalResult structurizeControlFlow();
 
+  /// Creates a block for graph with the given graphID.
+  LogicalResult createGraphBlock(uint32_t graphID);
+
   //===--------------------------------------------------------------------===//
   // Instruction
   //===--------------------------------------------------------------------===//
@@ -546,6 +580,9 @@ class Deserializer {
   /// The current function under construction.
   std::optional<spirv::FuncOp> curFunction;
 
+  /// The current graph under construction.
+  std::optional<spirv::GraphARMOp> curGraph;
+
   /// The current block under construction.
   Block *curBlock = nullptr;
 
@@ -599,12 +636,19 @@ class Deserializer {
   DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
       specConstOperationMap;
 
+  // Result <id> to GraphConstantARM ID attribute and result type.
+  DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+      graphConstantMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
   // Result <id> to function mapping.
   DenseMap<uint32_t, spirv::FuncOp> funcMap;
 
+  // Result <id> to graph mapping.
+  DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
   // Result <id> to block mapping.
   DenseMap<uint32_t, Block *> blockMap;
 
@@ -668,6 +712,9 @@ class Deserializer {
   /// Deserialization options.
   DeserializationOptions options;
 
+  /// Values assigned to the current graph's outputs, indexed by output index.
+  SmallVector<Value> graphOutputs;
+
 #ifndef NDEBUG
   /// A logger used to emit information during the deserialzation process.
   llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index d62529b85b3aa..e9b180a70bb23 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -203,6 +203,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
   return success();
 }
 
+/// Serializes a spirv.GraphConstantARM op by emitting (or reusing) its
+/// OpGraphConstantARM and mapping the op's result to that <id>.
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+  uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
+                                             op.getGraphConstantIdAttr());
+  // A zero <id> means the constant could not be emitted.
+  if (resultID == 0)
+    return failure();
+  valueIDMap[op.getResult()] = resultID;
+  return success();
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -368,6 +378,118 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   return success();
 }
 
+/// Serializes a spirv.GraphARM op as OpGraphARM, one OpGraphInputARM per
+/// argument, the body blocks, and a closing OpGraphEndARM. The encoded words
+/// are appended to the `graphs` section.
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+  if (op.getNumResults() < 1) {
+    return op.emitError("cannot serialize graph with no return types");
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
+  assert(functionHeader.empty() && functionBody.empty());
+
+  // Graph <id>s are allocated through the function <id> map, the same map
+  // processGraphEntryPointARMOp uses to look them up by name.
+  uint32_t graphID = getOrCreateFunctionID(op.getName());
+  uint32_t graphTypeID = 0;
+  if (failed(processType(op.getLoc(), op.getFunctionType(), graphTypeID)))
+    return failure();
+  encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
+                        {graphTypeID, graphID});
+
+  // Emit an OpGraphInputARM for every block argument, tagged with its
+  // position as an i32 constant.
+  IntegerType i32Type = IntegerType::get(op.getContext(), 32);
+  for (auto [position, argument] : llvm::enumerate(op.getArguments())) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(op.getLoc(), argument.getType(), argTypeID)))
+      return failure();
+
+    uint32_t argValueID = getNextID();
+    valueIDMap[argument] = argValueID;
+
+    uint32_t indexID = prepareConstantInt(
+        op.getLoc(), IntegerAttr::get(i32Type, position), /*isSpec=*/false);
+    encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
+                          {argTypeID, argValueID, indexID});
+  }
+
+  // The entry block is emitted without an OpLabel; the remaining blocks
+  // follow in pretty block order.
+  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
+    return failure();
+  auto emitBlock = [&](Block *block) { return processBlock(block); };
+  if (failed(visitInPrettyBlockOrder(&op.front(), emitBlock,
+                                     /*skipHeader=*/true)))
+    return failure();
+
+  LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
+                          << "' --\n");
+  encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
+
+  // Move the assembled header and body into the graphs section and reset the
+  // scratch buffers.
+  llvm::append_range(graphs, functionHeader);
+  llvm::append_range(graphs, functionBody);
+  functionHeader.clear();
+  functionBody.clear();
+
+  return success();
+}
+
+/// Serializes a spirv.GraphEntryPointARM op into OpGraphEntryPointARM: the
+/// graph <id>, the graph name as a string literal, and the <id>s of the
+/// interface global variables.
+LogicalResult
+Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
+  SmallVector<uint32_t, 4> operands;
+  StringRef graph = op.getFn();
+  // Add the graph <id>.
+  uint32_t graphID = getOrCreateFunctionID(graph);
+  operands.push_back(graphID);
+  // Add the name of the graph.
+  spirv::encodeStringLiteralInto(operands, graph);
+
+  // Add the interface values; each referenced variable must already have an
+  // <id>.
+  if (ArrayAttr interface = op.getInterface()) {
+    for (Attribute var : interface.getValue()) {
+      StringRef value = cast<FlatSymbolRefAttr>(var).getValue();
+      uint32_t id = getVariableID(value);
+      if (!id) {
+        return op.emitError(
+            "referencing undefined global variable. "
+            "spirv.GraphEntryPointARM is at the end of spirv.module. All "
+            "referenced variables should already be defined");
+      }
+      operands.push_back(id);
+    }
+  }
+  encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
+  return success();
+}
+
+/// Serializes a spirv.GraphOutputsARM op as one OpGraphSetOutputARM per
+/// operand, pairing the operand's value <id> with its output index (emitted
+/// as an i32 constant).
+LogicalResult
+Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
+  for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
+    // OpGraphSetOutputARM carries only the value <id> and the output index,
+    // so the output's type is not processed here.
+    uint32_t outputID = getValueID(value);
+    if (!outputID) {
+      return op.emitError("cannot find <id> for graph output ") << idx;
+    }
+
+    auto indexAttr =
+        IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
+    uint32_t indexID =
+        prepareConstantInt(op.getLoc(), indexAttr, /*isSpec=*/false);
+
+    encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
+                          {outputID, indexID});
+  }
+  return success();
+}
+
 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
   SmallVector<uint32_t, 4> operands;
   SmallVector<StringRef, 2> elidedAttrs;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 7fc779587f4f1..b4be8de670906 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -136,7 +136,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
                     extensions.size() + extendedSets.size() +
                     memoryModel.size() + entryPoints.size() +
                     executionModes.size() + decorations.size() +
-                    typesGlobalValues.size() + functions.size();
+                    typesGlobalValues.size() + functions.size() + graphs.size();
 
   binary.clear();
   binary.reserve(moduleSize);
@@ -154,6 +154,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
   binary.append(decorations.begin(), decorations.end());
   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
   binary.append(functions.begin(), functions.end());
+  binary.append(graphs.begin(), graphs.end());
 }
 
 #ifndef NDEBUG
@@ -509,6 +510,9 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
   if ((isa<FunctionType>(type) &&
        succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
                                      operands))) ||
+      (isa<GraphType>(type) &&
+       succeeded(
+           prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
                                  deferSerialization, serializationCtx))) {
     if (deferSerialization)
@@ -539,7 +543,7 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
     return success();
   }
 
-  return failure();
+  return emitError(loc, "failed to process type: ") << type;
 }
 
 LogicalResult Serializer::prepareBasicType(
@@ -875,6 +879,35 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
   return success();
 }
 
+/// Prepares the operands of an OpTypeGraphARM for `type`: the number of
+/// inputs, followed by the type <id>s of all inputs and then of all results.
+LogicalResult
+Serializer::prepareGraphType(Location loc, GraphType type,
+                             spirv::Opcode &typeEnum,
+                             SmallVectorImpl<uint32_t> &operands) {
+  typeEnum = spirv::Opcode::OpTypeGraphARM;
+  assert(type.getNumResults() >= 1 &&
+         "serialization requires at least a return value");
+
+  operands.push_back(type.getNumInputs());
+
+  // Inputs and results are encoded back to back with the same scheme.
+  auto appendTypeIDs = [&](auto types) -> LogicalResult {
+    for (Type elementType : types) {
+      uint32_t elementTypeID = 0;
+      if (failed(processType(loc, elementType, elementTypeID)))
+        return failure();
+      operands.push_back(elementTypeID);
+    }
+    return success();
+  };
+
+  if (failed(appendTypeIDs(type.getInputs())))
+    return failure();
+  return appendTypeIDs(type.getResults());
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
@@ -1135,6 +1168,41 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
   return resultID;
 }
 
+/// Returns the <id> of the OpGraphConstantARM for `intAttr`, emitting the
+/// instruction (and processing its result type) on first use. Returns 0 on
+/// failure.
+uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
+                                            IntegerAttr intAttr) {
+  // De-duplicate graph constants.
+  if (uint32_t id = getGraphConstantARMId(intAttr)) {
+    return id;
+  }
+
+  // The GraphConstantID is encoded as a single 32-bit literal word. Reject
+  // wider attributes before emitting the type or consuming a result <id>.
+  APInt value = intAttr.getValue();
+  unsigned bitwidth = value.getBitWidth();
+  if (bitwidth > 32) {
+    emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
+        << bitwidth << " bits";
+    return 0;
+  }
+
+  // Process the type for this graph constant.
+  uint32_t typeID = 0;
+  if (failed(processType(loc, graphConstType, typeID))) {
+    return 0;
+  }
+
+  // Sign-extend narrower values to the full word; the int64 -> uint32
+  // conversion is modular and well-defined.
+  auto word = static_cast<uint32_t>(value.getSExtValue());
+  uint32_t resultID = getNextID();
+  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
+                        {typeID, resultID, word});
+  graphConstIDMap[intAttr] = resultID;
+  return resultID;
+}
+
 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
                                        bool isSpec) {
   if (!isSpec) {
@@ -1469,9 +1537,19 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
         return processConstantCompositeReplicateOp(op);
       })
       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
+      .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
+      .Case([&](spirv::GraphEntryPointARMOp op) {
+        return processGraphEntryPointARMOp(op);
+      })
+      .Case([&](spirv::GraphOutputsARMOp op) {
+        return processGraphOutputsARMOp(op);
+      })
       .Case([&](spirv::GlobalVariableOp op) {
         return processGlobalVariableOp(op);
       })
+      .Case([&](spirv::GraphConstantARMOp op) {
+        return processGraphConstantARMOp(op);
+      })
       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index fb2cecdff8e43..add372b19b5af 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -122,6 +122,8 @@ class Serializer {
   LogicalResult
   processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
 
+  LogicalResult processGraphConstantARMOp(spirv::GraphConstantARMOp op);
+
   /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -135,6 +137,15 @@ class Serializer {
   LogicalResult processFuncOp(spirv::FuncOp op);
   LogicalResult processFuncParameter(spirv::FuncOp op);
 
+  /// Processes a SPIR-V GraphARM op.
+  LogicalResult processGraphARMOp(spirv::GraphARMOp op);
+
+  /// Processes a SPIR-V GraphEntryPointARM op.
+  LogicalResult processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op);
+
+  /// Processes a SPIR-V GraphOutputsARM op.
+  LogicalResult processGraphOutputsARMOp(spirv::GraphOutputsARMOp op);
+
   LogicalResult processVariableOp(spirv::VariableOp op);
 
   /// Process a SPIR-V GlobalVariableOp
@@ -189,6 +200,10 @@ class Serializer {
                                     spirv::Opcode &typeEnum,
                                     SmallVectorImpl<uint32_t> &operands);
 
+  LogicalResult prepareGraphType(Location loc, GraphType type,
+                                 spirv::Opcode &typeEnum,
+                                 SmallVectorImpl<uint32_t> &operands);
+
   //===--------------------------------------------------------------------===//
   // Constant
   //===--------------------------------------------------------------------===//
@@ -238,6 +253,13 @@ class Serializer {
   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
                               bool isSpec = false);
 
+  uint32_t getGraphConstantARMId(Attribute value) const {
+    return graphConstIDMap.lookup(value);
+  }
+
+  uint32_t prepareGraphConstantId(Location loc, Type graphConstType,
+                                  IntegerAttr intAttr);
+
   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
                              bool isSpec = false);
 
@@ -372,6 +394,7 @@ class Serializer {
   SmallVector<uint32_t, 0> decorations;
   SmallVector<uint32_t, 0> typesGlobalValues;
   SmallVector<uint32_t, 0> functions;
+  SmallVector<uint32_t, 0> graphs;
 
   /// Recursive struct references are serialized as OpTypePointer instructions
   /// to the recursive struct type. However, the OpTypePointer instruction
@@ -388,15 +411,22 @@ class Serializer {
       recursiveStructInfos;
 
   /// `functionHeader` contains all the instructions that must be in the first
-  /// block in the function, and `functionBody` contains the rest. After
-  /// processing FuncOp, the encoded instructions of a function are appended to
-  /// `functions`. An example of instructions in `functionHeader` in order:
+  /// block in the function or graph, and `functionBody` contains the rest.
+  /// After processing FuncOp/GraphARMOp, the encoded instructions of a function
+  /// or graph are appended to `functions` or `graphs` respectively. Examples of
+  /// instructions in `functionHeader` in order:
+  ///
+  /// For a FuncOp:
   /// OpFunction ...
   /// OpFunctionParameter ...
   /// OpFunctionParameter ...
   /// OpLabel ...
   /// OpVariable ...
   /// OpVariable ...
+  ///
+  /// For a GraphARMOp:
+  /// OpGraphARM ...
+  /// OpGraphInputARM ...
   SmallVector<uint32_t, 0> functionHeader;
   SmallVector<uint32_t, 0> functionBody;
 
@@ -412,6 +442,9 @@ class Serializer {
   /// Map from specialization constant names to their <id>s.
   llvm::StringMap<uint32_t> specConstIDMap;
 
+  /// Map from graph constant ID attributes to their <id>s.
+  DenseMap<Attribute, uint32_t> graphConstIDMap;
+
   /// Map from GlobalVariableOps name to <id>s.
   llvm::StringMap<uint32_t> globalVarIDMap;
 
diff --git a/mlir/test/Target/SPIRV/graph-ops.mlir b/mlir/test/Target/SPIRV/graph-ops.mlir
new file mode 100644
index 0000000000000..c956157bfa6c1
--- /dev/null
+++ b/mlir/test/Target/SPIRV/graph-ops.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+spirv.module Logical Vulkan requires #spirv.vce<v1.3, [VulkanMemoryModel, Shader, Int8, Int16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]> {
+  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+    // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[CONST]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+  }
+
+  // CHECK: spirv.ARM.Graph {{@.*}}([[ARG:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = false} {
+  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+    // CHECK: spirv.ARM.GraphOutputs [[ARG]] : !spirv.arm.tensor<1x16x16x16xi8>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+}

>From d32b486de720b509bdefec721312df8d98c3214a Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 11 Sep 2025 11:01:50 +0200
Subject: [PATCH 2/2] fix code review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I7028dbe4bbcd644f82369785614b16d990a6f11c
---
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  4 +--
 .../SPIRV/Deserialization/Deserializer.cpp    | 28 +++++++++----------
 .../Target/SPIRV/Serialization/Serializer.cpp | 14 ++++------
 3 files changed, 22 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 4c49ec868bbc8..c27f9aa91332c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -90,8 +90,8 @@ Value spirv::Deserializer::getValue(uint32_t id) {
           graphConstantARMInfo = getGraphConstantARM(id)) {
     IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
     Type resultType = graphConstantARMInfo->resultType;
-    return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
-                                                       graphConstantID);
+    return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
+                                             graphConstantID);
   }
   return valueMap.lookup(id);
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 37b5d348b0a1c..54ea7779e2fa9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -677,14 +677,14 @@ spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
   }
 
   unsigned wordIndex = 0;
-  uint32_t grID = operands[wordIndex++];
-  if (!graphMap.count(grID)) {
+  uint32_t graphID = operands[wordIndex++];
+  if (!graphMap.contains(graphID)) {
     return emitError(unknownLoc,
                      "missing graph definition/declaration with id ")
-           << grID;
+           << graphID;
   }
 
-  spirv::GraphARMOp graphARM = graphMap[grID];
+  spirv::GraphARMOp graphARM = graphMap[graphID];
   StringRef name = decodeStringLiteral(operands, wordIndex);
   graphARM.setSymName(name);
   graphARM.setEntryPoint(true);
@@ -729,22 +729,22 @@ spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
     return emitError(unknownLoc, "expected at least one result");
   }
 
-  uint32_t grID = operands[1];
-  if (graphMap.count(grID)) {
+  uint32_t graphID = operands[1];
+  if (graphMap.count(graphID)) {
     return emitError(unknownLoc, "duplicate graph definition/declaration");
   }
 
-  std::string grName = getGraphSymbol(grID);
+  std::string graphName = getGraphSymbol(graphID);
   auto graphOp =
-      opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
-  curGraph = graphMap[grID] = graphOp;
+      opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
+  curGraph = graphMap[graphID] = graphOp;
   Block *entryBlock = graphOp.addEntryBlock();
   LLVM_DEBUG({
     logger.startLine()
         << "//===-------------------------------------------===//\n";
-    logger.startLine() << "[graph] name: " << grName << "\n";
+    logger.startLine() << "[graph] name: " << graphName << "\n";
     logger.startLine() << "[graph] type: " << graphType << "\n";
-    logger.startLine() << "[graph] ID: " << grID << "\n";
+    logger.startLine() << "[graph] ID: " << graphID << "\n";
     logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
     logger.indent();
   });
@@ -795,8 +795,8 @@ spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
   // deserializing the body of this function.
   OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
 
-  blockMap[grID] = entryBlock;
-  if (failed(createGraphBlock(grID))) {
+  blockMap[graphID] = entryBlock;
+  if (failed(createGraphBlock(graphID))) {
     return failure();
   }
 
@@ -1535,7 +1535,7 @@ spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
   uint32_t numInputs = operands[1];
   SmallVector<Type, 1> argTypes;
   SmallVector<Type, 1> returnTypes;
-  for (unsigned i = 2; i < size; i++) {
+  for (unsigned i = 2; i < size; ++i) {
     Type inOutTy = getType(operands[i]);
     if (!inOutTy) {
       return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b4be8de670906..b56e7788625f5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -889,20 +889,18 @@ Serializer::prepareGraphType(Location loc, GraphType type,
 
   operands.push_back(type.getNumInputs());
 
-  for (const Type &res : type.getInputs()) {
+  for (Type argType : type.getInputs()) {
     uint32_t argTypeID = 0;
-    if (failed(processType(loc, res, argTypeID))) {
+    if (failed(processType(loc, argType, argTypeID)))
       return failure();
-    }
     operands.push_back(argTypeID);
   }
 
-  for (const Type &res : type.getResults()) {
-    uint32_t resultID = 0;
-    if (failed(processType(loc, res, resultID))) {
+  for (Type resType : type.getResults()) {
+    uint32_t resTypeID = 0;
+    if (failed(processType(loc, resType, resTypeID)))
       return failure();
-    }
-    operands.push_back(resultID);
+    operands.push_back(resTypeID);
   }
 
   return success();



More information about the Mlir-commits mailing list