[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (PR #151934)

Davide Grohmann llvmlistbot at llvm.org
Mon Aug 25 04:57:16 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/151934

From f3eb19fdd5acf04635e842ce0b002c607047d688 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 4 Aug 2025 10:42:01 +0200
Subject: [PATCH 1/8] [mlir][spirv] Add support for SPV_ARM_graph extension -
 part 1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is the first patch to add support for the SPV_ARM_graph SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
Graph abstraction for expressing dataflow computations over full
resources.

Part 1 includes:

    - A new GraphType, modeled on FunctionType, for typed graph signatures.
    - New operations in the spirv.ARM namespace:
        - spirv.ARM.Graph
        - spirv.ARM.GraphEntryPoint
        - spirv.ARM.GraphConstant
        - spirv.ARM.GraphOutputs
    - Verifier and VCE updates to gate usage on SPV_ARM_graph.
    - Tests covering parsing and verification.

Graphs currently operate only on SPV_ARM_tensors resources, but are
designed to generalize to other resource types, such as images.

Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia74b7ab0161b03d3d4702e93c34d7f55cd295a5f
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  26 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 201 +++++++++++++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |   1 +
 mlir/include/mlir/IR/Builders.h               |   2 +
 mlir/include/mlir/IR/BuiltinTypes.td          |  18 +-
 mlir/include/mlir/IR/CommonTypeConstraints.td |   7 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |   7 +-
 .../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp    |  12 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 230 ++++++++++++++++++
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |   8 +
 mlir/lib/IR/AsmPrinter.cpp                    |  17 +-
 mlir/lib/IR/Builders.cpp                      |   4 +
 mlir/lib/IR/BuiltinTypes.cpp                  |  39 +++
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  17 ++
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     |  30 +++
 .../SPIRV/Transforms/vce-deduction.mlir       |  11 +
 .../lib/Dialect/SPIRV/TestAvailability.cpp    |  18 +-
 17 files changed, 630 insertions(+), 18 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
 create mode 100644 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index bdfd728d1d0b3..a27554a3c6f64 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
 def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph                        : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
 
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
-      SPV_ARM_tensors,
+      SPV_ARM_tensors, SPV_ARM_graph,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
     Extension<[SPV_ARM_tensors]>
   ];
 }
+def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_graph]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
       SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
-      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
 def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM               : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM             : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM                       : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM                  : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM              : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM                    : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM                   : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
       SPIRV_OC_OpGroupNonUniformLogicalXor,
       SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+      SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+      SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
       SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4862,6 +4879,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
+
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
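In the hunks above, `SPIRV_C_GraphARM` lists `SPIRV_C_TensorsARM`, `SPIRV_C_Shader`, and `SPIRV_C_VulkanMemoryModel` in its `implies` list, so requiring `GraphARM` also makes those capabilities available. A minimal standalone sketch of computing such an implied-capability closure follows. The `Cap` enum and `kImplies` table are hypothetical stand-ins that model only the `GraphARM` edges added in this patch (other capabilities' own implications are omitted); this is not MLIR's generated enum utility code.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <vector>

// Hypothetical stand-in for a few SPIR-V capabilities.
enum class Cap { GraphARM, TensorsARM, Shader, VulkanMemoryModel };

// Direct implications; only the GraphARM edges from this patch are modeled.
const std::map<Cap, std::vector<Cap>> kImplies = {
    {Cap::GraphARM, {Cap::TensorsARM, Cap::Shader, Cap::VulkanMemoryModel}},
};

// Returns every capability reachable from `root` through `implies` edges,
// including `root` itself (worklist traversal with a visited set).
std::set<Cap> impliedClosure(Cap root) {
  std::set<Cap> seen;
  std::vector<Cap> worklist{root};
  while (!worklist.empty()) {
    Cap c = worklist.back();
    worklist.pop_back();
    if (!seen.insert(c).second)
      continue; // Already visited.
    auto it = kImplies.find(c);
    if (it != kImplies.end())
      worklist.insert(worklist.end(), it->second.begin(), it->second.end());
  }
  return seen;
}
```

A worklist with a visited set keeps the traversal terminating even if the implication table were to contain cycles.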
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..38fb4b2eff414
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,201 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all Graph ops.
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_ArmVendorOp<mnemonic, traits> {
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+  }];
+
+  let arguments = (ins
+    I32Attr:$graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+    AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+    FunctionOpInterface, InModuleScope, IsolatedFromAbove
+  ]> {
+
+  let summary = "Declare or define a SPIR-V graph";
+
+  let description = [{
+    This op declares or defines a SPIR-V graph using one region, which
+    contains one or more blocks.
+
+    Unlike the SPIR-V binary format, this op is not allowed to implicitly
+    capture global values; all external references must use graph
+    arguments or symbol references. This op itself defines a symbol that
+    is unique in the enclosing module op.
+
+    This op itself takes no operands and generates no results. Its region
+    can take zero or more arguments and must return one or more values.
+
+    ```
+    spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
+                        region
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttrOf<GraphType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<BoolAttr>:$entry_point,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 0;
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+
+  let extraClassDeclaration = [{
+    /// Hook for FunctionOpInterface, called after verifying that the
+    /// function type attribute is present. Checks that the graph signature
+    /// declares at least one result.
+    LogicalResult verifyType();
+
+    /// Hook for FunctionOpInterface, called after verifying the function
+    /// type and the presence of the (potentially empty) function body.
+    /// Ensures SPIR-V specific semantics.
+    LogicalResult verifyBody();
+  }];
+}
+
+// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
+def InGraphScope : PredOpTrait<
+  "op must appear in a spirv.ARM.Graph op's block",
+  CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+  let summary = [{
+    Declare a graph entry point and its interface.
+  }];
+
+  let description = [{
+    Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+    Name is a name string for the graph entry point. A module cannot have two
+    OpGraphEntryPointARM instructions with the same Name string.
+
+    Interface is a list of symbol references to `spirv.GlobalVariable`
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    `spirv.GlobalVariable`s referenced by the entry point’s static call
+    tree, within the interface’s storage classes.
+
+    ```
+    graph-entry-point-op ::= `spirv.ARM.GraphEntryPoint`
+                             symbol-reference (`, ` symbol-reference)*
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SymbolRefArrayAttr:$interface
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let builders = [
+    OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+                                               Terminator]> {
+
+  let summary = "Define graph outputs.";
+
+  let description = [{
+    Values are the graph output values and must match the GraphOutputs Type
+    operand of the OpTypeGraphARM type of the OpGraphARM body this
+    instruction is in.
+
+    This instruction must be the last instruction in a block.
+
+    ```
+    graph-outputs-op ::= `spirv.ARM.GraphOutputs` ssa-use-list `:` type-list-no-parens
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPIRV_AnyTensorArm>:$value
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let hasOpcode = 0;
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
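The `GraphOutputs` description requires the terminator's operands to match the result types in the enclosing graph's signature, both in number and in type. That rule can be sketched standalone as below, assuming types are modeled as plain strings; `checkGraphOutputs` is a hypothetical helper, whereas this patch performs the real checks in `GraphARMOp::verifyBody()` and `GraphOutputsARMOp::verify()`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: returns an error message if the terminator operands
// do not match the graph's declared result types, or std::nullopt if they do.
std::optional<std::string>
checkGraphOutputs(const std::vector<std::string> &graphResults,
                  const std::vector<std::string> &outputOperands) {
  // The operand count must equal the number of graph results.
  if (outputOperands.size() != graphResults.size())
    return "has " + std::to_string(outputOperands.size()) +
           " operands, but enclosing graph returns " +
           std::to_string(graphResults.size());
  // Each operand type must equal the corresponding result type.
  for (size_t i = 0; i < graphResults.size(); ++i)
    if (outputOperands[i] != graphResults[i])
      return "type of return operand " + std::to_string(i) + " (" +
             outputOperands[i] + ") doesn't match graph result type (" +
             graphResults[i] + ")";
  return std::nullopt;
}
```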
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2e356dec1981f..9d8d81a839fcb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
 class IntegerType;
 class FloatType;
 class FunctionType;
+class GraphType;
 class IndexType;
 class MemRefType;
 class VectorType;
@@ -81,6 +82,7 @@ class Builder {
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
   FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+  GraphType getGraphType(TypeRange inputs, TypeRange results);
   TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let storageClass = "FunctionTypeStorage";
   let genStorageClass = 0;
   let extraClassDeclaration = [{
     /// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     unsigned getNumResults() const;
     Type getResult(unsigned i) const { return getResults()[i]; }
 
-    /// Returns a clone of this function type with the given argument
+    /// Returns a clone of this function-like type with the given argument
     /// and result types.
-    FunctionType clone(TypeRange inputs, TypeRange results) const;
+    }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
 
-    /// Returns a new function type with the specified arguments and results
+    /// Returns a new function-like type with the specified arguments and results
     /// inserted.
-    FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+    }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
                                        TypeRange argTypes,
                                        ArrayRef<unsigned> resultIndices,
                                        TypeRange resultTypes);
 
-    /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+    /// Returns a new function-like type without the specified arguments and results.
+    }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
                                           const BitVector &resultIndices);
   }];
 }
 
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
 //===----------------------------------------------------------------------===//
 // IndexType
 //===----------------------------------------------------------------------===//
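`Builtin_FunctionLike` stamps out two distinct types that share `FunctionTypeStorage` and the same accessors, splicing `Name # "Type"` into the declarations so that, for example, `GraphType::clone` returns a `GraphType` rather than a `FunctionType`. In plain C++ a comparable effect comes from the curiously recurring template pattern. The sketch below is only an analogy with hypothetical names (`FunctionLike`, `SketchFunctionType`, `SketchGraphType`), not the code TableGen generates.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Shared "storage": the input and result types of a function-like signature.
struct FunctionLikeStorage {
  std::vector<std::string> inputs;
  std::vector<std::string> results;
};

// One template provides the common API. `Derived` plays the role of the
// `Name # "Type"` splice, so clone() returns the concrete type.
template <typename Derived>
class FunctionLike {
public:
  explicit FunctionLike(FunctionLikeStorage s) : storage(std::move(s)) {}

  const std::vector<std::string> &getInputs() const { return storage.inputs; }
  const std::vector<std::string> &getResults() const { return storage.results; }
  unsigned getNumResults() const {
    return static_cast<unsigned>(storage.results.size());
  }

  // Returns a copy with new inputs/results, typed as the concrete class.
  Derived clone(std::vector<std::string> inputs,
                std::vector<std::string> results) const {
    return Derived(FunctionLikeStorage{std::move(inputs), std::move(results)});
  }

private:
  FunctionLikeStorage storage;
};

// Two distinct types generated from the same definition.
class SketchFunctionType : public FunctionLike<SketchFunctionType> {
public:
  using FunctionLike::FunctionLike;
};

class SketchGraphType : public FunctionLike<SketchGraphType> {
public:
  using FunctionLike::FunctionLike;
};
```

Because the two types are distinct, a `SketchGraphType` cannot be passed where a `SketchFunctionType` is expected, which mirrors why `isa<GraphType>` and `isa<FunctionType>` can tell the two builtin types apart.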
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 45ec1846580f2..aab1b01c5cff9 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
 def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
                               "function type", "::mlir::FunctionType">;
 
+// Graph Type
+
+// Any graph type.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+                              "graph type", "::mlir::GraphType">;
+
+
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..6f18dcefea14d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1065,8 +1065,9 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
-    NamedAttribute attribute) {
+LogicalResult
+SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+                                          unsigned resultIndex,
+                                          NamedAttribute attribute) {
   return op->emitError("cannot attach SPIR-V attributes to region result");
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
   return isNestedInFunctionOpInterface(op->getParentOp());
 }
 
+/// Returns true if the given op is a GraphARM op or nested in a
+/// GraphARM op without a module-like op in the middle.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+  if (!op)
+    return false;
+  if (op->hasTrait<OpTrait::SymbolTable>())
+    return false;
+  if (isa<spirv::GraphARMOp>(op))
+    return true;
+  return isNestedInGraphARMOpInterface(op->getParentOp());
+}
+
 /// Returns true if the given op is an module-like op that maintains a symbol
 /// table.
 static bool isDirectInModuleLikeOp(Operation *op) {
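`isNestedInGraphARMOpInterface` walks up the parent chain, succeeding at the first `spirv.ARM.Graph` and giving up at the first op that carries the `SymbolTable` trait (such as a module). The sketch below reproduces that walk over a hypothetical `Node` struct rather than `mlir::Operation`; it loops instead of recursing only to keep the example short.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for mlir::Operation: an op name, whether the op
// has the SymbolTable trait, and a pointer to the parent op.
struct Node {
  std::string kind;
  bool isSymbolTable = false;
  const Node *parent = nullptr;
};

// Same decision order as the patch's helper: stop at a symbol table,
// succeed at a graph op, otherwise keep climbing.
bool isNestedInGraph(const Node *op) {
  for (; op; op = op->parent) {
    if (op->isSymbolTable)
      return false;
    if (op->kind == "spirv.ARM.Graph")
      return true;
  }
  return false;
}
```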
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index f99339852824c..8dfdfea8a5c54 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1126,6 +1126,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+                                        OperationState &state,
+                                        spirv::GraphARMOp graph,
+                                        ArrayRef<Attribute> interfaceVars) {
+  build(builder, state, SymbolRefAttr::get(graph),
+        builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  SmallVector<Type, 0> idTypes;
+  SmallVector<Attribute, 4> interfaceVars;
+
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
+    return failure();
+  }
+
+  if (!parser.parseOptionalComma()) {
+    // Parse the interface variables.
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // The name of the interface variable attribute isn't important.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute("interface",
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  auto interfaceVars = getInterface().getValue();
+  if (!interfaceVars.empty()) {
+    printer << ", ";
+    llvm::interleaveComma(interfaceVars, printer);
+  }
+}
+
+LogicalResult spirv::GraphEntryPointARMOp::verify() {
+  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
+  // verification.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  SmallVector<Type> resultTypes;
+  auto &builder = parser.getBuilder();
+
+  // Parse the name as a symbol.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Parse the function signature.
+  bool isVariadic = false;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+          resultAttrs))
+    return failure();
+
+  SmallVector<Type> argTypes;
+  for (auto &arg : entryArgs)
+    argTypes.push_back(arg.type);
+  auto grType = builder.getGraphType(argTypes, resultTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(grType));
+
+  // If additional attributes are present, parse them.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Add the attributes to the function arguments.
+  assert(resultAttrs.size() == resultTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // Parse the optional function body.
+  auto *body = result.addRegion();
+  OptionalParseResult parseResult =
+      parser.parseOptionalRegion(*body, entryArgs);
+  return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Print graph name, signature, and attributes.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  auto grType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, grType.getInputs(),
+      /*isVariadic=*/false, grType.getResults());
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Print the body.
+  Region &body = this->getBody();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getNumResults() < 1)
+    return emitOpError("there should be at least one result");
+  return success();
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  GraphType grType = getFunctionType();
+  if (!isExternal()) {
+    Block &entryBlock = front();
+
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << " (" << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature (" << grArgType << ')';
+      }
+    }
+  }
+
+  auto walkResult = walk([grType](Operation *op) -> WalkResult {
+    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
+      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
+        return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
+               << graphOutputsARMOp.getNumOperands()
+               << " value(s) but enclosing graph requires "
+               << grType.getNumResults() << " results";
+
+      auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
+      for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
+        auto graphOutputOperandType = graphOutputOperandTypes[i];
+        auto grResultType = grType.getResult(i);
+        if (graphOutputOperandType != grResultType)
+          return graphOutputsARMOp.emitError("type of return operand ")
+                 << i << " (" << graphOutputOperandType
+                 << ") doesn't match graph result type (" << grResultType
+                 << ")";
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addAttribute(getEntryPointAttrName(state.name),
+                     builder.getBoolAttr(entryPoint));
+  state.addRegion();
+}
+
+// Returns the argument types of this graph.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  return getFunctionType().getInputs();
+}
+
+// Returns the result types of this graph.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  return getFunctionType().getResults();
+}
+
+// CallableOpInterface
+Region *spirv::GraphARMOp::getCallableRegion() {
+  return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  auto graph = cast<GraphARMOp>((*this)->getParentOp());
+
+  // The operand number and types must match the graph signature.
+  const auto &results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing graph (@"
+           << graph.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0; i < results.size(); i++)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of return operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match graph result type (" << results[i]
+                         << ")"
+                         << " in graph @" << graph.getName();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.GLFClampOp
 //===----------------------------------------------------------------------===//
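`GraphEntryPointARMOp::parse` reads a graph symbol followed by an optional comma-separated list of interface-variable symbols, and `print` emits the same shape after the op name, for instance `@graph, @in0, @out0` (symbol names here are illustrative). A string-level sketch of that round trip follows; `GraphEntryPoint`, `parseGraphEntryPoint`, and `printGraphEntryPoint` are hypothetical and skip the real parser's symbol-name rules such as quoted names.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

struct GraphEntryPoint {
  std::string fn;                     // Graph symbol, without the '@'.
  std::vector<std::string> interface; // Interface variable symbols.
};

// Strips leading and trailing spaces.
static std::string trim(const std::string &s) {
  size_t b = s.find_first_not_of(' ');
  if (b == std::string::npos)
    return "";
  size_t e = s.find_last_not_of(' ');
  return s.substr(b, e - b + 1);
}

// Parses "@fn" or "@fn, @v0, @v1, ..."; each item must start with '@'.
std::optional<GraphEntryPoint> parseGraphEntryPoint(const std::string &text) {
  std::vector<std::string> symbols;
  std::stringstream ss(text);
  std::string item;
  while (std::getline(ss, item, ',')) {
    item = trim(item);
    if (item.size() < 2 || item[0] != '@')
      return std::nullopt;
    symbols.push_back(item.substr(1));
  }
  if (symbols.empty())
    return std::nullopt;
  // First symbol is the graph; the rest form the interface list.
  return GraphEntryPoint{symbols.front(),
                         {symbols.begin() + 1, symbols.end()}};
}

// Prints in the same shape as GraphEntryPointARMOp::print.
std::string printGraphEntryPoint(const GraphEntryPoint &ep) {
  std::string out = "@" + ep.fn;
  for (const std::string &var : ep.interface)
    out += ", @" + var;
  return out;
}
```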
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6d3bda421f309..fd97b09d802f1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -158,6 +158,14 @@ void UpdateVCEPass::runOnOperation() {
     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
       valueTypes.push_back(globalVar.getType());
 
+    // If the op is function-like, also process its input and result types.
+    if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
+      auto inputTypes = funcOpInterface.getArgumentTypes();
+      auto resultTypes = funcOpInterface.getResultTypes();
+      valueTypes.append(inputTypes.begin(), inputTypes.end());
+      valueTypes.append(resultTypes.begin(), resultTypes.end());
+    }
+
     // Requirements from values' types
     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index de52fbd3f215c..9a5dbcf6f598e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,7 +104,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      llvm::isa<FunctionType>(op->getResult(0).getType()))
+      (llvm::isa<FunctionType>(op->getResult(0).getType()) ||
+       llvm::isa<GraphType>(op->getResult(0).getType())))
     wrapped = true;
 
   if (wrapped)
@@ -2836,6 +2837,20 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         os << '>';
       })
       .Case<NoneType>([&](Type) { os << "none"; })
+      .Case<GraphType>([&](GraphType graphTy) {
+        os << '(';
+        interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
+        os << ") -> ";
+        ArrayRef<Type> results = graphTy.getResults();
+        if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
+                                     llvm::isa<GraphType>(results[0]))) {
+          printType(results[0]);
+        } else {
+          os << '(';
+          interleaveComma(results, [&](Type ty) { printType(ty); });
+          os << ')';
+        }
+      })
       .Default([&](Type type) { return printDialectType(type); });
 }
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index f657db142eeb9..3d366276b4375 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
   return FunctionType::get(context, inputs, results);
 }
 
+GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) {
+  return GraphType::get(context, inputs, results);
+}
+
 TupleType Builder::getTupleType(TypeRange elementTypes) {
   return TupleType::get(context, elementTypes);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..ce47c60c9b932 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -179,6 +179,45 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
   return clone(newArgTypes, newResultTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// GraphType
+//===----------------------------------------------------------------------===//
+
+unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
+
+ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
+
+unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
+
+ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
+
+GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
+  return get(getContext(), inputs, results);
+}
+
+/// Returns a new graph type with the specified arguments and results
+/// inserted.
+GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                           TypeRange argTypes,
+                                           ArrayRef<unsigned> resultIndices,
+                                           TypeRange resultTypes) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes =
+      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
+  TypeRange newResultTypes =
+      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
+
+/// Returns a new graph type without the specified arguments and results.
+GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
+                                              const BitVector &resultIndices) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
+  TypeRange newResultTypes =
+      filterTypesOut(getResults(), resultIndices, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index f56bc3967b4b7..bc1505d32d4d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -306,3 +306,20 @@ func.func @constant_composite_replicate() -> () {
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
   spirv.Return
 }
+
+//===----------------------------------------------------------------------===//
+// GraphARM ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: graph_arm
+spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  // CHECK: spirv.ARM.GraphOutputs min version: v1.0
+  // CHECK: spirv.ARM.GraphOutputs max version: v1.6
+  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+  // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+// CHECK: spirv.ARM.Graph min version: v1.0
+// CHECK: spirv.ARM.Graph max version: v1.6
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
new file mode 100644
index 0000000000000..90c31e19db382
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphConstant
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
+  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
+
+  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
+  }
+
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4e534a30ad516..cf9d86576b1f6 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes {
     spirv.ReturnValue %val : bf16
   }
 }
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
+    #spirv.resource_limits<>>
+} {
+  spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
+      spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
+  }
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 2e5e591fe5f91..9efca825a663d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -21,7 +21,7 @@ using namespace mlir;
 namespace {
 /// A pass for testing SPIR-V op availability.
 struct PrintOpAvailability
-    : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
+    : public PassWrapper<PrintOpAvailability, OperationPass<mlir::ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
 
   void runOnOperation() override;
@@ -33,12 +33,10 @@ struct PrintOpAvailability
 } // namespace
 
 void PrintOpAvailability::runOnOperation() {
-  auto f = getOperation();
-  llvm::outs() << f.getName() << "\n";
-
+  auto moduleOp = getOperation();
   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
 
-  f->walk([&](Operation *op) {
+  auto opCallback = [&](Operation *op) {
     if (op->getDialect() != spirvDialect)
       return WalkResult::advance();
 
@@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() {
     os.flush();
 
     return WalkResult::advance();
+  };
+
+  moduleOp.walk([&](func::FuncOp f) {
+    llvm::outs() << f.getName() << "\n";
+    f->walk(opCallback);
+  });
+
+  moduleOp.walk([&](spirv::GraphARMOp g) {
+    llvm::outs() << g.getName() << "\n";
+    g->walk(opCallback);
   });
 }
 

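A minimal sketch of the result-wrapping rule used by the new `GraphType` case in `AsmPrinter::Impl::printTypeImpl`. Inputs are always parenthesized. The result list is printed bare only when there is exactly one result and that result is not itself a function or graph type, which is the same ambiguity `printFunctionalType` now guards against for `->`. `SimpleType`, `joinTypes` and `printGraphSignature` are hypothetical stand-ins for illustration, not MLIR APIs.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for mlir::Type: a printed name plus a
// flag saying whether the type is itself a function or graph type.
struct SimpleType {
  std::string name;
  bool isFunctionLike = false;
};

// Joins printed type names with ", ", mirroring interleaveComma.
static std::string joinTypes(const std::vector<SimpleType> &types) {
  std::string out;
  for (size_t i = 0; i < types.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += types[i].name;
  }
  return out;
}

// Inputs are always wrapped in parentheses. A single result is printed bare
// unless it is function-like; zero or multiple results are parenthesized.
std::string printGraphSignature(const std::vector<SimpleType> &inputs,
                                const std::vector<SimpleType> &results) {
  std::string out = "(" + joinTypes(inputs) + ") -> ";
  if (results.size() == 1 && !results[0].isFunctionLike)
    return out + results[0].name;
  return out + "(" + joinTypes(results) + ")";
}
```

For example, `printGraphSignature({{"i8"}}, {{"(i8) -> i8", true}})` yields `(i8) -> ((i8) -> i8)`, so the nested arrow stays unambiguous.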
>From 101e67424627d0762c745a3ec9b7e506657b82b2 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 6 Aug 2025 10:11:35 +0200
Subject: [PATCH 2/8] Resolve code review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ie69a1696a7b31869c1ba94bdf7aa214d52175565
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  3 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 64 ++++++++++--------
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 19 +++---
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |  4 +-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  4 +-
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     | 67 +++++++++++++++----
 .../SPIRV/Transforms/vce-deduction.mlir       |  4 +-
 7 files changed, 104 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a27554a3c6f64..0e42d08cdb1fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1343,7 +1343,7 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
   ];
 }
 def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
-  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM];
   list<Availability> availability = [
     Extension<[SPV_ARM_graph]>
   ];
@@ -4883,7 +4883,6 @@ class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "ARM", traits> {
 }
 
-
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index 38fb4b2eff414..f2913239cc4e8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -29,39 +29,11 @@ class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
   let availability = [
     MinVersion<SPIRV_V_1_0>,
     MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
     Capability<[SPIRV_C_GraphARM]>
   ];
 }
 
-def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
-  let summary = "Declare a graph constant.";
-
-  let description = [{
-    Declare a graph constant.
-    Result Type must be an OpTypeTensorARM.
-    GraphConstantID must be a 32-bit integer literal.
-  }];
-
-  let arguments = (ins
-    I32Attr: $graph_constant_id
-  );
-
-  let results = (outs
-    SPIRV_AnyTensorArm:$output
-  );
-
-  let hasVerifier = 0;
-
-  let autogenSerialization = 0;
-
-  let assemblyFormat = [{
-    attr-dict `:` type($output)
-  }];
-}
-
-// -----
-
 def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
     FunctionOpInterface, InModuleScope, IsolatedFromAbove
@@ -122,6 +94,8 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
   }];
 }
 
+// -----
+
 // Check that an op can only be used within the scope of a spirv.ARM.Graph op.
 def InGraphScope : PredOpTrait<
   "op must appear in a spirv.ARM.Graph op's block",
@@ -129,6 +103,38 @@ def InGraphScope : PredOpTrait<
 
 // -----
 
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure, ConstantLike]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+
+    ```
+    spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
+    ```
+  }];
+
+  let arguments = (ins
+    I32Attr: $graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
 def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
   let summary = [{
     Declare a graph entry point and its interface.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 8dfdfea8a5c54..953406da60a57 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1140,13 +1140,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
 
 ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  SmallVector<Type, 0> idTypes;
   SmallVector<Attribute, 4> interfaceVars;
 
   FlatSymbolRefAttr fn;
-  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
     return failure();
-  }
 
   if (!parser.parseOptionalComma()) {
     // Parse the interface variables
@@ -1224,7 +1222,7 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
       getResAttrsAttrName(result.name));
 
   // Parse the optional function body.
-  auto *body = result.addRegion();
+  Region *body = result.addRegion();
   OptionalParseResult parseResult =
       parser.parseOptionalRegion(*body, entryArgs);
   return failure(parseResult.has_value() && failed(*parseResult));
@@ -1234,7 +1232,7 @@ void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
   // Print graph name, signature, and control.
   printer << " ";
   printer.printSymbolName(getSymName());
-  auto grType = getFunctionType();
+  GraphType grType = getFunctionType();
   function_interface_impl::printFunctionSignature(
       printer, *this, grType.getInputs(),
       /*isVariadic=*/false, grType.getResults());
@@ -1288,9 +1286,10 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
                << grType.getNumResults() << " results";
 
       auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
-        auto graphOutputOperandType = graphOutputOperandTypes[i];
-        auto grResultType = grType.getResult(i);
+      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
+           ++i) {
+        Type graphOutputOperandType = graphOutputOperandTypes[i];
+        Type grResultType = grType.getResult(i);
         if (graphOutputOperandType != grResultType)
           return graphOutputsARMOp.emitError("type of return operand ")
                  << i << " (" << graphOutputOperandType
@@ -1339,13 +1338,13 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
   auto graph = cast<GraphARMOp>((*this)->getParentOp());
 
   // The operand number and types must match the graph signature.
-  const auto &results = graph.getFunctionType().getResults();
+  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
   if (getNumOperands() != results.size())
     return emitOpError("has ")
            << getNumOperands() << " operands, but enclosing graph (@"
            << graph.getName() << ") returns " << results.size();
 
-  for (unsigned i = 0; i < results.size(); i++)
+  for (unsigned i = 0, size = results.size(); i < size; ++i)
     if (getOperand(i).getType() != results[i])
       return emitError() << "type of return operand " << i << " ("
                          << getOperand(i).getType()
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index fd97b09d802f1..a2d221252fb69 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -160,8 +160,8 @@ void UpdateVCEPass::runOnOperation() {
 
     // If the op is FunctionLike make sure to process input and result types
     if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
-      auto inputTypes = funcOpInterface.getArgumentTypes();
-      auto resultTypes = funcOpInterface.getResultTypes();
+      ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
+      ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
       valueTypes.append(inputTypes.begin(), inputTypes.end());
       valueTypes.append(resultTypes.begin(), resultTypes.end());
     }
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index bc1505d32d4d5..4ef242bdc5b16 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -315,11 +315,11 @@ func.func @constant_composite_replicate() -> () {
 spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
   // CHECK: spirv.ARM.GraphOutputs min version: v1.0
   // CHECK: spirv.ARM.GraphOutputs max version: v1.6
-  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
   // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
   spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
 // CHECK: spirv.ARM.Graph min version: v1.0
 // CHECK: spirv.ARM.Graph max version: v1.6
-// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
 // CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 90c31e19db382..6919c7eecc632 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -1,29 +1,68 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph and spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ARM.GraphConstant
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
-  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
+  spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
+    // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[CONST]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+  }
+}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphEntryPoint
+//===----------------------------------------------------------------------===//
+
 
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
   // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
-  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
   // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
-  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
-  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
-  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
-    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
-    %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
-    // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
-    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
+  spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Multiple spirv.ARM.Graphs
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
   }
 
   // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
     // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
     spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
   }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index cf9d86576b1f6..18958cef7b00a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -232,10 +232,10 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
 spirv.module Logical Vulkan attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.5, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
+    #spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
     #spirv.resource_limits<>>
 } {
   spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {

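The output check in `spirv::GraphOutputsARMOp::verify()` first compares the operand count with the enclosing graph's result count, then compares types index by index and reports the first mismatch. It can be modeled in isolation as below. Types are reduced to their printed names, and `checkGraphOutputs` is a hypothetical helper. The messages mirror the strings emitted in the diff.

```cpp
#include <optional>
#include <string>
#include <vector>

// Simplified model of the checks in spirv::GraphOutputsARMOp::verify().
// Types are represented by their printed names, which is an assumption made
// only for this sketch. Returns the error message for the first violation,
// or std::nullopt when the operands match the graph's result types.
std::optional<std::string>
checkGraphOutputs(const std::string &graphName,
                  const std::vector<std::string> &operandTypes,
                  const std::vector<std::string> &resultTypes) {
  // Arity check comes first, so the per-index loop below never reads out of
  // bounds.
  if (operandTypes.size() != resultTypes.size())
    return "has " + std::to_string(operandTypes.size()) +
           " operands, but enclosing graph (@" + graphName + ") returns " +
           std::to_string(resultTypes.size());

  // size() is evaluated once in the loop header, as in this revision.
  for (size_t i = 0, size = resultTypes.size(); i < size; ++i)
    if (operandTypes[i] != resultTypes[i])
      return "type of return operand " + std::to_string(i) + " (" +
             operandTypes[i] + ") doesn't match graph result type (" +
             resultTypes[i] + ") in graph @" + graphName;

  return std::nullopt;
}
```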
>From 3cf2ee9ce028265868e05ac84d561b09db4e847e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 6 Aug 2025 12:14:31 +0200
Subject: [PATCH 3/8] Fix one more comment

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia24695d965919ffebdff9945cd7d72233faa922a
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 953406da60a57..fdefd5e3966ca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1167,7 +1167,7 @@ ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
 void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   printer << " ";
   printer.printSymbolName(getFn());
-  auto interfaceVars = getInterface().getValue();
+  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
   if (!interfaceVars.empty()) {
     printer << ", ";
     llvm::interleaveComma(interfaceVars, printer);

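Here is a rough model of what `spirv::GraphEntryPointARMOp::print()` emits. It prints the graph symbol first. It then appends a comma-separated list of interface-variable symbols, but only if the interface list is non-empty. Symbols are plain strings here rather than `FlatSymbolRefAttr` values, and `printGraphEntryPoint` is a hypothetical helper, not part of the dialect.

```cpp
#include <string>
#include <vector>

// Simplified model of spirv::GraphEntryPointARMOp::print(). Symbol names are
// plain strings (an assumption for this sketch). The interface list is
// optional; when present it follows the graph symbol after ", ".
std::string printGraphEntryPoint(const std::string &graphSym,
                                 const std::vector<std::string> &interfaceVars) {
  std::string out = "spirv.ARM.GraphEntryPoint @" + graphSym;
  if (!interfaceVars.empty()) {
    out += ", ";
    // Comma-separate the interface symbols, mirroring interleaveComma.
    for (size_t i = 0; i < interfaceVars.size(); ++i) {
      if (i != 0)
        out += ", ";
      out += "@" + interfaceVars[i];
    }
  }
  return out;
}
```

With the tests' names, `printGraphEntryPoint("entrypoint", {"entrypoint_arg_0", "entrypoint_res_0"})` produces `spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0`, which is the form parsed in `graph-ops.mlir`.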
>From 5ae57fe2d6fdfd8ccf95bd6a03c96bcc9db93ba1 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 Aug 2025 11:04:19 +0200
Subject: [PATCH 4/8] Resolve more review comments and expand testing

In particular add negative testing.

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Iee4ba17c74b451eda7f76c6f905ca12c734d39d6
---
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    |  36 +++++
 mlir/include/mlir/IR/CommonTypeConstraints.td |   1 -
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  36 +++--
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     | 131 +++++++++++++-----
 4 files changed, 151 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index f2913239cc4e8..51df4dc79ae68 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -57,6 +57,14 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
                         region
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
   }];
 
   let arguments = (ins
@@ -114,6 +122,12 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
     ```
     spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
     ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    ```
   }];
 
   let arguments = (ins
@@ -157,6 +171,17 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
     entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
                        symbol-reference (`, ` symbol-reference)*
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        ...
+    }
+    ```
   }];
 
   let arguments = (ins
@@ -166,6 +191,9 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
 
   let results = (outs);
 
+  // Checks for graph and interface symbol reference are done in spirv::ModuleOp verification.
+  let hasVerifier = 0;
+
   let autogenSerialization = 0;
 
   let builders = [
@@ -189,6 +217,14 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
     ```
     graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index aab1b01c5cff9..8ba2daefd97aa 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -393,7 +393,6 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
 def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
                               "graph type", "::mlir::GraphType">;
 
-
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fdefd5e3966ca..398dc046b3912 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1174,12 +1174,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   }
 }
 
-LogicalResult spirv::GraphEntryPointARMOp::verify() {
-  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
-  // verification.
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GraphARM
 //===----------------------------------------------------------------------===//
@@ -1257,7 +1251,19 @@ LogicalResult spirv::GraphARMOp::verifyType() {
 }
 
 LogicalResult spirv::GraphARMOp::verifyBody() {
-  GraphType grType = getFunctionType();
+  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+    if (!isa<spirv::TensorArmType>(graphArgType)) {
+      return emitOpError("type of argument #")
+             << index << " must be a TensorArmType, but got " << graphArgType;
+    }
+  }
+  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+    if (!isa<spirv::TensorArmType>(graphResType)) {
+      return emitOpError("type of result #")
+             << index << " must be a TensorArmType, but got " << graphResType;
+    }
+  }
+
   if (!isExternal()) {
     Block &entryBlock = front();
 
@@ -1277,15 +1283,17 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
     }
   }
 
+  GraphType grType = getFunctionType();
   auto walkResult = walk([grType](Operation *op) -> WalkResult {
     if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
       if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
+        return graphOutputsARMOp.emitOpError("is returning ")
                << graphOutputsARMOp.getNumOperands()
-               << "value(s) but enclosing graph requires "
-               << grType.getNumResults() << " results";
+               << " value(s) but enclosing spirv.ARM.Graph requires "
+               << grType.getNumResults() << " result(s)";
 
-      auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
+      ValueTypeRange<OperandRange> graphOutputOperandTypes =
+          graphOutputsARMOp.getValue().getType();
       for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
            ++i) {
         Type graphOutputOperandType = graphOutputOperandTypes[i];
@@ -1341,15 +1349,15 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
   const ArrayRef<Type> &results = graph.getFunctionType().getResults();
   if (getNumOperands() != results.size())
     return emitOpError("has ")
-           << getNumOperands() << " operands, but enclosing graph (@"
+           << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
            << graph.getName() << ") returns " << results.size();
 
   for (unsigned i = 0, size = results.size(); i < size; ++i)
     if (getOperand(i).getType() != results[i])
       return emitError() << "type of return operand " << i << " ("
                          << getOperand(i).getType()
-                         << ") doesn't match graph result type (" << results[i]
-                         << ")"
+                         << ") doesn't match spirv.ARM.Graph result type ("
+                         << results[i] << ")"
                          << " in graph @" << graph.getName();
 
   return success();
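With these changes, `GraphARMOp` rejects any signature whose arguments or results are not `!spirv.arm.tensor` types. This comes on top of the existing "at least one result" check in `verifyType`. The sketch below restates those rules on plain C++ structs, and checks them in the same order (the result count first, then argument types, then result types). It assumes, as the new tests suggest, that a type streamed into a diagnostic appears in single quotes. `SimpleType`, `SimpleGraphType` and `verifyGraphSignature` are hypothetical. The real code works on `mlir::Type` and `emitOpError`, and `emitOpError` also adds the `'spirv.ARM.Graph' op` prefix, which this sketch leaves out.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR type: it only records whether the type is
// a !spirv.arm.tensor and how it prints.
struct SimpleType {
  bool isTensorArm;
  std::string text;  // e.g. "!spirv.arm.tensor<14x19xi16>" or "i8"
};

struct SimpleGraphType {
  std::vector<SimpleType> inputs;
  std::vector<SimpleType> results;
};

// Simplified mirror of GraphARMOp::verifyType plus the type checks at the
// start of GraphARMOp::verifyBody. Returns the first diagnostic, if any.
std::optional<std::string> verifyGraphSignature(const SimpleGraphType &type) {
  if (type.results.empty())
    return "there should be at least one result";
  for (std::size_t i = 0; i < type.inputs.size(); ++i)
    if (!type.inputs[i].isTensorArm)
      return "type of argument #" + std::to_string(i) +
             " must be a TensorArmType, but got '" + type.inputs[i].text + "'";
  for (std::size_t i = 0; i < type.results.size(); ++i)
    if (!type.results[i].isTensorArm)
      return "type of result #" + std::to_string(i) +
             " must be a TensorArmType, but got '" + type.results[i].text + "'";
  return std::nullopt;
}
```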
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 6919c7eecc632..591eaaea4c802 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -4,12 +4,10 @@
 // spirv.ARM.Graph and spirv.ARM.GraphOutputs
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
 }
 
 // -----
@@ -18,14 +16,12 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8,
 // spirv.ARM.GraphConstant
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
-  spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
-    // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
-    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
-    // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
-    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
-  }
+// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
+spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
+  // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+  // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
 }
 // -----
 
@@ -33,37 +29,96 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8,
 // spirv.ARM.GraphEntryPoint
 //===----------------------------------------------------------------------===//
 
+// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
-  spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
-  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no terminator
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{empty block: expect at least a terminator}}
+spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no result types
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}}
+spirv.ARM.Graph @graphNoOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> () {
 }
 
 // -----
 
 //===----------------------------------------------------------------------===//
-// Multiple spirv.ARM.Graphs
+// spirv.ARM.GraphConstant outside graph scope
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// expected-error @+1 {{'spirv.ARM.GraphConstant' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphOutputs outside graph scope
+//===----------------------------------------------------------------------===//
+
+%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+// expected-error @+1 {{'spirv.ARM.GraphOutputs' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
+  // expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}}
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
+  // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}}
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}}
+  spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non TensorArmType argument
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non TensorArmType result
+//===----------------------------------------------------------------------===//
 
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-  spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
-  }
+// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> i8 {
 }
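The negative tests above fix the exact diagnostics reported when `spirv.ARM.GraphOutputs` disagrees with the enclosing graph's result types. The following sketch replays the walk in `GraphARMOp::verifyBody` on plain strings, so the expected messages can be checked on their own. Representing each type by its printed form is a simplification, and `verifyGraphOutputs` is a hypothetical helper. A count mismatch goes through `emitOpError`, which is why that message carries the `'spirv.ARM.GraphOutputs' op` prefix. A type mismatch goes through `emitError`, so that message has no prefix.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model: a type is represented only by its printed form.
using TypeText = std::string;

// Simplified version of the GraphOutputs check in GraphARMOp::verifyBody.
// Returns "" when the outputs match the graph signature.
std::string verifyGraphOutputs(const std::vector<TypeText> &graphResults,
                               const std::vector<TypeText> &outputOperands) {
  if (graphResults.size() != outputOperands.size())
    return "'spirv.ARM.GraphOutputs' op is returning " +
           std::to_string(outputOperands.size()) +
           " value(s) but enclosing spirv.ARM.Graph requires " +
           std::to_string(graphResults.size()) + " result(s)";
  for (std::size_t i = 0; i < outputOperands.size(); ++i)
    if (outputOperands[i] != graphResults[i])
      return "type of return operand " + std::to_string(i) + " ('" +
             outputOperands[i] + "') doesn't match graph result type ('" +
             graphResults[i] + "')";
  return "";
}
```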

>From 1a4f0aa4486f25230821c3c08830729b0efd900e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 Aug 2025 13:20:35 +0200
Subject: [PATCH 5/8] Extract SPV_ARM_graph operations in its own file

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia74db44157fb724c9f787387386338814413db30
---
 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp | 264 ++++++++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt  |   1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp    | 237 -------------------
 3 files changed, 265 insertions(+), 237 deletions(-)
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp

diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
new file mode 100644
index 0000000000000..e300596fd3733
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -0,0 +1,264 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
+//------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  SmallVector<Type> resultTypes;
+  auto &builder = parser.getBuilder();
+
+  // Parse the name as a symbol.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Parse the function signature.
+  bool isVariadic = false;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+          resultAttrs))
+    return failure();
+
+  SmallVector<Type> argTypes;
+  for (auto &arg : entryArgs)
+    argTypes.push_back(arg.type);
+  auto grType = builder.getGraphType(argTypes, resultTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(grType));
+
+  // If additional attributes are present, parse them.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Add the attributes to the graph arguments and results.
+  assert(resultAttrs.size() == resultTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // Parse the optional function body.
+  Region *body = result.addRegion();
+  OptionalParseResult parseResult =
+      parser.parseOptionalRegion(*body, entryArgs);
+  return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Print graph name, signature, and attributes.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  GraphType grType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, grType.getInputs(),
+      /*isVariadic=*/false, grType.getResults());
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Print the body.
+  Region &body = this->getBody();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getNumResults() < 1)
+    return emitOpError("there should be at least one result");
+  return success();
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+    if (!isa<spirv::TensorArmType>(graphArgType)) {
+      return emitOpError("type of argument #")
+             << index << " must be a TensorArmType, but got " << graphArgType;
+    }
+  }
+  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+    if (!isa<spirv::TensorArmType>(graphResType)) {
+      return emitOpError("type of result #")
+             << index << " must be a TensorArmType, but got " << graphResType;
+    }
+  }
+
+  if (!isExternal()) {
+    Block &entryBlock = front();
+
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << " (" << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature (" << grArgType << ')';
+      }
+    }
+  }
+
+  GraphType grType = getFunctionType();
+  auto walkResult = walk([grType](Operation *op) -> WalkResult {
+    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
+      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
+        return graphOutputsARMOp.emitOpError("is returning ")
+               << graphOutputsARMOp.getNumOperands()
+               << " value(s) but enclosing spirv.ARM.Graph requires "
+               << grType.getNumResults() << " result(s)";
+
+      ValueTypeRange<OperandRange> graphOutputOperandTypes =
+          graphOutputsARMOp.getValue().getType();
+      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
+           ++i) {
+        Type graphOutputOperandType = graphOutputOperandTypes[i];
+        Type grResultType = grType.getResult(i);
+        if (graphOutputOperandType != grResultType)
+          return graphOutputsARMOp.emitError("type of return operand ")
+                 << i << " (" << graphOutputOperandType
+                 << ") doesn't match graph result type (" << grResultType
+                 << ")";
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addAttribute(getEntryPointAttrName(state.name),
+                     builder.getBoolAttr(entryPoint));
+  state.addRegion();
+}
+
+// Returns the argument types of this graph.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  return getFunctionType().getInputs();
+}
+
+// Returns the result types of this graph.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  return getFunctionType().getResults();
+}
+
+// CallableOpInterface
+Region *spirv::GraphARMOp::getCallableRegion() {
+  return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  auto graph = cast<GraphARMOp>((*this)->getParentOp());
+
+  // The operand number and types must match the graph signature.
+  ArrayRef<Type> results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
+           << graph.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0, size = results.size(); i < size; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of return operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match spirv.ARM.Graph result type ("
+                         << results[i] << ")"
+                         << " in graph @" << graph.getName();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphEntryPoint
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+                                        OperationState &state,
+                                        spirv::GraphARMOp graph,
+                                        ArrayRef<Attribute> interfaceVars) {
+  build(builder, state, SymbolRefAttr::get(graph),
+        builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  SmallVector<Attribute, 4> interfaceVars;
+
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
+    return failure();
+
+  if (!parser.parseOptionalComma()) {
+    // Parse the interface variables
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // The name of the interface variable attribute isn't important.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute("interface",
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
+  if (!interfaceVars.empty()) {
+    printer << ", ";
+    llvm::interleaveComma(interfaceVars, printer);
+  }
+}
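`GraphEntryPointARMOp::parse` and `print` accept and emit the form `@graph, @var0, @var1`: one symbol for the graph, followed by an optional comma-separated list of interface variables. Below is a standalone sketch of that round trip. It uses a hand-written scanner instead of `OpAsmParser`, so `ParsedGraphEntryPoint`, `parseSymbol`, `parseGraphEntryPoint` and `printGraphEntryPoint` are all hypothetical. For simplicity, symbol names are restricted to `[A-Za-z0-9_]+`, a narrower set than real MLIR allows.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical result of parsing `@graph, @var0, @var1`.
struct ParsedGraphEntryPoint {
  std::string fn;                      // graph symbol, without the '@'
  std::vector<std::string> interface;  // interface variable symbols
};

// Advances pos past any whitespace.
static void skipSpaces(const std::string &s, std::size_t &pos) {
  while (pos < s.size() && std::isspace(static_cast<unsigned char>(s[pos])))
    ++pos;
}

// Parses `@name` at pos (after optional whitespace); returns name without '@'.
static std::optional<std::string> parseSymbol(const std::string &s,
                                              std::size_t &pos) {
  skipSpaces(s, pos);
  if (pos >= s.size() || s[pos] != '@')
    return std::nullopt;
  std::size_t start = ++pos;
  while (pos < s.size() &&
         (std::isalnum(static_cast<unsigned char>(s[pos])) || s[pos] == '_'))
    ++pos;
  if (pos == start)
    return std::nullopt;
  return s.substr(start, pos - start);
}

// Same shape as GraphEntryPointARMOp::parse: the graph symbol, then an
// optional `,`-separated list with at least one interface symbol.
std::optional<ParsedGraphEntryPoint> parseGraphEntryPoint(const std::string &s) {
  std::size_t pos = 0;
  ParsedGraphEntryPoint result;
  std::optional<std::string> fn = parseSymbol(s, pos);
  if (!fn)
    return std::nullopt;
  result.fn = *fn;
  skipSpaces(s, pos);
  while (pos < s.size() && s[pos] == ',') {
    ++pos;  // consume ','
    std::optional<std::string> var = parseSymbol(s, pos);
    if (!var)
      return std::nullopt;
    result.interface.push_back(*var);
    skipSpaces(s, pos);
  }
  if (pos != s.size())
    return std::nullopt;  // trailing input that is not part of the grammar
  return result;
}

// Same shape as GraphEntryPointARMOp::print, minus the op name.
std::string printGraphEntryPoint(const ParsedGraphEntryPoint &ep) {
  std::string out = "@" + ep.fn;
  for (const std::string &var : ep.interface)
    out += ", @" + var;
  return out;
}
```

As in the real parser, a trailing comma with no interface symbol after it is rejected, and printing a parsed entry point reproduces the input text.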
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b9aa7b7491abf..60d705d940cfc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
 add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
 
 add_mlir_dialect_library(MLIRSPIRVDialect
+  ArmGraphOps.cpp
   AtomicOps.cpp
   CastOps.cpp
   ControlFlowOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 398dc046b3912..f99339852824c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1126,243 +1126,6 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GraphEntryPointARM
-//===----------------------------------------------------------------------===//
-
-void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
-                                        OperationState &state,
-                                        spirv::GraphARMOp graph,
-                                        ArrayRef<Attribute> interfaceVars) {
-  build(builder, state, SymbolRefAttr::get(graph),
-        builder.getArrayAttr(interfaceVars));
-}
-
-ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
-                                               OperationState &result) {
-  SmallVector<Attribute, 4> interfaceVars;
-
-  FlatSymbolRefAttr fn;
-  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
-    return failure();
-
-  if (!parser.parseOptionalComma()) {
-    // Parse the interface variables
-    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
-          // The name of the interface variable attribute isnt important
-          FlatSymbolRefAttr var;
-          NamedAttrList attrs;
-          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
-            return failure();
-          interfaceVars.push_back(var);
-          return success();
-        }))
-      return failure();
-  }
-  result.addAttribute("interface",
-                      parser.getBuilder().getArrayAttr(interfaceVars));
-  return success();
-}
-
-void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
-  printer << " ";
-  printer.printSymbolName(getFn());
-  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
-  if (!interfaceVars.empty()) {
-    printer << ", ";
-    llvm::interleaveComma(interfaceVars, printer);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GraphARM
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  SmallVector<OpAsmParser::Argument> entryArgs;
-  SmallVector<DictionaryAttr> resultAttrs;
-  SmallVector<Type> resultTypes;
-  auto &builder = parser.getBuilder();
-
-  // Parse the name as a symbol.
-  StringAttr nameAttr;
-  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
-                             result.attributes))
-    return failure();
-
-  // Parse the function signature.
-  bool isVariadic = false;
-  if (function_interface_impl::parseFunctionSignatureWithArguments(
-          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
-          resultAttrs))
-    return failure();
-
-  SmallVector<Type> argTypes;
-  for (auto &arg : entryArgs)
-    argTypes.push_back(arg.type);
-  auto grType = builder.getGraphType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(grType));
-
-  // If additional attributes are present, parse them.
-  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-    return failure();
-
-  // Add the attributes to the function arguments.
-  assert(resultAttrs.size() == resultTypes.size());
-  call_interface_impl::addArgAndResultAttrs(
-      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
-      getResAttrsAttrName(result.name));
-
-  // Parse the optional function body.
-  Region *body = result.addRegion();
-  OptionalParseResult parseResult =
-      parser.parseOptionalRegion(*body, entryArgs);
-  return failure(parseResult.has_value() && failed(*parseResult));
-}
-
-void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
-  // Print graph name, signature, and control.
-  printer << " ";
-  printer.printSymbolName(getSymName());
-  GraphType grType = getFunctionType();
-  function_interface_impl::printFunctionSignature(
-      printer, *this, grType.getInputs(),
-      /*isVariadic=*/false, grType.getResults());
-  function_interface_impl::printFunctionAttributes(printer, *this,
-                                                   {getFunctionTypeAttrName(),
-                                                    getArgAttrsAttrName(),
-                                                    getResAttrsAttrName()});
-
-  // Print the body.
-  Region &body = this->getBody();
-  if (!body.empty()) {
-    printer << ' ';
-    printer.printRegion(body, /*printEntryBlockArgs=*/false,
-                        /*printBlockTerminators=*/true);
-  }
-}
-
-LogicalResult spirv::GraphARMOp::verifyType() {
-  if (getFunctionType().getNumResults() < 1)
-    return emitOpError("there should be at least one result");
-  return success();
-}
-
-LogicalResult spirv::GraphARMOp::verifyBody() {
-  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
-    if (!isa<spirv::TensorArmType>(graphArgType)) {
-      return emitOpError("type of argument #")
-             << index << " must be a TensorArmType, but got " << graphArgType;
-    }
-  }
-  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
-    if (!isa<spirv::TensorArmType>(graphResType)) {
-      return emitOpError("type of result #")
-             << index << " must be a TensorArmType, but got " << graphResType;
-    }
-  }
-
-  if (!isExternal()) {
-    Block &entryBlock = front();
-
-    unsigned numArguments = this->getNumArguments();
-    if (entryBlock.getNumArguments() != numArguments)
-      return emitOpError("entry block must have ")
-             << numArguments << " arguments to match graph signature";
-
-    for (auto [index, grArgType, blockArgType] :
-         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
-      if (blockArgType != grArgType) {
-        return emitOpError("type of entry block argument #")
-               << index << '(' << blockArgType
-               << ") must match the type of the corresponding argument in "
-               << "graph signature(" << grArgType << ')';
-      }
-    }
-  }
-
-  GraphType grType = getFunctionType();
-  auto walkResult = walk([grType](Operation *op) -> WalkResult {
-    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
-      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("is returning ")
-               << graphOutputsARMOp.getNumOperands()
-               << " value(s) but enclosing spirv.ARM.Graph requires "
-               << grType.getNumResults() << " result(s)";
-
-      ValueTypeRange<OperandRange> graphOutputOperandTypes =
-          graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
-           ++i) {
-        Type graphOutputOperandType = graphOutputOperandTypes[i];
-        Type grResultType = grType.getResult(i);
-        if (graphOutputOperandType != grResultType)
-          return graphOutputsARMOp.emitError("type of return operand ")
-                 << i << " (" << graphOutputOperandType
-                 << ") doesn't match graph result type (" << grResultType
-                 << ")";
-      }
-    }
-    return WalkResult::advance();
-  });
-
-  return failure(walkResult.wasInterrupted());
-}
-
-void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
-                              StringRef name, GraphType type,
-                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
-  state.addAttribute(SymbolTable::getSymbolAttrName(),
-                     builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
-  state.attributes.append(attrs.begin(), attrs.end());
-  state.addAttribute(getEntryPointAttrName(state.name),
-                     builder.getBoolAttr(entryPoint));
-  state.addRegion();
-}
-
-// Returns the argument types of this function.
-ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
-  return getFunctionType().getInputs();
-}
-
-// Returns the result types of this function.
-ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
-  return getFunctionType().getResults();
-}
-
-// CallableOpInterface
-Region *spirv::GraphARMOp::getCallableRegion() {
-  return isExternal() ? nullptr : &getBody();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GraphOutputsARM
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GraphOutputsARMOp::verify() {
-  auto graph = cast<GraphARMOp>((*this)->getParentOp());
-
-  // The operand number and types must match the graph signature.
-  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
-  if (getNumOperands() != results.size())
-    return emitOpError("has ")
-           << getNumOperands() << " operands, but enclosing  spirv.ARM.Graph (@"
-           << graph.getName() << ") returns " << results.size();
-
-  for (unsigned i = 0, size = results.size(); i < size; ++i)
-    if (getOperand(i).getType() != results[i])
-      return emitError() << "type of return operand " << i << " ("
-                         << getOperand(i).getType()
-                         << ") doesn't match  spirv.ARM.Graph result type ("
-                         << results[i] << ")"
-                         << " in graph @" << graph.getName();
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GLFClampOp
 //===----------------------------------------------------------------------===//

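The output checks in the removed `GraphARMOp::verifyBody` and `GraphOutputsARMOp::verify` above can be sketched without MLIR. The terminator must return exactly as many values as the graph signature declares, and each operand type must equal the corresponding result type. This is a simplified model that assumes types can be compared by their printed names. `TypeName` and `verifyGraphOutputs` are hypothetical helpers, not part of the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR Type: types are compared by their
// printed form only.
using TypeName = std::string;

// Simplified model of the spirv.ARM.GraphOutputs checks: first the operand
// count, then each operand type against the graph signature's result types.
// Returns std::nullopt on success, otherwise the diagnostic text.
std::optional<std::string>
verifyGraphOutputs(const std::vector<TypeName> &signatureResults,
                   const std::vector<TypeName> &outputOperands) {
  if (signatureResults.size() != outputOperands.size())
    return "is returning " + std::to_string(outputOperands.size()) +
           " value(s) but enclosing spirv.ARM.Graph requires " +
           std::to_string(signatureResults.size()) + " result(s)";

  for (std::size_t i = 0; i < outputOperands.size(); ++i)
    if (outputOperands[i] != signatureResults[i])
      return "type of return operand " + std::to_string(i) + " (" +
             outputOperands[i] + ") doesn't match graph result type (" +
             signatureResults[i] + ")";

  return std::nullopt;
}
```

The count is checked before the per-operand loop, as in the original verifier, so the loop never indexes past the end of either list.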
>From f4c2c3a524e2d432b3fc84bb6eb55a21dfcaa7b1 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 18 Aug 2025 16:36:11 +0200
Subject: [PATCH 6/8] Fix review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ib592a4b99af01c0d3c88eaf63a61cb4c6cca8cbe
---
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 15 +----
 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp     | 63 ++++++++-----------
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |  6 +-
 mlir/lib/IR/AsmPrinter.cpp                    |  7 +--
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     |  2 +-
 .../lib/Dialect/SPIRV/TestAvailability.cpp    |  2 +-
 6 files changed, 35 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index 51df4dc79ae68..ff2ec7363fe38 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -51,7 +51,7 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     that is unique in the enclosing module op.
 
     This op itself takes no operands and generates no results. Its region
-    can take zero or more arguments and return zero or more values.
+    can take zero or more arguments and return one or more values.
 
     ```
     spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
@@ -119,10 +119,6 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
     Result Type must be an OpTypeTensorARM.
     GraphConstantID must be a 32-bit integer literal.
 
-    ```
-    spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
-    ```
-
     #### Example:
 
     ```mlir
@@ -167,11 +163,6 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
     `spirv.GlobalVariable`s referenced by the entry point’s static call
     tree, within the interface’s storage classes.
 
-    ```
-    entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
-                       symbol-reference (`, ` symbol-reference)*
-    ```
-
     #### Example:
 
     ```mlir
@@ -214,10 +205,6 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
 
     This instruction must be the last instruction in a block.
 
-    ```
-    graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
-    ```
-
     #### Example:
 
     ```mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
index e300596fd3733..722e771425a47 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -1,5 +1,4 @@
-//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
-//------------------------------===//
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -22,6 +21,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
 
 using namespace mlir;
 using namespace mlir::spirv::AttrNames;
@@ -32,10 +32,7 @@ using namespace mlir::spirv::AttrNames;
 
 ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  SmallVector<OpAsmParser::Argument> entryArgs;
-  SmallVector<DictionaryAttr> resultAttrs;
-  SmallVector<Type> resultTypes;
-  auto &builder = parser.getBuilder();
+  Builder &builder = parser.getBuilder();
 
   // Parse the name as a symbol.
   StringAttr nameAttr;
@@ -45,15 +42,18 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
 
   // Parse the function signature.
   bool isVariadic = false;
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<Type> resultTypes;
+  SmallVector<DictionaryAttr> resultAttrs;
   if (function_interface_impl::parseFunctionSignatureWithArguments(
           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
           resultAttrs))
     return failure();
 
   SmallVector<Type> argTypes;
-  for (auto &arg : entryArgs)
+  for (OpAsmParser::Argument &arg : entryArgs)
     argTypes.push_back(arg.type);
-  auto grType = builder.getGraphType(argTypes, resultTypes);
+  GraphType grType = builder.getGraphType(argTypes, resultTypes);
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(grType));
 
@@ -136,26 +136,22 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
   }
 
   GraphType grType = getFunctionType();
-  auto walkResult = walk([grType](Operation *op) -> WalkResult {
-    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
-      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("is returning ")
-               << graphOutputsARMOp.getNumOperands()
-               << " value(s) but enclosing spirv.ARM.Graph requires "
-               << grType.getNumResults() << " result(s)";
-
-      ValueTypeRange<OperandRange> graphOutputOperandTypes =
-          graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
-           ++i) {
-        Type graphOutputOperandType = graphOutputOperandTypes[i];
-        Type grResultType = grType.getResult(i);
-        if (graphOutputOperandType != grResultType)
-          return graphOutputsARMOp.emitError("type of return operand ")
-                 << i << " (" << graphOutputOperandType
-                 << ") doesn't match graph result type (" << grResultType
-                 << ")";
-      }
+  auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
+    if (grType.getNumResults() != op.getNumOperands())
+      return op.emitOpError("is returning ")
+             << op.getNumOperands()
+             << " value(s) but enclosing spirv.ARM.Graph requires "
+             << grType.getNumResults() << " result(s)";
+
+    ValueTypeRange<OperandRange> graphOutputOperandTypes =
+        op.getValue().getType();
+    for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; ++i) {
+      Type graphOutputOperandType = graphOutputOperandTypes[i];
+      Type grResultType = grType.getResult(i);
+      if (graphOutputOperandType != grResultType)
+        return op.emitError("type of return operand ")
+               << i << " (" << graphOutputOperandType
+               << ") doesn't match graph result type (" << grResultType << ")";
     }
     return WalkResult::advance();
   });
@@ -169,23 +165,20 @@ void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
-  state.attributes.append(attrs.begin(), attrs.end());
+  state.attributes.append(attrs);
   state.addAttribute(getEntryPointAttrName(state.name),
                      builder.getBoolAttr(entryPoint));
   state.addRegion();
 }
 
-// Returns the argument types of this function.
 ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
   return getFunctionType().getInputs();
 }
 
-// Returns the result types of this function.
 ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
   return getFunctionType().getResults();
 }
 
-// CallableOpInterface
 Region *spirv::GraphARMOp::getCallableRegion() {
   return isExternal() ? nullptr : &getBody();
 }
@@ -229,12 +222,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
 
 ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  SmallVector<Attribute, 4> interfaceVars;
-
   FlatSymbolRefAttr fn;
   if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
     return failure();
 
+  SmallVector<Attribute, 4> interfaceVars;
   if (!parser.parseOptionalComma()) {
     // Parse the interface variables
     if (parser.parseCommaSeparatedList([&]() -> ParseResult {
@@ -258,7 +250,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getFn());
   ArrayRef<Attribute> interfaceVars = getInterface().getValue();
   if (!interfaceVars.empty()) {
-    printer << ", ";
-    llvm::interleaveComma(interfaceVars, printer);
+    printer << ", " << llvm::interleaved(interfaceVars);
   }
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index a2d221252fb69..e8c299932f06d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -160,10 +160,8 @@ void UpdateVCEPass::runOnOperation() {
 
     // If the op is FunctionLike make sure to process input and result types
     if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
-      ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
-      ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
-      valueTypes.append(inputTypes.begin(), inputTypes.end());
-      valueTypes.append(resultTypes.begin(), resultTypes.end());
+      llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
+      llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
     }
 
     // Requirements from values' types
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9a5dbcf6f598e..348068dbc84c8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,8 +104,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      (llvm::isa<FunctionType>(op->getResult(0).getType()) ||
-       llvm::isa<GraphType>(op->getResult(0).getType())))
+      isa<GraphType>(op->getResult(0).getType()))
     wrapped = true;
 
   if (wrapped)
@@ -2842,8 +2841,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
         os << ") -> ";
         ArrayRef<Type> results = graphTy.getResults();
-        if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
-                                     llvm::isa<GraphType>(results[0]))) {
+        if (results.size() == 1 &&
+            !(isa<FunctionType>(results[0]) || isa<GraphType>(results[0]))) {
           printType(results[0]);
         } else {
           os << '(';
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 591eaaea4c802..c7763d45c6b5e 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.ARM.Graph and spirv.ARM.GraphOutputs
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 9efca825a663d..5643a0ff5b91c 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -33,7 +33,7 @@ struct PrintOpAvailability
 } // namespace
 
 void PrintOpAvailability::runOnOperation() {
-  auto moduleOp = getOperation();
+  mlir::ModuleOp moduleOp = getOperation();
   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
 
   auto opCallback = [&](Operation *op) {

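Patch 6 replaces the generic `walk([grType](Operation *op) ...)` plus `dyn_cast` with a callback typed on `spirv::GraphOutputsARMOp`. MLIR's `walk` uses the callback's parameter type to visit only matching ops, and it walks in post-order by default. The following is a standalone sketch of that filtering idea on a hypothetical miniature IR. `Op`, `GraphOutputsOp`, `OtherOp` and `walkTyped` are illustrative names and do not reproduce MLIR's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical miniature IR: every op owns its nested ops.
struct Op {
  virtual ~Op() = default;
  std::vector<std::unique_ptr<Op>> children;
};
struct GraphOutputsOp : Op {
  std::size_t numOperands = 0;
};
struct OtherOp : Op {};

// Post-order walk that invokes `fn` only for ops whose dynamic type is OpT,
// similar in spirit to a typed callback passed to Operation::walk.
// Returning false from `fn` interrupts the walk; the function returns
// false if the walk was interrupted.
template <typename OpT>
bool walkTyped(Op &root, const std::function<bool(OpT &)> &fn) {
  for (std::unique_ptr<Op> &child : root.children)
    if (!walkTyped<OpT>(*child, fn))
      return false;
  if (auto *typed = dynamic_cast<OpT *>(&root))
    return fn(*typed);
  return true;
}
```

With the type filter moved into the callback signature, the body only contains the checks themselves. This is the readability gain the review comment asked for.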
>From 6d364967b74fc3cc6c1ae3ed895f9241534e1ce1 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Tue, 19 Aug 2025 11:51:48 +0200
Subject: [PATCH 7/8] Small fixes

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I8df473ef101d12f2d57d1cc2f71f05268e3dc5c3
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 7 +++----
 mlir/lib/IR/AsmPrinter.cpp                 | 2 +-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 6f18dcefea14d..fcf1526491971 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1065,9 +1065,8 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult
-SPIRVDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
-                                          unsigned resultIndex,
-                                          NamedAttribute attribute) {
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(
+    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+    NamedAttribute attribute) {
   return op->emitError("cannot attach SPIR-V attributes to region result");
 }
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 348068dbc84c8..66ab6ee977366 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,7 +104,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      isa<GraphType>(op->getResult(0).getType()))
+      isa<FunctionType>(op->getResult(0).getType()))
     wrapped = true;
 
   if (wrapped)

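Patch 7 restores the `FunctionType` check in `printFunctionalType`, and the `GraphType` printer in patch 6 applies a similar rule to its result list. A single result is printed bare unless it is itself function-like, because its own `->` would otherwise be ambiguous. The sketch below models only that parenthesization rule. `SimpleType`, `joinTypes` and `printSignature` are hypothetical, and any prefix or delimiters the real `GraphType` syntax adds around the signature are deliberately omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical type model: a printed spelling plus a flag telling whether the
// type is function-like (a FunctionType or GraphType in the real printer).
struct SimpleType {
  std::string spelling;
  bool isFunctionLike = false;
};

// Joins type spellings with ", ", like interleaveComma.
std::string joinTypes(const std::vector<SimpleType> &types) {
  std::string out;
  for (std::size_t i = 0; i < types.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += types[i].spelling;
  }
  return out;
}

// Prints "(inputs) -> results", wrapping the results in parentheses unless
// there is exactly one result and it is not function-like.
std::string printSignature(const std::vector<SimpleType> &inputs,
                           const std::vector<SimpleType> &results) {
  bool wrapped = results.size() != 1 || results[0].isFunctionLike;
  std::string out = "(" + joinTypes(inputs) + ") -> ";
  if (wrapped)
    return out + "(" + joinTypes(results) + ")";
  return out + results[0].spelling;
}
```

For example, a single `i32` result prints as `(i32) -> i32`. A single function-typed result prints as `() -> ((i32) -> i32)`, which keeps the outer arrow unambiguous.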
>From 234c095d51a5191f879ecdaa318679c11e2ca43c Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 25 Aug 2025 13:55:47 +0200
Subject: [PATCH 8/8] Clarify documentation for SPIRV_GraphARMOp

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I7e73e1277bb988813f19f5684ec024c2d2dcb1df
---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index ff2ec7363fe38..748fd25b3f56a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -45,10 +45,15 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     This op declares or defines a SPIR-V graph using one region, which
     contains one or more blocks.
 
-    Different from the SPIR-V binary format, this op is not allowed to
-    implicitly capture global values, and all external references must use
-    function arguments or symbol references. This op itself defines a symbol
-    that is unique in the enclosing module op.
+    This op is not allowed to implicitly capture global values, and all external
+    references must use function arguments or symbol references. This op itself
+    defines a symbol that is unique in the enclosing module op.
+
+    Note that this op does not have a 1:1 mapping to the SPIR-V ops representing
+    a graph. During serialization, a single GraphARMOp is serialized into
+    several different SPIR-V ops: OpGraphARM, OpGraphInputARM and OpGraphEndARM.
+    There is one OpGraphInputARM op for each graph input. Deserialization maps
+    that set of operations back into a single GraphARMOp.
 
     This op itself takes no operands and generates no results. Its region
     can take zero or more arguments and return one or more values.

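The 1:N mapping described in the updated `SPIRV_GraphARMOp` documentation can be sketched as a list of opcode names. The emitted sequence is one `OpGraphARM` header, one `OpGraphInputARM` per graph input, the body, and a closing `OpGraphEndARM`. This is a minimal sketch under two assumptions. First, the input ops immediately follow the header, by analogy with `OpFunctionParameter`. Second, body instructions are opaque placeholders. No operands, result ids or binary encoding are modeled, and `serializeGraphSkeleton` and `countGraphInputs` are hypothetical helpers, not the MLIR serializer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical skeleton of the serialized form of one graph op, as opcode
// names only: header, one input op per graph input, body, end marker.
std::vector<std::string>
serializeGraphSkeleton(std::size_t numInputs,
                       const std::vector<std::string> &bodyOps) {
  std::vector<std::string> ops;
  ops.push_back("OpGraphARM");
  for (std::size_t i = 0; i < numInputs; ++i)
    ops.push_back("OpGraphInputARM");
  ops.insert(ops.end(), bodyOps.begin(), bodyOps.end());
  ops.push_back("OpGraphEndARM");
  return ops;
}

// Inverse direction (assumption: inputs directly follow the header):
// counting the OpGraphInputARM ops after OpGraphARM recovers the number of
// graph inputs, which deserialization needs to rebuild one graph op.
std::size_t countGraphInputs(const std::vector<std::string> &ops) {
  if (ops.empty() || ops[0] != "OpGraphARM")
    return 0;
  std::size_t count = 0;
  for (std::size_t i = 1; i < ops.size() && ops[i] == "OpGraphInputARM"; ++i)
    ++count;
  return count;
}
```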

