[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (PR #156665)

Jakub Kuderski llvmlistbot at llvm.org
Wed Sep 10 09:35:01 PDT 2025


================
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Handles an OpGraphEntryPointARM instruction.
+///
+/// Operands: the <id> of a graph that must already be in `graphMap`, a
+/// string literal holding the entry point name, then zero or more interface
+/// <id>s, each of which must refer to a deserialized global variable.
+/// The graph is renamed to the decoded name and marked as an entry point,
+/// and a GraphEntryPointARMOp referencing it (by name) and its interface
+/// variables is inserted immediately before the graph op.
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+  // Need at least the graph <id> plus one word of the name literal.
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "missing graph definition in OpGraphEntryPointARM");
+  }
+
+  unsigned wordIndex = 0;
+  uint32_t grID = operands[wordIndex++];
+  if (!graphMap.count(grID)) {
+    return emitError(unknownLoc,
+                     "missing graph definition/declaration with id ")
+           << grID;
+  }
+
+  spirv::GraphARMOp graphARM = graphMap[grID];
+  // Decodes the name literal and moves `wordIndex` past it, so the loop
+  // below starts at the first interface <id>.
+  StringRef name = decodeStringLiteral(operands, wordIndex);
+  graphARM.setSymName(name);
+  graphARM.setEntryPoint(true);
+
+  // Every remaining word is an interface <id> naming a global variable.
+  SmallVector<Attribute, 4> interface;
+  for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
+    if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
+      interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+    } else {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << operands[wordIndex] << " while decoding OpGraphEntryPointARM";
+    }
+  }
+
+  // RAII guard to reset the insertion point to previous value when done.
+  OpBuilder::InsertionGuard insertionGuard(opBuilder);
+  opBuilder.setInsertionPoint(graphARM);
+  spirv::GraphEntryPointARMOp::create(
+      opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+      opBuilder.getArrayAttr(interface));
+
+  return success();
+}
+
+/// Handles an OpGraphARM instruction together with the instructions of the
+/// graph body.
+///
+/// Operands: the <id> of a GraphType and the result <id> of the graph.
+/// Creates a GraphARMOp with an entry block and registers it in `graphMap`
+/// and `curGraph`. It then consumes one OpGraphInputARM per graph input,
+/// mapping each input's result <id> to the entry-block argument selected by
+/// its input index. Finally it processes instructions up to and including
+/// OpGraphEndARM.
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+  if (curGraph) {
+    return emitError(unknownLoc, "found graph inside graph");
+  }
+  // Get the result type.
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+  }
+
+  Type type = getType(operands[0]);
+  if (!type || !isa<GraphType>(type)) {
+    return emitError(unknownLoc, "unknown graph type from <id> ")
+           << operands[0];
+  }
+  auto graphType = cast<GraphType>(type);
+  // A graph must produce at least one output.
+  if (graphType.getNumResults() == 0) {
+    return emitError(unknownLoc, "expected at least one result");
+  }
+
+  uint32_t grID = operands[1];
+  if (graphMap.count(grID)) {
+    return emitError(unknownLoc, "duplicate graph definition/declaration");
+  }
+
+  std::string grName = getGraphSymbol(grID);
+  auto graphOp =
+      spirv::GraphARMOp::create(opBuilder, unknownLoc, grName, graphType);
+  curGraph = graphMap[grID] = graphOp;
+  Block *entryBlock = graphOp.addEntryBlock();
+  LLVM_DEBUG({
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+    logger.startLine() << "[graph] name: " << grName << "\n";
+    logger.startLine() << "[graph] type: " << graphType << "\n";
+    logger.startLine() << "[graph] ID: " << grID << "\n";
+    logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+    logger.indent();
+  });
+
+  // Parse the op argument instructions: one OpGraphInputARM per graph input.
+  for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
+    spirv::Opcode opcode;
+    // Named `argOperands` so it does not shadow the function parameter.
+    ArrayRef<uint32_t> argOperands;
+    if (failed(sliceInstruction(opcode, argOperands,
+                                spirv::Opcode::OpGraphInputARM))) {
+      return failure();
+    }
+    if (argOperands.size() != 3) {
+      return emitError(unknownLoc, "expected result type, result <id> and "
+                                   "input index for OpGraphInputARM");
+    }
+
+    Type argDefinedType = getType(argOperands[0]);
+    if (!argDefinedType) {
+      return emitError(unknownLoc, "unknown operand type <id> ")
+             << argOperands[0];
+    }
+
+    if (argDefinedType != argType) {
+      return emitError(unknownLoc,
+                       "mismatch in argument type between graph type "
+                       "definition ")
+             << graphType << " and argument type definition " << argDefinedType
+             << " at argument " << index;
+    }
+    if (getValue(argOperands[1])) {
+      return emitError(unknownLoc, "duplicate definition of result <id> ")
+             << argOperands[1];
+    }
+
+    IntegerAttr inputIndexAttr = getConstantInt(argOperands[2]);
+    if (!inputIndexAttr) {
+      return emitError(unknownLoc,
+                       "unable to read inputIndex value from constant op ")
+             << argOperands[2];
+    }
+    // The index comes straight from the binary; validate it before using it
+    // to select an entry-block argument.
+    int64_t inputIndex = inputIndexAttr.getInt();
+    if (inputIndex < 0 ||
+        inputIndex >= static_cast<int64_t>(entryBlock->getNumArguments())) {
+      return emitError(unknownLoc, "input index ")
+             << inputIndex << " out of range for OpGraphInputARM";
+    }
+    BlockArgument argValue = graphOp.getArgument(inputIndex);
+    valueMap[argOperands[1]] = argValue;
+  }
+
+  // One slot per graph result, filled by OpGraphSetOutputARM. Clear first so
+  // no value from a previously deserialized graph survives the resize.
+  graphOutputs.clear();
+  graphOutputs.resize(graphType.getNumResults());
+
+  // RAII guard to reset the insertion point to the module's region after
+  // deserializing the body of this graph.
+  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+  blockMap[grID] = entryBlock;
+  if (failed(createGraphBlock(grID))) {
+    return failure();
+  }
+
+  // Process all the instructions in the graph until and including
+  // OpGraphEndARM.
+  spirv::Opcode opcode;
+  ArrayRef<uint32_t> instOperands;
+  do {
+    if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+      return failure();
+    }
+
+    if (failed(processInstruction(opcode, instOperands))) {
+      return failure();
+    }
+  } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+  return success();
+}
+
+/// Handles an OpGraphSetOutputARM instruction.
+///
+/// Operands: the result <id> of the value to output, and the <id> of an
+/// integer constant holding the output index. Stores the value in the
+/// `graphOutputs` slot for that index.
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(
+        unknownLoc,
+        "expected value id and output index for OpGraphSetOutputARM");
+  }
+
+  // Graph outputs are only meaningful while a graph is being deserialized.
+  if (!curGraph) {
+    return emitError(unknownLoc,
+                     "OpGraphSetOutputARM must appear inside a graph");
+  }
+
+  uint32_t id = operands[0];
+  Value value = getValue(id);
+  if (!value) {
+    return emitError(unknownLoc, "could not find result <id> ") << id;
+  }
+
+  IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+  if (!outputIndexAttr) {
+    return emitError(unknownLoc,
+                     "unable to read outputIndex value from constant op ")
+           << operands[1];
+  }
+  // The index comes from the binary; validate it so a malformed module
+  // cannot write outside `graphOutputs`.
+  int64_t outputIndex = outputIndexAttr.getInt();
+  if (outputIndex < 0 ||
+      outputIndex >= static_cast<int64_t>(graphOutputs.size())) {
+    return emitError(unknownLoc, "output index ")
+           << outputIndex << " out of range for OpGraphSetOutputARM";
+  }
+  graphOutputs[outputIndex] = value;
+  return success();
+}
+
+LogicalResult
+spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
+  // Create GraphOutputsARM instruction.
+  opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
----------------
kuhar wrote:

also here

https://github.com/llvm/llvm-project/pull/156665


More information about the Mlir-commits mailing list