[Mlir-commits] [mlir] cf3a887 - [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (#151934)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 2 12:18:15 PDT 2025
Author: Davide Grohmann
Date: 2025-09-02T15:18:10-04:00
New Revision: cf3a8876f4129f76884a67f6db9214adb7adedc6
URL: https://github.com/llvm/llvm-project/commit/cf3a8876f4129f76884a67f6db9214adb7adedc6
DIFF: https://github.com/llvm/llvm-project/commit/cf3a8876f4129f76884a67f6db9214adb7adedc6.diff
LOG: [mlir][spirv] Add support for SPV_ARM_graph extension - part 1 (#151934)
This is the first patch to add support for the SPV_ARM_graph SPIR-V
extension to MLIR’s SPIR-V dialect. The extension introduces a new Graph
abstraction for expressing dataflow computations over full resources.
The part 1 implementation includes:
- A new `GraphType`, modeled similarly to `FunctionType`, for typed
graph signatures.
- New operations in the `spirv.ARM` namespace:
  - `spirv.ARM.Graph`
  - `spirv.ARM.GraphEntryPoint`
  - `spirv.ARM.GraphConstant`
  - `spirv.ARM.GraphOutputs`
- Verifier and VCE updates to properly gate usage under SPV_ARM_graph.
- Tests covering parsing and verification.
Graphs currently support only SPV_ARM_tensors, but are designed to
generalize to other resource types, such as images.
Spec: KhronosGroup/SPIRV-Registry#346
RFC:
https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947
---------
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Added:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/include/mlir/IR/CommonTypeConstraints.td
mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/Dialect/SPIRV/IR/availability.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index bdfd728d1d0b3..0e42d08cdb1fc 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -425,6 +425,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+def SPV_ARM_graph : I32EnumAttrCase<"SPV_ARM_graph", 6001>;
def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
@@ -449,7 +450,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader, SPV_EXT_replicated_composites,
- SPV_ARM_tensors,
+ SPV_ARM_tensors, SPV_ARM_graph,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1341,6 +1342,12 @@ def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"Stora
Extension<[SPV_ARM_tensors]>
];
}
+def SPIRV_C_GraphARM : I32EnumAttrCase<"GraphARM", 4191> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_graph]>
+ ];
+}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
@@ -1560,7 +1567,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
- SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
+ SPIRV_C_StorageTensorArrayNonUniformIndexingEXT, SPIRV_C_GraphARM,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4569,6 +4576,13 @@ def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
+def SPIRV_OC_OpGraphConstantARM : I32EnumAttrCase<"OpGraphConstantARM", 4181>;
+def SPIRV_OC_OpGraphEntryPointARM : I32EnumAttrCase<"OpGraphEntryPointARM", 4182>;
+def SPIRV_OC_OpGraphARM : I32EnumAttrCase<"OpGraphARM", 4183>;
+def SPIRV_OC_OpGraphInputARM : I32EnumAttrCase<"OpGraphInputARM", 4184>;
+def SPIRV_OC_OpGraphSetOutputARM : I32EnumAttrCase<"OpGraphSetOutputARM", 4185>;
+def SPIRV_OC_OpGraphEndARM : I32EnumAttrCase<"OpGraphEndARM", 4186>;
+def SPIRV_OC_OpTypeGraphARM : I32EnumAttrCase<"OpTypeGraphARM", 4190>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4689,6 +4703,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
SPIRV_OC_OpGroupNonUniformLogicalXor,
SPIRV_OC_OpTypeTensorARM,
+ SPIRV_OC_OpGraphEntryPointARM, SPIRV_OC_OpGraphARM,
+ SPIRV_OC_OpGraphInputARM, SPIRV_OC_OpGraphSetOutputARM, SPIRV_OC_OpGraphEndARM,
+ SPIRV_OC_OpTypeGraphARM, SPIRV_OC_OpGraphConstantARM,
SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
@@ -4862,6 +4879,10 @@ class SPIRV_NvVendorOp<string mnemonic, list<Trait> traits = []> :
SPIRV_VendorOp<mnemonic, "NV", traits> {
}
+class SPIRV_ArmVendorOp<string mnemonic, list<Trait> traits = []> :
+ SPIRV_VendorOp<mnemonic, "ARM", traits> {
+}
+
def SPIRV_FPFMM_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FPFMM_NotNaN : I32BitEnumAttrCaseBit<"NotNaN", 0>;
def SPIRV_FPFMM_NotInf : I32BitEnumAttrCaseBit<"NotInf", 1>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
new file mode 100644
index 0000000000000..69551a9c0b976
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td
@@ -0,0 +1,242 @@
+//===- SPIRVGraphOps.td - Graph extended insts spec file -----*- tablegen -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the op definition spec of Graph extension ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+#define MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
+
+include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/FunctionInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Graph opcode specification.
+//===----------------------------------------------------------------------===//
+
+// Base class for all Graph ops.
+// Base class for all Graph ops. Every op derived from it is an ARM vendor op
+// (spelled `spirv.ARM.<mnemonic>` in the assembly examples below) and carries
+// the availability listed here: SPIR-V versions 1.0 through 1.6, the
+// SPV_ARM_graph / SPV_ARM_tensors extensions, and the GraphARM capability.
+class SPIRV_GraphARMOp<string mnemonic, list<Trait> traits = []> :
+ SPIRV_ArmVendorOp<mnemonic, traits> {
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[SPV_ARM_graph, SPV_ARM_tensors]>,
+ Capability<[SPIRV_C_GraphARM]>
+ ];
+}
+
+def SPIRV_GraphARMOp : SPIRV_GraphARMOp<"Graph", [
+    AutomaticAllocationScope, DeclareOpInterfaceMethods<CallableOpInterface>,
+    FunctionOpInterface, InModuleScope, IsolatedFromAbove
+  ]> {
+
+  let summary = "Declare or define a SPIR-V graph";
+
+  let description = [{
+    This op declares or defines a SPIR-V graph using one region, which
+    contains one or more blocks.
+
+    This op is not allowed to implicitly capture global values, and all external
+    references must use graph arguments or symbol references. This op itself
+    defines a symbol that is unique in the enclosing module op.
+
+    Note that this op does not have a 1:1 mapping to the SPIR-V ops representing
+    a graph. Indeed, during serialization a single GraphARMOp is serialized into
+    several different SPIR-V ops: OpGraphARM, OpGraphInputARM and OpGraphEndARM.
+    There is one OpGraphInputARM op for each input of the graph.
+    Deserialization maps that set of operations back into a single GraphARMOp.
+
+    This op itself takes no operands and generates no results. Its region
+    can take zero or more arguments and must return one or more values (a
+    graph signature without results is rejected by the verifier). The region
+    is optional: a graph without a body is a declaration.
+
+    ```
+    spv-graph-arm-op ::= `spirv.ARM.Graph` function-signature
+                         region?
+    ```
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+      spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    TypeAttrOf<GraphType>:$function_type,
+    OptionalAttr<DictArrayAttr>:$arg_attrs,
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<BoolAttr>:$entry_point,
+    StrAttr:$sym_name
+  );
+
+  let results = (outs);
+
+  let regions = (region AnyRegion:$body);
+
+  let hasVerifier = 0;
+
+  let builders = [
+    OpBuilder<(ins "StringRef":$name, "GraphType":$type,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs, CArg<"bool", "false">:$entry_point)>];
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+
+  let extraClassDeclaration = [{
+    /// Hook for FunctionOpInterface, called after verifying that the 'type'
+    /// attribute is present and checks if it holds a graph type. Ensures
+    /// getType, getNumArguments, and getNumResults can be called safely.
+    /// Additionally rejects graph signatures that declare no results.
+    LogicalResult verifyType();
+
+    /// Hook for FunctionOpInterface, called after verifying the graph
+    /// type and the presence of the (potentially empty) graph body.
+    /// Ensures SPIR-V specific semantics.
+    LogicalResult verifyBody();
+  }];
+}
+
+// -----
+
+// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
+// Check that an op can only be used within the scope of a spirv.ARM.Graph op.
+// The predicate starts at the op's parent and walks outwards, so the op may be
+// nested at any depth inside the graph as long as no symbol-table op (for
+// example a module) lies in between.
+def InGraphScope : PredOpTrait<
+ "op must appear in a spirv.ARM.Graph op's block",
+ CPred<"isNestedInGraphARMOpInterface($_op.getParentOp())">>;
+
+// -----
+
+// NOTE: this op is deliberately not ConstantLike. Its value is only bound at
+// pipeline creation time (see the description), so it has no compile-time
+// attribute value and no folder. Generic infrastructure matches ConstantLike
+// ops with m_Constant (e.g. constant CSE in the greedy rewrite driver), which
+// requires the op to fold and asserts otherwise.
+def SPIRV_GraphConstantARMOp : SPIRV_GraphARMOp<"GraphConstant", [InGraphScope, Pure]> {
+  let summary = "Declare a graph constant.";
+
+  let description = [{
+    Declare a graph constant.
+    Result Type must be an OpTypeTensorARM.
+    GraphConstantID must be a 32-bit integer literal.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+    ```
+
+    GraphConstantID is a unique identifier which is used to map the constants
+    defined by GraphConstantARM in the SPIR-V module to the ones provided at
+    shader creation time via the VkDataGraphPipelineShaderModuleCreateInfoARM.
+    That Vulkan structure provides a list of VkDataGraphPipelineConstantARM
+    which contains the bindings from id to data. (For more details see
+    https://registry.khronos.org/vulkan/specs/latest/html/vkspec.html#graphs)
+  }];
+
+  let arguments = (ins
+    I32Attr:$graph_constant_id
+  );
+
+  let results = (outs
+    SPIRV_AnyTensorArm:$output
+  );
+
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let assemblyFormat = [{
+    attr-dict `:` type($output)
+  }];
+}
+
+// -----
+
+def SPIRV_GraphEntryPointARMOp : SPIRV_GraphARMOp<"GraphEntryPoint", [InModuleScope]> {
+  let summary = [{
+    Declare a graph entry point and its interface.
+  }];
+
+  let description = [{
+    Graph Entry Point must be the Result <id> of an OpGraphARM instruction.
+
+    Name is a name string for the graph entry point. A module cannot have two
+    OpGraphEntryPointARM instructions with the same Name string.
+
+    Interface is a list of symbol references to `spirv.GlobalVariable`
+    operations. These declare the set of global variables from a
+    module that form the interface of this entry point. The set of
+    Interface symbols must be equal to or a superset of the
+    `spirv.GlobalVariable`s referenced by the entry point’s static call
+    tree, within the interface’s storage classes.
+
+    #### Example:
+
+    ```mlir
+    spirv.GlobalVariable @arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.GlobalVariable @res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+    spirv.ARM.GraphEntryPoint @graph, @arg_0, @res_0
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+      ...
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$fn,
+    SymbolRefArrayAttr:$interface
+  );
+
+  let results = (outs);
+
+  // Checks for graph and interface symbol reference are done in spirv::ModuleOp verification.
+  let hasVerifier = 0;
+
+  let autogenSerialization = 0;
+
+  let builders = [
+    OpBuilder<(ins "spirv::GraphARMOp":$graph, "ArrayRef<Attribute>":$interfaceVars)>];
+}
+
+// -----
+
+def SPIRV_GraphOutputsARMOp : SPIRV_GraphARMOp<"GraphOutputs", [InGraphScope, Pure,
+                                                                  Terminator]> {
+
+  let summary = "Define graph outputs.";
+
+  let description = [{
+    Values are the graph output values and must match the GraphOutputs Type
+    operand of the OpTypeGraphARM type of the OpGraphARM body this
+    instruction is in.
+
+    This instruction must be the last instruction in a block.
+
+    #### Example:
+
+    ```mlir
+    spirv.ARM.Graph @graph(%arg0 : !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+      spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+    }
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<SPIRV_AnyTensorArm>:$value
+  );
+
+  let results = (outs);
+
+  let autogenSerialization = 0;
+
+  let hasOpcode = 0;
+
+  let assemblyFormat = "$value attr-dict `:` type($value)";
+}
+
+#endif // MLIR_DIALECT_SPIRV_IR_GRAPH_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
index 0fa1bb9d5bd01..96ef035eda37a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
@@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
+include "mlir/Dialect/SPIRV/IR/SPIRVGraphOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVIntegerDotProductOps.td"
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 2e356dec1981f..9d8d81a839fcb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -24,6 +24,7 @@ class Type;
class IntegerType;
class FloatType;
class FunctionType;
+class GraphType;
class IndexType;
class MemRefType;
class VectorType;
@@ -81,6 +82,7 @@ class Builder {
IntegerType getIntegerType(unsigned width);
IntegerType getIntegerType(unsigned width, bool isSigned);
FunctionType getFunctionType(TypeRange inputs, TypeRange results);
+ GraphType getGraphType(TypeRange inputs, TypeRange results);
TupleType getTupleType(TypeRange elementTypes);
NoneType getNoneType();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index a0c8acea91dc5..08847dd11c685 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -403,7 +403,7 @@ def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
// FunctionType
//===----------------------------------------------------------------------===//
-def Builtin_Function : Builtin_Type<"Function", "function"> {
+class Builtin_FunctionLike<string Name, string typeMnemonic> : Builtin_Type<Name, typeMnemonic> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:
@@ -434,6 +434,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
}]>
];
let skipDefaultBuilders = 1;
+ let storageClass = "FunctionTypeStorage";
let genStorageClass = 0;
let extraClassDeclaration = [{
/// Input types.
@@ -444,23 +445,26 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; }
- /// Returns a clone of this function type with the given argument
+ /// Returns a clone of this function-like type with the given argument
/// and result types.
- FunctionType clone(TypeRange inputs, TypeRange results) const;
+ }] # Name # "Type" # [{ clone(TypeRange inputs, TypeRange results) const;
- /// Returns a new function type with the specified arguments and results
+ /// Returns a new function-like type with the specified arguments and results
/// inserted.
- FunctionType getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+ }] # Name # "Type" # [{ getWithArgsAndResults(ArrayRef<unsigned> argIndices,
TypeRange argTypes,
ArrayRef<unsigned> resultIndices,
TypeRange resultTypes);
- /// Returns a new function type without the specified arguments and results.
- FunctionType getWithoutArgsAndResults(const BitVector &argIndices,
+ /// Returns a new function-like type without the specified arguments and results.
+ }] # Name # "Type" # [{ getWithoutArgsAndResults(const BitVector &argIndices,
const BitVector &resultIndices);
}];
}
+def Builtin_Function : Builtin_FunctionLike<"Function", "function">;
+def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">;
+
//===----------------------------------------------------------------------===//
// IndexType
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index b682f4c025a46..6b4e3dd603198 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -387,6 +387,12 @@ class OpaqueType<string dialect, string name, string summary>
def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
"function type", "::mlir::FunctionType">;
+// Graph Type.
+
+// Any graph type.
+def GraphType : Type<CPred<"::llvm::isa<::mlir::GraphType>($_self)">,
+ "graph type", "::mlir::GraphType">;
+
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppType = "::mlir::Type"> :
diff --git a/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
new file mode 100644
index 0000000000000..47fe4d9c5b21c
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/IR/ArmGraphOps.cpp
@@ -0,0 +1,251 @@
+//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+#include "SPIRVParsingUtils.h"
+
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/FunctionImplementation.h"
+#include "llvm/Support/InterleavedRange.h"
+
+using namespace mlir;
+using namespace mlir::spirv::AttrNames;
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphARM
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom form of spirv.ARM.Graph:
+///   @name(%arg0: type, ...) -> result-types [attributes {...}] [region]
+/// The parsed signature is stored as a GraphType in the function_type
+/// attribute. The order of the parse calls below is the grammar itself.
+ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ Builder &builder = parser.getBuilder();
+
+ // Parse the name as a symbol.
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ // Parse the signature `(args) -> results`; variadic argument lists are
+ // rejected (allowVariadic=false), so isVariadic is only an out-parameter.
+ bool isVariadic = false;
+ SmallVector<OpAsmParser::Argument> entryArgs;
+ SmallVector<Type> resultTypes;
+ SmallVector<DictionaryAttr> resultAttrs;
+ if (function_interface_impl::parseFunctionSignatureWithArguments(
+ parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
+ resultAttrs))
+ return failure();
+
+ // Build the GraphType from the parsed argument and result types.
+ SmallVector<Type> argTypes = llvm::map_to_vector(
+ entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; });
+ GraphType grType = builder.getGraphType(argTypes, resultTypes);
+ result.addAttribute(getFunctionTypeAttrName(result.name),
+ TypeAttr::get(grType));
+
+ // If additional attributes are present (`attributes {...}`), parse them.
+ if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+ return failure();
+
+ // Attach the parsed per-argument and per-result attribute dictionaries
+ // (stored under arg_attrs / res_attrs).
+ assert(resultAttrs.size() == resultTypes.size());
+ call_interface_impl::addArgAndResultAttrs(
+ builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
+ getResAttrsAttrName(result.name));
+
+ // Parse the optional graph body, using the parsed signature arguments as the
+ // entry block arguments. When no body is present the region stays empty.
+ Region *body = result.addRegion();
+ OptionalParseResult parseResult =
+ parser.parseOptionalRegion(*body, entryArgs);
+ return failure(parseResult.has_value() && failed(*parseResult));
+}
+
+/// Prints the custom form of spirv.ARM.Graph: `@name(<args>) -> <results>`,
+/// any remaining attributes, and the body region when the graph has one.
+void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getSymName());
+
+  GraphType signature = getFunctionType();
+  function_interface_impl::printFunctionSignature(
+      printer, *this, signature.getInputs(), /*isVariadic=*/false,
+      signature.getResults());
+
+  // The signature already conveys the type and argument/result attributes, so
+  // keep them out of the trailing attribute dictionary.
+  function_interface_impl::printFunctionAttributes(
+      printer, *this,
+      {getFunctionTypeAttrName(), getArgAttrsAttrName(),
+       getResAttrsAttrName()});
+
+  // Declarations (empty region) end here.
+  Region &region = getBody();
+  if (region.empty())
+    return;
+
+  printer << ' ';
+  printer.printRegion(region, /*printEntryBlockArgs=*/false,
+                      /*printBlockTerminators=*/true);
+}
+
+/// FunctionOpInterface type hook: a graph must declare at least one result.
+LogicalResult spirv::GraphARMOp::verifyType() {
+  bool hasResults = getFunctionType().getNumResults() != 0;
+  if (hasResults)
+    return success();
+  return emitOpError("there should be at least one result");
+}
+
+/// FunctionOpInterface body hook. Checks that:
+///  - every graph input and output type is a TensorArmType;
+///  - a defined graph's entry block arguments match the signature inputs;
+///  - every nested spirv.ARM.GraphOutputs yields values matching the results.
+LogicalResult spirv::GraphARMOp::verifyBody() {
+  for (auto [argIdx, argTy] : llvm::enumerate(getArgumentTypes()))
+    if (!isa<spirv::TensorArmType>(argTy))
+      return emitOpError("type of argument #")
+             << argIdx << " must be a TensorArmType, but got " << argTy;
+
+  for (auto [resIdx, resTy] : llvm::enumerate(getResultTypes()))
+    if (!isa<spirv::TensorArmType>(resTy))
+      return emitOpError("type of result #")
+             << resIdx << " must be a TensorArmType, but got " << resTy;
+
+  // Declarations have no entry block to compare against the signature.
+  if (!isExternal()) {
+    Block &entry = front();
+    unsigned expectedArgs = getNumArguments();
+    if (entry.getNumArguments() != expectedArgs)
+      return emitOpError("entry block must have ")
+             << expectedArgs << " arguments to match graph signature";
+
+    for (auto [argIdx, sigTy, blockTy] :
+         llvm::enumerate(getArgumentTypes(), entry.getArgumentTypes()))
+      if (blockTy != sigTy)
+        return emitOpError("type of entry block argument #")
+               << argIdx << '(' << blockTy
+               << ") must match the type of the corresponding argument in "
+               << "graph signature(" << sigTy << ')';
+  }
+
+  // Any failing GraphOutputs interrupts the walk with its diagnostic.
+  GraphType signature = getFunctionType();
+  WalkResult outcome =
+      walk([signature](spirv::GraphOutputsARMOp outOp) -> WalkResult {
+        unsigned expected = signature.getNumResults();
+        if (expected != outOp.getNumOperands())
+          return outOp.emitOpError("is returning ")
+                 << outOp.getNumOperands()
+                 << " value(s) but enclosing spirv.ARM.Graph requires "
+                 << expected << " result(s)";
+
+        for (auto [opIdx, operandTy] :
+             llvm::enumerate(outOp.getValue().getType())) {
+          Type wantTy = signature.getResult(opIdx);
+          if (operandTy != wantTy)
+            return outOp.emitError("type of return operand ")
+                   << opIdx << " (" << operandTy
+                   << ") doesn't match graph result type (" << wantTy << ")";
+        }
+        return WalkResult::advance();
+      });
+
+  return failure(outcome.wasInterrupted());
+}
+
+/// Builds a spirv.ARM.Graph named `name` with signature `type`, extra
+/// attributes `attrs`, and an explicit entry_point flag. The body region is
+/// created empty.
+void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
+                              StringRef name, GraphType type,
+                              ArrayRef<NamedAttribute> attrs, bool entryPoint) {
+  StringAttr symName = builder.getStringAttr(name);
+  TypeAttr signature = TypeAttr::get(type);
+  BoolAttr isEntry = builder.getBoolAttr(entryPoint);
+
+  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
+  state.addAttribute(getFunctionTypeAttrName(state.name), signature);
+  state.attributes.append(attrs);
+  state.addAttribute(getEntryPointAttrName(state.name), isEntry);
+  state.addRegion();
+}
+
+/// Input types of the graph signature.
+ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
+  GraphType signature = getFunctionType();
+  return signature.getInputs();
+}
+
+/// Result types of the graph signature.
+ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
+  GraphType signature = getFunctionType();
+  return signature.getResults();
+}
+
+/// CallableOpInterface hook: the body region, or null for a declaration.
+Region *spirv::GraphARMOp::getCallableRegion() {
+  if (isExternal())
+    return nullptr;
+  return &getBody();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphOutputsARM
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the yielded values match the enclosing graph's result types
+/// in number and type.
+LogicalResult spirv::GraphOutputsARMOp::verify() {
+  // InGraphScope only guarantees that *some* ancestor (below the nearest
+  // symbol table) is a spirv.ARM.Graph, not that it is the immediate parent.
+  // Look the graph up instead of cast<>-ing the parent, which would assert
+  // (or be undefined behavior in release builds) when this op sits in a
+  // nested region.
+  auto graph = (*this)->getParentOfType<GraphARMOp>();
+  if (!graph)
+    return emitOpError("must be nested in a spirv.ARM.Graph op");
+
+  // The operand count and types must match the graph signature.
+  ArrayRef<Type> results = graph.getFunctionType().getResults();
+  if (getNumOperands() != results.size())
+    return emitOpError("has ")
+           << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
+           << graph.getName() << ") returns " << results.size();
+
+  for (auto [index, result] : llvm::enumerate(results))
+    if (getOperand(index).getType() != result)
+      return emitError() << "type of return operand " << index << " ("
+                         << getOperand(index).getType()
+                         << ") doesn't match spirv.ARM.Graph result type ("
+                         << result << ")"
+                         << " in graph @" << graph.getName();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.GraphEntryPointARM
+//===----------------------------------------------------------------------===//
+
+/// Builds an entry point that references `graph` by symbol and lists
+/// `interfaceVars` as its interface.
+void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
+                                        OperationState &state,
+                                        spirv::GraphARMOp graph,
+                                        ArrayRef<Attribute> interfaceVars) {
+  FlatSymbolRefAttr graphRef = SymbolRefAttr::get(graph);
+  ArrayAttr interfaceAttr = builder.getArrayAttr(interfaceVars);
+  build(builder, state, graphRef, interfaceAttr);
+}
+
+/// Parses `@graph` optionally followed by `, @var0, @var1, ...`.
+/// The interface attribute is always added, possibly as an empty array.
+ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Use the ODS-generated attribute names (as GraphARMOp::parse does) so the
+  // parser cannot drift from the `fn` / `interface` names that getFn() and
+  // getInterface() read.
+  FlatSymbolRefAttr fn;
+  if (parser.parseAttribute(fn, Type(), getFnAttrName(result.name).getValue(),
+                            result.attributes))
+    return failure();
+
+  SmallVector<Attribute, 4> interfaceVars;
+  if (succeeded(parser.parseOptionalComma())) {
+    // Parse the interface variables.
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // Each symbol is parsed into a scratch list and only the attribute
+          // itself is kept, so the attribute name used here is irrelevant.
+          FlatSymbolRefAttr var;
+          NamedAttrList scratch;
+          if (parser.parseAttribute(var, Type(), "var_symbol", scratch))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
+  }
+  result.addAttribute(getInterfaceAttrName(result.name),
+                      parser.getBuilder().getArrayAttr(interfaceVars));
+  return success();
+}
+
+/// Prints `@graph`, followed by `, @var0, @var1, ...` when the interface list
+/// is non-empty.
+void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
+  printer << " ";
+  printer.printSymbolName(getFn());
+  ArrayRef<Attribute> vars = getInterface().getValue();
+  if (vars.empty())
+    return;
+  printer << ", " << llvm::interleaved(vars);
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
index b9aa7b7491abf..60d705d940cfc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_mlir_dialect_library(MLIRSPIRVDialect
+ ArmGraphOps.cpp
AtomicOps.cpp
CastOps.cpp
ControlFlowOps.cpp
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..2f3a28ff16173 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -31,6 +31,18 @@ static bool isNestedInFunctionOpInterface(Operation *op) {
return isNestedInFunctionOpInterface(op->getParentOp());
}
+/// Walks outwards from `op` (inclusive) and returns true if a GraphARM op is
+/// reached before any op that owns a symbol table (a module-like op). A null
+/// `op` yields false.
+static bool isNestedInGraphARMOpInterface(Operation *op) {
+  for (Operation *cur = op; cur != nullptr; cur = cur->getParentOp()) {
+    // Crossing a module-like boundary means we are not inside a graph.
+    if (cur->hasTrait<OpTrait::SymbolTable>())
+      return false;
+    if (isa<spirv::GraphARMOp>(cur))
+      return true;
+  }
+  return false;
+}
+
/// Returns true if the given op is an module-like op that maintains a symbol
/// table.
static bool isDirectInModuleLikeOp(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6d3bda421f309..670eabf2584ea 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -158,6 +158,12 @@ void UpdateVCEPass::runOnOperation() {
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
valueTypes.push_back(globalVar.getType());
+ // If the op is FunctionLike make sure to process input and result types.
+ if (auto funcOpInterface = dyn_cast<FunctionOpInterface>(op)) {
+ llvm::append_range(valueTypes, funcOpInterface.getArgumentTypes());
+ llvm::append_range(valueTypes, funcOpInterface.getResultTypes());
+ }
+
// Requirements from values' types
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index de52fbd3f215c..3d19c5ad8fbca 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -104,7 +104,7 @@ void OpAsmPrinter::printFunctionalType(Operation *op) {
// it is a function (avoiding a grammar ambiguity).
bool wrapped = op->getNumResults() != 1;
if (!wrapped && op->getResult(0).getType() &&
- llvm::isa<FunctionType>(op->getResult(0).getType()))
+ isa<FunctionType>(op->getResult(0).getType()))
wrapped = true;
if (wrapped)
@@ -2836,6 +2836,19 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << '>';
})
.Case<NoneType>([&](Type) { os << "none"; })
+ .Case<GraphType>([&](GraphType graphTy) {
+ os << '(';
+ interleaveComma(graphTy.getInputs(), [&](Type ty) { printType(ty); });
+ os << ") -> ";
+ ArrayRef<Type> results = graphTy.getResults();
+ if (results.size() == 1 && !isa<FunctionType, GraphType>(results[0])) {
+ printType(results[0]);
+ } else {
+ os << '(';
+ interleaveComma(results, [&](Type ty) { printType(ty); });
+ os << ')';
+ }
+ })
.Default([&](Type type) { return printDialectType(type); });
}
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index f657db142eeb9..3d366276b4375 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -76,6 +76,10 @@ FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
return FunctionType::get(context, inputs, results);
}
+GraphType Builder::getGraphType(TypeRange inputs, TypeRange results) {
+ return GraphType::get(context, inputs, results);
+}
+
TupleType Builder::getTupleType(TypeRange elementTypes) {
return TupleType::get(context, elementTypes);
}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1604ebba190a1..ce47c60c9b932 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -179,6 +179,45 @@ FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
return clone(newArgTypes, newResultTypes);
}
+//===----------------------------------------------------------------------===//
+// GraphType
+//===----------------------------------------------------------------------===//
+
+unsigned GraphType::getNumInputs() const { return getImpl()->numInputs; }
+
+ArrayRef<Type> GraphType::getInputs() const { return getImpl()->getInputs(); }
+
+unsigned GraphType::getNumResults() const { return getImpl()->numResults; }
+
+ArrayRef<Type> GraphType::getResults() const { return getImpl()->getResults(); }
+
+GraphType GraphType::clone(TypeRange inputs, TypeRange results) const {
+ return get(getContext(), inputs, results);
+}
+
+/// Returns a new graph type with the specified arguments and results
+/// inserted at the given indices.
+GraphType GraphType::getWithArgsAndResults(ArrayRef<unsigned> argIndices,
+ TypeRange argTypes,
+ ArrayRef<unsigned> resultIndices,
+ TypeRange resultTypes) {
+ SmallVector<Type> argStorage, resultStorage;
+ TypeRange newArgTypes =
+ insertTypesInto(getInputs(), argIndices, argTypes, argStorage);
+ TypeRange newResultTypes =
+ insertTypesInto(getResults(), resultIndices, resultTypes, resultStorage);
+ return clone(newArgTypes, newResultTypes);
+}
+
+/// Returns a new graph type without the specified arguments and results.
+GraphType GraphType::getWithoutArgsAndResults(const BitVector &argIndices,
+ const BitVector &resultIndices) {
+ SmallVector<Type> argStorage, resultStorage;
+ TypeRange newArgTypes = filterTypesOut(getInputs(), argIndices, argStorage);
+ TypeRange newResultTypes =
+ filterTypesOut(getResults(), resultIndices, resultStorage);
+ return clone(newArgTypes, newResultTypes);
+}
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index f56bc3967b4b7..4ef242bdc5b16 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -306,3 +306,20 @@ func.func @constant_composite_replicate() -> () {
%0 = spirv.EXT.ConstantCompositeReplicate [1 : i32] : vector<2xi32>
spirv.Return
}
+
+//===----------------------------------------------------------------------===//
+// GraphARM ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: graph_arm
+spirv.ARM.Graph @graph_arm(%arg0: !spirv.arm.tensor<1x16x16x16xi8>) -> !spirv.arm.tensor<1x16x16x16xi8> {
+ // CHECK: spirv.ARM.GraphOutputs min version: v1.0
+ // CHECK: spirv.ARM.GraphOutputs max version: v1.6
+ // CHECK: spirv.ARM.GraphOutputs extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
+ // CHECK: spirv.ARM.GraphOutputs capabilities: [ [GraphARM] ]
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<1x16x16x16xi8>
+// CHECK: spirv.ARM.Graph min version: v1.0
+// CHECK: spirv.ARM.Graph max version: v1.6
+// CHECK: spirv.ARM.Graph extensions: [ [SPV_ARM_graph, SPV_ARM_tensors] ]
+// CHECK: spirv.ARM.Graph capabilities: [ [GraphARM] ]
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
new file mode 100644
index 0000000000000..798147df45a34
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/IR/graph-ops.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph and spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.ARM.Graph {{@.*}}({{%.*}}: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+ // CHECK: spirv.ARM.GraphOutputs {{%.*}} : !spirv.arm.tensor<14x19xi16>
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphConstant
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.ARM.Graph {{@.*}}() -> !spirv.arm.tensor<2x3xi16> {
+spirv.ARM.Graph @graphConstant() -> !spirv.arm.tensor<2x3xi16> {
+ // CHECK: [[CONST:%.*]] = spirv.ARM.GraphConstant {graph_constant_id = 42 : i32} : !spirv.arm.tensor<2x3xi16>
+ %0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+ // CHECK: spirv.ARM.GraphOutputs [[CONST]] : !spirv.arm.tensor<2x3xi16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<2x3xi16>
+}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphEntryPoint
+//===----------------------------------------------------------------------===//
+
+// CHECK: spirv.GlobalVariable [[VARARG0:@.*]] bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_arg_0 bind(0, 0) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.GlobalVariable [[VARRES0:@.*]] bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+spirv.GlobalVariable @entrypoint_res_0 bind(0, 1) : !spirv.ptr<!spirv.arm.tensor<14x19xi16>, UniformConstant>
+// CHECK: spirv.ARM.GraphEntryPoint @entrypoint, [[VARARG0]], [[VARRES0]]
+spirv.ARM.GraphEntryPoint @entrypoint, @entrypoint_arg_0, @entrypoint_res_0
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no terminator
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{empty block: expect at least a terminator}}
+spirv.ARM.Graph @graphNoterminator(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph with no result types
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op there should be at least one result}}
+spirv.ARM.Graph @graphNoOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> () {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphConstant outside graph scope
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.GraphConstant' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+%0 = spirv.ARM.GraphConstant { graph_constant_id = 42 : i32 } : !spirv.arm.tensor<2x3xi16>
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.GraphOutputs outside graph scope
+//===----------------------------------------------------------------------===//
+
+%0 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
+// expected-error @+1 {{'spirv.ARM.GraphOutputs' op failed to verify that op must appear in a spirv.ARM.Graph op's block}}
+spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1xi16>
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<5x3xi16> {
+ // expected-error @+1 {{type of return operand 0 ('!spirv.arm.tensor<14x19xi16>') doesn't match graph result type ('!spirv.arm.tensor<5x3xi16>')}}
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph return type does not match number of results in spirv.ARM.GraphOutputs
+//===----------------------------------------------------------------------===//
+
+spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> (!spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>) {
+ // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 1 value(s) but enclosing spirv.ARM.Graph requires 2 result(s)}}
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> !spirv.arm.tensor<14x19xi16> {
+ // expected-error @+1 {{'spirv.ARM.GraphOutputs' op is returning 2 value(s) but enclosing spirv.ARM.Graph requires 1 result(s)}}
+ spirv.ARM.GraphOutputs %arg0, %arg0 : !spirv.arm.tensor<14x19xi16>, !spirv.arm.tensor<14x19xi16>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non TensorArmType argument
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op type of argument #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0: i8) -> !spirv.arm.tensor<14x19xi16> {
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.ARM.Graph using a non TensorArmType result
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spirv.ARM.Graph' op type of result #0 must be a TensorArmType, but got 'i8'}}
+spirv.ARM.Graph @graphAndOutputs(%arg0: !spirv.arm.tensor<14x19xi16>) -> i8 {
+}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index 4e534a30ad516..2d20ae0a13105 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -231,3 +231,14 @@ spirv.module Logical GLSL450 attributes {
spirv.ReturnValue %val : bf16
}
}
+
+// CHECK: requires #spirv.vce<v1.5, [GraphARM, TensorsARM, Int8, Float16, VulkanMemoryModel], [SPV_ARM_graph, SPV_ARM_tensors, SPV_KHR_vulkan_memory_model]>
+spirv.module Logical Vulkan attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.5, [VulkanMemoryModel, GraphARM, TensorsARM, Float16], [SPV_ARM_tensors, SPV_ARM_graph]>,
+ #spirv.resource_limits<>>
+} {
+ spirv.ARM.Graph @argmax(%arg0: !spirv.arm.tensor<14x19xi8>, %arg1 : !spirv.arm.tensor<1xf16>) -> !spirv.arm.tensor<14x19xi8> {
+ spirv.ARM.GraphOutputs %arg0 : !spirv.arm.tensor<14x19xi8>
+ }
+}
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 2e5e591fe5f91..5643a0ff5b91c 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -21,7 +21,7 @@ using namespace mlir;
namespace {
/// A pass for testing SPIR-V op availability.
struct PrintOpAvailability
- : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
+ : public PassWrapper<PrintOpAvailability, OperationPass<mlir::ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
void runOnOperation() override;
@@ -33,12 +33,10 @@ struct PrintOpAvailability
} // namespace
void PrintOpAvailability::runOnOperation() {
- auto f = getOperation();
- llvm::outs() << f.getName() << "\n";
-
+ mlir::ModuleOp moduleOp = getOperation();
Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
- f->walk([&](Operation *op) {
+ auto opCallback = [&](Operation *op) {
if (op->getDialect() != spirvDialect)
return WalkResult::advance();
@@ -89,6 +87,16 @@ void PrintOpAvailability::runOnOperation() {
os.flush();
return WalkResult::advance();
+ };
+
+ moduleOp.walk([&](func::FuncOp f) {
+ llvm::outs() << f.getName() << "\n";
+ f->walk(opCallback);
+ });
+
+ moduleOp.walk([&](spirv::GraphARMOp g) {
+ llvm::outs() << g.getName() << "\n";
+ g->walk(opCallback);
});
}
More information about the Mlir-commits
mailing list