[Mlir-commits] [mlir] [mlir][spirv] Add SPIR-V NonSemantic.Graph.DebugInfo (PR #199519)

Davide Grohmann llvmlistbot at llvm.org
Mon May 25 04:49:13 PDT 2026


https://github.com/davidegrohmann created https://github.com/llvm/llvm-project/pull/199519

Add serialization and deserialization support for the NonSemantic.Graph.DebugInfo.1 extended instruction set used with ARM graph modules.

Definition:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/nonsemantic/NonSemantic.Graph.DebugInfo.asciidoc

Serialize DebugGraph, DebugOperation, and DebugTensor records when debug info is enabled. DebugOperation now references the DebugGraph result id, and debug graph records are emitted before operations that reference them. DebugTensor records cover graph inputs, graph outputs, and tensor-typed spirv.Constant results.
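
For readers unfamiliar with the binary layout, here is a minimal standalone sketch of how one DebugOperation record could be packed into SPIR-V words. It is not code from this patch: `encodeDebugOperation` and its parameter names are hypothetical, and the ids are caller-supplied. Only the OpExtInst opcode (12), the header packing (word count in the high 16 bits), and the DebugOperation opcode (2) come from the SPIR-V specification and from the enum added in this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// SPIR-V core opcode for OpExtInst (from the SPIR-V specification).
constexpr uint32_t kOpExtInst = 12;
// DebugOperation opcode, matching GraphDebugInfoExtInst in this patch.
constexpr uint32_t kDebugOperation = 2;

// Hypothetical helper: packs one DebugOperation record into SPIR-V words.
// `setID` is assumed to be the <id> of the OpExtInstImport for
// "NonSemantic.Graph.DebugInfo.1".
std::vector<uint32_t>
encodeDebugOperation(uint32_t voidTypeID, uint32_t resultID, uint32_t setID,
                     uint32_t debugGraphID, uint32_t stringID,
                     const std::vector<uint32_t> &opResultIDs) {
  assert(!opResultIDs.empty() && "a record must reference at least one op");
  std::vector<uint32_t> words;
  // The word count includes the header word itself: 1 header word, 6 fixed
  // operands, then one word per referenced operation result.
  const uint32_t wordCount =
      1 + 6 + static_cast<uint32_t>(opResultIDs.size());
  words.push_back((wordCount << 16) | kOpExtInst);
  words.push_back(voidTypeID);      // result type: void
  words.push_back(resultID);        // result <id> of this debug record
  words.push_back(setID);           // extended instruction set <id>
  words.push_back(kDebugOperation); // instruction number within the set
  words.push_back(debugGraphID);    // result <id> of the DebugGraph record
  words.push_back(stringID);        // OpString holding the location payload
  words.insert(words.end(), opResultIDs.begin(), opResultIDs.end());
  return words;
}
```

With the header word stripped, this is the operand order the deserializer in this patch reads: index 3 is the instruction number, index 4 the DebugGraph id, index 5 the string id, and indices 6 onward the operation result ids.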

Deserialize the debug records back into MLIR locations for graphs, TOSA operation results, tensors, and materialized tensor constants. Avoid default-inserting empty Values while applying debug records, and diagnose undefined debug ids instead.
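
The point about default insertion can be illustrated with a small standalone sketch. It uses `std::unordered_map` and strings as stand-ins for the deserializer's id-to-Value map; `ValueMap` and `applyDebugLoc` are hypothetical names, not part of the patch. The idea is that `find` leaves the map untouched when an id is missing, whereas `operator[]` would silently insert an empty entry.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_map>

// Stand-in for the deserializer's <id> -> Value map. A std::string models
// the Value; this is an illustration only.
using ValueMap = std::unordered_map<uint32_t, std::string>;

// Attaches `loc` to the value with the given id. Returns an error message
// when the id is undefined, and never inserts into `map`.
std::optional<std::string> applyDebugLoc(ValueMap &map, uint32_t id,
                                         const std::string &loc) {
  auto it = map.find(id); // find() does not create a default entry
  if (it == map.end())
    return "undefined operation <id> " + std::to_string(id) +
           " in DebugOperation";
  it->second += " loc(\"" + loc + "\")";
  return std::nullopt;
}
```

Calling `applyDebugLoc` with an unknown id reports the problem and leaves the map at its original size, which is the behaviour the paragraph above describes for undefined debug ids.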

Add round-trip lit coverage for FileLineColLoc, NameLoc, FusedLoc, grouped TOSA ops, and missing extension diagnostics. Add a binary-level serialization unit test for DebugGraph/DebugOperation references, ordering, and DebugTensor records for inputs, outputs, and constants.
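
To make the round-trip expectations concrete, the following standalone sketch models how the three covered location kinds are flattened into the OpString payload. It mirrors the logic of `getDebugInfoStringFromLoc` in this patch, but `Loc` and `toDebugString` are simplified, hypothetical stand-ins that do not depend on MLIR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for an MLIR location (illustration only).
struct Loc {
  enum class Kind { FileLineCol, Name, Fused, Unknown } kind;
  std::string text;       // file name (FileLineCol) or name (Name)
  unsigned line;          // used by FileLineCol only
  unsigned col;           // used by FileLineCol only
  std::vector<Loc> parts; // sub-locations, used by Fused only
};

// Flattens a location into the string stored in the debug OpString.
std::string toDebugString(const Loc &loc) {
  switch (loc.kind) {
  case Loc::Kind::FileLineCol:
    return loc.text + ":" + std::to_string(loc.line) + ":" +
           std::to_string(loc.col);
  case Loc::Kind::Name:
    return loc.text;
  case Loc::Kind::Fused: {
    // Each sub-location is followed by ';', including the last one.
    std::string result;
    for (const Loc &part : loc.parts)
      result += toDebugString(part) + ";";
    return result;
  }
  case Loc::Kind::Unknown:
    break;
  }
  return "";
}
```

Because the deserializer in this patch wraps the payload in a NameLoc, a FileLineColLoc is expected to come back as a name location whose name is the `file:line:col` string, which is what the CHECK lines in `debug-info.mlir` match.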

From a84778a2a835cca3f28c6cb5f9a07e77854bd4c8 Mon Sep 17 00:00:00 2001
From: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Date: Tue, 10 Mar 2026 11:00:32 +0000
Subject: [PATCH] [mlir][spirv] Add SPIR-V NonSemantic.Graph.DebugInfo

Add serialization and deserialization support for the
NonSemantic.Graph.DebugInfo.1 extended instruction set used with ARM
graph modules.

Definition:
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/nonsemantic/NonSemantic.Graph.DebugInfo.asciidoc

Serialize DebugGraph, DebugOperation, and DebugTensor records when
debug info is enabled. DebugOperation now references the DebugGraph
result id, and debug graph records are emitted before operations that
reference them. DebugTensor records cover graph inputs, graph outputs,
and tensor-typed spirv.Constant results.

Deserialize the debug records back into MLIR locations for graphs,
TOSA operation results, tensors, and materialized tensor
constants. Avoid default-inserting empty Values while applying debug
records, and diagnose undefined debug ids instead.

Add round-trip lit coverage for FileLineColLoc, NameLoc, FusedLoc,
grouped TOSA ops, and missing extension diagnostics. Add a
binary-level serialization unit test for DebugGraph/DebugOperation
references, ordering, and DebugTensor records for inputs, outputs, and
constants.

Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: If71c026aa08b2bf9052eba35b173e1ce4498cb9d
---
 .../mlir/Target/SPIRV/SPIRVExtInstSets.h      |  40 ++++
 .../TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp   |   1 +
 .../SPIRV/Deserialization/DeserializeOps.cpp  | 110 +++++++++-
 .../SPIRV/Deserialization/Deserializer.h      |   8 +
 .../SPIRV/Serialization/SerializeOps.cpp      | 161 ++++++++++++++
 .../Target/SPIRV/Serialization/Serializer.cpp |  45 +++-
 .../Target/SPIRV/Serialization/Serializer.h   |  28 +++
 mlir/test/Target/SPIRV/debug-info.mlir        | 131 +++++++++++
 .../Dialect/SPIRV/DeserializationTest.cpp     |  46 ++++
 .../Dialect/SPIRV/SerializationTest.cpp       | 206 ++++++++++++++++++
 10 files changed, 769 insertions(+), 7 deletions(-)
 create mode 100644 mlir/include/mlir/Target/SPIRV/SPIRVExtInstSets.h
 create mode 100644 mlir/test/Target/SPIRV/debug-info.mlir

diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVExtInstSets.h b/mlir/include/mlir/Target/SPIRV/SPIRVExtInstSets.h
new file mode 100644
index 0000000000000..88ee32c5defb1
--- /dev/null
+++ b/mlir/include/mlir/Target/SPIRV/SPIRVExtInstSets.h
@@ -0,0 +1,40 @@
+//===- SPIRVExtInstSets.h - SPIR-V ext inst sets ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares extended instruction set constants used by SPIR-V
+// (de)serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_SPIRV_SPIRVEXTINSTSETS_H
+#define MLIR_TARGET_SPIRV_SPIRVEXTINSTSETS_H
+
+#include "llvm/ADT/StringRef.h"
+#include <cstdint>
+
+namespace mlir {
+namespace spirv {
+
+/// Extension set name for TOSA ops.
+constexpr StringLiteral extTosa("TOSA.001000.1");
+
+/// Extension set name for non-semantic graph debug info.
+constexpr StringLiteral extDebugInfo("NonSemantic.Graph.DebugInfo.1");
+
+/// Instruction opcodes in the NonSemantic.Graph.DebugInfo.1 extended
+/// instruction set.
+enum class GraphDebugInfoExtInst : uint32_t {
+  DebugGraph = 1,
+  DebugOperation = 2,
+  DebugTensor = 3,
+};
+
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_TARGET_SPIRV_SPIRVEXTINSTSETS_H
diff --git a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
index bef30e84b3289..b5ae5c0275b26 100644
--- a/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
+++ b/mlir/lib/Conversion/TosaToSPIRVTosa/TosaToSPIRVTosaPass.cpp
@@ -51,6 +51,7 @@ spirv::VerCapExtAttr getDefaultVerCapExtAttr(MLIRContext *context) {
           spirv::Extension::SPV_EXT_replicated_composites,
           spirv::Extension::SPV_KHR_bfloat16,
           spirv::Extension::SPV_EXT_float8,
+          spirv::Extension::SPV_KHR_non_semantic_info,
       },
       context);
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index f65b559ed1369..54cd47716a254 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Target/SPIRV/SPIRVExtInstSets.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
@@ -35,6 +36,12 @@ static inline spirv::Opcode extractOpcode(uint32_t word) {
   return static_cast<spirv::Opcode>(word & 0xffff);
 }
 
+/// Returns a NameLoc location from the given debug info string.
+static inline NameLoc getLocFromDebugInfoString(OpBuilder &builder,
+                                                StringRef source) {
+  return NameLoc::get(builder.getStringAttr(source));
+}
+
 //===----------------------------------------------------------------------===//
 // Instruction
 //===----------------------------------------------------------------------===//
@@ -42,7 +49,10 @@ static inline spirv::Opcode extractOpcode(uint32_t word) {
 Value spirv::Deserializer::getValue(uint32_t id) {
   if (auto constInfo = getConstant(id)) {
     // Materialize a `spirv.Constant` op at every use site.
-    return spirv::ConstantOp::create(opBuilder, unknownLoc, constInfo->second,
+    Location loc = unknownLoc;
+    if (auto locAttr = constantLocMap.lookup(id))
+      loc = Location(locAttr);
+    return spirv::ConstantOp::create(opBuilder, loc, constInfo->second,
                                      constInfo->first);
   }
   if (std::optional<std::pair<Attribute, Type>> constCompositeReplicateInfo =
@@ -171,8 +181,13 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processCapability(operands);
   case spirv::Opcode::OpExtension:
     return processExtension(operands);
-  case spirv::Opcode::OpExtInst:
+  case spirv::Opcode::OpExtInst: {
+    auto setIt = operands.size() >= 4 ? extendedInstSets.find(operands[2])
+                                      : extendedInstSets.end();
+    if (setIt != extendedInstSets.end() && setIt->second == extDebugInfo)
+      return processDebugInfoExtInst(operands, deferInstructions);
     return processExtInst(operands);
+  }
   case spirv::Opcode::OpExtInstImport:
     return processExtInstImport(operands);
   case spirv::Opcode::OpMemberName:
@@ -388,6 +403,97 @@ LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processDebugInfoExtInst(ArrayRef<uint32_t> operands,
+                                             bool deferInstructions) {
+  if (deferInstructions) {
+    deferredInstructions.emplace_back(spirv::Opcode::OpExtInst, operands);
+    return success();
+  }
+
+  if (operands.size() < 4) {
+    return emitError(unknownLoc,
+                     "OpExtInst must have at least 4 operands, result type "
+                     "<id>, result <id>, set <id> and instruction opcode");
+  }
+
+  auto &extensionSetName = extendedInstSets[operands[2]];
+  assert(extensionSetName == extDebugInfo);
+
+  auto getDebugLoc = [&](uint32_t stringID) -> FailureOr<Location> {
+    auto stringIt = debugInfoMap.find(stringID);
+    if (stringIt == debugInfoMap.end()) {
+      emitError(unknownLoc, "undefined string <id> ")
+          << stringID << " in DebugInfo";
+      return failure();
+    }
+    Location loc = getLocFromDebugInfoString(opBuilder, stringIt->second);
+    return loc;
+  };
+
+  auto instructionID = static_cast<spirv::GraphDebugInfoExtInst>(operands[3]);
+  switch (instructionID) {
+  case spirv::GraphDebugInfoExtInst::DebugGraph: {
+    if (operands.size() < 6)
+      return emitError(unknownLoc, "DebugGraph must have graph and string IDs");
+    auto &graphID = operands[4];
+    auto &stringID = operands[5];
+    auto graphIt = graphMap.find(graphID);
+    if (graphIt == graphMap.end())
+      return emitError(unknownLoc, "undefined graph <id> ")
+             << graphID << " in DebugGraph";
+    FailureOr<Location> loc = getDebugLoc(stringID);
+    if (failed(loc))
+      return failure();
+    graphIt->second->setLoc(*loc);
+    break;
+  }
+  case spirv::GraphDebugInfoExtInst::DebugOperation: {
+    if (operands.size() < 7)
+      return emitError(unknownLoc,
+                       "DebugOperation must have graph, string and "
+                       "instruction IDs");
+    auto &stringID = operands[5];
+    FailureOr<Location> loc = getDebugLoc(stringID);
+    if (failed(loc))
+      return failure();
+    SmallVector<uint32_t> operationIDs;
+    operationIDs.append(std::next(operands.begin(), 6), operands.end());
+    for (auto &operationID : operationIDs) {
+      auto valueIt = valueMap.find(operationID);
+      if (valueIt == valueMap.end())
+        return emitError(unknownLoc, "undefined operation <id> ")
+               << operationID << " in DebugOperation";
+      valueIt->second.setLoc(*loc);
+    }
+    break;
+  }
+  case spirv::GraphDebugInfoExtInst::DebugTensor: {
+    if (operands.size() < 6)
+      return emitError(unknownLoc, "DebugTensor must have tensor and string IDs");
+    auto &stringID = operands[5];
+    auto &tensorID = operands[4];
+    FailureOr<Location> loc = getDebugLoc(stringID);
+    if (failed(loc))
+      return failure();
+    if (constantMap.contains(tensorID)) {
+      constantLocMap[tensorID] = *loc;
+      break;
+    }
+    auto valueIt = valueMap.find(tensorID);
+    if (valueIt == valueMap.end())
+      return emitError(unknownLoc, "undefined tensor <id> ")
+             << tensorID << " in DebugTensor";
+    valueIt->second.setLoc(*loc);
+    break;
+  }
+  default:
+    return failure();
+  }
+
+  return success();
+}
+
 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
   if (operands.size() < 4) {
     return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index b2adbb5518789..d2b7dbaad4655 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -563,6 +563,11 @@ class Deserializer {
   /// other entries.
   LogicalResult processExtInst(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpExtInst with given `operands` for a DebugInfo
+  /// extension instruction.
+  LogicalResult processDebugInfoExtInst(ArrayRef<uint32_t> operands,
+                                        bool deferInstructions);
+
   /// Dispatches the deserialization of extended instruction set operation based
   /// on the extended instruction set name, and instruction opcode. This is
   /// autogenerated from ODS.
@@ -632,6 +637,9 @@ class Deserializer {
   /// (and type) here. Later when it's used, we materialize the constant.
   DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
 
+  // Result <id> to debug location for constants materialized from constantMap.
+  DenseMap<uint32_t, LocationAttr> constantLocMap;
+
   // Result <id> to replicated constant attribute and type mapping.
   ///
   /// In the SPIR-V binary format, OpConstantCompositeReplicateEXT is placed in
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 841fc55a8627a..83341519fbc9e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Target/SPIRV/SPIRVExtInstSets.h"
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -24,6 +25,33 @@
 
 using namespace mlir;
 
+namespace {
+// Location::print() emits MLIR syntax such as `loc("name")` or
+// `loc(fused["op", "file":1:2])`. NonSemantic.Graph.DebugInfo stores the
+// source/debug name itself in an OpString, so keep this conversion to the
+// payload string explicit.
+std::string getDebugInfoStringFromLoc(Location loc) {
+  if (auto fileLineCol = dyn_cast<FileLineColLoc>(loc)) {
+    return fileLineCol.getFilename().str() + ":" +
+           std::to_string(fileLineCol.getLine()) + ":" +
+           std::to_string(fileLineCol.getColumn());
+  }
+  if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
+    return nameLoc.getName().str();
+  }
+  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
+    std::ostringstream result;
+    std::transform(
+        fusedLoc.getLocations().begin(), fusedLoc.getLocations().end(),
+        std::ostream_iterator<std::string>(result, ";"),
+        [&](const Location loc) { return getDebugInfoStringFromLoc(loc); });
+
+    return result.str();
+  }
+  return "";
+}
+} // namespace
+
 /// A pre-order depth-first visitor function for processing basic blocks.
 ///
 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
@@ -61,6 +89,9 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
   if (auto resultID =
           prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
     valueIDMap[op.getResult()] = resultID;
+    if (isa<spirv::TensorArmType>(op.getType()) &&
+        failed(encodeDebugInfoTensorInst(op.getResult())))
+      return failure();
     return success();
   }
   return failure();
@@ -386,6 +417,121 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   return success();
 }
 
+LogicalResult Serializer::encodeDebugStringInst(const std::string &str,
+                                                uint32_t &stringID) {
+  if (!options.emitDebugInfo)
+    return success();
+
+  SmallVector<uint32_t, 2> operands;
+  stringID = getNextID();
+  operands.push_back(stringID);
+  spirv::encodeStringLiteralInto(operands, str);
+  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
+
+  return success();
+}
+
+LogicalResult Serializer::encodeDebugInfoGraphInst(spirv::GraphARMOp op,
+                                                   uint32_t &debugGraphID) {
+  if (!options.emitDebugInfo)
+    return success();
+
+  processVoidType(typesGlobalValues);
+
+  uint32_t stringID = 0;
+  if (failed(encodeDebugStringInst(getDebugInfoStringFromLoc(op.getLoc()),
+                                   stringID)))
+    return failure();
+
+  SmallVector<uint32_t, 6> operands;
+  operands.push_back(getTypeID(getVoidType()));
+  debugGraphID = getNextID();
+  operands.push_back(debugGraphID);
+  uint32_t graphID = getOrCreateFunctionID(op.getName());
+  operands.push_back(graphID);
+  operands.push_back(stringID);
+
+  if (failed(encodeExtensionInstruction(
+          nullptr, extDebugInfo,
+          static_cast<uint32_t>(GraphDebugInfoExtInst::DebugGraph), operands,
+          graphsDebugInfo)))
+    return failure();
+
+  return success();
+}
+
+LogicalResult
+Serializer::encodeDebugInfoOperationInst(uint32_t debugGraphID,
+                                         SetVector<Operation *> ops) {
+  if (!options.emitDebugInfo)
+    return success();
+
+  if (ops.empty())
+    return success();
+
+  SmallVector<uint32_t, 4> instructionIDs;
+  for (auto op : ops)
+    for (auto result : op->getOpResults())
+      instructionIDs.push_back(getValueID(result));
+
+  if (instructionIDs.empty())
+    return success();
+
+  processVoidType(typesGlobalValues);
+
+  uint32_t stringID = 0;
+  if (failed(encodeDebugStringInst(getDebugInfoStringFromLoc(ops[0]->getLoc()),
+                                   stringID)))
+    return failure();
+
+  SmallVector<uint32_t, 5> operands;
+  operands.push_back(getTypeID(getVoidType()));
+  operands.push_back(getNextID());
+  operands.push_back(debugGraphID);
+  operands.push_back(stringID);
+  operands.append(instructionIDs);
+
+  if (failed(encodeExtensionInstruction(
+          nullptr, extDebugInfo,
+          static_cast<uint32_t>(GraphDebugInfoExtInst::DebugOperation),
+          operands,
+          graphsDebugInfo)))
+    return failure();
+
+  return success();
+}
+
+LogicalResult Serializer::encodeDebugInfoTensorInst(Value tensor) {
+  if (!options.emitDebugInfo)
+    return success();
+
+  processVoidType(typesGlobalValues);
+
+  auto it = valueIDMap.find(tensor);
+  if (it == valueIDMap.end())
+    return success();
+  auto tensorID = it->second;
+
+  uint32_t stringID = 0;
+  if (failed(encodeDebugStringInst(getDebugInfoStringFromLoc(tensor.getLoc()),
+                                   stringID)))
+    return failure();
+
+  SmallVector<uint32_t, 4> operands;
+  operands.push_back(getTypeID(getVoidType()));
+  operands.push_back(getNextID());
+  operands.push_back(tensorID);
+  operands.push_back(stringID);
+
+  if (failed(encodeExtensionInstruction(
+          nullptr, extDebugInfo,
+          static_cast<uint32_t>(GraphDebugInfoExtInst::DebugTensor), operands,
+          graphsDebugInfo)))
+    return failure();
+
+  return success();
+}
+
 LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
   if (op.getNumResults() < 1) {
     return op.emitError("cannot serialize graph with no return types");
@@ -423,6 +569,9 @@ LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
 
     encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
                           inputOperands);
+
+    if (failed(encodeDebugInfoTensorInst(arg)))
+      return failure();
   }
 
   if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
@@ -443,6 +592,15 @@ LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
   functionHeader.clear();
   functionBody.clear();
 
+  uint32_t debugGraphID = 0;
+  if (failed(encodeDebugInfoGraphInst(op, debugGraphID)))
+    return failure();
+
+  for (const auto &[loc, ops] : tosaOpsMap[funcID]) {
+    if (failed(encodeDebugInfoOperationInst(debugGraphID, ops)))
+      return failure();
+  }
+
   return success();
 }
 
@@ -492,6 +650,9 @@ Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
     outputOperands.push_back(outputID);
     outputOperands.push_back(indexID);
 
+    if (failed(encodeDebugInfoTensorInst(value)))
+      return failure();
+
     encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
                           outputOperands);
   }
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 11a7bf66d792d..553fcce6c5ea7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -173,6 +173,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
   binary.append(functions.begin(), functions.end());
   binary.append(graphs.begin(), graphs.end());
+  binary.append(graphsDebugInfo.begin(), graphsDebugInfo.end());
 }
 
 #ifndef NDEBUG
@@ -632,6 +633,16 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
   return emitError(loc, "failed to process type: ") << type;
 }
 
+void Serializer::processVoidType(SmallVectorImpl<uint32_t> &binary) {
+  auto voidType = getVoidType();
+  uint32_t voidTypeID = getTypeID(voidType);
+  if (!voidTypeID) {
+    voidTypeID = getNextID();
+    encodeInstructionInto(binary, spirv::Opcode::OpTypeVoid, {voidTypeID});
+    typeIDMap[voidType] = voidTypeID;
+  }
+}
+
 LogicalResult Serializer::prepareBasicType(
     Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
     SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
@@ -1612,7 +1623,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
 
 LogicalResult Serializer::encodeExtensionInstruction(
     Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
-    ArrayRef<uint32_t> operands) {
+    ArrayRef<uint32_t> operands, SmallVectorImpl<uint32_t> &binary) {
   // Check if the extension has been imported.
   auto &setID = extendedInstSetIDMap[extensionSetName];
   if (!setID) {
@@ -1635,8 +1646,20 @@ LogicalResult Serializer::encodeExtensionInstruction(
   extInstOperands.push_back(setID);
   extInstOperands.push_back(extensionOpcode);
   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
-  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
-                        extInstOperands);
+  encodeInstructionInto(binary, spirv::Opcode::OpExtInst, extInstOperands);
+  return success();
+}
+
+LogicalResult Serializer::encodeExtensionInstruction(
+    Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
+    ArrayRef<uint32_t> operands) {
+  if (failed(encodeExtensionInstruction(op, extensionSetName, extensionOpcode,
+                                        operands, functionBody)))
+    return failure();
+
+  if (extensionSetName == extTosa)
+    updateTosaOpsMap(op);
+
   return success();
 }
 
@@ -1751,8 +1774,10 @@ LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
   for (Value operand : op->getOperands())
     operands.push_back(getValueID(operand));
 
-  if (failed(emitDebugLine(functionBody, loc)))
-    return failure();
+  if (extInstSet != extTosa)
+    // OpLine cannot be present in graphs, so skip it for TOSA ops.
+    if (failed(emitDebugLine(functionBody, loc)))
+      return failure();
 
   if (extInstSet.empty()) {
     encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
@@ -1772,6 +1797,16 @@ LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
   return success();
 }
 
+void Serializer::updateTosaOpsMap(Operation *op) {
+  if (!options.emitDebugInfo)
+    return;
+
+  if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op->getParentOp())) {
+    if (auto graphID = getFunctionID(graphOp.getName()))
+      tosaOpsMap[graphID][op->getLoc()].insert(op);
+  }
+}
+
 LogicalResult Serializer::emitDecoration(uint32_t target,
                                          spirv::Decoration decoration,
                                          ArrayRef<uint32_t> params) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index e43556fab9acf..3f95c76a466b8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/Target/SPIRV/SPIRVExtInstSets.h"
 #include "mlir/Target/SPIRV/Serialization.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -217,6 +218,8 @@ class Serializer {
                                  spirv::Opcode &typeEnum,
                                  SmallVectorImpl<uint32_t> &operands);
 
+  void processVoidType(SmallVectorImpl<uint32_t> &binary);
+
   //===--------------------------------------------------------------------===//
   // Constant
   //===--------------------------------------------------------------------===//
@@ -328,6 +331,23 @@ class Serializer {
                                            uint32_t opcode,
                                            ArrayRef<uint32_t> operands);
 
+  LogicalResult encodeExtensionInstruction(Operation *op,
+                                           StringRef extensionSetName,
+                                           uint32_t opcode,
+                                           ArrayRef<uint32_t> operands,
+                                           SmallVectorImpl<uint32_t> &binary);
+
+  LogicalResult encodeDebugStringInst(const std::string &str,
+                                      uint32_t &stringID);
+
+  LogicalResult encodeDebugInfoGraphInst(spirv::GraphARMOp op,
+                                         uint32_t &debugGraphID);
+
+  LogicalResult encodeDebugInfoOperationInst(uint32_t debugGraphID,
+                                             SetVector<Operation *> ops);
+
+  LogicalResult encodeDebugInfoTensorInst(Value tensor);
+
   uint32_t getValueID(Value val) const { return valueIDMap.lookup(val); }
 
   LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
@@ -363,6 +383,9 @@ class Serializer {
   // Utilities
   //===--------------------------------------------------------------------===//
 
+  /// Updates tosaOpsMap after ensuring that the op is inside a graph.
+  void updateTosaOpsMap(Operation *op);
+
   /// Emits an OpDecorate instruction to decorate the given `target` with the
   /// given `decoration`.
   LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration,
@@ -417,6 +440,7 @@ class Serializer {
   SmallVector<uint32_t, 0> typesGlobalValues;
   SmallVector<uint32_t, 0> functions;
   SmallVector<uint32_t, 0> graphs;
+  SmallVector<uint32_t, 0> graphsDebugInfo;
 
   /// Recursive struct references are serialized as OpTypePointer instructions
   /// to the recursive struct type. However, the OpTypePointer instruction
@@ -485,6 +509,10 @@ class Serializer {
   /// Map from extended instruction set name to <id>s.
   llvm::StringMap<uint32_t> extendedInstSetIDMap;
 
+  /// Map from graph <id> to a map from each location in that graph to the
+  /// set of TOSA ops at that location.
+  DenseMap<uint32_t, DenseMap<Location, SetVector<Operation *>>> tosaOpsMap;
+
   /// Map from values used in OpPhi instructions to their offset in the
   /// `functions` section.
   ///
diff --git a/mlir/test/Target/SPIRV/debug-info.mlir b/mlir/test/Target/SPIRV/debug-info.mlir
new file mode 100644
index 0000000000000..5fc221e4af034
--- /dev/null
+++ b/mlir/test/Target/SPIRV/debug-info.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-translate -no-implicit-module -split-input-file --verify-diagnostics -mlir-print-debuginfo -test-spirv-roundtrip-debug %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: FileLineCol Locations
+//===----------------------------------------------------------------------===//
+
+// CHECK: #loc[[LOC_TENSOR0:.*]] = loc("{{.*}}debug-info.mlir{{.*}}:12:27")
+// CHECK: #loc[[LOC_TENSOR1:.*]] = loc("{{.*}}debug-info.mlir{{.*}}:12:67")
+// CHECK: #loc[[LOC_TENSOR2:.*]] = loc("{{.*}}debug-info.mlir{{.*}}:12:105")
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_non_semantic_info]> {
+  // CHECK: spirv.ARM.Graph @{{.*}}(%arg0: !spirv.arm.tensor<1x16x16x1xi8> loc("{{.*}}debug-info.mlir{{.*}}:12:27"), %arg1: !spirv.arm.tensor<8x3x3x1xi8> loc("{{.*}}debug-info.mlir{{.*}}:12:67"), %arg2: !spirv.arm.tensor<8xi32> loc("{{.*}}debug-info.mlir{{.*}}:12:105")) -> !spirv.arm.tensor<1x14x14x8xi32> attributes {entry_point = false} {
+  spirv.ARM.Graph @conv2d(%arg0: !spirv.arm.tensor<1x16x16x1xi8>, %arg1: !spirv.arm.tensor<8x3x3x1xi8>, %arg2: !spirv.arm.tensor<8xi32>) -> (!spirv.arm.tensor<1x14x14x8xi32>) {
+      %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      // CHECK: {{%.*}} = spirv.Tosa.Conv2D{{.*}}loc(#loc[[LOC_OP:.*]])
+      %2 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %0, %1 : !spirv.arm.tensor<1x16x16x1xi8>, !spirv.arm.tensor<8x3x3x1xi8>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x14x14x8xi32>
+      spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<1x14x14x8xi32>
+  // CHECK: } loc(#loc[[LOC_GRAPH:.*]])
+  }
+}
+// CHECK-DAG: #loc[[LOC_GRAPH]] = loc("{{.*}}debug-info.mlir{{.*}}:12:3")
+// CHECK-DAG: #loc[[LOC_OP]] = loc("{{.*}}debug-info.mlir{{.*}}:16:12")
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: Name Locations
+//===----------------------------------------------------------------------===//
+
+// CHECK: #loc[[NAME_TENSOR0:.*]] = loc("tensor_0")
+// CHECK: #loc[[NAME_TENSOR1:.*]] = loc("tensor_1")
+// CHECK: #loc[[NAME_TENSOR2:.*]] = loc("tensor_2")
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_non_semantic_info]> {
+  // CHECK: spirv.ARM.Graph @{{.*}}(%arg0: !spirv.arm.tensor<1x16x16x1xi8> loc("tensor_0"), %arg1: !spirv.arm.tensor<8x3x3x1xi8> loc("tensor_1"), %arg2: !spirv.arm.tensor<8xi32> loc("tensor_2")) -> !spirv.arm.tensor<1x14x14x8xi32> attributes {entry_point = false} {
+  spirv.ARM.Graph @conv2d(%arg0: !spirv.arm.tensor<1x16x16x1xi8> loc("tensor_0") , %arg1: !spirv.arm.tensor<8x3x3x1xi8> loc("tensor_1"), %arg2: !spirv.arm.tensor<8xi32> loc("tensor_2")) -> (!spirv.arm.tensor<1x14x14x8xi32>) {
+      %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      // CHECK: {{%.*}} = spirv.Tosa.Conv2D{{.*}}loc(#loc[[NAME_OP:.*]])
+      %2 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %0, %1 : !spirv.arm.tensor<1x16x16x1xi8>, !spirv.arm.tensor<8x3x3x1xi8>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x14x14x8xi32> loc("op_0")
+      spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<1x14x14x8xi32>
+  // CHECK: } loc(#loc[[NAME_GRAPH:.*]])
+  } loc("graph_0")
+}
+// CHECK-DAG: #loc[[NAME_GRAPH]] = loc("graph_0")
+// CHECK-DAG: #loc[[NAME_OP]] = loc("op_0")
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: Multiple tosa ops with same locations
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_non_semantic_info]> {
+  spirv.ARM.Graph @test(%arg0: !spirv.arm.tensor<2x9x3x32xi16> loc("tensor_0")) -> (!spirv.arm.tensor<2x9x3x32xi8>) attributes {entry_point = true} {
+    %weight = spirv.ARM.GraphConstant {graph_constant_id = 0 : i32} : !spirv.arm.tensor<32x1x1x32xi8>
+    %bias = spirv.ARM.GraphConstant {graph_constant_id = 1 : i32} : !spirv.arm.tensor<32xi64>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+    %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+    // CHECK: %[[CONV2D:.*]] = spirv.Tosa.Conv2D{{.*}}loc(#loc[[SAME_OP:.*]])
+    %conv2d = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %weight, %bias, %0, %1  : !spirv.arm.tensor<2x9x3x32xi16>, !spirv.arm.tensor<32x1x1x32xi8>, !spirv.arm.tensor<32xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x9x3x32xi64> loc("op_0")
+    %multiplier = spirv.ARM.GraphConstant {graph_constant_id = 2 : i32} : !spirv.arm.tensor<1xi16>
+    %shift = spirv.ARM.GraphConstant {graph_constant_id = 3 : i32} : !spirv.arm.tensor<1xi8>
+    %5 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+    %6 = spirv.Constant dense<-4> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.Rescale{{.*}}loc(#loc[[SAME_OP]])
+    %rescale = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %conv2d, %multiplier, %shift, %5, %6 : !spirv.arm.tensor<2x9x3x32xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x9x3x32xi8> loc("op_0")
+    spirv.ARM.GraphOutputs %rescale : !spirv.arm.tensor<2x9x3x32xi8>
+  // CHECK: } loc(#loc[[SAME_GRAPH:.*]])
+  } loc("graph_0")
+}
+// CHECK-DAG: #loc[[SAME_GRAPH]] = loc("graph_0")
+// CHECK-DAG: #loc[[SAME_OP]] = loc("op_0")
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: Multiple tosa ops with differing locations
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_non_semantic_info]> {
+  spirv.ARM.Graph @test(%arg0: !spirv.arm.tensor<2x9x3x32xi16> loc("tensor_0")) -> (!spirv.arm.tensor<2x9x3x32xi8>) attributes {entry_point = true} {
+    %weight = spirv.ARM.GraphConstant {graph_constant_id = 0 : i32} : !spirv.arm.tensor<32x1x1x32xi8>
+    %bias = spirv.ARM.GraphConstant {graph_constant_id = 1 : i32} : !spirv.arm.tensor<32xi64>
+    %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
+    %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+    // CHECK: %[[CONV2D:.*]] = spirv.Tosa.Conv2D{{.*}}loc(#loc[[MULTI_OP0:.*]])
+    %conv2d = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT48>, local_bound = false, %arg0, %weight, %bias, %0, %1  : !spirv.arm.tensor<2x9x3x32xi16>, !spirv.arm.tensor<32x1x1x32xi8>, !spirv.arm.tensor<32xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x9x3x32xi64> loc("op_0")
+    %multiplier = spirv.ARM.GraphConstant {graph_constant_id = 2 : i32} : !spirv.arm.tensor<1xi16>
+    %shift = spirv.ARM.GraphConstant {graph_constant_id = 3 : i32} : !spirv.arm.tensor<1xi8>
+    %5 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
+    %6 = spirv.Constant dense<-4> : !spirv.arm.tensor<1xi8>
+    // CHECK: {{%.*}} = spirv.Tosa.Rescale{{.*}}loc(#loc[[MULTI_OP1:.*]])
+    %rescale = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %conv2d, %multiplier, %shift, %5, %6 : !spirv.arm.tensor<2x9x3x32xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<2x9x3x32xi8> loc("op_1")
+    spirv.ARM.GraphOutputs %rescale : !spirv.arm.tensor<2x9x3x32xi8>
+  // CHECK: } loc(#loc[[MULTI_GRAPH:.*]])
+  } loc("graph_0")
+}
+// CHECK-DAG: #loc[[MULTI_GRAPH]] = loc("graph_0")
+// CHECK-DAG: #loc[[MULTI_OP0]] = loc("op_0")
+// CHECK-DAG: #loc[[MULTI_OP1]] = loc("op_1")
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: Fused Locations
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, Int16, Int64, Float16, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_non_semantic_info]> {
+  spirv.ARM.Graph @fused(%arg0: !spirv.arm.tensor<1x16x16x1xi8> loc("tensor_0"), %arg1: !spirv.arm.tensor<8x3x3x1xi8>, %arg2: !spirv.arm.tensor<8xi32>) -> (!spirv.arm.tensor<1x14x14x8xi32>) {
+      %0 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      %1 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
+      // CHECK: {{%.*}} = spirv.Tosa.Conv2D{{.*}}loc(#loc[[FUSED:.*]])
+      %2 = spirv.Tosa.Conv2D pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %0, %1 : !spirv.arm.tensor<1x16x16x1xi8>, !spirv.arm.tensor<8x3x3x1xi8>, !spirv.arm.tensor<8xi32>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x14x14x8xi32> loc(fused["op_0", "source.cc":12:34])
+      spirv.ARM.GraphOutputs %2 : !spirv.arm.tensor<1x14x14x8xi32>
+  // CHECK: } loc(#loc[[FUSED_GRAPH:.*]])
+  } loc("graph_0")
+}
+// CHECK-DAG: #loc[[FUSED_GRAPH]] = loc("graph_0")
+// CHECK-DAG: #loc[[FUSED]] = loc("op_0;source.cc:12:34;")
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv DebugInfo: Missing non-semantic info extension
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{SPV_KHR_non_semantic_info extension not available}}
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  spirv.ARM.Graph @g(%arg0: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<1xi8>) {
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1xi8>
+  }
+}
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
index ffe815b9be5e2..b485f58435fea 100644
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Target/SPIRV/SPIRVExtInstSets.h"
 #include "gmock/gmock.h"
 
 #include <memory>
@@ -162,6 +163,51 @@ TEST_F(DeserializationTest, InsufficientWordFailure) {
   expectDiagnostic("insufficient words for the last instruction");
 }
 
+TEST_F(DeserializationTest, OpExtInstMissingOperands) {
+  addHeader();
+  addInstruction(spirv::Opcode::OpExtInst, {1, 2});
+
+  ASSERT_FALSE(deserialize());
+  expectDiagnostic("OpExtInst must have at least 4 operands, result type <id>, "
+                   "result <id>, set <id> and instruction opcode");
+}
+
+TEST_F(DeserializationTest, DebugInfoExtInstMissingOperands) {
+  addHeader();
+  SmallVector<uint32_t, 4> importOperands = {nextID++};
+  uint32_t extInstSetID = importOperands[0];
+  spirv::encodeStringLiteralInto(importOperands, spirv::extDebugInfo);
+  addInstruction(spirv::Opcode::OpExtInstImport, importOperands);
+
+  uint32_t voidType = addVoidType();
+  addInstruction(spirv::Opcode::OpExtInst,
+                 {voidType, nextID++, extInstSetID,
+                  static_cast<uint32_t>(
+                      spirv::GraphDebugInfoExtInst::DebugTensor)});
+
+  ASSERT_FALSE(deserialize());
+  expectDiagnostic("DebugTensor must have tensor and string IDs");
+}
+
+TEST_F(DeserializationTest, DebugOperationMissingInstructionIDs) {
+  addHeader();
+  SmallVector<uint32_t, 4> importOperands = {nextID++};
+  uint32_t extInstSetID = importOperands[0];
+  spirv::encodeStringLiteralInto(importOperands, spirv::extDebugInfo);
+  addInstruction(spirv::Opcode::OpExtInstImport, importOperands);
+
+  uint32_t voidType = addVoidType();
+  addInstruction(spirv::Opcode::OpExtInst,
+                 {voidType, nextID++, extInstSetID,
+                  static_cast<uint32_t>(
+                      spirv::GraphDebugInfoExtInst::DebugOperation),
+                  /*debugGraphID=*/42, /*stringID=*/43});
+
+  ASSERT_FALSE(deserialize());
+  expectDiagnostic(
+      "DebugOperation must have graph, string and instruction IDs");
+}
+
 //===----------------------------------------------------------------------===//
 // Types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index b0413a8e994ba..fa96a59983ead 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -22,6 +22,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Target/SPIRV/Deserialization.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Target/SPIRV/SPIRVExtInstSets.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/Sequence.h"
@@ -98,6 +99,65 @@ class SerializationTest : public ::testing::Test {
     llvm_unreachable("unimplemented types for AddConstInt()");
   }
 
+  /// Inserts the programmatic equivalent of this minimal SPIR-V graph:
+  ///
+  ///   spirv.ARM.Graph @argmax(%arg0: !spirv.arm.tensor<1x2xi8>
+  ///       loc("tensor_0")) -> !spirv.arm.tensor<1xi32> {
+  ///     %cst = spirv.Constant dense<7> : !spirv.arm.tensor<1x2xi8>
+  ///         loc("constant_0")
+  ///     %out = spirv.Tosa.ArgMax 1, Propagate, %arg0
+  ///         : !spirv.arm.tensor<1x2xi8> to !spirv.arm.tensor<1xi32>
+  ///         loc("op_0")
+  ///     spirv.ARM.GraphOutputs %out : !spirv.arm.tensor<1xi32>
+  ///   } loc("graph_0")
+  ///
+  /// The constant is intentionally unused; it gives the test a distinct
+  /// tensor value that is neither a graph input nor a graph output.
+  void addGraphWithDebugInfo() {
+    OpBuilder builder(module->getRegion());
+    (*module)->setAttr("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
+                                           spirv::MemoryModel::Vulkan));
+    (*module)->setAttr(
+        "vce_triple",
+        spirv::VerCapExtAttr::get(
+            spirv::Version::V_1_6,
+            ArrayRef<spirv::Capability>{
+                spirv::Capability::VulkanMemoryModel, spirv::Capability::Shader,
+                spirv::Capability::Int8, spirv::Capability::TensorsARM,
+                spirv::Capability::GraphARM},
+            ArrayRef<spirv::Extension>{
+                spirv::Extension::SPV_ARM_tensors,
+                spirv::Extension::SPV_ARM_graph,
+                spirv::Extension::SPV_KHR_non_semantic_info},
+            &context));
+
+    Type inputType =
+        spirv::TensorArmType::get({1, 2}, builder.getIntegerType(8));
+    Type outputType =
+        spirv::TensorArmType::get({1}, builder.getIntegerType(32));
+    GraphType graphType = builder.getGraphType({inputType}, {outputType});
+
+    auto graphLoc = NameLoc::get(builder.getStringAttr("graph_0"));
+    auto graph =
+        spirv::GraphARMOp::create(builder, graphLoc, "argmax", graphType);
+    Block *entry = graph.addEntryBlock();
+    entry->getArgument(0).setLoc(
+        NameLoc::get(builder.getStringAttr("tensor_0")));
+
+    OpBuilder bodyBuilder = OpBuilder::atBlockBegin(entry);
+    auto constLoc = NameLoc::get(builder.getStringAttr("constant_0"));
+    auto constAttr =
+        DenseIntElementsAttr::get(cast<ShapedType>(inputType), {APInt(8, 7)});
+    spirv::ConstantOp::create(bodyBuilder, constLoc, inputType, constAttr);
+
+    auto opLoc = NameLoc::get(builder.getStringAttr("op_0"));
+    auto argMax = spirv::TosaArgMaxOp::create(
+        bodyBuilder, opLoc, outputType, 1,
+        spirv::TosaExtNaNPropagationModeType::Propagate, entry->getArgument(0));
+    spirv::GraphOutputsARMOp::create(bodyBuilder, UnknownLoc::get(&context),
+                                     argMax.getOutput());
+  }
+
   /// Handles a SPIR-V instruction with the given `opcode` and `operand`.
   /// Returns true to interrupt.
   using HandleFn = llvm::function_ref<bool(spirv::Opcode opcode,
@@ -302,8 +362,154 @@ bool allInstructionsWithinWordLimit(SmallVectorImpl<uint32_t> &binary) {
   return true;
 }
 
+using InstructionCallback =
+    function_ref<void(size_t, spirv::Opcode, ArrayRef<uint32_t>)>;
+
+void walkInstructions(SmallVectorImpl<uint32_t> &binary,
+                      InstructionCallback callback) {
+  size_t offset = spirv::kHeaderWordCount;
+  while (offset < binary.size()) {
+    uint32_t wordCount = binary[offset] >> 16;
+    if (!wordCount || offset + wordCount > binary.size())
+      return;
+
+    auto opcode = static_cast<spirv::Opcode>(binary[offset] & 0xffff);
+    ArrayRef<uint32_t> operands(binary.begin() + offset + 1,
+                                binary.begin() + offset + wordCount);
+    callback(offset, opcode, operands);
+    offset += wordCount;
+  }
+}
+
+std::optional<uint32_t> getExtInstSetID(SmallVectorImpl<uint32_t> &binary,
+                                        StringRef extInstSetName) {
+  std::optional<uint32_t> extInstSetID;
+  walkInstructions(binary, [&](size_t, spirv::Opcode opcode,
+                               ArrayRef<uint32_t> operands) {
+    if (opcode == spirv::Opcode::OpExtInstImport && operands.size() >= 2) {
+      unsigned stringIndex = 1;
+      if (spirv::decodeStringLiteral(operands, stringIndex) == extInstSetName) {
+        extInstSetID = operands[0];
+        return;
+      }
+    }
+  });
+  return extInstSetID;
+}
+
+struct ExtInstRecord {
+  size_t offset;
+  uint32_t resultID;
+  uint32_t setID;
+  uint32_t instruction;
+  SmallVector<uint32_t, 4> arguments;
+};
+
+SmallVector<ExtInstRecord> getExtInstRecords(SmallVectorImpl<uint32_t> &binary,
+                                             uint32_t extInstSetID) {
+  SmallVector<ExtInstRecord> records;
+  walkInstructions(binary, [&](size_t offset, spirv::Opcode opcode,
+                               ArrayRef<uint32_t> operands) {
+    if (opcode == spirv::Opcode::OpExtInst && operands.size() >= 4 &&
+        operands[2] == extInstSetID) {
+      ExtInstRecord record{offset, operands[1], operands[2], operands[3], {}};
+      record.arguments.append(operands.begin() + 4, operands.end());
+      records.push_back(std::move(record));
+    }
+  });
+
+  return records;
+}
+
 } // namespace
 
+TEST_F(SerializationTest, GraphDebugInfoReferencesSerializedObjects) {
+  addGraphWithDebugInfo();
+
+  spirv::SerializationOptions options;
+  options.emitDebugInfo = true;
+  ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options)));
+
+  std::optional<uint32_t> extInstSetID =
+      getExtInstSetID(binary, spirv::extDebugInfo);
+  ASSERT_TRUE(extInstSetID);
+
+  SmallVector<ExtInstRecord> records = getExtInstRecords(binary, *extInstSetID);
+  ASSERT_FALSE(records.empty());
+
+  // DebugOperation's first payload operand must reference the DebugGraph result
+  // ID, not the OpGraphARM result ID.
+  auto debugGraphIt = llvm::find_if(records, [](const ExtInstRecord &record) {
+    return record.instruction ==
+           static_cast<uint32_t>(spirv::GraphDebugInfoExtInst::DebugGraph);
+  });
+  ASSERT_NE(debugGraphIt, records.end());
+  auto debugOperationIt =
+      llvm::find_if(records, [](const ExtInstRecord &record) {
+        return record.instruction ==
+               static_cast<uint32_t>(
+                   spirv::GraphDebugInfoExtInst::DebugOperation);
+      });
+  ASSERT_NE(debugOperationIt, records.end());
+
+  ASSERT_GE(debugOperationIt->arguments.size(), 3u);
+  EXPECT_EQ(debugOperationIt->arguments[0], debugGraphIt->resultID);
+
+  // NonSemantic.Graph.DebugInfo does not allow forward references, so a
+  // DebugGraph must be emitted before any DebugOperation that references it.
+  EXPECT_LT(debugGraphIt->offset, debugOperationIt->offset);
+
+  // Collect the serialized object IDs that DebugTensor is expected to describe.
+  SmallVector<uint32_t> graphInputIDs;
+  SmallVector<uint32_t> graphOutputIDs;
+  SmallVector<uint32_t> tensorConstantIDs;
+  walkInstructions(binary, [&](size_t, spirv::Opcode opcode,
+                               ArrayRef<uint32_t> operands) {
+    if (opcode == spirv::Opcode::OpGraphInputARM && operands.size() >= 2)
+      graphInputIDs.push_back(operands[1]);
+    if (opcode == spirv::Opcode::OpGraphSetOutputARM && operands.size() >= 1)
+      graphOutputIDs.push_back(operands[0]);
+    if (opcode == spirv::Opcode::OpConstantComposite && operands.size() >= 2)
+      tensorConstantIDs.push_back(operands[1]);
+  });
+  ASSERT_FALSE(graphInputIDs.empty());
+  ASSERT_FALSE(graphOutputIDs.empty());
+  ASSERT_FALSE(tensorConstantIDs.empty());
+
+  // Graph inputs are tensor values and should get DebugTensor records.
+  auto debugTensorIt = llvm::find_if(records, [&](const ExtInstRecord &record) {
+    return record.instruction ==
+               static_cast<uint32_t>(
+                   spirv::GraphDebugInfoExtInst::DebugTensor) &&
+           !record.arguments.empty() &&
+           llvm::is_contained(graphInputIDs, record.arguments[0]);
+  });
+  ASSERT_NE(debugTensorIt, records.end());
+
+  // Graph outputs are described through the SSA value passed to
+  // OpGraphSetOutputARM.
+  auto outputDebugTensorIt =
+      llvm::find_if(records, [&](const ExtInstRecord &record) {
+        return record.instruction ==
+                   static_cast<uint32_t>(
+                       spirv::GraphDebugInfoExtInst::DebugTensor) &&
+               !record.arguments.empty() &&
+               llvm::is_contained(graphOutputIDs, record.arguments[0]);
+      });
+  ASSERT_NE(outputDebugTensorIt, records.end());
+
+  // Tensor-typed spirv.Constant results should also get DebugTensor records.
+  auto constantDebugTensorIt =
+      llvm::find_if(records, [&](const ExtInstRecord &record) {
+        return record.instruction ==
+                   static_cast<uint32_t>(
+                       spirv::GraphDebugInfoExtInst::DebugTensor) &&
+               !record.arguments.empty() &&
+               llvm::is_contained(tensorConstantIDs, record.arguments[0]);
+      });
+  ASSERT_NE(constantDebugTensorIt, records.end());
+}
+
 TEST_F(SerializationTest, LongTypeStructIsSplit) {
   OpBuilder builder(module->getRegion());
   Type i32Type = builder.getIntegerType(32);
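
Note on the binary walk used by the new serialization test: the `walkInstructions` helper in the patch relies on the SPIR-V instruction encoding, where the first word of each instruction carries the word count in its high 16 bits and the opcode in its low 16 bits, and the module header occupies five words. The following is a minimal standalone sketch of that decoding, assuming plain `std::vector` storage instead of the LLVM containers used in the patch. The names `Instruction`, `packFirstWord`, `decodeInstructions`, and `kHeaderWords` are hypothetical and exist only for this illustration; they are not part of MLIR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Number of words in a SPIR-V module header: magic number, version,
// generator, ID bound, and schema.
constexpr std::size_t kHeaderWords = 5;

// One decoded instruction (hypothetical helper type for this sketch).
struct Instruction {
  std::size_t offset;             // Word index of the instruction's first word.
  uint16_t opcode;                // Low 16 bits of the first word.
  std::vector<uint32_t> operands; // All words after the first one.
};

// Builds the first word of an instruction from its word count and opcode.
inline uint32_t packFirstWord(uint16_t wordCount, uint16_t opcode) {
  assert(wordCount != 0 && "an instruction occupies at least one word");
  return (static_cast<uint32_t>(wordCount) << 16) | opcode;
}

// Walks the instruction stream that follows the header. Decoding stops at the
// first malformed instruction (zero word count, or a word count that runs
// past the end of the binary), mirroring the early return in the patch.
inline std::vector<Instruction>
decodeInstructions(const std::vector<uint32_t> &binary) {
  std::vector<Instruction> result;
  std::size_t offset = kHeaderWords;
  while (offset < binary.size()) {
    uint32_t wordCount = binary[offset] >> 16;
    if (wordCount == 0 || offset + wordCount > binary.size())
      break;

    Instruction inst;
    inst.offset = offset;
    inst.opcode = static_cast<uint16_t>(binary[offset] & 0xffff);
    inst.operands.assign(binary.begin() + offset + 1,
                         binary.begin() + offset + wordCount);
    result.push_back(std::move(inst));
    offset += wordCount;
  }
  return result;
}
```

Recording the word offset of each instruction is what lets the test compare the position of a DebugGraph record against the DebugOperation that references it, since a smaller offset means the record was emitted earlier in the binary.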


