[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_graph extension (PR #147937)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 10 04:06:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-ods

Author: Davide Grohmann (davidegrohmann)

<details>
<summary>Changes</summary>

This patch adds support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources.

The implementation includes:

- A new `GraphType`, modeled similarly to `FunctionType`, for typed graph signatures.
- New operations in the `spirv.ARM` namespace:
  - `spirv.ARM.Graph`
  - `spirv.ARM.GraphEntryPoint`
  - `spirv.ARM.GraphConstant`
  - `spirv.ARM.GraphOutputs`
- Serialization and deserialization support for:
  - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM`
  - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM`
- ABI lowering support for graph entry points via `LowerABIAttributesPass`.
- Verifier and VCE updates to properly gate usage under `SPV_ARM_graph`.
- Tests covering parsing, verification, ABI handling, and binary round-tripping.

Graph inputs, outputs, and constants are currently limited to tensors from `SPV_ARM_tensors`, but the design is intended to generalize to other resource types, such as images.

Spec: https://github.com/KhronosGroup/SPIRV-Registry/pull/346
RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947

---

Patch is 85.76 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147937.diff


27 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+25-2) 
- (added) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td (+201) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td (+1) 
- (modified) mlir/include/mlir/IR/Builders.h (+2) 
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+11-7) 
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+7) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+10-4) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+12) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+230) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp (+137-1) 
- (modified) mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp (+8) 
- (modified) mlir/lib/IR/AsmPrinter.cpp (+16-1) 
- (modified) mlir/lib/IR/Builders.cpp (+4) 
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+43) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp (+21) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+304) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+49-2) 
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+126) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+82-4) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+37-4) 
- (modified) mlir/test/Dialect/SPIRV/IR/availability.mlir (+17) 
- (added) mlir/test/Dialect/SPIRV/IR/graph-ops.mlir (+30) 
- (modified) mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir (+22-1) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir (+22) 
- (modified) mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir (+11) 
- (added) mlir/test/Target/SPIRV/graph-ops.mlir (+24) 
- (modified) mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp (+13-5) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 910418f1706a6..ce4bb6c2e4934 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -423,6 +423,7 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
 def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph                        : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
 
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -447,7 +448,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
-      SPV_ARM_tensors,
+      SPV_ARM_tensors, SPV_ARM_graph,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1332,6 +1333,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"Stora
     Extension<[SPV_ARM_tensors]>
   ];
 }
+def SPIRV_C_GraphARM                                    : I32EnumAttrCase<"GraphARM", 4191> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader, SPIRV_C_VulkanMemoryModel];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_graph]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1545,7 +1552,7 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
       SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
-      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4245,6 +4252,7 @@ def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
+
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite :
     AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
@@ -4551,6 +4559,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
 def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM               : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM             : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM                       : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM                  : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM              : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM                    : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM                   : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4666,6 +4681,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
       SPIRV_OC_OpGroupNonUniformLogicalXor,
       SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+      SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+      SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
       SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4836,6 +4854,11 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
   SPIRV_VendorOp<mnemonic, "NV", traits> {
 }
 
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
+
 def SPIRV_FPFMM_None         : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FPFMM_NotNaN       : I32BitEnumAttrCaseBit<"NotNaN", 0>;
 def SPIRV_FPFMM_NotInf       : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..38fb4b2eff414
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,201 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all Graph ops; sets the ARM vendor prefix and availability.
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+  SPIRV_ArmVendorOp<mnemonic, traits> {
+  // NOTE(review): confirm Extension<[...]> below is meant as any-of vs all-of.
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>,
+    Capability<[SPIRV_C_GraphARM]>
+  ];
+}
+
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [Pure]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+  }];
+
+  let arguments = (ins
+    I32Attr:$graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+    AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+    FunctionOpInterface, InModuleScope, IsolatedFromAbove
+  ]> {
+
+  let summary = "Declare or define a SPIR-V graph";
+
+  let description = [{
+    This op declares or defines a SPIR-V graph using one region, which
+    contains one or more blocks.
+
+    Different from the SPIR-V binary format, this op is not allowed to
+    implicitly capture global values, and all external references must use
+    function arguments or symbol references. This op itself defines a symbol
+    that is unique in the enclosing module op.
+
+    This op itself takes no operands and generates no results. Its region
+    can take zero or more arguments and return zero or more values.
+
+    ```
+    spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
+                        region
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttrOf<GraphType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<BoolAttr>:$entry_point,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 0;
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,  CArg<"bool", "false">:$entry_point)>];
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+
+  let extraClassDeclaration = [{
+    /// Hook for FunctionOpInterface, called after verifying that the 'type'
+    /// attribute is present and checks if it holds a GraphType. Ensures
+    /// getType, getNumArguments, and getNumResults can be called safely.
+    LogicalResult verifyType();
+
+    /// Hook for FunctionOpInterface, called after verifying the graph
+    /// type and the presence of the (potentially empty) graph body.
+    /// Ensures SPIR-V specific semantics.
+    LogicalResult verifyBody();
+  }];
+}
+
+// Op trait: the op's parent must be, or be nested in, a spirv.ARM.Graph op.
+def InGraphScope : PredOpTrait<
+  "op must appear in a spirv.ARM.Graph op's block",
+  CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+  let summary = [{
+    Declare a graph entry point and its interface.
+  }];
+
+  let description = [{
+    Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+    Name is a name string for the graph entry point. A module cannot have two
+    OpGraphEntryPointARM instructions with the same Name string.
+
+    Interface is a list of symbol references to `spirv.GlobalVariable`
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    `spirv.GlobalVariable`s referenced by the entry point’s static call
+    tree, within the interface’s storage classes.
+
+    ```
+    graph-entry-point-op ::= `spirv.ARM.GraphEntryPoint` symbol-reference
+                             (`,` symbol-reference)*
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SymbolRefArrayAttr:$interface
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let builders = [
+    OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+                                               Terminator]> {
+
+  let summary = "Define graph outputs.";
+
+  let description = [{
+    Values are the graph output values and must match the GraphOutputs Type
+    operand of the OpTypeGraphARM type of the OpGraphARM body this
+    instruction is in.
+
+    This instruction must be the last instruction in a block.
+
+    ```
+    graph-output-op ::= `spirv.ARM.GraphOutputs` ssa-use-list `:` type-list-no-parens
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPIRV_AnyTensorArm>:$value
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let hasOpcode = 0;
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
 include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index ad59ea63a6901..aa7d30b87db14 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
 class IntegerType;
 class FloatType;
 class FunctionType;
+class GraphType;
 class IndexType;
 class MemRefType;
 class VectorType;
@@ -81,6 +82,7 @@ class Builder {
   IntegerType getIntegerType(unsigned width);
   IntegerType getIntegerType(unsigned width, bool isSigned);
   FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+  GraphType getGraphType(TypeRange inputs, TypeRange results);
   TupleType getTupleType(TypeRange elementTypes);
   NoneType getNoneType();
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
 // FunctionType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
   let summary = "Map from a list of inputs to a list of results";
   let description = [{
     Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let storageClass = "FunctionTypeStorage";
   let genStorageClass = 0;
   let extraClassDeclaration = [{
     /// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
     unsigned getNumResults() const;
     Type getResult(unsigned i) const { return getResults()[i]; }
 
-    /// Returns a clone of this function type with the given argument
+    /// Returns a clone of this function-like type with the given argument
     /// and result types.
-    FunctionType clone(TypeRange inputs, TypeRange results) const;
+    }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
 
-    /// Returns a new function type with the specified arguments and results
+    /// Returns a new function-like type with the specified arguments and results
     /// inserted.
-    FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+    }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
                                        TypeRange argTypes,
                                        ArrayRef<unsigned> resultIndices,
                                        TypeRange resultTypes);
 
-    /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+    /// Returns a new function-like type without the specified arguments and results.
+    }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
                                           const BitVector &resultIndices);
   }];
 }
 
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
 //===----------------------------------------------------------------------===//
 // IndexType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 45ec1846580f2..aab1b01c5cff9 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,13 @@ class OpaqueType<string dialect, string name, string summary>
 def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
                               "function type", "::mlir::FunctionType">;
 
+// Graph Type
+
+// Any graph type.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+                              "graph type", "::mlir::GraphType">;
+
+
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
                     string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 88c7adf3dfcb3..e66d4b0ffc446 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -1019,8 +1019,14 @@ LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
   return verifyRegionAttribute(op->getLoc(), argType, attribute);
 }
 
-LogicalResult SPIRVDialect::verifyRegionResultAttribute(
-    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
-    NamedAttribute attribute) {
-  return op->emitError("cannot attach SPIR-V attributes to region result");
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(
+    Operation *op, unsigned /*regionIndex*/, unsigned resultIndex,
+    NamedAttribute attribute) {
+  // Only function-like ops (e.g. spirv.ARM.Graph) may carry result attributes.
+  auto funcOp = dyn_cast<FunctionOpInterface>(op);
+  if (!funcOp)
+    return op->emitError("cannot attach SPIR-V attributes to region result "
+                         "which is not a FunctionOpInterface type");
+  return verifyRegionAttribute(
+      op->getLoc(), funcOp.getResultTypes()[resultIndex], attribute);
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
   return isNestedInFunctionOpInterface(op->getParentOp());
 }
 
+/// Walks the parent chain starting at `op`; returns true once a GraphARM op
+/// is found, or false on reaching a symbol-table (module-like) op or the top.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+  for (Operation *current = op; current; current = current->getParentOp()) {
+    if (current->hasTrait<OpTrait::SymbolTable>())
+      return false;
+    if (isa<spirv::GraphARMOp>(current))
+      return true;
+  }
+  return false;
+}
+
 /// Returns true if the given op is an module-like op that maintains a symbol
 /// table.
 static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index eb2974d62fdd1..17cbab189588f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1084,6 +1084,236 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
   state.addRegion();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+void spirv::GraphEntryPointARMOp::build(
+    OpBuilder &builder, OperationState &state, spirv::GraphARMOp graph,
+    ArrayRef<Attribute> interfaceVars) {
+  // Reference the graph by symbol and wrap the interface list in an ArrayAttr.
+  auto graphRef = SymbolRefAttr::get(graph);
+  build(builder, state, graphRef, builder.getArrayAttr(interfaceVars));
+}
+
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  SmallVector<Type, 0> idTypes;
+  SmallVector<Attribute, 4> interfaceVars;
+
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
+    return failure();
+  }
+
+  if (!parser.parseOptionalComma()) {
+    // Parse the interface variables
+    if ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/147937


More information about the Mlir-commits mailing list