[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (PR #151934)

Davide Grohmann llvmlistbot at llvm.org
Tue Aug 19 01:06:12 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/151934

>From f3eb19fdd5acf04635e842ce0b002c607047d688 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 4 Aug 2025 10:42:01 +0200
Subject: [PATCH 1/6] [mlir][spirv] Add support for SPV_ARM_graph extension -
 part 1
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This is the first patch to add support for the SPV_ARM_graph SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new
Graph abstraction for expressing dataflow computations over full
resources.

The part 1 implementation includes:

    A new GraphType, modeled similarly to FunctionType, for typed graph signatures.
    New operations in the spirv.ARM namespace:
        spirv.ARM.Graph
        spirv.ARM.GraphEntryPoint
        spirv.ARM.GraphConstant
        spirv.ARM.GraphOutputs
    Verifier and VCE updates to properly gate usage under SPV_ARM_graph.
    Tests covering parsing and verification.

Graphs currently support only SPV_ARM_tensors, but are designed to
generalize to other resource types, such as images.

Spec: KhronosGroup/SPIRV-Registry#346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia74b7ab0161b03d3d4702e93c34d7f55cd295a5f
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  26 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 201 +++++++++++++++
 .../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td |   1 +
 mlir/include/mlir/IR/Builders.h               |   2 +
 mlir/include/mlir/IR/BuiltinTypes.td          |  18 +-
 mlir/include/mlir/IR/CommonTypeConstraints.td |   7 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |   7 +-
 .../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp    |  12 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 230 ++++++++++++++++++
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |   8 +
 mlir/lib/IR/AsmPrinter.cpp                    |  17 +-
 mlir/lib/IR/Builders.cpp                      |   4 +
 mlir/lib/IR/BuiltinTypes.cpp                  |  39 +++
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  17 ++
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     |  30 +++
 .../SPIRV/Transforms/vce-deduction.mlir       |  11 +
 .../lib/Dialect/SPIRV/TestAvailability.cpp    |  18 +-
 17 files changed, 630 insertions(+), 18 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
 create mode 100644 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index bdfd728d1d0b3..a27554a3c6f64 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
 def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph                        : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
 
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
-      SPV_ARM_tensors,
+      SPV_ARM_tensors, SPV_ARM_graph,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
     Extension<[SPV_ARM_tensors]>
   ];
 }
+def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_graph]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
       SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
-      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
 def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM               : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM             : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM                       : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM                  : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM              : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM                    : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM                   : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
       SPIRV_OC_OpGroupNonUniformLogicalXor,
       SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+      SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+      SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
       SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4862,6 +4879,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+// Base class for ARM vendor ops (SPIRV_VendorOp with vendor string "ARM").
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..38fb4b2eff414
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,201 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all Graph ops.
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_ArmVendorOp<mnemonic, traits> {
+  // Availability shared by all graph ops; GraphARM capability is required.
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+  }];
+
+  let arguments = (ins
+    I32Attr:$graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  // Only the ODS constraints above are checked; no custom verifier.
+  let hasVerifier = 0;
+  // Serialization for this op is not autogenerated.
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+    AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+    FunctionOpInterface, InModuleScope, IsolatedFromAbove
+  ]> {
+
+  let summary = "Declare or define a SPIR-V graph";
+
+  let description = [{
+    This op declares or defines a SPIR-V graph using one region, which
+    contains one or more blocks.
+
+    Different from the SPIR-V binary format, this op is not allowed to
+    implicitly capture global values, and all external references must use
+    graph arguments or symbol references. This op itself defines a symbol
+    that is unique in the enclosing module op.
+
+    This op itself takes no operands and generates no results. Its region
+    can take zero or more arguments and must return at least one value.
+
+    ```
+    spv-graph-arm-op ::= `spirv.ARM.Graph` symbol-ref-id function-signature
+                         (`attributes` attribute-dict)? region?
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttrOf<GraphType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<BoolAttr>:$entry_point,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 0;
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,  CArg<"bool", "false">:$entry_point)>];
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+
+  let extraClassDeclaration = [{
+    /// Hook for FunctionOpInterface, called after verifying that the
+    /// 'function_type' attribute is present and holds a GraphType. Rejects
+    /// graph signatures that declare no results.
+    LogicalResult verifyType();
+
+    /// Hook for FunctionOpInterface, called after verifying the graph
+    /// type and the presence of the (potentially empty) graph body.
+    /// Ensures SPIR-V specific semantics.
+    LogicalResult verifyBody();
+  }];
+}
+
+// Checks that an op is nested, possibly indirectly, in a spirv.ARM.Graph op.
+def InGraphScope : PredOpTrait<
+  "op must appear in a spirv.ARM.Graph op's block",
+  CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+  let summary = [{
+    Declare a graph entry point and its interface.
+  }];
+
+  let description = [{
+    Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+    Name is a name string for the graph entry point. A module cannot have two
+    OpGraphEntryPointARM instructions with the same Name string.
+
+    Interface is a list of symbol references to `spirv.GlobalVariable`
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    `spirv.GlobalVariable`s referenced by the entry point’s static call
+    tree, within the interface’s storage classes.
+
+    ```
+    entry-point-op ::= `spirv.ARM.GraphEntryPoint` symbol-reference
+                       (`, ` symbol-reference)*
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SymbolRefArrayAttr:$interface
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let builders = [
+    OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+                                               Terminator]> {
+
+  let summary = "Define graph outputs.";
+
+  let description = [{
+    Values are the graph output values and must match the GraphOutputs Type
+    operand of the OpTypeGraphARM type of the OpGraphARM body this
+    instruction is in.
+
+    This instruction must be the last instruction in a block.
+
+    ```
+    graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use-list `:` type-list-no-parens
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPIRV_AnyTensorArm>:$value
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let hasOpcode = 0;
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2e356dec1981f..9d8d81a839fcb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
 class IntegerType;
 class FloatType;
 class FunctionType;
+class GraphType;
 class IndexType;
 class MemRefType;
 class VectorType;
@@ -81,6 +82,7 @@ class Builder {
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
   FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+  GraphType getGraphType(TypeRange inputs, TypeRange results);
   TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let storageClass = "FunctionTypeStorage";
   let genStorageClass = 0;
   let extraClassDeclaration = [{
     /// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     unsigned getNumResults() const;
     Type getResult(unsigned i) const { return getResults()[i]; }
 
-    /// Returns a clone of this function type with the given argument
+    /// Returns a clone of this function-like type with the given argument
     /// and result types.
-    FunctionType clone(TypeRange inputs, TypeRange results) const;
+    }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
 
-    /// Returns a new function type with the specified arguments and results
+    /// Returns a new function-like type with the specified arguments and results
     /// inserted.
-    FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+    }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
                                        TypeRange argTypes,
                                        ArrayRef<unsigned> resultIndices,
                                        TypeRange resultTypes);
 
-    /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+    /// Returns a new function-like type without the specified arguments and results.
+    }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
                                           const BitVector &resultIndices);
   }];
 }
 
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
 //===----------------------------------------------------------------------===//
 // IndexType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 45ec1846580f2..aab1b01c5cff9 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
 def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
                               "function type", "::mlir::FunctionType">;
 
+// Graph Type
+
+// Any graph type.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+                              "graph type", "::mlir::GraphType">;
+
+
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index fcf1526491971..6f18dcefea14d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1065,8 +1065,9 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
-    NamedAttribute attribute) {
+// Region results cannot carry SPIR-V attributes; every request is rejected.
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(
+    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+    NamedAttribute attribute) {
   return op->emitError("cannot attach SPIR-V attributes to region result");
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
   return isNestedInFunctionOpInterface(op->getParentOp());
 }
 
+/// Walks `op` and its ancestors: true once a GraphARM op is reached, false
+/// if a symbol-table (module-like) op or the top of the IR is hit first.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+  for (; op; op = op->getParentOp()) {
+    if (op->hasTrait<OpTrait::SymbolTable>())
+      return false;
+    if (isa<spirv::GraphARMOp>(op))
+      return true;
+  }
+  return false;
+}
+
 /// Returns true if the given op is an module-like op that maintains a symbol
 /// table.
 static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index f99339852824c..8dfdfea8a5c54 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1126,6 +1126,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+                                        OperationState &state,
+                                        spirv::GraphARMOp graph,
+                                        ArrayRef<Attribute> interfaceVars) {
+  ArrayAttr interfaceAttr = builder.getArrayAttr(interfaceVars);
+  build(builder, state, SymbolRefAttr::get(graph), interfaceAttr);
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Syntax: graph symbol, optionally followed by `,` and interface symbols.
+  SmallVector<Attribute, 4> interfaceVars;
+
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
+    return failure();
+
+  // The interface variable list is optional and introduced by a comma.
+  if (succeeded(parser.parseOptionalComma())) {
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // `attrs` is scratch space: only the symbol is kept, so the
+          // attribute name used while parsing it does not matter.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute("interface",
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  ArrayRef<Attribute> vars = getInterface().getValue();
+  if (vars.empty())
+    return;
+  printer << ", ";
+  llvm::interleaveComma(vars, printer);
+}
+
+LogicalResult spirv::GraphEntryPointARMOp::verify() {
+  // Symbol checks for `fn` and `interface` are deferred to spirv::ModuleOp
+  // verification. NOTE(review): confirm it handles GraphEntryPointARMOp.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  SmallVector<Type> resultTypes;
+  auto &builder = parser.getBuilder();
+
+  // Parse the graph name as a symbol into the `sym_name` attribute.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Parse the graph signature; variadic signatures are not allowed.
+  bool isVariadic = false;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+          resultAttrs))
+    return failure();
+
+  SmallVector<Type> argTypes;
+  for (auto &arg : entryArgs)
+    argTypes.push_back(arg.type);
+  auto grType = builder.getGraphType(argTypes, resultTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(grType));
+
+  // Parse the optional `attributes {...}` dictionary.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Attach the per-argument and per-result attribute dictionaries.
+  assert(resultAttrs.size() == resultTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // Parse the optional body; its entry block arguments are `entryArgs`.
+  auto *body = result.addRegion();
+  OptionalParseResult parseResult =
+      parser.parseOptionalRegion(*body, entryArgs);
+  return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Print graph name, signature, and remaining (non-signature) attributes.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  auto grType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, grType.getInputs(),
+      /*isVariadic=*/false, grType.getResults());
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Print the body, if any; entry block arguments are not printed here.
+  Region &body = this->getBody();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getNumResults() != 0)
+    return success();
+  return emitOpError("there should be at least one result");
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  GraphType grType = getFunctionType();
+  if (!isExternal()) {
+    Block &entryBlock = front();
+
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << '(' << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature(" << grArgType << ')';
+      }
+    }
+  }
+  // Every GraphOutputs terminator must agree with the graph's result types.
+  auto walkResult = walk([grType](Operation *op) -> WalkResult {
+    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
+      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
+        return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
+               << graphOutputsARMOp.getNumOperands()
+               << " value(s) but enclosing graph requires "
+               << grType.getNumResults() << " results";
+      // Counts match here, so indexing the graph results by `i` is in range.
+      auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
+      for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
+        auto graphOutputOperandType = graphOutputOperandTypes[i];
+        auto grResultType = grType.getResult(i);
+        if (graphOutputOperandType != grResultType)
+          return graphOutputsARMOp.emitOpError("type of return operand ")
+                 << i << " (" << graphOutputOperandType
+                 << ") doesn't match graph result type (" << grResultType
+                 << ")";
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  StringAttr symName = builder.getStringAttr(name);
+  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  BoolAttr entryPointAttr = builder.getBoolAttr(entryPoint);
+  state.addAttribute(getEntryPointAttrName(state.name), entryPointAttr);
+  state.addRegion();
+}
+
+// Returns the input types of this graph's signature.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  return getFunctionType().getInputs();
+}
+
+// Returns the result (output) types of this graph's signature.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  return getFunctionType().getResults();
+}
+
+// CallableOpInterface: external graphs (no body) have no callable region.
+Region *spirv::GraphARMOp::getCallableRegion() {
+  return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  // InGraphScope admits indirect nesting; the checks below need the parent.
+  auto graph = dyn_cast<GraphARMOp>((*this)->getParentOp());
+  if (!graph)
+    return emitOpError("must be directly nested in a spirv.ARM.Graph op");
+
+  // The operand count and types must match the graph signature.
+  ArrayRef<Type> results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing graph (@"
+           << graph.getName() << ") returns " << results.size();
+  for (unsigned i = 0, e = results.size(); i < e; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitOpError("type of return operand ")
+             << i << " (" << getOperand(i).getType()
+             << ") doesn't match graph result type (" << results[i]
+             << ") in graph @" << graph.getName();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.GLFClampOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6d3bda421f309..fd97b09d802f1 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -158,6 +158,14 @@ void UpdateVCEPass::runOnOperation() {
     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
       valueTypes.push_back(globalVar.getType());
 
+    // If the op is FunctionLike make sure to process input and result types
+    if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
+      auto inputTypes = funcOpInterface.getArgumentTypes();
+      auto resultTypes = funcOpInterface.getResultTypes();
+      valueTypes.append(inputTypes.begin(), inputTypes.end());
+      valueTypes.append(resultTypes.begin(), resultTypes.end());
+    }
+
     // Requirements from values' types
     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index de52fbd3f215c..9a5dbcf6f598e 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,7 +104,8 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      llvm::isa<FunctionType>(op->getResult(0).getType()))
+      (llvm::isa<FunctionType>(op->getResult(0).getType()) ||
+       llvm::isa<GraphType>(op->getResult(0).getType())))
     wrapped = true;
 
   if (wrapped)
@@ -2836,6 +2837,20 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         os << '>';
       })
       .Case<NoneType>([&](Type) { os << "none"; })
+      .Case<GraphType>([&](GraphType graphTy) {
+        os << '(';
+        interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
+        os << ") -> ";
+        ArrayRef<Type> results = graphTy.getResults();
+        if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
+                                     llvm::isa<GraphType>(results[0]))) {
+          printType(results[0]);
+        } else {
+          os << '(';
+          interleaveComma(results, [&](Type ty) { printType(ty); });
+          os << ')';
+        }
+      })
       .Default([&](Type type) { return printDialectType(type); });
 }
 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index f657db142eeb9..3d366276b4375 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
   return FunctionType::get(context, inputs, results);
 }
 
+GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) {
+  return GraphType::get(context, inputs, results);
+}
+
 TupleType Builder::getTupleType(TypeRange elementTypes) {
   return TupleType::get(context, elementTypes);
 }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..ce47c60c9b932 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -179,6 +179,45 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
   return clone(newArgTypes, newResultTypes);
 }
 
+//===----------------------------------------------------------------------===//
+// GraphType
+//===----------------------------------------------------------------------===//
+
+unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
+
+ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
+
+unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
+
+ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
+
+GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
+  return get(getContext(), inputs, results);
+}
+
+/// Returns a new graph type with `argTypes` inserted at `argIndices` and
+/// `resultTypes` inserted at `resultIndices`.
+GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+                                           TypeRange argTypes,
+                                           ArrayRef<unsigned> resultIndices,
+                                           TypeRange resultTypes) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes =
+      insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
+  TypeRange newResultTypes =
+      insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
+
+/// Returns a new graph type without the specified arguments and results.
+GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
+                                              const BitVector &resultIndices) {
+  SmallVector<Type> argStorage, resultStorage;
+  TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
+  TypeRange newResultTypes =
+      filterTypesOut(getResults(), resultIndices, resultStorage);
+  return clone(newArgTypes, newResultTypes);
+}
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index f56bc3967b4b7..bc1505d32d4d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -306,3 +306,20 @@ func.func @constant_composite_replicate() -> () {
   %0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
   spirv.Return
 }
+
+//===----------------------------------------------------------------------===//
+// GraphARM ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: graph_arm
+spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  // CHECK: spirv.ARM.GraphOutputs min version: v1.0
+  // CHECK: spirv.ARM.GraphOutputs max version: v1.6
+  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+  // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+// CHECK: spirv.ARM.Graph min version: v1.0
+// CHECK: spirv.ARM.Graph max version: v1.6
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
new file mode 100644
index 0000000000000..90c31e19db382
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphConstant
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
+  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
+
+  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
+    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
+  }
+
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4e534a30ad516..cf9d86576b1f6 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes {
     spirv.ReturnValue %val : bf16
   }
 }
+
+// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.5, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
+    #spirv.resource_limits<>>
+} {
+  spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
+      spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
+  }
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 2e5e591fe5f91..9efca825a663d 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -21,7 +21,7 @@ using namespace mlir;
 namespace {
 /// A pass for testing SPIR-V op availability.
 struct PrintOpAvailability
-    : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
+    : public PassWrapper<PrintOpAvailability, OperationPass<mlir::ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
 
   void runOnOperation() override;
@@ -33,12 +33,10 @@ struct PrintOpAvailability
 } // namespace
 
 void PrintOpAvailability::runOnOperation() {
-  auto f = getOperation();
-  llvm::outs() << f.getName() << "\n";
-
+  auto moduleOp = getOperation();
   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
 
-  f->walk([&](Operation *op) {
+  auto opCallback = [&](Operation *op) {
     if (op->getDialect() != spirvDialect)
       return WalkResult::advance();
 
@@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() {
     os.flush();
 
     return WalkResult::advance();
+  };
+
+  moduleOp.walk([&](func::FuncOp f) {
+    llvm::outs() << f.getName() << "\n";
+    f->walk(opCallback);
+  });
+
+  moduleOp.walk([&](spirv::GraphARMOp g) {
+    llvm::outs() << g.getName() << "\n";
+    g->walk(opCallback);
   });
 }
 

>From 101e67424627d0762c745a3ec9b7e506657b82b2 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 6 Aug 2025 10:11:35 +0200
Subject: [PATCH 2/6] Resolve code review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ie69a1696a7b31869c1ba94bdf7aa214d52175565
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  3 +-
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 64 ++++++++++--------
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 19 +++---
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |  4 +-
 mlir/test/Dialect/SPIRV/IR/availability.mlir  |  4 +-
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     | 67 +++++++++++++++----
 .../SPIRV/Transforms/vce-deduction.mlir       |  4 +-
 7 files changed, 104 insertions(+), 61 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index a27554a3c6f64..0e42d08cdb1fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1343,7 +1343,7 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
   ];
 }
 def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
-  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM];
   list<Availability> availability = [
     Extension<[SPV_ARM_graph]>
   ];
@@ -4883,7 +4883,6 @@ class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "ARM", traits> {
 }
 
-
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index 38fb4b2eff414..f2913239cc4e8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -29,39 +29,11 @@ class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
   let availability = [
     MinVersion<SPIRV_V_1_0>,
     MaxVersion<SPIRV_V_1_6>,
-    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
     Capability<[SPIRV_C_GraphARM]>
   ];
 }
 
-def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
-  let summary = "Declare a graph constant.";
-
-  let description = [{
-    Declare a graph constant.
-    Result Type must be an OpTypeTensorARM.
-    GraphConstantID must be a 32-bit integer literal.
-  }];
-
-  let arguments = (ins
-    I32Attr: $graph_constant_id
-  );
-
-  let results = (outs
-    SPIRV_AnyTensorArm:$output
-  );
-
-  let hasVerifier = 0;
-
-  let autogenSerialization = 0;
-
-  let assemblyFormat = [{
-    attr-dict `:` type($output)
-  }];
-}
-
-// -----
-
 def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
     FunctionOpInterface, InModuleScope, IsolatedFromAbove
@@ -122,6 +94,8 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
   }];
 }
 
+// -----
+
 // Check that an op can only be used within the scope of a spirv.ARM.Graph op.
 def InGraphScope : PredOpTrait<
   "op must appear in a spirv.ARM.Graph op's block",
@@ -129,6 +103,38 @@ def InGraphScope : PredOpTrait<
 
 // -----
 
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure, ConstantLike]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+
+    ```
+    spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
+    ```
+  }];
+
+  let arguments = (ins
+    I32Attr: $graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
 def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
   let summary = [{
     Declare a graph entry point and its interface.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 8dfdfea8a5c54..953406da60a57 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1140,13 +1140,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
 
 ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  SmallVector<Type, 0> idTypes;
   SmallVector<Attribute, 4> interfaceVars;
 
   FlatSymbolRefAttr fn;
-  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
     return failure();
-  }
 
   if (!parser.parseOptionalComma()) {
     // Parse the interface variables
@@ -1224,7 +1222,7 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
       getResAttrsAttrName(result.name));
 
   // Parse the optional function body.
-  auto *body = result.addRegion();
+  Region *body = result.addRegion();
   OptionalParseResult parseResult =
       parser.parseOptionalRegion(*body, entryArgs);
   return failure(parseResult.has_value() && failed(*parseResult));
@@ -1234,7 +1232,7 @@ void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
   // Print graph name, signature, and control.
   printer << " ";
   printer.printSymbolName(getSymName());
-  auto grType = getFunctionType();
+  GraphType grType = getFunctionType();
   function_interface_impl::printFunctionSignature(
       printer, *this, grType.getInputs(),
       /*isVariadic=*/false, grType.getResults());
@@ -1288,9 +1286,10 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
                << grType.getNumResults() << " results";
 
       auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0; i < graphOutputOperandTypes.size(); ++i) {
-        auto graphOutputOperandType = graphOutputOperandTypes[i];
-        auto grResultType = grType.getResult(i);
+      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
+           ++i) {
+        Type graphOutputOperandType = graphOutputOperandTypes[i];
+        Type grResultType = grType.getResult(i);
         if (graphOutputOperandType != grResultType)
           return graphOutputsARMOp.emitError("type of return operand ")
                  << i << " (" << graphOutputOperandType
@@ -1339,13 +1338,13 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
   auto graph = cast<GraphARMOp>((*this)->getParentOp());
 
   // The operand number and types must match the graph signature.
-  const auto &results = graph.getFunctionType().getResults();
+  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
   if (getNumOperands() != results.size())
     return emitOpError("has ")
            << getNumOperands() << " operands, but enclosing graph (@"
            << graph.getName() << ") returns " << results.size();
 
-  for (unsigned i = 0; i < results.size(); i++)
+  for (unsigned i = 0, size = results.size(); i < size; ++i)
     if (getOperand(i).getType() != results[i])
       return emitError() << "type of return operand " << i << " ("
                          << getOperand(i).getType()
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index fd97b09d802f1..a2d221252fb69 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -160,8 +160,8 @@ void UpdateVCEPass::runOnOperation() {
 
     // If the op is FunctionLike make sure to process input and result types
     if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
-      auto inputTypes = funcOpInterface.getArgumentTypes();
-      auto resultTypes = funcOpInterface.getResultTypes();
+      ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
+      ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
       valueTypes.append(inputTypes.begin(), inputTypes.end());
       valueTypes.append(resultTypes.begin(), resultTypes.end());
     }
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index bc1505d32d4d5..4ef242bdc5b16 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -315,11 +315,11 @@ func.func @constant_composite_replicate() -> () {
 spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
   // CHECK: spirv.ARM.GraphOutputs min version: v1.0
   // CHECK: spirv.ARM.GraphOutputs max version: v1.6
-  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+  // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
   // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
   spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
 // CHECK: spirv.ARM.Graph min version: v1.0
 // CHECK: spirv.ARM.Graph max version: v1.6
-// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model] ]
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
 // CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 90c31e19db382..6919c7eecc632 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -1,29 +1,68 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph and spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ARM.GraphConstant
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Shader, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<14xi32>
-  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<14xi32>
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
+  spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
+    // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    // CHECK: spirv.ARM.GraphOutputs [[CONST]] : !spirv.arm.tensor<2x3xi16>
+    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+  }
+}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphEntryPoint
+//===----------------------------------------------------------------------===//
+
 
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
   // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @main_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
-  spirv.GlobalVariable @main_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<2x3xi16>, UniformConstant>
+  spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+  spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
   // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
-  spirv.ARM.GraphEntryPoint @main, @main_arg_0, @main_res_0
-  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
-  spirv.ARM.Graph @main(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<2x3xi16> attributes {entry_point = true} {
-    // CHECK: [[CONST2:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
-    %1 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
-    // CHECK: spirv.ARM.GraphOutputs [[OUT:%.*]] : !spirv.arm.tensor<2x3xi16>
-    spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<2x3xi16>
+  spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
+  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+  }
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Multiple spirv.ARM.Graphs
+//===----------------------------------------------------------------------===//
+
+spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
+  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
   }
 
   // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-  spirv.ARM.Graph @empty_graph(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+  spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
     // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
     spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
   }
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index cf9d86576b1f6..18958cef7b00a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -232,10 +232,10 @@ spirv.module Logical GLSL450 attributes {
   }
 }
 
-// CHECK: spirv.module Logical Vulkan requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
 spirv.module Logical Vulkan attributes {
   spirv.target_env = #spirv.target_env<
-    #spirv.vce<v1.5, [GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph, SPV_KHR_vulkan_memory_model]>,
+    #spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
     #spirv.resource_limits<>>
 } {
   spirv.ARM.Graph @argmax(%arg0 : !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {

>From 3cf2ee9ce028265868e05ac84d561b09db4e847e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 6 Aug 2025 12:14:31 +0200
Subject: [PATCH 3/6] Fix one more comment

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia24695d965919ffebdff9945cd7d72233faa922a
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 953406da60a57..fdefd5e3966ca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1167,7 +1167,7 @@ ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
 void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   printer << " ";
   printer.printSymbolName(getFn());
-  auto interfaceVars = getInterface().getValue();
+  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
   if (!interfaceVars.empty()) {
     printer << ", ";
     llvm::interleaveComma(interfaceVars, printer);

>From 5ae57fe2d6fdfd8ccf95bd6a03c96bcc9db93ba1 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 Aug 2025 11:04:19 +0200
Subject: [PATCH 4/6] Resolve more review comments and expand testing

In particular add negative testing.

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Iee4ba17c74b451eda7f76c6f905ca12c734d39d6
---
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    |  36 +++++
 mlir/include/mlir/IR/CommonTypeConstraints.td |   1 -
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  36 +++--
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     | 131 +++++++++++++-----
 4 files changed, 151 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index f2913239cc4e8..51df4dc79ae68 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -57,6 +57,14 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
                         region
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
   }];
 
   let arguments = (ins
@@ -114,6 +122,12 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
     ```
     spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
     ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    ```
   }];
 
   let arguments = (ins
@@ -157,6 +171,17 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
     entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
                        symbol-reference (`, ` symbol-reference)*
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        ...
+    }
+    ```
   }];
 
   let arguments = (ins
@@ -166,6 +191,9 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
 
   let results = (outs);
 
+  // Checks for graph and interface symbol reference are done in spirv::ModuleOp verification.
+  let hasVerifier = 0;
+
   let autogenSerialization = 0;
 
   let builders = [
@@ -189,6 +217,14 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
     ```
     graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
     ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+        spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index aab1b01c5cff9..8ba2daefd97aa 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -393,7 +393,6 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
 def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
                               "graph type", "::mlir::GraphType">;
 
-
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fdefd5e3966ca..398dc046b3912 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1174,12 +1174,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   }
 }
 
-LogicalResult spirv::GraphEntryPointARMOp::verify() {
-  // Checks for fn and interface symbol reference are done in spirv::ModuleOp
-  // verification.
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GraphARM
 //===----------------------------------------------------------------------===//
@@ -1257,7 +1251,19 @@ LogicalResult spirv::GraphARMOp::verifyType() {
 }
 
 LogicalResult spirv::GraphARMOp::verifyBody() {
-  GraphType grType = getFunctionType();
+  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+    if (!isa<spirv::TensorArmType>(graphArgType)) {
+      return emitOpError("type of argument #")
+             << index << " must be a TensorArmType, but got " << graphArgType;
+    }
+  }
+  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+    if (!isa<spirv::TensorArmType>(graphResType)) {
+      return emitOpError("type of result #")
+             << index << " must be a TensorArmType, but got " << graphResType;
+    }
+  }
+
   if (!isExternal()) {
     Block &entryBlock = front();
 
@@ -1277,15 +1283,17 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
     }
   }
 
+  GraphType grType = getFunctionType();
   auto walkResult = walk([grType](Operation *op) -> WalkResult {
     if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
       if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("has GraphOutputsARM returning ")
+        return graphOutputsARMOp.emitOpError("is returning ")
                << graphOutputsARMOp.getNumOperands()
-               << "value(s) but enclosing graph requires "
-               << grType.getNumResults() << " results";
+               << " value(s) but enclosing spirv.ARM.Graph requires "
+               << grType.getNumResults() << " result(s)";
 
-      auto graphOutputOperandTypes = graphOutputsARMOp.getValue().getType();
+      ValueTypeRange<OperandRange> graphOutputOperandTypes =
+          graphOutputsARMOp.getValue().getType();
       for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
            ++i) {
         Type graphOutputOperandType = graphOutputOperandTypes[i];
@@ -1341,15 +1349,15 @@ LogicalResult spirv::GraphOutputsARMOp::verify() {
   const ArrayRef<Type> &results = graph.getFunctionType().getResults();
   if (getNumOperands() != results.size())
     return emitOpError("has ")
-           << getNumOperands() << " operands, but enclosing graph (@"
+           << getNumOperands() << " operands, but enclosing  spirv.ARM.Graph (@"
            << graph.getName() << ") returns " << results.size();
 
   for (unsigned i = 0, size = results.size(); i < size; ++i)
     if (getOperand(i).getType() != results[i])
       return emitError() << "type of return operand " << i << " ("
                          << getOperand(i).getType()
-                         << ") doesn't match graph result type (" << results[i]
-                         << ")"
+                         << ") doesn't match  spirv.ARM.Graph result type ("
+                         << results[i] << ")"
                          << " in graph @" << graph.getName();
 
   return success();
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 6919c7eecc632..591eaaea4c802 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -4,12 +4,10 @@
 // spirv.ARM.Graph and spirv.ARM.GraphOutputs
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
 }
 
 // -----
@@ -18,14 +16,12 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8,
 // spirv.ARM.GraphConstant
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
-  spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
-    // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
-    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
-    // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
-    spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
-  }
+// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
+spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
+  // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+  %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+  // CHECK: spirv.ARM.GraphOutputs [[CONST:%.*]] : !spirv.arm.tensor<2x3xi16>
+  spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
 }
 // -----
 
@@ -33,37 +29,96 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8,
 // spirv.ARM.GraphEntryPoint
 //===----------------------------------------------------------------------===//
 
+// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
+spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
-  // CHECK: spirv.ARM.GraphEntryPoint [[GN:@.*]], [[VARARG0]], [[VARRES0]]
-  spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
-  // CHECK: spirv.ARM.Graph [[GN]]({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @entrypoint(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no terminator
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{empty block: expect at least a terminator}}
+spirv.ARM.Graph @graphNoterminator(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no result types
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}}
+spirv.ARM.Graph @graphNoOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> () {
 }
 
 // -----
 
 //===----------------------------------------------------------------------===//
-// Multiple spirv.ARM.Graphs
+// spirv.ARM.GraphConstant outside graph scope
 //===----------------------------------------------------------------------===//
 
-spirv.module Logical Vulkan requires #spirv.vce<v1.0, [VulkanMemoryModel, Int8, TensorsARM, GraphARM], [SPV_ARM_tensors, SPV_ARM_graph]> {
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-  spirv.ARM.Graph @graph1(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
-  }
+// expected-error @+1 {{'spirv.ARM.GraphConstant' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphOutputs outside graph scope
+//===----------------------------------------------------------------------===//
+
+%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+// expected-error @+1 {{'spirv.ARM.GraphOutputs' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
+  // expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}}
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
+  // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}}
+  spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+  // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}}
+  spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non-TensorArmType argument
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0 : i8) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non-TensorArmType result
+//===----------------------------------------------------------------------===//
 
-  // CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-  spirv.ARM.Graph @graph2(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
-    // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<1x16x16x16xi8>
-    spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
-  }
+// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0 : !spirv.arm.tensor<14x19xi16>) -> i8 {
 }

>From 1a4f0aa4486f25230821c3c08830729b0efd900e Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 8 Aug 2025 13:20:35 +0200
Subject: [PATCH 5/6] Extract SPV_ARM_graph operations in its own file

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ia74db44157fb724c9f787387386338814413db30
---
 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp | 264 ++++++++++++++++++++++
 mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt  |   1 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp    | 237 -------------------
 3 files changed, 265 insertions(+), 237 deletions(-)
 create mode 100644 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp

diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
new file mode 100644
index 0000000000000..e300596fd3733
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -0,0 +1,264 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
+//------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph
+//===----------------------------------------------------------------------===//
+
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+                                     OperationState &result) {
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<DictionaryAttr> resultAttrs;
+  SmallVector<Type> resultTypes;
+  auto &builder = parser.getBuilder();
+
+  // Parse the graph name as a symbol.
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes))
+    return failure();
+
+  // Parse the signature; graphs do not accept variadic arguments.
+  bool isVariadic = false;
+  if (function_interface_impl::parseFunctionSignatureWithArguments(
+          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+          resultAttrs))
+    return failure();
+
+  SmallVector<Type> argTypes;
+  for (auto &arg : entryArgs)
+    argTypes.push_back(arg.type);
+  auto grType = builder.getGraphType(argTypes, resultTypes);
+  result.addAttribute(getFunctionTypeAttrName(result.name),
+                      TypeAttr::get(grType));
+
+  // Parse the optional `attributes {...}` dictionary.
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // Attach the per-argument and per-result attribute dictionaries.
+  assert(resultAttrs.size() == resultTypes.size());
+  call_interface_impl::addArgAndResultAttrs(
+      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+      getResAttrsAttrName(result.name));
+
+  // Parse the optional body; its entry block arguments are `entryArgs`.
+  Region *body = result.addRegion();
+  OptionalParseResult parseResult =
+      parser.parseOptionalRegion(*body, entryArgs);
+  return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  // Print the graph name, signature, and any non-elided attributes.
+  printer << " ";
+  printer.printSymbolName(getSymName());
+  GraphType grType = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, grType.getInputs(),
+      /*isVariadic=*/false, grType.getResults());
+  function_interface_impl::printFunctionAttributes(printer, *this,
+                                                   {getFunctionTypeAttrName(),
+                                                    getArgAttrsAttrName(),
+                                                    getResAttrsAttrName()});
+
+  // Print the body, if any, without repeating the entry block arguments.
+  Region &body = this->getBody();
+  if (!body.empty()) {
+    printer << ' ';
+    printer.printRegion(body, /*printEntryBlockArgs=*/false,
+                        /*printBlockTerminators=*/true);
+  }
+}
+
+LogicalResult spirv::GraphARMOp::verifyType() {
+  if (getFunctionType().getNumResults() < 1)
+    return emitOpError("there should be at least one result");
+  return success();
+}
+
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
+    if (!isa<spirv::TensorArmType>(graphArgType)) {
+      return emitOpError("type of argument #")
+             << index << " must be a TensorArmType, but got " << graphArgType;
+    }
+  }
+  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
+    if (!isa<spirv::TensorArmType>(graphResType)) {
+      return emitOpError("type of result #")
+             << index << " must be a TensorArmType, but got " << graphResType;
+    }
+  }
+  // With a body present, the entry block must mirror the graph signature.
+  if (!isExternal()) {
+    Block &entryBlock = front();
+
+    unsigned numArguments = this->getNumArguments();
+    if (entryBlock.getNumArguments() != numArguments)
+      return emitOpError("entry block must have ")
+             << numArguments << " arguments to match graph signature";
+
+    for (auto [index, grArgType, blockArgType] :
+         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
+      if (blockArgType != grArgType) {
+        return emitOpError("type of entry block argument #")
+               << index << '(' << blockArgType
+               << ") must match the type of the corresponding argument in "
+               << "graph signature(" << grArgType << ')';
+      }
+    }
+  }
+  // Every nested GraphOutputs must agree with the graph result types.
+  GraphType grType = getFunctionType();
+  auto walkResult = walk([grType](Operation *op) -> WalkResult {
+    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
+      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
+        return graphOutputsARMOp.emitOpError("is returning ")
+               << graphOutputsARMOp.getNumOperands()
+               << " value(s) but enclosing spirv.ARM.Graph requires "
+               << grType.getNumResults() << " result(s)";
+      // Counts agree, so indexing grType results by operand index is safe.
+      ValueTypeRange<OperandRange> graphOutputOperandTypes =
+          graphOutputsARMOp.getValue().getType();
+      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
+           ++i) {
+        Type graphOutputOperandType = graphOutputOperandTypes[i];
+        Type grResultType = grType.getResult(i);
+        if (graphOutputOperandType != grResultType)
+          return graphOutputsARMOp.emitError("type of return operand ")
+                 << i << " (" << graphOutputOperandType
+                 << ") doesn't match graph result type (" << grResultType
+                 << ")";
+      }
+    }
+    return WalkResult::advance();
+  });
+
+  return failure(walkResult.wasInterrupted());
+}
+
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  state.addAttribute(SymbolTable::getSymbolAttrName(),
+                     builder.getStringAttr(name));
+  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addAttribute(getEntryPointAttrName(state.name),
+                     builder.getBoolAttr(entryPoint));
+  state.addRegion();
+}
+
+// Returns the argument (input) types from this graph's GraphType.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  return getFunctionType().getInputs();
+}
+
+// Returns the result types from this graph's GraphType.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  return getFunctionType().getResults();
+}
+
+// CallableOpInterface: an external graph has no callable region.
+Region *spirv::GraphARMOp::getCallableRegion() {
+  return isExternal() ? nullptr : &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  // InGraphScope (checked before this hook) guarantees a GraphARMOp parent.
+  auto graph = cast<GraphARMOp>((*this)->getParentOp());
+  // The operand number and types must match the graph signature.
+  ArrayRef<Type> results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
+           << graph.getName() << ") returns " << results.size();
+
+  for (unsigned i = 0, size = results.size(); i < size; ++i)
+    if (getOperand(i).getType() != results[i])
+      return emitError() << "type of return operand " << i << " ("
+                         << getOperand(i).getType()
+                         << ") doesn't match spirv.ARM.Graph result type ("
+                         << results[i] << ")"
+                         << " in graph @" << graph.getName();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphEntryPoint
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+                                        OperationState &state,
+                                        spirv::GraphARMOp graph,
+                                        ArrayRef<Attribute> interfaceVars) {
+  build(builder, state, SymbolRefAttr::get(graph),
+        builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  SmallVector<Attribute, 4> interfaceVars;
+  // Syntax: graph symbol, optionally followed by `, @var, @var, ...`.
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
+    return failure();
+
+  if (!parser.parseOptionalComma()) {
+    // Parse the comma-separated list of interface variable symbols.
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // Only `var` is kept; `attrs` and its key are scratch storage.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute(getInterfaceAttrName(result.name),
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
+  if (!interfaceVars.empty()) {
+    printer << ", ";
+    llvm::interleaveComma(interfaceVars, printer);
+  }
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b9aa7b7491abf..60d705d940cfc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
 add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
 
 add_mlir_dialect_library(MLIRSPIRVDialect
+  ArmGraphOps.cpp
   AtomicOps.cpp
   CastOps.cpp
   ControlFlowOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 398dc046b3912..f99339852824c 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1126,243 +1126,6 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.GraphEntryPointARM
-//===----------------------------------------------------------------------===//
-
-void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
-                                        OperationState &state,
-                                        spirv::GraphARMOp graph,
-                                        ArrayRef<Attribute> interfaceVars) {
-  build(builder, state, SymbolRefAttr::get(graph),
-        builder.getArrayAttr(interfaceVars));
-}
-
-ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
-                                               OperationState &result) {
-  SmallVector<Attribute, 4> interfaceVars;
-
-  FlatSymbolRefAttr fn;
-  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
-    return failure();
-
-  if (!parser.parseOptionalComma()) {
-    // Parse the interface variables
-    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
-          // The name of the interface variable attribute isnt important
-          FlatSymbolRefAttr var;
-          NamedAttrList attrs;
-          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
-            return failure();
-          interfaceVars.push_back(var);
-          return success();
-        }))
-      return failure();
-  }
-  result.addAttribute("interface",
-                      parser.getBuilder().getArrayAttr(interfaceVars));
-  return success();
-}
-
-void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
-  printer << " ";
-  printer.printSymbolName(getFn());
-  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
-  if (!interfaceVars.empty()) {
-    printer << ", ";
-    llvm::interleaveComma(interfaceVars, printer);
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GraphARM
-//===----------------------------------------------------------------------===//
-
-ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
-                                     OperationState &result) {
-  SmallVector<OpAsmParser::Argument> entryArgs;
-  SmallVector<DictionaryAttr> resultAttrs;
-  SmallVector<Type> resultTypes;
-  auto &builder = parser.getBuilder();
-
-  // Parse the name as a symbol.
-  StringAttr nameAttr;
-  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
-                             result.attributes))
-    return failure();
-
-  // Parse the function signature.
-  bool isVariadic = false;
-  if (function_interface_impl::parseFunctionSignatureWithArguments(
-          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
-          resultAttrs))
-    return failure();
-
-  SmallVector<Type> argTypes;
-  for (auto &arg : entryArgs)
-    argTypes.push_back(arg.type);
-  auto grType = builder.getGraphType(argTypes, resultTypes);
-  result.addAttribute(getFunctionTypeAttrName(result.name),
-                      TypeAttr::get(grType));
-
-  // If additional attributes are present, parse them.
-  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
-    return failure();
-
-  // Add the attributes to the function arguments.
-  assert(resultAttrs.size() == resultTypes.size());
-  call_interface_impl::addArgAndResultAttrs(
-      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
-      getResAttrsAttrName(result.name));
-
-  // Parse the optional function body.
-  Region *body = result.addRegion();
-  OptionalParseResult parseResult =
-      parser.parseOptionalRegion(*body, entryArgs);
-  return failure(parseResult.has_value() && failed(*parseResult));
-}
-
-void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
-  // Print graph name, signature, and control.
-  printer << " ";
-  printer.printSymbolName(getSymName());
-  GraphType grType = getFunctionType();
-  function_interface_impl::printFunctionSignature(
-      printer, *this, grType.getInputs(),
-      /*isVariadic=*/false, grType.getResults());
-  function_interface_impl::printFunctionAttributes(printer, *this,
-                                                   {getFunctionTypeAttrName(),
-                                                    getArgAttrsAttrName(),
-                                                    getResAttrsAttrName()});
-
-  // Print the body.
-  Region &body = this->getBody();
-  if (!body.empty()) {
-    printer << ' ';
-    printer.printRegion(body, /*printEntryBlockArgs=*/false,
-                        /*printBlockTerminators=*/true);
-  }
-}
-
-LogicalResult spirv::GraphARMOp::verifyType() {
-  if (getFunctionType().getNumResults() < 1)
-    return emitOpError("there should be at least one result");
-  return success();
-}
-
-LogicalResult spirv::GraphARMOp::verifyBody() {
-  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
-    if (!isa<spirv::TensorArmType>(graphArgType)) {
-      return emitOpError("type of argument #")
-             << index << " must be a TensorArmType, but got " << graphArgType;
-    }
-  }
-  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
-    if (!isa<spirv::TensorArmType>(graphResType)) {
-      return emitOpError("type of result #")
-             << index << " must be a TensorArmType, but got " << graphResType;
-    }
-  }
-
-  if (!isExternal()) {
-    Block &entryBlock = front();
-
-    unsigned numArguments = this->getNumArguments();
-    if (entryBlock.getNumArguments() != numArguments)
-      return emitOpError("entry block must have ")
-             << numArguments << " arguments to match graph signature";
-
-    for (auto [index, grArgType, blockArgType] :
-         llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
-      if (blockArgType != grArgType) {
-        return emitOpError("type of entry block argument #")
-               << index << '(' << blockArgType
-               << ") must match the type of the corresponding argument in "
-               << "graph signature(" << grArgType << ')';
-      }
-    }
-  }
-
-  GraphType grType = getFunctionType();
-  auto walkResult = walk([grType](Operation *op) -> WalkResult {
-    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
-      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("is returning ")
-               << graphOutputsARMOp.getNumOperands()
-               << " value(s) but enclosing spirv.ARM.Graph requires "
-               << grType.getNumResults() << " result(s)";
-
-      ValueTypeRange<OperandRange> graphOutputOperandTypes =
-          graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
-           ++i) {
-        Type graphOutputOperandType = graphOutputOperandTypes[i];
-        Type grResultType = grType.getResult(i);
-        if (graphOutputOperandType != grResultType)
-          return graphOutputsARMOp.emitError("type of return operand ")
-                 << i << " (" << graphOutputOperandType
-                 << ") doesn't match graph result type (" << grResultType
-                 << ")";
-      }
-    }
-    return WalkResult::advance();
-  });
-
-  return failure(walkResult.wasInterrupted());
-}
-
-void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
-                              StringRef name, GraphType type,
-                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
-  state.addAttribute(SymbolTable::getSymbolAttrName(),
-                     builder.getStringAttr(name));
-  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
-  state.attributes.append(attrs.begin(), attrs.end());
-  state.addAttribute(getEntryPointAttrName(state.name),
-                     builder.getBoolAttr(entryPoint));
-  state.addRegion();
-}
-
-// Returns the argument types of this function.
-ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
-  return getFunctionType().getInputs();
-}
-
-// Returns the result types of this function.
-ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
-  return getFunctionType().getResults();
-}
-
-// CallableOpInterface
-Region *spirv::GraphARMOp::getCallableRegion() {
-  return isExternal() ? nullptr : &getBody();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.GraphOutputsARM
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::GraphOutputsARMOp::verify() {
-  auto graph = cast<GraphARMOp>((*this)->getParentOp());
-
-  // The operand number and types must match the graph signature.
-  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
-  if (getNumOperands() != results.size())
-    return emitOpError("has ")
-           << getNumOperands() << " operands, but enclosing  spirv.ARM.Graph (@"
-           << graph.getName() << ") returns " << results.size();
-
-  for (unsigned i = 0, size = results.size(); i < size; ++i)
-    if (getOperand(i).getType() != results[i])
-      return emitError() << "type of return operand " << i << " ("
-                         << getOperand(i).getType()
-                         << ") doesn't match  spirv.ARM.Graph result type ("
-                         << results[i] << ")"
-                         << " in graph @" << graph.getName();
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.GLFClampOp
 //===----------------------------------------------------------------------===//

>From f4c2c3a524e2d432b3fc84bb6eb55a21dfcaa7b1 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Mon, 18 Aug 2025 16:36:11 +0200
Subject: [PATCH 6/6] Fix review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: Ib592a4b99af01c0d3c88eaf63a61cb4c6cca8cbe
---
 .../mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td    | 15 +----
 mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp     | 63 ++++++++-----------
 .../SPIRV/Transforms/UpdateVCEPass.cpp        |  6 +-
 mlir/lib/IR/AsmPrinter.cpp                    |  7 +--
 mlir/test/Dialect/SPIRV/IR/graph-ops.mlir     |  2 +-
 .../lib/Dialect/SPIRV/TestAvailability.cpp    |  2 +-
 6 files changed, 35 insertions(+), 60 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
index 51df4dc79ae68..ff2ec7363fe38 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -51,7 +51,7 @@ def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
     that is unique in the enclosing module op.
 
     This op itself takes no operands and generates no results. Its region
-    can take zero or more arguments and return zero or more values.
+    can take zero or more arguments and return one or more values.
 
     ```
     spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
@@ -119,10 +119,6 @@ def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope,
     Result Type must be an OpTypeTensorARM.
     GraphConstantID must be a 32-bit integer literal.
 
-    ```
-    spv-graph-constant-arm-op ::= `spirv.ARM.GraphConstant` { graph_constant_id = 42 : i32 }
-    ```
-
     #### Example:
 
     ```mlir
@@ -167,11 +163,6 @@ def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleSc
     `spirv.GlobalVariable`s referenced by the entry point’s static call
     tree, within the interface’s storage classes.
 
-    ```
-    entry-point-op ::= ssa-id `=` `spirv.ARM.GraphEntryPoint`
-                       symbol-reference (`, ` symbol-reference)*
-    ```
-
     #### Example:
 
     ```mlir
@@ -214,10 +205,6 @@ def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pu
 
     This instruction must be the last instruction in a block.
 
-    ```
-    graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use `:` type-list-no-parens
-    ```
-
     #### Example:
 
     ```mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
index e300596fd3733..722e771425a47 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -1,5 +1,4 @@
-//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations
-//------------------------------===//
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -22,6 +21,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
 
 using namespace mlir;
 using namespace mlir::spirv::AttrNames;
@@ -32,10 +32,7 @@ using namespace mlir::spirv::AttrNames;
 
 ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  SmallVector<OpAsmParser::Argument> entryArgs;
-  SmallVector<DictionaryAttr> resultAttrs;
-  SmallVector<Type> resultTypes;
-  auto &builder = parser.getBuilder();
+  Builder &builder = parser.getBuilder();
 
   // Parse the name as a symbol.
   StringAttr nameAttr;
@@ -45,15 +42,18 @@ ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
 
   // Parse the function signature.
   bool isVariadic = false;
+  SmallVector<OpAsmParser::Argument> entryArgs;
+  SmallVector<Type> resultTypes;
+  SmallVector<DictionaryAttr> resultAttrs;
   if (function_interface_impl::parseFunctionSignatureWithArguments(
           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
           resultAttrs))
     return failure();
 
   SmallVector<Type> argTypes;
-  for (auto &arg : entryArgs)
+  for (OpAsmParser::Argument &arg : entryArgs)
     argTypes.push_back(arg.type);
-  auto grType = builder.getGraphType(argTypes, resultTypes);
+  GraphType grType = builder.getGraphType(argTypes, resultTypes);
   result.addAttribute(getFunctionTypeAttrName(result.name),
                       TypeAttr::get(grType));
 
@@ -136,26 +136,22 @@ LogicalResult spirv::GraphARMOp::verifyBody() {
   }
 
   GraphType grType = getFunctionType();
-  auto walkResult = walk([grType](Operation *op) -> WalkResult {
-    if (auto graphOutputsARMOp = dyn_cast<spirv::GraphOutputsARMOp>(op)) {
-      if (grType.getNumResults() != graphOutputsARMOp.getNumOperands())
-        return graphOutputsARMOp.emitOpError("is returning ")
-               << graphOutputsARMOp.getNumOperands()
-               << " value(s) but enclosing spirv.ARM.Graph requires "
-               << grType.getNumResults() << " result(s)";
-
-      ValueTypeRange<OperandRange> graphOutputOperandTypes =
-          graphOutputsARMOp.getValue().getType();
-      for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size;
-           ++i) {
-        Type graphOutputOperandType = graphOutputOperandTypes[i];
-        Type grResultType = grType.getResult(i);
-        if (graphOutputOperandType != grResultType)
-          return graphOutputsARMOp.emitError("type of return operand ")
-                 << i << " (" << graphOutputOperandType
-                 << ") doesn't match graph result type (" << grResultType
-                 << ")";
-      }
+  auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
+    if (grType.getNumResults() != op.getNumOperands())
+      return op.emitOpError("is returning ")
+             << op.getNumOperands()
+             << " value(s) but enclosing spirv.ARM.Graph requires "
+             << grType.getNumResults() << " result(s)";
+
+    ValueTypeRange<OperandRange> graphOutputOperandTypes =
+        op.getValue().getType();
+    for (unsigned i = 0, size = graphOutputOperandTypes.size(); i < size; ++i) {
+      Type graphOutputOperandType = graphOutputOperandTypes[i];
+      Type grResultType = grType.getResult(i);
+      if (graphOutputOperandType != grResultType)
+        return op.emitError("type of return operand ")
+               << i << " (" << graphOutputOperandType
+               << ") doesn't match graph result type (" << grResultType << ")";
     }
     return WalkResult::advance();
   });
@@ -169,23 +165,20 @@ void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
   state.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
   state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
-  state.attributes.append(attrs.begin(), attrs.end());
+  state.attributes.append(attrs);
   state.addAttribute(getEntryPointAttrName(state.name),
                      builder.getBoolAttr(entryPoint));
   state.addRegion();
 }
 
-// Returns the argument types of this function.
 ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
   return getFunctionType().getInputs();
 }
 
-// Returns the result types of this function.
 ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
   return getFunctionType().getResults();
 }
 
-// CallableOpInterface
 Region *spirv::GraphARMOp::getCallableRegion() {
   return isExternal() ? nullptr : &getBody();
 }
@@ -229,12 +222,11 @@ void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
 
 ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
-  SmallVector<Attribute, 4> interfaceVars;
-
   FlatSymbolRefAttr fn;
   if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
     return failure();
 
+  SmallVector<Attribute, 4> interfaceVars;
   if (!parser.parseOptionalComma()) {
     // Parse the interface variables
     if (parser.parseCommaSeparatedList([&]() -> ParseResult {
@@ -258,7 +250,6 @@ void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
   printer.printSymbolName(getFn());
   ArrayRef<Attribute> interfaceVars = getInterface().getValue();
   if (!interfaceVars.empty()) {
-    printer << ", ";
-    llvm::interleaveComma(interfaceVars, printer);
+    printer << ", " << llvm::interleaved(interfaceVars);
   }
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index a2d221252fb69..e8c299932f06d 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -160,10 +160,8 @@ void UpdateVCEPass::runOnOperation() {
 
     // If the op is FunctionLike make sure to process input and result types
     if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
-      ArrayRef<Type> inputTypes = funcOpInterface.getArgumentTypes();
-      ArrayRef<Type> resultTypes = funcOpInterface.getResultTypes();
-      valueTypes.append(inputTypes.begin(), inputTypes.end());
-      valueTypes.append(resultTypes.begin(), resultTypes.end());
+      llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
+      llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
     }
 
     // Requirements from values' types
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9a5dbcf6f598e..348068dbc84c8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,8 +104,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
   // it is a function (avoiding a grammar ambiguity).
   bool wrapped = op->getNumResults() != 1;
   if (!wrapped && op->getResult(0).getType() &&
-      (llvm::isa<FunctionType>(op->getResult(0).getType()) ||
-       llvm::isa<GraphType>(op->getResult(0).getType())))
+      isa<FunctionType, GraphType>(op->getResult(0).getType()))
     wrapped = true;

   if (wrapped)
@@ -2842,8 +2841,8 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
         interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
         os << ") -> ";
         ArrayRef<Type> results = graphTy.getResults();
-        if (results.size() == 1 && !(llvm::isa<FunctionType>(results[0]) ||
-                                     llvm::isa<GraphType>(results[0]))) {
+        if (results.size() == 1 &&
+            !isa<FunctionType, GraphType>(results[0])) {
           printType(results[0]);
         } else {
           os << '(';
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
index 591eaaea4c802..c7763d45c6b5e 100644
--- a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // spirv.ARM.Graph and spirv.ARM.GraphOutputs
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 9efca825a663d..5643a0ff5b91c 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -33,7 +33,7 @@ struct PrintOpAvailability
 } // namespace
 
 void PrintOpAvailability::runOnOperation() {
-  auto moduleOp = getOperation();
+  mlir::ModuleOp moduleOp = getOperation();
   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
 
   auto opCallback = [&](Operation *op) {



More information about the Mlir-commits mailing list