[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension (PR #147937)

Davide Grohmann llvmlistbot at llvm.org
Thu Jul 10 05:08:55 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/147937

>From 60ace99000c4f86d3aaae4644ebe97658aa0a6b3 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 19 Jun 2025 16:57:33 +0200
Subject: [PATCH 1/2] [mlir][spirv] Add support for SPV_ARM_graph extension
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch adds support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources.

The implementation includes:

- A new `GraphType`, modeled similarly to `FunctionType`, for typed graph signatures.
- New operations in the `spirv.ARM` namespace:
  - `spirv.ARM.Graph`
  - `spirv.ARM.GraphEntryPoint`
  - `spirv.ARM.GraphConstant`
  - `spirv.ARM.GraphOutputs`
- Serialization and deserialization support for:
  - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM`
  - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM`
- ABI lowering support for graph entry points via `LowerABIAttributesPass`.
- Verifier and VCE updates to properly gate usage under `SPV_ARM_graph`.
- Tests covering parsing, verification, ABI handling, and binary round-tripping.

Graphs currently operate only on tensor resources from `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images.

Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I99aa469f2108219591544056db55bcd3f0702c7e
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  27 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 201 ++++++++++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |   1 +
 mlir/include/mlir/IR/Builders.h               |   2 +
 mlir/include/mlir/IR/BuiltinTypes.td          |  18 +-
 mlir/include/mlir/IR/CommonTypeConstraints.td |   7 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  14 +-
 .../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp    |  12 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 230 +++++++++++++
 .../Transforms/LowerABIAttributesPass.cpp     | 138 +++++++-
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |   8 +
 mlir/lib/IR/AsmPrinter.cpp                    |  17 +-
 mlir/lib/IR/Builders.cpp                      |   4 +
 mlir/lib/IR/BuiltinTypes.cpp                  |  43 +++
 .../SPIRV/Deserialization/DeserializeOps.cpp  |  21 ++
 .../SPIRV/Deserialization/Deserializer.cpp    | 304 ++++++++++++++++++
 .../SPIRV/Deserialization/Deserializer.h      |  51 ++-
 .../SPIRV/Serialization/SerializeOps.cpp      | 126 ++++++++
 .../Target/SPIRV/Serialization/Serializer.cpp |  86 ++++-
 .../Target/SPIRV/Serialization/Serializer.h   |  41 ++-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  17 +
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     |  30 ++
 .../test/Dialect/SPIRV/IR/target-and-abi.mlir |  23 +-
 .../SPIRV/Transforms/abi-interface.mlir       |  22 ++
 .../SPIRV/Transforms/vce-deduction.mlir       |  11 +
 mlir/test/Target/SPIRV/graph-ops.mlir         |  24 ++
 .../lib/Dialect/SPIRV/TestAvailability.cpp    |  18 +-
 27 files changed, 1465 insertions(+), 31 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
 create mode 100644 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
 create mode 100644 mlir/test/Target/SPIRV/graph-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 910418f1706a6..ce4bb6c2e4934 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -423,6 +423,7 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
 def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph                        : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
 
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -447,7 +448,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
-      SPV_ARM_tensors,
+      SPV_ARM_tensors, SPV_ARM_graph,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1332,6 +1333,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
     Extension<[SPV_ARM_tensors]>
   ];
 }
+def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_graph]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1545,7 +1552,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
       SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
-      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4245,6 +4252,7 @@ def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
+
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
@@ -4551,6 +4559,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
 def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM               : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM             : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM                       : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM                  : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM              : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM                    : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM                   : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4666,6 +4681,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
       SPIRV_OC_OpGroupNonUniformLogicalXor,
       SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+      SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+      SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
       SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4836,6 +4854,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+// Base class for ARM vendor ops (e.g. `spirv.ARM.Graph`).
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..38fb4b2eff414
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,201 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all SPV_ARM_graph ops (GraphARM capability required).
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_ArmVendorOp<mnemonic, traits> {
+  // NOTE(review): confirm Extension<[...]> is all-of here, not any-of.
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant (OpGraphConstantARM).
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID (`graph_constant_id`) must be a 32-bit integer literal.
+  }];
+
+  let arguments = (ins
+    I32Attr: $graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+  // Binary (de)serialization is not autogenerated for this op.
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+    AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+    FunctionOpInterface, InModuleScope, IsolatedFromAbove
+  ]> {
+
+  let summary = "Declare or define a SPIR-V graph";
+
+  let description = [{
+    This op declares or defines a SPIR-V graph using one region, which
+    contains one or more blocks.
+
+    Different from the SPIR-V binary format, this op is not allowed to
+    implicitly capture global values, and all external references must use
+    graph arguments or symbol references. This op itself defines a symbol
+    that is unique in the enclosing module op.
+
+    This op itself takes no operands and generates no results. Its region
+    can take zero or more arguments and must return one or more values.
+
+    ```
+    spv-graph-arm-op ::= `spirv.ARM.Graph` symbol-ref-id function-signature
+                         (`attributes` attribute-dict)? region?
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttrOf<GraphType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<BoolAttr>:$entry_point,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 0;
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,  CArg<"bool", "false">:$entry_point)>];
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+
+  let extraClassDeclaration = [{
+    /// Hook for FunctionOpInterface, called after verifying that the 'type'
+    /// attribute is present and holds a graph type. Rejects graphs that
+    /// declare no results.
+    LogicalResult verifyType();
+
+    /// Hook for FunctionOpInterface, called after verifying the graph type
+    /// and the presence of the (potentially empty) body. Checks the entry
+    /// block arguments and GraphOutputs operands against the signature.
+    LogicalResult verifyBody();
+  }];
+}
+
+// Op must be nested (not across a symbol table) in a spirv.ARM.Graph op.
+def InGraphScope : PredOpTrait<
+  "op must appear in a spirv.ARM.Graph op's block",
+  CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+  let summary = [{
+    Declare a graph entry point and its interface.
+  }];
+
+  let description = [{
+    Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+    Name is a name string for the graph entry point. A module cannot have two
+    OpGraphEntryPointARM instructions with the same Name string.
+
+    Interface is a list of symbol references to `spirv.GlobalVariable`
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    `spirv.GlobalVariable`s referenced by the entry point’s static call
+    tree, within the interface’s storage classes.
+
+    ```
+    entry-point-op ::= `spirv.ARM.GraphEntryPoint`
+                       symbol-reference (`, ` symbol-reference)*
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SymbolRefArrayAttr:$interface
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let builders = [
+    OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+                                               Terminator]> {
+
+  let summary = "Define graph outputs.";
+
+  let description = [{
+    Values are the graph output values and must match the GraphOutputs Type
+    operand of the OpTypeGraphARM type of the OpGraphARM body this
+    instruction is in.
+
+    This instruction must be the last instruction in a block.
+
+    ```
+    graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use-list `:` type-list-no-parens
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPIRV_AnyTensorArm>:$value
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let hasOpcode = 0;
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index ad59ea63a6901..aa7d30b87db14 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
 class IntegerType;
 class FloatType;
 class FunctionType;
+class GraphType;
 class IndexType;
 class MemRefType;
 class VectorType;
@@ -81,6 +82,7 @@ class Builder {
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
   FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+  GraphType getGraphType(TypeRange inputs, TypeRange results);
   TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let storageClass = "FunctionTypeStorage";
   let genStorageClass = 0;
   let extraClassDeclaration = [{
     /// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     unsigned getNumResults() const;
     Type getResult(unsigned i) const { return getResults()[i]; }
 
-    /// Returns a clone of this function type with the given argument
+    /// Returns a clone of this function-like type with the given argument
     /// and result types.
-    FunctionType clone(TypeRange inputs, TypeRange results) const;
+    }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
 
-    /// Returns a new function type with the specified arguments and results
+    /// Returns a new function-like type with the specified arguments and results
     /// inserted.
-    FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+    }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
                                        TypeRange argTypes,
                                        ArrayRef<unsigned> resultIndices,
                                        TypeRange resultTypes);
 
-    /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+    /// Returns a new function-like type without the specified arguments and results.
+    }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
                                           const BitVector &resultIndices);
   }];
 }
 
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
 //===----------------------------------------------------------------------===//
 // IndexType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 45ec1846580f2..aab1b01c5cff9 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
 def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
                               "function type", "::mlir::FunctionType">;
 
+// Graph Type
+
+// Any graph type (::mlir::GraphType), e.g. a spirv.ARM.Graph signature.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+                              "graph type", "::mlir::GraphType">;
+
+
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 88c7adf3dfcb3..e66d4b0ffc446 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1019,8 +1019,14 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
-    NamedAttribute attribute) {
-  return op->emitError("cannot attach SPIR-V attributes to region result");
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(
+    Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
+    NamedAttribute attribute) {
+  // Result attributes only make sense on function-like ops (e.g.
+  // spirv.ARM.Graph), where they are checked against the result type.
+  if (auto funcLike = dyn_cast<FunctionOpInterface>(op))
+    return verifyRegionAttribute(
+        op->getLoc(), funcLike.getResultTypes()[resultIndex], attribute);
+  return op->emitError("cannot attach SPIR-V attributes to region result which is "
+                       "not a FunctionOpInterface type");
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
   return isNestedInFunctionOpInterface(op->getParentOp());
 }
 
+/// Returns true if `op` is a spirv.ARM.Graph op or sits inside one, walking up
+/// the parent chain and giving up at the first symbol-table (module-like) op.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+  for (Operation *cur = op; cur != nullptr; cur = cur->getParentOp()) {
+    if (cur->hasTrait<OpTrait::SymbolTable>())
+      return false;
+    if (isa<spirv::GraphARMOp>(cur))
+      return true;
+  }
+  return false;
+}
+
 /// Returns true if the given op is an module-like op that maintains a symbol
 /// table.
 static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..17cbab189588f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1084,6 +1084,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(
+    OpBuilder &builder, OperationState &state, spirv::GraphARMOp graph,
+    ArrayRef<Attribute> interfaceVars) {
+  // Reference the graph by symbol and wrap the interface symbols in an array.
+  ArrayAttr interfaceAttr = builder.getArrayAttr(interfaceVars);
+  build(builder, state, SymbolRefAttr::get(graph), interfaceAttr);
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Syntax: `@graph` optionally followed by `, @var0, @var1, ...`.
+  SmallVector<Attribute, 4> interfaceVars;
+
+  // Parse the referenced graph symbol (stored under kFnNameAttrName).
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
+    return failure();
+
+  if (!parser.parseOptionalComma()) {
+    // Parse the interface variables.
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // The name of the interface variable attribute isn't important.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute("interface",
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  // Emits `@graph` and, when present, `, @var0, @var1, ...` (see parse()).
+  printer << " ";
+  printer.printSymbolName(getFn());
+  if (ArrayRef<Attribute> vars = getInterface().getValue(); !vars.empty()) {
+    printer << ", ";
+    llvm::interleaveComma(vars, printer);
+  }
+}
+
+LogicalResult spirv::GraphEntryPointARMOp::verify() {
+  // NOTE(review): this patch leaves spirv::ModuleOp verification untouched, so
+  // the `fn`/`interface` symbols appear unchecked; TODO verify them here.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  SmallVector<Type> resultTypes;
+  auto &builder = parser.getBuilder();
+
+  // Parse the name as a symbol.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Parse the graph signature; variadic arguments are not allowed.
+  bool isVariadic = false;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+          resultAttrs))
+    return failure();
+  // Build the GraphType from the parsed argument and result types.
+  SmallVector<Type> argTypes;
+  for (auto &arg : entryArgs)
+    argTypes.push_back(arg.type);
+  auto grType = builder.getGraphType(argTypes, resultTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(grType));
+
+  // If additional attributes are present, parse them.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Attach the parsed argument and result attributes to the op.
+  assert(resultAttrs.size() == resultTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // Parse the optional graph body; without one the graph is external.
+  auto *body = result.addRegion();
+  OptionalParseResult parseResult =
+      parser.parseOptionalRegion(*body, entryArgs);
+  return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Print the graph name, signature and any non-elided attributes.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  auto grType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, grType.getInputs(),
+      /*isVariadic=*/false, grType.getResults());
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Print the body, if any; entry block args are part of the signature.
+  Region &body = this->getBody();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getResults().empty()) // Graphs need >= 1 output.
+    return emitOpError("there should be at least one result");
+  return success();
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  GraphType grType = getFunctionType();
+  if (!isExternal()) {
+    Block &entryBlock = front();
+    // The entry block must mirror the graph's input signature.
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << '(' << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature(" << grArgType << ')';
+      }
+    }
+  }
+  // Every GraphOutputs terminator must match the graph's result types.
+  auto walkResult = walk([grType](Operation *op) -> WalkResult {
+    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
+      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
+        return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
+               << graphOutputsARMOp.getNumOperands()
+               << " value(s) but enclosing graph requires "
+               << grType.getNumResults() << " results";
+
+      auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
+      for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
+        auto graphOutputOperandType = graphOutputOperandTypes[i];
+        auto grResultType = grType.getResult(i);
+        if (graphOutputOperandType != grResultType)
+          return graphOutputsARMOp.emitError("type of return operand ")
+                 << i << " (" << graphOutputOperandType
+                 << ") doesn't match graph result type (" << grResultType
+                 << ")";
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  StringAttr symName = builder.getStringAttr(name);
+  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.addAttributes(attrs); // Caller-supplied extras, before the flag.
+  BoolAttr entryAttr = builder.getBoolAttr(entryPoint);
+  state.addAttribute(getEntryPointAttrName(state.name), entryAttr);
+  state.addRegion(); // Empty body region for the caller to populate.
+}
+
+// Returns the input types of this graph's GraphType signature.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  return getFunctionType().getInputs();
+}
+
+// Returns the output types of this graph's GraphType signature.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  return getFunctionType().getResults();
+}
+
+// CallableOpInterface: the body region, or null for an external graph.
+Region *spirv::GraphARMOp::getCallableRegion() {
+  return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  // InGraphScope accepts any graph ancestor, so check the direct parent here
+  // instead of letting `cast` assert on deeper nesting.
+  auto graph = dyn_cast_or_null<GraphARMOp>((*this)->getParentOp());
+  if (!graph)
+    return emitOpError("must be directly nested in a spirv.ARM.Graph op");
+  // The operand number and types must match the graph signature.
+  const auto &results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing graph (@"
+           << graph.getName() << ") returns " << results.size();
+  for (unsigned i = 0; i < results.size(); i++)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of return operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match graph result type (" << results[i]
+                         << ") in graph @" << graph.getName();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.GLFClampOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 6fd20466e36e3..40a85dca60939 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -76,10 +76,36 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
       abiInfo.getBinding());
 }
 
+/// Materializes, right before `graphOp`, a spirv.GlobalVariable backing the
+/// graph argument (`isArg` true) or result (`isArg` false) at position
+/// `index`. The variable is named `<graph>_arg_<index>` / `<graph>_res_<index>`
+/// and takes its storage class (UniformConstant if absent), descriptor set and
+/// binding from `abiInfo`. Returns a null op when `graphOp` is not nested in a
+/// spirv.module.
+static spirv::GlobalVariableOp
+createGlobalVarForGraphEntryPoint(OpBuilder &builder, spirv::GraphARMOp graphOp,
+                                  unsigned index, bool isArg,
+                                  spirv::InterfaceVarABIAttr abiInfo) {
+  if (!graphOp->getParentOfType<spirv::ModuleOp>())
+    return nullptr;
+
+  auto graphType = graphOp.getFunctionType();
+  Type varType = isArg ? graphType.getInput(index) : graphType.getResult(index);
+  spirv::StorageClass storageClass =
+      abiInfo.getStorageClass().value_or(spirv::StorageClass::UniformConstant);
+  auto pointerType = spirv::PointerType::get(varType, storageClass);
+
+  std::string varName = graphOp.getName().str();
+  varName += isArg ? "_arg_" : "_res_";
+  varName += std::to_string(index);
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(graphOp.getOperation());
+  return builder.create<spirv::GlobalVariableOp>(
+      graphOp.getLoc(), pointerType, varName, abiInfo.getDescriptorSet(),
+      abiInfo.getBinding());
+}
+
 /// Gets the global variables that need to be specified as interface variable
 /// with an spirv.EntryPointOp. Traverses the body of a entry function to do so.
 static LogicalResult
-getInterfaceVariables(spirv::FuncOp funcOp,
+getInterfaceVariables(mlir::FunctionOpInterface funcOp,
                       SmallVectorImpl<Attribute> &interfaceVars) {
   auto module = funcOp->getParentOfType<spirv::ModuleOp>();
   if (!module) {
@@ -215,6 +241,21 @@ class ProcessInterfaceVarABI final : public OpConversionPattern<spirv::FuncOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Conversion pattern that lowers the interface variable ABI attributes
+/// attached to the arguments and results of a graph entry point: a global
+/// variable is created for each of them and the attributes are then dropped
+/// from the graph signature.
+class ProcessGraphInterfaceVarABI final
+    : public OpConversionPattern<spirv::GraphARMOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pass to implement the ABI information specified as attributes.
 class LowerABIAttributesPass final
     : public spirv::impl::SPIRVLowerABIAttributesPassBase<
@@ -288,6 +329,89 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
   return success();
 }
 
+/// Creates a spirv::GraphEntryPointARMOp for `graphOp`, inserted right before
+/// it, whose interface list is `interfaceVars`. Fails without touching the IR
+/// when `graphOp` is not marked as an entry point.
+///
+/// A file-local `static` function (like the other helpers in this file) is
+/// used rather than an anonymous namespace, per LLVM coding standards.
+static LogicalResult lowerGraphEntryPoint(OpBuilder &builder,
+                                          spirv::GraphARMOp graphOp,
+                                          ArrayRef<Attribute> interfaceVars) {
+  if (!graphOp.getEntryPoint().value_or(false))
+    return failure();
+
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  builder.setInsertionPoint(graphOp);
+  builder.create<spirv::GraphEntryPointARMOp>(graphOp.getLoc(), graphOp,
+                                              interfaceVars);
+  return success();
+}
+
+/// Lowers the interface variable ABI attributes of a graph entry point:
+/// creates one global variable per argument and per result (arguments first),
+/// removes the ABI attributes from the signature, and emits the graph entry
+/// point op listing those variables.
+///
+/// All preconditions are checked before any IR is created or modified, so a
+/// failing match leaves the IR untouched, as conversion patterns require.
+LogicalResult ProcessGraphInterfaceVarABI::matchAndRewrite(
+    spirv::GraphARMOp graphOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Only graph entry points are lowered by this pattern.
+  if (!graphOp.getEntryPoint().value_or(false))
+    return failure();
+
+  // Interface variables are created inside the enclosing spirv.module.
+  if (!graphOp->getParentOfType<spirv::ModuleOp>())
+    return failure();
+
+  StringRef attrName = spirv::getInterfaceVarABIAttrName();
+  auto graphType = graphOp.getFunctionType();
+  unsigned numInputs = graphType.getNumInputs();
+  unsigned numResults = graphType.getNumResults();
+
+  // Every argument and result must carry ABI info; collect it up front so a
+  // missing attribute fails the match before anything has been created.
+  SmallVector<spirv::InterfaceVarABIAttr> argABIInfos;
+  argABIInfos.reserve(numInputs);
+  for (unsigned i = 0; i < numInputs; ++i) {
+    auto abiInfo =
+        graphOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(i, attrName);
+    if (!abiInfo)
+      return failure();
+    argABIInfos.push_back(abiInfo);
+  }
+
+  SmallVector<spirv::InterfaceVarABIAttr> resultABIInfos;
+  resultABIInfos.reserve(numResults);
+  for (unsigned i = 0; i < numResults; ++i) {
+    auto abiInfo =
+        graphOp.getResultAttrOfType<spirv::InterfaceVarABIAttr>(i, attrName);
+    if (!abiInfo)
+      return failure();
+    resultABIInfos.push_back(abiInfo);
+  }
+
+  // Create the backing global variables: arguments first, then results.
+  SmallVector<Attribute> interfaceVars;
+  interfaceVars.reserve(numInputs + numResults);
+  for (auto [index, abiInfo] : llvm::enumerate(argABIInfos)) {
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, /*isArg=*/true, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+  for (auto [index, abiInfo] : llvm::enumerate(resultABIInfos)) {
+    spirv::GlobalVariableOp var = createGlobalVarForGraphEntryPoint(
+        rewriter, graphOp, index, /*isArg=*/false, abiInfo);
+    if (!var)
+      return failure();
+    interfaceVars.push_back(
+        SymbolRefAttr::get(rewriter.getContext(), var.getSymName()));
+  }
+
+  // The ABI attributes are now lowered; strip them from the graph signature
+  // in place (this is what makes the op legal for the conversion target).
+  StringAttr attrNameAttr = rewriter.getStringAttr(attrName);
+  rewriter.modifyOpInPlace(graphOp, [&] {
+    for (unsigned i = 0; i < numInputs; ++i)
+      graphOp.removeArgAttr(i, attrNameAttr);
+    for (unsigned i = 0; i < numResults; ++i)
+      graphOp.removeResultAttr(i, attrNameAttr);
+  });
+
+  return lowerGraphEntryPoint(rewriter, graphOp, interfaceVars);
+}
+
 void LowerABIAttributesPass::runOnOperation() {
   // Uses the signature conversion methodology of the dialect conversion
   // framework to implement the conversion.
@@ -314,6 +438,7 @@ void LowerABIAttributesPass::runOnOperation() {
 
   RewritePatternSet patterns(context);
   patterns.add<ProcessInterfaceVarABI>(typeConverter, context);
+  patterns.add<ProcessGraphInterfaceVarABI>(typeConverter, context);
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.
@@ -324,6 +449,17 @@ void LowerABIAttributesPass::runOnOperation() {
         return false;
     return true;
   });
+  target.addDynamicallyLegalOp<spirv::GraphARMOp>([&](spirv::GraphARMOp op) {
+    StringRef attrName = spirv::getInterfaceVarABIAttrName();
+    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i)
+      if (op.getArgAttr(i, attrName))
+        return false;
+    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i)
+      if (op.getResultAttr(i, attrName))
+        return false;
+    return true;
+  });
+
   // All other SPIR-V ops are legal.
   target.markUnknownOpDynamicallyLegal([](Operation *op) {
     return op->getDialect()->getNamespace() ==
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 095db6b815f51..d636ea29fe019 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -154,6 +154,14 @@ void UpdateVCEPass::runOnOperation() {
     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
       valueTypes.push_back(globalVar.getType());
 
+    // If the op is FunctionLike make sure to process input and result types
+    if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
+      auto inputTypes = funcOpInterface.getArgumentTypes();
+      auto resultTypes = funcOpInterface.getResultTypes();
+      valueTypes.append(inputTypes.begin(), inputTypes.end());
+      valueTypes.append(resultTypes.begin(), resultTypes.end());
+    }
+
     // Requirements from values' types
     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad290a1981..58fa14c0f5251 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,7 +104,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      llvm::isa<FunctionType>(op->getResult(0).getType()))
+      (llvm::isa<FunctionType>(op->getResult(0).getType()) ||
+       llvm::isa<GraphType>(op->getResult(0).getType())))
     wrapped = true;
 
   if (wrapped)
@@ -2837,6 +2838,20 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         os << '>';
       })
       .Case<NoneType>([&](Type) { os << "none"; })
+      .Case<GraphType>([&](GraphType graphTy) {
+        os << '(';
+        interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
+        os << ") -> ";
+        ArrayRef<Type> results = graphTy.getResults();
+        if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
+                                     llvm::isa<GraphType>(results[0]))) {
+          printType(results[0]);
+        } else {
+          os << '(';
+          interleaveComma(results, [&](Type ty) { printType(ty); });
+          os << ')';
+        }
+      })
       .Default([&](Type type) { return printDialectType(type); });
 }
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index f657db142eeb9..3d366276b4375 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
   return FunctionType::get(context, inputs, results);
 }
 
+GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) {
+  MLIRContext *ctx = getContext();
+  return GraphType::get(ctx, inputs, results);
+}
+
 TupleType Builder::getTupleType(TypeRange elementTypes) {
   return TupleType::get(context, elementTypes);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..0ed2549bcc9ba 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -179,6 +179,49 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
   return clone(newArgTypes, newResultTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// GraphType
+//===----------------------------------------------------------------------===//
+
+/// Returns how many inputs the graph type has.
+unsigned GraphType::getNumInputs() const {
+  return getImpl()->numInputs;
+}
+
+/// Returns the input types of the graph type.
+ArrayRef<Type> GraphType::getInputs() const {
+  auto *storage = getImpl();
+  return storage->getInputs();
+}
+
+/// Returns how many results the graph type has.
+unsigned GraphType::getNumResults() const {
+  return getImpl()->numResults;
+}
+
+/// Returns the result types of the graph type.
+ArrayRef<Type> GraphType::getResults() const {
+  auto *storage = getImpl();
+  return storage->getResults();
+}
+
+/// Returns a graph type in the same context with the given inputs/results.
+GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
+  MLIRContext *ctx = getContext();
+  return GraphType::get(ctx, inputs, results);
+}
+
+/// Returns a new graph type with `argTypes` inserted into the inputs at
+/// `argIndices` and `resultTypes` inserted into the results at
+/// `resultIndices`, using the same insertion helper (`insertTypesInto`) as
+/// FunctionType.
+GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                           TypeRange argTypes,
+                                           ArrayRef<unsigned> resultIndices,
+                                           TypeRange resultTypes) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes =
+      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
+  TypeRange newResultTypes =
+      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
+
+/// Returns a new graph type with the inputs whose bits are set in
+/// `argIndices` and the results whose bits are set in `resultIndices`
+/// removed, using the same filtering helper (`filterTypesOut`) as
+/// FunctionType.
+GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
+                                            const BitVector &resultIndices) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
+  TypeRange newResultTypes =
+      filterTypesOut(getResults(), resultIndices, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 55d6a380d0bff..abe6d4bc7040b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -71,6 +71,12 @@ Value spirv::Deserializer::getValue(uint32_t id) {
   if (auto undef = getUndefType(id)) {
     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
   }
+  if (auto graphConstantARMInfo = getGraphConstantARM(id)) {
+    auto graphConstantID = graphConstantARMInfo->graphConstantID;
+    auto resultType = graphConstantARMInfo->resultType;
+    return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
+                                                       graphConstantID);
+  }
   return valueMap.lookup(id);
 }
 
@@ -165,6 +171,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
   case spirv::Opcode::OpTypeTensorARM:
+  case spirv::Opcode::OpTypeGraphARM:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
@@ -189,12 +196,26 @@ LogicalResult spirv::Deserializer::processInstruction(
     return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantNull:
     return processConstantNull(operands);
+  case spirv::Opcode::OpGraphConstantARM:
+    return processGraphConstantARM(operands);
   case spirv::Opcode::OpDecorate:
     return processDecoration(operands);
   case spirv::Opcode::OpMemberDecorate:
     return processMemberDecoration(operands);
   case spirv::Opcode::OpFunction:
     return processFunction(operands);
+  case spirv::Opcode::OpGraphEntryPointARM:
+    if (deferInstructions) {
+      deferredInstructions.emplace_back(opcode, operands);
+      return success();
+    }
+    return processGraphEntryPointARM(operands);
+  case spirv::Opcode::OpGraphARM:
+    return processGraphARM(operands);
+  case spirv::Opcode::OpGraphSetOutputARM:
+    return processOpGraphSetOutputARM(operands);
+  case spirv::Opcode::OpGraphEndARM:
+    return processGraphARMEnd(operands);
   case spirv::Opcode::OpLabel:
     return processLabel(operands);
   case spirv::Opcode::OpBranch:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b1abd8b3dffe9..de3dc19349642 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -670,6 +670,213 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes OpGraphEntryPointARM: `operands` are the graph <id>, the entry
+/// point name (string literal) and the interface global variable <id>s. Renames
+/// the referenced graph, marks it as an entry point, and creates a
+/// spirv::GraphEntryPointARMOp right before it.
+LogicalResult
+spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
+  unsigned wordIndex = 0;
+  if (wordIndex >= operands.size()) {
+    return emitError(unknownLoc,
+                     "missing graph definition in OpGraphEntryPointARM");
+  }
+
+  uint32_t grID = operands[wordIndex++];
+  auto graphIt = graphMap.find(grID);
+  if (graphIt == graphMap.end()) {
+    return emitError(unknownLoc,
+                     "missing graph definition/declaration with id ")
+           << grID;
+  }
+  spirv::GraphARMOp graphARM = graphIt->second;
+
+  // The name literal must follow the graph <id>; decoding a string from an
+  // exhausted operand list would read past its end.
+  if (wordIndex >= operands.size()) {
+    return emitError(unknownLoc, "missing name in OpGraphEntryPointARM");
+  }
+  StringRef name = decodeStringLiteral(operands, wordIndex);
+  graphARM.setSymName(name);
+  graphARM.setEntryPoint(true);
+
+  SmallVector<Attribute, 4> interface;
+  for (; wordIndex < operands.size(); ++wordIndex) {
+    auto arg = getGlobalVariable(operands[wordIndex]);
+    if (!arg) {
+      return emitError(unknownLoc, "undefined result <id> ")
+             << operands[wordIndex] << " while decoding OpGraphEntryPointARM";
+    }
+    interface.push_back(SymbolRefAttr::get(arg.getOperation()));
+  }
+
+  // RAII guard to reset the insertion point to previous value when done.
+  OpBuilder::InsertionGuard insertionGuard(opBuilder);
+  opBuilder.setInsertionPoint(graphARM);
+  opBuilder.create<spirv::GraphEntryPointARMOp>(
+      unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
+      opBuilder.getArrayAttr(interface));
+
+  return success();
+}
+
+/// Processes OpGraphARM (`operands` = graph type <id>, result <id>), the
+/// OpGraphInputARM instructions binding its inputs, and every instruction of
+/// the graph body up to and including OpGraphEndARM.
+LogicalResult
+spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
+  if (curGraph) {
+    return emitError(unknownLoc, "found graph inside graph");
+  }
+  // The graph type <id> and the result <id> are required.
+  if (operands.size() < 2) {
+    return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
+  }
+
+  Type grType = getType(operands[0]);
+  if (!grType || !llvm::isa<GraphType>(grType)) {
+    return emitError(unknownLoc, "unknown graph type from <id> ")
+           << operands[0];
+  }
+  auto graphType = llvm::cast<GraphType>(grType);
+  // The result count is unsigned: only zero needs rejecting.
+  if (graphType.getNumResults() == 0) {
+    return emitError(unknownLoc, "expected at least one result");
+  }
+
+  uint32_t grID = operands[1];
+  if (graphMap.count(grID)) {
+    return emitError(unknownLoc, "duplicate graph definition/declaration");
+  }
+
+  std::string grName = getGraphSymbol(grID);
+  auto graphOp =
+      opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
+  curGraph = graphMap[grID] = graphOp;
+  auto *entryBlock = graphOp.addEntryBlock();
+  LLVM_DEBUG({
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+    logger.startLine() << "[graph] name: " << grName << "\n";
+    logger.startLine() << "[graph] type: " << grType << "\n";
+    logger.startLine() << "[graph] ID: " << grID << "\n";
+    logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
+    logger.indent();
+  });
+
+  // Each graph input is introduced by an OpGraphInputARM instruction that
+  // binds a result <id> to the graph argument selected by its input index.
+  unsigned numInputs = graphType.getNumInputs();
+  for (unsigned i = 0; i != numInputs; ++i) {
+    spirv::Opcode inputOpcode = spirv::Opcode::OpNop;
+    ArrayRef<uint32_t> inputOperands;
+    if (failed(sliceInstruction(inputOpcode, inputOperands,
+                                spirv::Opcode::OpGraphInputARM))) {
+      return failure();
+    }
+    if (inputOpcode != spirv::Opcode::OpGraphInputARM) {
+      return emitError(unknownLoc,
+                       "missing OpGraphInputARM instruction for argument ")
+             << i;
+    }
+
+    if (inputOperands.size() != 3) {
+      return emitError(unknownLoc, "expected result type, result <id> and "
+                                   "input index for OpGraphInputARM");
+    }
+
+    Type argDefinedType = getType(inputOperands[0]);
+    if (!argDefinedType) {
+      return emitError(unknownLoc, "unknown operand type <id> ")
+             << inputOperands[0];
+    }
+
+    IntegerAttr inputIndexAttr = getConstantInt(inputOperands[2]);
+    if (!inputIndexAttr) {
+      return emitError(unknownLoc,
+                       "unable to read inputIndex value from constant op ")
+             << inputOperands[2];
+    }
+    // The index comes straight from the binary: reject values that do not
+    // name one of the graph's arguments instead of indexing out of bounds.
+    int64_t inputIndex = inputIndexAttr.getInt();
+    if (inputIndex < 0 || inputIndex >= static_cast<int64_t>(numInputs)) {
+      return emitError(unknownLoc, "inputIndex ")
+             << inputIndex << " of OpGraphInputARM is out of range for a "
+             << "graph with " << numInputs << " inputs";
+    }
+
+    // Check the declared type against the argument this input binds to.
+    Type argType = graphType.getInput(static_cast<unsigned>(inputIndex));
+    if (argDefinedType != argType) {
+      return emitError(unknownLoc,
+                       "mismatch in argument type between graph type "
+                       "definition ")
+             << graphType << " and argument type definition "
+             << argDefinedType << " at argument " << inputIndex;
+    }
+    if (getValue(inputOperands[1])) {
+      return emitError(unknownLoc, "duplicate definition of result <id> ")
+             << inputOperands[1];
+    }
+
+    valueMap[inputOperands[1]] =
+        graphOp.getArgument(static_cast<unsigned>(inputIndex));
+  }
+
+  // Start with every output unset; OpGraphSetOutputARM fills them in.
+  graphOutputs.assign(graphType.getNumResults(), Value());
+
+  // RAII guard to reset the insertion point to the module's region after
+  // deserializing the body of this graph.
+  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
+
+  blockMap[grID] = entryBlock;
+  if (failed(createGraphBlock(grID))) {
+    return failure();
+  }
+
+  // Process all the instructions in the graph until and including
+  // OpGraphEndARM.
+  spirv::Opcode opcode = spirv::Opcode::OpNop;
+  ArrayRef<uint32_t> instOperands;
+  do {
+    if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
+      return failure();
+    }
+
+    if (failed(processInstruction(opcode, instOperands))) {
+      return failure();
+    }
+  } while (opcode != spirv::Opcode::OpGraphEndARM);
+
+  return success();
+}
+
+/// Processes OpGraphSetOutputARM (`operands` = value <id>, output index
+/// constant <id>): records the value as the given output of the graph under
+/// construction. The op yielding the outputs is built at OpGraphEndARM.
+LogicalResult
+spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
+  if (!curGraph) {
+    return emitError(unknownLoc,
+                     "OpGraphSetOutputARM must appear inside a graph");
+  }
+
+  if (operands.size() != 2) {
+    return emitError(
+        unknownLoc,
+        "expected value id and output index for OpGraphSetOutputARM");
+  }
+
+  uint32_t id = operands[0];
+  Value value = getValue(id);
+  if (!value) {
+    return emitError(unknownLoc, "could not find result <id> ") << id;
+  }
+
+  IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
+  if (!outputIndexAttr) {
+    return emitError(unknownLoc,
+                     "unable to read outputIndex value from constant op ")
+           << operands[1];
+  }
+  // The index comes straight from the binary; writing with an unchecked
+  // index would corrupt memory past the end of `graphOutputs`.
+  int64_t outputIndex = outputIndexAttr.getInt();
+  if (outputIndex < 0 ||
+      outputIndex >= static_cast<int64_t>(graphOutputs.size())) {
+    return emitError(unknownLoc, "outputIndex ")
+           << outputIndex << " of OpGraphSetOutputARM is out of range for a "
+           << "graph with " << graphOutputs.size() << " outputs";
+  }
+  graphOutputs[outputIndex] = value;
+  return success();
+}
+
+/// Processes OpGraphEndARM: emits the GraphOutputsARM terminator with the
+/// values collected from OpGraphSetOutputARM and resets the per-graph state.
+LogicalResult
+spirv::Deserializer::processGraphARMEnd(ArrayRef<uint32_t> operands) {
+  if (!curGraph) {
+    return emitError(unknownLoc, "OpGraphEndARM must appear inside a graph");
+  }
+
+  // Validate before building anything.
+  if (!operands.empty()) {
+    return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
+  }
+
+  // Each output must have been set; a null value would yield an invalid op.
+  for (auto [index, output] : llvm::enumerate(graphOutputs)) {
+    if (!output) {
+      return emitError(unknownLoc, "graph output ")
+             << index << " is not set by any OpGraphSetOutputARM";
+    }
+  }
+
+  // Create GraphOutputsARM instruction
+  opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
+
+  curBlock = nullptr;
+  curGraph = std::nullopt;
+  graphOutputs.clear();
+
+  LLVM_DEBUG({
+    logger.unindent();
+    logger.startLine()
+        << "//===-------------------------------------------===//\n";
+  });
+  return success();
+}
+
 std::optional<std::pair<Attribute, Type>>
 spirv::Deserializer::getConstant(uint32_t id) {
   auto constIt = constantMap.find(id);
@@ -694,6 +901,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
   return funcName;
 }
 
+std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
+  // Prefer the name recorded for this <id>; otherwise synthesize one.
+  auto recordedName = nameMap.lookup(id);
+  if (!recordedName.empty())
+    return recordedName.str();
+  return "spirv_graph_" + std::to_string(id);
+}
+
 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
   auto constName = nameMap.lookup(id).str();
   if (constName.empty()) {
@@ -716,6 +931,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
   return op;
 }
 
+std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+spirv::Deserializer::getGraphConstantARM(uint32_t id) {
+  // Entries are recorded by processGraphConstantARM; unknown ids yield
+  // std::nullopt.
+  if (auto entry = graphConstantMap.find(id); entry != graphConstantMap.end())
+    return entry->getSecond();
+  return std::nullopt;
+}
+
 LogicalResult
 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
   unsigned wordIndex = 0;
@@ -937,6 +1160,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processMatrixType(operands);
   case spirv::Opcode::OpTypeTensorARM:
     return processTensorARMType(operands);
+  case spirv::Opcode::OpTypeGraphARM:
+    return processGraphTypeARM(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1289,6 +1514,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+/// Processes OpTypeGraphARM. `operands` are the result <id>, the number of
+/// inputs N, and then the graph's input types followed by its output types:
+/// the first N type <id>s are inputs, the remaining ones are outputs.
+LogicalResult
+spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
+  unsigned size = operands.size();
+  if (size < 2) {
+    return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
+                                 "(result_id, num_inputs, (inout0_type, "
+                                 "inout1_type, ...)), but got ")
+           << size;
+  }
+  uint32_t numInputs = operands[1];
+  // There must be at least as many type operands as declared inputs.
+  if (numInputs > size - 2) {
+    return emitError(unknownLoc, "OpTypeGraphARM declares ")
+           << numInputs << " inputs but only provides " << (size - 2)
+           << " input/output types";
+  }
+  SmallVector<Type, 1> argTypes;
+  SmallVector<Type, 1> returnTypes;
+  for (unsigned i = 2; i < size; ++i) {
+    Type inOutTy = getType(operands[i]);
+    if (!inOutTy) {
+      return emitError(unknownLoc, "OpTypeGraphARM references undefined "
+                                   "element type <id> ")
+             << operands[i];
+    }
+    if (i - 2 < numInputs)
+      argTypes.push_back(inOutTy);
+    else
+      returnTypes.push_back(inOutTy);
+  }
+  typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2)
@@ -1699,6 +1953,38 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
          << resultType;
 }
 
+/// Processes OpGraphConstantARM (`operands` = result type <id>, result <id>,
+/// GraphConstantID literal). Only records the information; the
+/// GraphConstantARMOp is materialized when the result <id> is looked up.
+LogicalResult
+spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc)
+           << "OpGraphConstantARM must have type <id> and result <id>";
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc)
+           << "OpGraphConstantARM must have at least 1 more parameter";
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  uint32_t resultID = operands[1];
+
+  if (!llvm::isa<spirv::TensorArmType>(resultType)) {
+    return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
+  }
+
+  // The GraphConstantID is a raw 32-bit literal word. Build the APInt as
+  // unsigned: the word always fits in 32 unsigned bits, whereas a
+  // zero-extended word >= 2^31 does not fit as a 32-bit signed value.
+  APInt graphConstantID(/*numBits=*/32, operands[2]);
+  auto attr =
+      opBuilder.getIntegerAttr(opBuilder.getIntegerType(32), graphConstantID);
+  graphConstantMap.try_emplace(
+      resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Control flow
 //===----------------------------------------------------------------------===//
@@ -1796,6 +2082,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
+  if (!curGraph)
+    return emitError(unknownLoc, "a graph block must appear inside a graph");
+
+  // The block may already exist as a forward declaration, but it must not
+  // have been populated yet.
+  Block *graphBlock = getOrCreateBlock(graphID);
+  LLVM_DEBUG(logger.startLine()
+             << "[block] populating block " << graphBlock << "\n");
+  assert(graphBlock->empty() && "re-deserialize the same block!");
+
+  // Make it the current block and start inserting at its beginning.
+  curBlock = graphBlock;
+  blockMap[graphID] = graphBlock;
+  opBuilder.setInsertionPointToStart(graphBlock);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
   if (!curBlock) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 1bc9e4a3c75d8..90740112c8d13 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -105,6 +105,13 @@ struct SpecConstOperationMaterializationInfo {
   SmallVector<uint32_t> enclosedOpOperands;
 };
 
+/// A struct that collects the info needed to materialize/emit a
+/// GraphConstantARMOp.
+struct GraphConstantARMOpMaterializationInfo {
+  /// Type of the value produced by the graph constant.
+  Type resultType;
+  /// The GraphConstantID literal of the constant, as a 32-bit integer
+  /// attribute.
+  IntegerAttr graphConstantID;
+};
+
 //===----------------------------------------------------------------------===//
 // Deserializer Declaration
 //===----------------------------------------------------------------------===//
@@ -205,9 +212,14 @@ class Deserializer {
   /// exists; otherwise creates one based on the <id>.
   std::string getFunctionSymbol(uint32_t id);
 
-  /// Returns a symbol to be used for the specialization constant with the given
-  /// result <id>. This tries to use the specialization constant's OpName if
+  /// Returns a symbol to be used for the graph name with the given
+  /// result <id>. This tries to use the graph's OpName if
   /// exists; otherwise creates one based on the <id>.
+  std::string getGraphSymbol(uint32_t id);
+
+  /// Returns a symbol to be used for the specialization constant with the
+  /// given result <id>. This tries to use the specialization constant's
+  /// OpName if exists; otherwise creates one based on the <id>.
   std::string getSpecConstantSymbol(uint32_t id);
 
   /// Gets the specialization constant with the given result <id>.
@@ -224,6 +236,11 @@ class Deserializer {
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
                                            TypedAttr defaultValue);
 
+  /// Gets the GraphConstantARM ID attribute and result type with the given
+  /// result <id>.
+  std::optional<spirv::GraphConstantARMOpMaterializationInfo>
+  getGraphConstantARM(uint32_t id);
+
   /// Processes the OpVariable instructions at current `offset` into `binary`.
   /// It is expected that this method is used for variables that are to be
   /// defined at module scope and will be deserialized into a
@@ -293,6 +310,16 @@ class Deserializer {
 
   LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpTypeGraphARM instruction with the given `operands`.
+  LogicalResult processGraphTypeARM(ArrayRef<uint32_t> operands);
+
+  /// Processes a SPIR-V OpGraphEntryPointARM instruction with the given
+  /// `operands`.
+  LogicalResult processGraphEntryPointARM(ArrayRef<uint32_t> operands);
+
+  /// Processes a SPIR-V OpGraphARM instruction together with the graph body
+  /// that follows it, up to and including OpGraphEndARM.
+  LogicalResult processGraphARM(ArrayRef<uint32_t> operands);
+
+  /// Processes a SPIR-V OpGraphSetOutputARM instruction with the given
+  /// `operands`.
+  LogicalResult processOpGraphSetOutputARM(ArrayRef<uint32_t> operands);
+
+  /// Processes a SPIR-V OpGraphEndARM instruction with the given `operands`.
+  LogicalResult processGraphARMEnd(ArrayRef<uint32_t> operands);
+
   LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
@@ -330,6 +357,10 @@ class Deserializer {
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
+  /// Processes a SPIR-V OpGraphConstantARM instruction with the given
+  /// `operands`.
+  LogicalResult processGraphConstantARM(ArrayRef<uint32_t> operands);
+
   //===--------------------------------------------------------------------===//
   // Debug
   //===--------------------------------------------------------------------===//
@@ -427,6 +458,9 @@ class Deserializer {
   /// blocks declared as selection/loop headers are handled.
   LogicalResult structurizeControlFlow();
 
+  /// Gets (or creates) the block for the graph with the given `graphID` and
+  /// makes it the current block for insertion.
+  LogicalResult createGraphBlock(uint32_t graphID);
+
   //===--------------------------------------------------------------------===//
   // Instruction
   //===--------------------------------------------------------------------===//
@@ -523,6 +557,9 @@ class Deserializer {
   /// The current function under construction.
   std::optional<spirv::FuncOp> curFunction;
 
+  /// The current graph under construction.
+  std::optional<spirv::GraphARMOp> curGraph;
+
   /// The current block under construction.
   Block *curBlock = nullptr;
 
@@ -560,12 +597,19 @@ class Deserializer {
   DenseMap<uint32_t, SpecConstOperationMaterializationInfo>
       specConstOperationMap;
 
+  // Result <id> to GraphConstantARM ID attribute and result type.
+  DenseMap<uint32_t, spirv::GraphConstantARMOpMaterializationInfo>
+      graphConstantMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
   // Result <id> to function mapping.
   DenseMap<uint32_t, spirv::FuncOp> funcMap;
 
+  // Result <id> to graph mapping.
+  DenseMap<uint32_t, spirv::GraphARMOp> graphMap;
+
   // Result <id> to block mapping.
   DenseMap<uint32_t, Block *> blockMap;
 
@@ -629,6 +673,9 @@ class Deserializer {
   /// Deserialization options.
   DeserializationOptions options;
 
+  /// Values assigned to the outputs of the graph under construction, indexed
+  /// by output index.
+  SmallVector<Value> graphOutputs;
+
 #ifndef NDEBUG
   /// A logger used to emit information during the deserialzation process.
   llvm::ScopedPrinter logger;
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..ffed7ad7ec8a0 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -161,6 +161,16 @@ Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
   return success();
 }
 
+/// Maps the result of a spirv.ARM.GraphConstant op to the <id> of its
+/// OpGraphConstantARM instruction (which prepareGraphConstantId may reuse).
+/// Fails when prepareGraphConstantId reports failure by returning 0.
+LogicalResult
+Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
+  uint32_t constID = prepareGraphConstantId(op.getLoc(), op.getType(),
+                                            op.getGraphConstantIdAttr());
+  if (!constID)
+    return failure();
+
+  valueIDMap[op.getResult()] = constID;
+  return success();
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -326,6 +336,122 @@ LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
   return success();
 }
 
+/// Serializes a spirv.ARM.Graph op. The emitted sequence is:
+///   OpGraphARM, one OpGraphInputARM per argument, the body blocks, and
+///   OpGraphEndARM.
+/// The header and body are staged in `functionHeader`/`functionBody` and then
+/// appended to `graphs`.
+LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
+  if (op.getNumResults() < 1) {
+    return op.emitError("cannot serialize graph with no return types");
+  }
+
+  // Reject body-less graphs before touching op.getArguments(): the arguments
+  // are read from the entry block, which an external graph does not have.
+  if (op.isExternal()) {
+    return op.emitError("external graph is unhandled");
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
+  assert(functionHeader.empty() && functionBody.empty());
+
+  uint32_t graphID = getOrCreateFunctionID(op.getName());
+  uint32_t graphTypeID = 0;
+  // Generate the type of the graph.
+  if (failed(processType(op.getLoc(), op.getFunctionType(), graphTypeID)))
+    return failure();
+  encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
+                        {graphTypeID, graphID});
+
+  // Declare the inputs: OpGraphInputARM <type id> <value id> <index id>, where
+  // the index is an i32 constant holding the argument position.
+  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
+    uint32_t argTypeID = 0;
+    if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
+      return failure();
+    }
+
+    uint32_t argValueID = getNextID();
+    valueIDMap[arg] = argValueID;
+
+    auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
+    uint32_t indexID = prepareConstantInt(op.getLoc(), attr, /*isSpec=*/false);
+
+    encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
+                          {argTypeID, argValueID, indexID});
+  }
+
+  // Process the body: the entry block without a label, then the remaining
+  // blocks (skipping the already-emitted header block).
+  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
+    return failure();
+  if (failed(visitInPrettyBlockOrder(
+          &op.front(), [&](Block *block) { return processBlock(block); },
+          /*skipHeader=*/true))) {
+    return failure();
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
+                          << "' --\n");
+  // Insert OpGraphEndARM.
+  encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
+
+  graphs.append(functionHeader.begin(), functionHeader.end());
+  graphs.append(functionBody.begin(), functionBody.end());
+  functionHeader.clear();
+  functionBody.clear();
+
+  return success();
+}
+
+/// Serializes a spirv.ARM.GraphEntryPoint op as OpGraphEntryPointARM, with
+/// operands in this order:
+///   - the graph <id>;
+///   - the graph name as a string literal;
+///   - the <id> of each interface global variable.
+/// The instruction is appended to `graphs`.
+LogicalResult
+Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
+  SmallVector<uint32_t, 4> operands;
+  auto graph = op.getFn();
+  // Add the graph <id>.
+  uint32_t graphID = getOrCreateFunctionID(graph);
+  operands.push_back(graphID);
+  // Add the name of the graph.
+  spirv::encodeStringLiteralInto(operands, graph);
+
+  // Add the interface values; each must already have been serialized.
+  if (auto interface = op.getInterface()) {
+    for (auto var : interface.getValue()) {
+      auto id = getVariableID(llvm::cast<FlatSymbolRefAttr>(var).getValue());
+      if (!id) {
+        return op.emitError(
+            "referencing undefined global variable. "
+            "spirv.ARM.GraphEntryPoint is at the end of spirv.module. All "
+            "referenced variables should already be defined");
+      }
+      operands.push_back(id);
+    }
+  }
+  encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
+  return success();
+}
+
+/// Serializes spirv.ARM.GraphOutputs as one OpGraphSetOutputARM per operand.
+/// Each instruction carries the output value's <id> and an i32 constant
+/// holding the output position.
+LogicalResult Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
+  for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
+    // The output type is still run through processType, although its <id> is
+    // not one of the OpGraphSetOutputARM operands encoded below.
+    uint32_t resTypeID = 0;
+    if (failed(processType(op.getLoc(), value.getType(), resTypeID))) {
+      return failure();
+    }
+
+    // An <id> of 0 means the value was never assigned one; emitting it would
+    // produce an invalid instruction, so diagnose instead.
+    uint32_t outputID = getValueID(value);
+    if (!outputID) {
+      return op.emitError("cannot serialize graph output #")
+             << idx << ": value has no assigned <id>";
+    }
+
+    auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
+    uint32_t indexID = prepareConstantInt(op.getLoc(), attr, /*isSpec=*/false);
+
+    encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
+                          {outputID, indexID});
+  }
+  return success();
+}
+
 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
   SmallVector<uint32_t, 4> operands;
   SmallVector<StringRef, 2> elidedAttrs;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index ebebd2d283afa..cbbfdc247b55f 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -115,7 +115,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
                     extensions.size() + extendedSets.size() +
                     memoryModel.size() + entryPoints.size() +
                     executionModes.size() + decorations.size() +
-                    typesGlobalValues.size() + functions.size();
+                    typesGlobalValues.size() + functions.size() + graphs.size();
 
   binary.clear();
   binary.reserve(moduleSize);
@@ -133,6 +133,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
   binary.append(decorations.begin(), decorations.end());
   binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
   binary.append(functions.begin(), functions.end());
+  binary.append(graphs.begin(), graphs.end());
 }
 
 #ifndef NDEBUG
@@ -457,9 +458,12 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
   auto typeEnum = spirv::Opcode::OpTypeVoid;
   bool deferSerialization = false;
 
-  if ((isa<FunctionType>(type) &&
-       succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
-                                     operands))) ||
+  if ((llvm::isa<FunctionType>(type) &&
+       succeeded(prepareFunctionType(loc, llvm::cast<FunctionType>(type),
+                                     typeEnum, operands))) ||
+      (llvm::isa<GraphType>(type) &&
+       succeeded(prepareGraphType(loc, llvm::cast<GraphType>(type), typeEnum,
+                                  operands))) ||
       succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
                                  deferSerialization, serializationCtx))) {
     if (deferSerialization)
@@ -490,6 +494,7 @@ Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
     return success();
   }
 
+  emitError(loc, "failed to process type: ") << type;
   return failure();
 }
 
@@ -805,6 +810,35 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
   return success();
 }
 
+/// Prepares the OpTypeGraphARM operands for `type`, in this order:
+///   - the number of inputs;
+///   - the <id> of each input type;
+///   - the <id> of each result type.
+/// Graphs with no results are rejected.
+LogicalResult
+Serializer::prepareGraphType(Location loc, GraphType type,
+                             spirv::Opcode &typeEnum,
+                             SmallVectorImpl<uint32_t> &operands) {
+  typeEnum = spirv::Opcode::OpTypeGraphARM;
+  // Diagnose rather than assert, so release builds do not silently emit a
+  // graph type without outputs.
+  if (type.getNumResults() < 1)
+    return emitError(loc, "serialization requires at least a return value");
+
+  operands.push_back(type.getNumInputs());
+
+  for (Type inputType : type.getInputs()) {
+    uint32_t inputTypeID = 0;
+    if (failed(processType(loc, inputType, inputTypeID)))
+      return failure();
+    operands.push_back(inputTypeID);
+  }
+
+  for (Type resultType : type.getResults()) {
+    uint32_t resultTypeID = 0;
+    if (failed(processType(loc, resultType, resultTypeID)))
+      return failure();
+    operands.push_back(resultTypeID);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Constant
 //===----------------------------------------------------------------------===//
@@ -1056,6 +1090,41 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
   return resultID;
 }
 
+/// Returns the <id> of the OpGraphConstantARM whose GraphConstantID is
+/// `intAttr` and whose result type is `graphConstType`. On first request the
+/// instruction is emitted into `typesGlobalValues`. Returns 0 on failure.
+uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
+                                            IntegerAttr intAttr) {
+  // De-duplicate on the (ID, type) pair. A result <id> has exactly one type,
+  // so two graph constants that share an ID but differ in type must not be
+  // folded onto the same instruction.
+  Attribute key = ArrayAttr::get(intAttr.getContext(),
+                                 {intAttr, TypeAttr::get(graphConstType)});
+  if (uint32_t id = getGraphConstantARMId(key)) {
+    return id;
+  }
+
+  // The GraphConstantID operand is encoded as a single 32-bit word.
+  APInt value = intAttr.getValue();
+  unsigned bitwidth = value.getBitWidth();
+  if (bitwidth > 32) {
+    emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
+        << bitwidth << " bits";
+    return 0;
+  }
+
+  // Process the type for this graph constant.
+  uint32_t typeID = 0;
+  if (failed(processType(loc, graphConstType, typeID))) {
+    return 0;
+  }
+
+  // Sign-extend to 32 bits. Any N-bit APInt satisfies isSignedIntN(N), so a
+  // separate zero-extension path would never be taken.
+  uint32_t word = static_cast<uint32_t>(value.getSExtValue());
+
+  uint32_t resultID = getNextID();
+  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
+                        {typeID, resultID, word});
+  graphConstIDMap[key] = resultID;
+  return resultID;
+}
+
 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
                                        bool isSpec) {
   if (!isSpec) {
@@ -1329,9 +1398,18 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       })
       .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
       .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
+      .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
+      .Case([&](spirv::GraphEntryPointARMOp op) {
+        return processGraphEntryPointARMOp(op);
+      })
+      .Case(
+          [&](spirv::GraphOutputsARMOp op) { return processGraphOutputsARMOp(op); })
       .Case([&](spirv::GlobalVariableOp op) {
         return processGlobalVariableOp(op);
       })
+      .Case([&](spirv::GraphConstantARMOp op) {
+        return processGraphConstantARMOp(op);
+      })
       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..e26e873f02daa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -116,6 +116,8 @@ class Serializer {
   LogicalResult
   processSpecConstantOperationOp(spirv::SpecConstantOperationOp op);
 
+  LogicalResult processGraphConstantARMOp(spirv::GraphConstantARMOp op);
+
   /// SPIR-V dialect supports OpUndef using spirv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -129,6 +131,15 @@ class Serializer {
   LogicalResult processFuncOp(spirv::FuncOp op);
   LogicalResult processFuncParameter(spirv::FuncOp op);
 
+  /// Processes a SPIR-V GraphARM op.
+  LogicalResult processGraphARMOp(spirv::GraphARMOp op);
+
+  /// Processes a SPIR-V GraphEntryPointARM op.
+  LogicalResult processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op);
+
+  /// Processes a SPIR-V GraphOutputsARM op.
+  LogicalResult processGraphOutputsARMOp(spirv::GraphOutputsARMOp op);
+
   LogicalResult processVariableOp(spirv::VariableOp op);
 
   /// Process a SPIR-V GlobalVariableOp
@@ -183,6 +194,10 @@ class Serializer {
                                     spirv::Opcode &typeEnum,
                                     SmallVectorImpl<uint32_t> &operands);
 
+  LogicalResult prepareGraphType(Location loc, GraphType type,
+                                 spirv::Opcode &typeEnum,
+                                 SmallVectorImpl<uint32_t> &operands);
+
   //===--------------------------------------------------------------------===//
   // Constant
   //===--------------------------------------------------------------------===//
@@ -227,6 +242,13 @@ class Serializer {
   uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
                               bool isSpec = false);
 
+  /// Returns the <id> recorded in `graphConstIDMap` for `value`, or 0 if no
+  /// entry exists.
+  uint32_t getGraphConstantARMId(Attribute value) const {
+    return graphConstIDMap.lookup(value);
+  }
+
+  /// Returns the <id> of an OpGraphConstantARM carrying GraphConstantID
+  /// `intAttr` with result type `graphConstType`. Emits the instruction when
+  /// no cached <id> is found. Returns 0 on failure.
+  uint32_t prepareGraphConstantId(Location loc, Type graphConstType,
+                                  IntegerAttr intAttr);
+
   uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
                              bool isSpec = false);
 
@@ -323,7 +345,7 @@ class Serializer {
   spirv::ModuleOp module;
 
   /// An MLIR builder for getting MLIR constructs.
-  mlir::Builder mlirBuilder;
+  mlir::OpBuilder mlirBuilder;
 
   /// Serialization options.
   SerializationOptions options;
@@ -355,6 +377,7 @@ class Serializer {
   SmallVector<uint32_t, 0> decorations;
   SmallVector<uint32_t, 0> typesGlobalValues;
   SmallVector<uint32_t, 0> functions;
+  SmallVector<uint32_t, 0> graphs;
 
   /// Recursive struct references are serialized as OpTypePointer instructions
   /// to the recursive struct type. However, the OpTypePointer instruction
@@ -371,15 +394,22 @@ class Serializer {
       recursiveStructInfos;
 
   /// `functionHeader` contains all the instructions that must be in the first
-  /// block in the function, and `functionBody` contains the rest. After
-  /// processing FuncOp, the encoded instructions of a function are appended to
-  /// `functions`. An example of instructions in `functionHeader` in order:
+  /// block in the function or graph, and `functionBody` contains the rest.
+  /// After processing FuncOp/GraphARMOp, the encoded instructions of a function
+  /// or graph are appended to `functions` or `graphs` respectively. Examples of
+  /// instructions in `functionHeader` in order:
+  ///
+  /// For a FuncOp:
   /// OpFunction ...
   /// OpFunctionParameter ...
   /// OpFunctionParameter ...
   /// OpLabel ...
   /// OpVariable ...
   /// OpVariable ...
+  ///
+  /// For a GraphARMOp:
+  /// OpGraphARM ...
+  /// OpGraphInputARM ...
   SmallVector<uint32_t, 0> functionHeader;
   SmallVector<uint32_t, 0> functionBody;
 
@@ -392,6 +422,9 @@ class Serializer {
   /// Map from specialization constant names to their <id>s.
   llvm::StringMap<uint32_t> specConstIDMap;
 
+  /// Map from graph constant ID value to their <id>s.
+  DenseMap<Attribute, uint32_t> graphConstIDMap;
+
   /// Map from GlobalVariableOps name to <id>s.
   llvm::StringMap<uint32_t> globalVarIDMap;
 
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 64ba8e3fc249e..77dbdfaca19b9 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -278,3 +278,20 @@ func.func @set_mesh_outputs(%0 : i32, %1 : i32) -> () {
   spirv.EXT.SetMeshOutputs %0, %1 : i32, i32
   spirv.Return
 }
+
+//===----------------------------------------------------------------------===//
+// GraphARM ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: graph_arm
+spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  // CHECK: spirv.ARM.GraphOutputs min version: v1.0
+  // CHECK: spirv.ARM.GraphOutputs max version: v1.6
+  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+  // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+// CHECK: spirv.ARM.Graph min version: v1.0
+// CHECK: spirv.ARM.Graph max version: v1.6
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
new file mode 100644
index 0000000000000..90c31e19db382
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphConstant
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
+  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
+
+  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[CONST2]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
+  }
+
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 10fbcf06eb052..515162bf99aea 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -14,7 +14,7 @@ func.func @unknown_attr_on_region(%arg: i32 {spirv.something}) {
 
 // -----
 
-// expected-error @+1 {{cannot attach SPIR-V attributes to region result}}
+// expected-error @+1 {{found unsupported 'spirv.something' attribute on region argument}}
 func.func @unknown_attr_on_region() -> (i32 {spirv.something}) {
   %0 = arith.constant 10.0 : f32
   return %0: f32
@@ -101,6 +101,27 @@ func.func @interface_var(
 
 // -----
 
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+func.func @interface_var(%arg: f32) -> (
+    f32 {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+) { return %arg : f32 }
+
+// -----
+
+// CHECK: {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>}
+func.func @interface_var(%arg: f32) -> (
+    f32 {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>}
+) { return %arg : f32 }
+
+// -----
+
+// expected-error @+1 {{'spirv.interface_var_abi' attribute cannot specify storage class when attaching to a non-scalar value}}
+func.func @interface_var(%arg0 : memref<4xf32>) -> (
+  memref<4xf32> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1), Uniform>}
+) { return %arg0 : memref<4xf32> }
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.resource_limits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
index bd51a07843652..9f5694135d623 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir
@@ -35,6 +35,28 @@ spirv.module Logical GLSL450 {
 
 // -----
 
+module attributes {
+  spirv.target_env = #spirv.target_env<
+     #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: spirv.module
+spirv.module Logical Vulkan {
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+  //  CHECK-DAG:    spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<1x16x16x16xi8>, UniformConstant>
+
+  //      CHECK:    spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  //      CHECK:    spirv.ARM.Graph [[GN]]([[ARG0:%.*]]: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = true}
+  spirv.ARM.Graph @main(%arg0: !spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>})
+                  -> (!spirv.arm.tensor<1x16x16x16xi8> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}) attributes {entry_point = true} {
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+} // end spirv.module
+
+} // end module
+
+// -----
+
 module {
 // expected-error at +1 {{'spirv.module' op missing SPIR-V target env attribute}}
 spirv.module Logical GLSL450 {}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 2b237665ffc4a..2482af8927aa4 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes {
     spirv.ReturnValue %val : bf16
   }
 }
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.0, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
+    #spirv.resource_limits<>>
+} {
+  spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
+      spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
+  }
+}
diff --git a/mlir/test/Target/SPIRV/graph-ops.mlir b/mlir/test/Target/SPIRV/graph-ops.mlir
new file mode 100644
index 0000000000000..5b39d33cd49b9
--- /dev/null
+++ b/mlir/test/Target/SPIRV/graph-ops.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[CONST2]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+  }
+
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> attributes {entry_point = false} {
+  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 2e5e591fe5f91..9efca825a663d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -21,7 +21,7 @@ using namespace mlir;
 namespace {
 /// A pass for testing SPIR-V op availability.
 struct PrintOpAvailability
-    : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
+    : public PassWrapper<PrintOpAvailability, OperationPass<mlir::ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
 
   void runOnOperation() override;
@@ -33,12 +33,10 @@ struct PrintOpAvailability
 } // namespace
 
 void PrintOpAvailability::runOnOperation() {
-  auto f = getOperation();
-  llvm::outs() << f.getName() << "\n";
-
+  auto moduleOp = getOperation();
   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
 
-  f->walk([&](Operation *op) {
+  auto opCallback = [&](Operation *op) {
     if (op->getDialect() != spirvDialect)
       return WalkResult::advance();
 
@@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() {
     os.flush();
 
     return WalkResult::advance();
+  };
+
+  moduleOp.walk([&](func::FuncOp f) {
+    llvm::outs() << f.getName() << "\n";
+    f->walk(opCallback);
+  });
+
+  moduleOp.walk([&](spirv::GraphARMOp g) {
+    llvm::outs() << g.getName() << "\n";
+    g->walk(opCallback);
   });
 }
 

>From 6033bc847aa8e7008b7a74bc5dc2774a2d591c6d Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 10 Jul 2025 14:06:25 +0200
Subject: [PATCH 2/2] Resolve code formatting mistakes

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I07c5cad1f3092994af33ebbeda84e2018e03f6b7
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp      | 17 +++++++++--------
 mlir/lib/IR/BuiltinTypes.cpp                    | 10 +++-------
 .../Target/SPIRV/Serialization/SerializeOps.cpp |  3 ++-
 .../Target/SPIRV/Serialization/Serializer.cpp   |  5 +++--
 4 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index e66d4b0ffc446..4b8ed08249b3a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1019,14 +1019,15 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(Operation *op,
-                                                        unsigned regionIndex,
-                                                        unsigned resultIndex,
-                                                        NamedAttribute attribute) {
+LogicalResult
+SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+                                          unsigned resultIndex,
+                                          NamedAttribute attribute) {
   auto funcOp = dyn_cast<FunctionOpInterface>(op);
   if (!funcOp)
-    return op->emitError("cannot attach SPIR-V attributes to region result which is "
-                         "not a FunctionOpInterface type");
-  return verifyRegionAttribute(
-      op->getLoc(), funcOp.getResultTypes()[resultIndex], attribute);
+    return op->emitError(
+        "cannot attach SPIR-V attributes to region result which is "
+        "not a FunctionOpInterface type");
+  return verifyRegionAttribute(op->getLoc(),
+                               funcOp.getResultTypes()[resultIndex], attribute);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 0ed2549bcc9ba..ce47c60c9b932 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -185,15 +185,11 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
 
 unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
 
-ArrayRef<Type> GraphType::getInputs() const {
-  return getImpl()->getInputs();
-}
+ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
 
 unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
 
-ArrayRef<Type> GraphType::getResults() const {
-  return getImpl()->getResults();
-}
+ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
 
 GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
   return get(getContext(), inputs, results);
@@ -215,7 +211,7 @@ GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
 
 /// Returns a new function type without the specified arguments and results.
 GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
-                                            const BitVector &resultIndices) {
+                                              const BitVector &resultIndices) {
   SmallVector<Type> argStorage, resultStorage;
   TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
   TypeRange newResultTypes =
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ffed7ad7ec8a0..4a8b10001ec02 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -429,7 +429,8 @@ Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
   return success();
 }
 
-LogicalResult Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
+LogicalResult
+Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
   for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
     SmallVector<uint32_t, 2> outputOperands;
 
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index cbbfdc247b55f..b5ba5c91885d9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1402,8 +1402,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::GraphEntryPointARMOp op) {
         return processGraphEntryPointARMOp(op);
       })
-      .Case(
-          [&](spirv::GraphOutputsARMOp op) { return processGraphOutputsARMOp(op); })
+      .Case([&](spirv::GraphOutputsARMOp op) {
+        return processGraphOutputsARMOp(op);
+      })
       .Case([&](spirv::GlobalVariableOp op) {
         return processGlobalVariableOp(op);
       })



More information about the Mlir-commits mailing list