[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_tensors (PR #144667)
Davide Grohmann
llvmlistbot at llvm.org
Thu Jun 19 04:09:30 PDT 2025
https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/144667
>From 136f901344c205e4fe93f165ff4b7e32490abc98 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 13 Jun 2025 17:03:09 +0200
Subject: [PATCH 1/2] Add support for SPV_ARM_tensors
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Change-Id: If78909a47417ef3dda710847cfe90c34b984ff09
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 34 ++++-
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 35 ++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 77 ++++++++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 6 +
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 124 +++++++++++++++++-
.../SPIRV/Deserialization/DeserializeOps.cpp | 1 +
.../SPIRV/Deserialization/Deserializer.cpp | 52 ++++++++
.../SPIRV/Deserialization/Deserializer.h | 2 +
.../Target/SPIRV/Serialization/Serializer.cpp | 48 +++++++
mlir/test/Dialect/SPIRV/IR/types.mlir | 51 +++++++
mlir/test/Target/SPIRV/tensorARM.mlir | 66 ++++++++++
11 files changed, 487 insertions(+), 9 deletions(-)
create mode 100644 mlir/test/Target/SPIRV/tensorARM.mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76cdad904..d874817e6888d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
+def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
+ SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
}
+def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+ SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+ SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+ "any SPIR-V tensorArm type">;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage
+ SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
- SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+ SPIRV_OC_OpGroupNonUniformLogicalXor,
+ SPIRV_OC_OpTypeTensorARM,
+ SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d0a6bd2..7ffea6e7dba81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
+// StructType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};
+// SPIR-V TensorARM Type
+class TensorArmType
+ : public Type::TypeBase<TensorArmType, CompositeType,
+ detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+ using Base::Base;
+
+ static constexpr StringLiteral name = "spirv.arm.tensor";
+
+ // TensorArm supports minimum rank of 1, hence an empty shape here means
+ // unranked.
+ static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+ TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
+
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType);
+
+ Type getElementType() const;
+ ArrayRef<int64_t> getShape() const;
+ unsigned getNumElements() const;
+ bool hasRank() const { return !getShape().empty(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage = std::nullopt);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage = std::nullopt);
+};
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef1c4b43..15002f1d5d16e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
+ } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
+ if (!llvm::isa<ScalarType>(t.getElementType())) {
+ parser.emitError(
+ typeLoc, "only scalar element type allowed in tensor type but found ")
+ << t.getElementType();
+ return Type();
+ }
} else {
parser.emitError(typeLoc, "cannot use ")
<< type << " to compose SPIR-V types";
@@ -363,6 +370,54 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}
+// tensor-arm-type ::=
+// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return {};
+
+ bool unranked = false;
+ SmallVector<int64_t, 4> dims;
+ SMLoc countLoc = parser.getCurrentLocation();
+
+ if (parser.parseOptionalStar().succeeded()) {
+ unranked = true;
+ if (parser.parseXInDimensionList())
+ return {};
+ } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+ return {};
+
+ if (!unranked && dims.empty()) {
+ parser.emitError(countLoc, "arm.tensors do not support rank zero");
+ return {};
+ }
+
+ if (std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim == 0; })) {
+ parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+ return {};
+ }
+
+ if (std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim < 0; }) &&
+ std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim > 0; })) {
+ parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+ "fully dynamic or completed shaped");
+ return {};
+ }
+
+ auto elementTy = parseAndVerifyType(dialect, parser);
+ if (!elementTy)
+ return {};
+
+ if (parser.parseGreater())
+ return {};
+
+ return TensorArmType::get(dims, elementTy);
+}
+
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -759,6 +814,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
+ if (keyword == "arm.tensor")
+ return parseTensorArmType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -855,10 +912,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
os << ">";
}
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+ os << "arm.tensor<";
+
+ llvm::interleave(
+ type.getShape(), os,
+ [&](int64_t dim) {
+ if (ShapedType::isDynamic(dim))
+ os << '?';
+ else
+ os << dim;
+ },
+ "x");
+ if (!type.hasRank()) {
+ os << "*";
+ }
+ os << "x" << type.getElementType() << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType>(
+ ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027dae78d..eb2974d62fdd1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
}
+ if (llvm::isa<TensorArmType>(type)) {
+ if (parser.parseOptionalColon().succeeded())
+ if (parser.parseType(type))
+ return failure();
+ }
+
return parser.addTypeToList(type, result.types);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b33c546..e4eeb0a7f37d5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,10 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
#include <cstdint>
#include <iterator>
+#include <numeric>
using namespace mlir;
using namespace mlir::spirv;
@@ -96,7 +98,7 @@ bool CompositeType::classof(Type type) {
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
spirv::MatrixType, spirv::RuntimeArrayType,
- spirv::StructType>(type);
+ spirv::StructType, spirv::TensorArmType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -107,8 +109,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
- [](auto type) { return type.getElementType(); })
+ .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+ TensorArmType>([](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+ return tensorArmType.getNumElements();
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +155,14 @@ void CompositeType::getExtensions(
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static const Extension exts[] = {Extension::SPV_ARM_tensors};
+ ArrayRef<Extension> ref(exts, std::size(exts));
+ extensions.push_back(ref);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getExtensions(extensions, storage);
+ })
+
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -171,6 +183,13 @@ void CompositeType::getCapabilities(
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static const Capability caps[] = {Capability::TensorsARM};
+ ArrayRef<Capability> ref(caps, std::size(caps));
+ capabilities.push_back(ref);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getCapabilities(capabilities, storage);
+ })
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -186,6 +205,13 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
return std::nullopt;
return *elementSize * vectorType.getNumElements();
}
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ std::optional<int64_t> elementSize =
+ llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+ if (!elementSize)
+ return std::nullopt;
+ return *elementSize * tensorArmType.getNumElements();
+ }
return std::nullopt;
}
@@ -691,6 +717,9 @@ bool SPIRVType::classof(Type type) {
return true;
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return CompositeType::isValid(vectorType);
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+ return llvm::isa<ScalarType>(tensorArmType.getElementType());
+ }
return false;
}
@@ -712,6 +741,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
@@ -732,6 +763,8 @@ void SPIRVType::getCapabilities(
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
@@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities(
llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
}
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+ using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+ static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ auto shape = std::get<0>(key);
+ auto elementType = std::get<1>(key);
+ shape = allocator.copyInto(shape);
+ return new (allocator.allocate<TensorArmTypeStorage>())
+ TensorArmTypeStorage(std::move(shape), std::move(elementType));
+ }
+
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(shape, elementType);
+ }
+
+ TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+ : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+ ArrayRef<int64_t> shape;
+ Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+ return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+ auto shape = getShape();
+ return std::accumulate(shape.begin(), shape.end(), unsigned(1),
+ std::multiplies<unsigned>());
+}
+
+void TensorArmType::getExtensions(
+ SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
+ extensions.push_back(exts);
+}
+
+void TensorArmType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage) {
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
+ static constexpr Capability caps[] = {Capability::TensorsARM};
+ capabilities.push_back(caps);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType) {
+ if (std::any_of(shape.begin(), shape.end(),
+ [](int64_t dim) { return dim == 0; }))
+ return emitError() << "arm.tensor do not support dimensions = 0";
+ if (std::any_of(shape.begin(), shape.end(),
+ [](int64_t dim) { return dim < 0; }) &&
+ std::any_of(shape.begin(), shape.end(),
+ [](int64_t dim) { return dim > 0; }))
+ return emitError()
+ << "arm.tensor shape dimensions must be either fully dynamic or "
+ "completed shaped";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
void SPIRVDialect::registerTypes() {
addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
- RuntimeArrayType, SampledImageType, StructType>();
+ RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b30da773d4896..55d6a380d0bff 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
+ case spirv::Opcode::OpTypeTensorARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b9d9a9015eb61..f0e42047c559e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -138,6 +138,7 @@ LogicalResult spirv::Deserializer::processHeader() {
MIN_VERSION_CASE(3);
MIN_VERSION_CASE(4);
MIN_VERSION_CASE(5);
+ MIN_VERSION_CASE(6);
#undef MIN_VERSION_CASE
default:
return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
@@ -935,6 +936,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processStructType(operands);
case spirv::Opcode::OpTypeMatrix:
return processMatrixType(operands);
+ case spirv::Opcode::OpTypeTensorARM:
+ return processTensorARMType(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@@ -1238,6 +1241,55 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
return success();
}
+LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
+ unsigned size = operands.size();
+ if (size < 2 || size > 4) {
+ return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+ "(result_id, element_type, (rank), (shape))")
+ << size;
+ }
+ Type elementTy = getType(operands[1]);
+ if (!elementTy) {
+ return emitError(unknownLoc,
+ "OpTypeTensorARM references undefined element type.")
+ << operands[1];
+ }
+ if (size == 2) {
+ typeMap[operands[0]] = TensorArmType::get({}, elementTy);
+ return success();
+ }
+
+ auto rankAttr = getConstantInt(operands[2]);
+ if (!rankAttr)
+ return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
+ "scalar integer constant instruction");
+ unsigned rank = rankAttr.getValue().getZExtValue();
+ if (size == 3) {
+ SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
+ typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+ return success();
+ }
+
+ auto shapeInfo = getConstant(operands[3]);
+ if (!shapeInfo) {
+ return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
+ "constant instruction of type OpTypeArray");
+ }
+ ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+ SmallVector<int64_t, 1> shape;
+ for (auto dimAttr : shapeArrayAttr.getValue()) {
+ auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+ if (!dimIntAttr) {
+ return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
+ "dimension size");
+ }
+ shape.push_back(dimIntAttr.getValue().getSExtValue());
+ }
+ typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+ return success();
+}
+
LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e4556e7652b17..1bc9e4a3c75d8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -291,6 +291,8 @@ class Deserializer {
LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
+ LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+
LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d258bfd852961..ebebd2d283afa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -729,6 +729,54 @@ LogicalResult Serializer::prepareBasicType(
return success();
}
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+ uint32_t elementTypeID = 0;
+ uint32_t rank = 0;
+ uint32_t shapeID = 0;
+ uint32_t rankID = 0;
+ if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
+ elementTypeID, serializationCtx))) {
+ return failure();
+ }
+ if (tensorArmType.hasRank()) {
+ ArrayRef<int64_t> dims = tensorArmType.getShape();
+ rank = dims.size();
+ rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
+ if (rankID == 0) {
+ return failure();
+ }
+
+ bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
+ if (rank > 0 && shaped) {
+ auto I32Type = IntegerType::get(type.getContext(), 32);
+ auto shapeType = ArrayType::get(I32Type, rank);
+ if (rank == 1) {
+ SmallVector<uint64_t, 1> index(rank);
+ shapeID = prepareDenseElementsConstant(
+ loc, shapeType,
+ mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
+ index);
+ } else {
+ shapeID = prepareArrayConstant(
+ loc, shapeType,
+ mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
+ }
+ if (shapeID == 0) {
+ return failure();
+ }
+ }
+ }
+ typeEnum = spirv::Opcode::OpTypeTensorARM;
+ operands.push_back(elementTypeID);
+ if (rankID == 0)
+ return success();
+ operands.push_back(rankID);
+ if (shapeID == 0)
+ return success();
+ operands.push_back(shapeID);
+ return success();
+ }
+
// TODO: Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index c23894c62826b..7d45b5ea82643 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -564,3 +564,54 @@ func.func private @matrix_size_type(!spirv.matrix< x vector<3xi32>>) -> ()
func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
// -----
+
+//===----------------------------------------------------------------------===//
+// TensorArm
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>)
+func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>)
+func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>)
+func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>)
+func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>)
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>)
+func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>) -> ()
+// -----
+
+// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or completed shaped}}
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support rank zero}}
+func.func private @arm_tensor_rank_zero(!spirv.arm.tensor<i32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>)
+func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support zero dimensions}}
+func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
new file mode 100644
index 0000000000000..25c2a25b47d88
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.6, [Shader, TensorsARM], [SPV_ARM_tensors]> {
+ // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
+ spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" {
+ spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" {
+ spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" {
+ // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64>
+ %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64>
+
+ spirv.ReturnValue %0: !spirv.arm.tensor<3xui64>
+ }
+
+// -----
+
+ // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+ spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+ // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+ %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32>
+ }
+
+// -----
+
+ // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+ spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+ // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32>
+ %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32>
+
+ spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+ }
+
+// -----
+
+ // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" {
+ spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor<?xi32>) "None" {
+ spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<?xi32>) "None" {
+ spirv.Return
+ }
+
+// -----
+
+ // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor<?x?xi32>) "None" {
+ spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
+ spirv.Return
+ }
+}
>From 4750c732d13dd609924045df31dc685cf193849b Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 19 Jun 2025 10:36:29 +0200
Subject: [PATCH 2/2] Resolve several review comments
Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I9f99f2e5efc1b433bcf885a4890c730cb8cac213
---
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 6 ++--
mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 16 ++++-----
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 34 +++++++------------
.../SPIRV/Deserialization/Deserializer.cpp | 23 ++++++-------
4 files changed, 33 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 7ffea6e7dba81..2d3a4e1778490 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -97,8 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};
-// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
-// StructType.
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
+// StructType, or SPIR-V TensorArmType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
@@ -479,7 +479,7 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};
-// SPIR-V TensorARM Type
+/// SPIR-V TensorARM Type
class TensorArmType
: public Type::TypeBase<TensorArmType, CompositeType,
detail::TensorArmTypeStorage, ShapedType::Trait> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 15002f1d5d16e..88c7adf3dfcb3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,8 +194,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
- } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
- if (!llvm::isa<ScalarType>(t.getElementType())) {
+ } else if (auto t = dyn_cast<TensorArmType>(type)) {
+ if (!isa<ScalarType>(t.getElementType())) {
parser.emitError(
typeLoc, "only scalar element type allowed in tensor type but found ")
<< t.getElementType();
@@ -385,24 +385,22 @@ static Type parseTensorArmType(SPIRVDialect const &dialect,
unranked = true;
if (parser.parseXInDimensionList())
return {};
- } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+ } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
return {};
+ }
if (!unranked && dims.empty()) {
parser.emitError(countLoc, "arm.tensors do not support rank zero");
return {};
}
- if (std::any_of(dims.begin(), dims.end(),
- [](int64_t dim) { return dim == 0; })) {
+ if (llvm::is_contained(dims, 0)) {
parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
return {};
}
- if (std::any_of(dims.begin(), dims.end(),
- [](int64_t dim) { return dim < 0; }) &&
- std::any_of(dims.begin(), dims.end(),
- [](int64_t dim) { return dim > 0; })) {
+ if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
+ llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
"fully dynamic or completed shaped");
return {};
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index e4eeb0a7f37d5..18e13a7d5d3ec 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,7 +20,6 @@
#include <algorithm>
#include <cstdint>
-#include <iterator>
#include <numeric>
using namespace mlir;
@@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
- if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+ if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
return tensorArmType.getNumElements();
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
@@ -156,9 +155,7 @@ void CompositeType::getExtensions(
.getExtensions(extensions, storage);
})
.Case<TensorArmType>([&](TensorArmType type) {
- static const Extension exts[] = {Extension::SPV_ARM_tensors};
- ArrayRef<Extension> ref(exts, std::size(exts));
- extensions.push_back(ref);
+ extensions.push_back({Extension::SPV_ARM_tensors});
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
@@ -184,9 +181,7 @@ void CompositeType::getCapabilities(
.getCapabilities(capabilities, storage);
})
.Case<TensorArmType>([&](TensorArmType type) {
- static const Capability caps[] = {Capability::TensorsARM};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
+ capabilities.push_back({Capability::TensorsARM});
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
})
@@ -1245,15 +1240,15 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
- auto shape = std::get<0>(key);
- auto elementType = std::get<1>(key);
+ auto [shape, elementType] = key;
shape = allocator.copyInto(shape);
return new (allocator.allocate<TensorArmTypeStorage>())
TensorArmTypeStorage(std::move(shape), std::move(elementType));
}
static llvm::hash_code hashKey(const KeyTy &key) {
- return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+ auto [shape, elementType] = key;
+ return llvm::hash_combine(shape, elementType);
}
bool operator==(const KeyTy &key) const {
@@ -1280,7 +1275,7 @@ Type TensorArmType::getElementType() const { return getImpl()->elementType; }
ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
unsigned TensorArmType::getNumElements() const {
- auto shape = getShape();
+ ArrayRef<int64_t> shape = getShape();
return std::accumulate(shape.begin(), shape.end(), unsigned(1),
std::multiplies<unsigned>());
}
@@ -1290,8 +1285,7 @@ void TensorArmType::getExtensions(
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
- static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
- extensions.push_back(exts);
+ extensions.push_back({Extension::SPV_ARM_tensors});
}
void TensorArmType::getCapabilities(
@@ -1299,20 +1293,16 @@ void TensorArmType::getCapabilities(
std::optional<StorageClass> storage) {
llvm::cast<SPIRVType>(getElementType())
.getCapabilities(capabilities, storage);
- static constexpr Capability caps[] = {Capability::TensorsARM};
- capabilities.push_back(caps);
+ capabilities.push_back({Capability::TensorsARM});
}
LogicalResult
TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
- if (std::any_of(shape.begin(), shape.end(),
- [](int64_t dim) { return dim == 0; }))
+ if (llvm::is_contained(shape, 0))
return emitError() << "arm.tensor do not support dimensions = 0";
- if (std::any_of(shape.begin(), shape.end(),
- [](int64_t dim) { return dim < 0; }) &&
- std::any_of(shape.begin(), shape.end(),
- [](int64_t dim) { return dim > 0; }))
+ if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
+ llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
return emitError()
<< "arm.tensor shape dimensions must be either fully dynamic or "
"completed shaped";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f0e42047c559e..b801f5a4660fc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1244,23 +1244,23 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
LogicalResult
spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
unsigned size = operands.size();
- if (size < 2 || size > 4) {
+ if (size < 2 || size > 4)
return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
- "(result_id, element_type, (rank), (shape))")
+ "(result_id, element_type, (rank), (shape)) ")
<< size;
- }
+
Type elementTy = getType(operands[1]);
- if (!elementTy) {
+ if (!elementTy)
return emitError(unknownLoc,
- "OpTypeTensorARM references undefined element type.")
+ "OpTypeTensorARM references undefined element type ")
<< operands[1];
- }
+
if (size == 2) {
typeMap[operands[0]] = TensorArmType::get({}, elementTy);
return success();
}
- auto rankAttr = getConstantInt(operands[2]);
+ IntegerAttr rankAttr = getConstantInt(operands[2]);
if (!rankAttr)
return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
"scalar integer constant instruction");
@@ -1271,19 +1271,18 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}
- auto shapeInfo = getConstant(operands[3]);
- if (!shapeInfo) {
+ std::optional<std::pair<Attribute, Type>> shapeInfo = getConstant(operands[3]);
+ if (!shapeInfo)
return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
"constant instruction of type OpTypeArray");
- }
+
ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
SmallVector<int64_t, 1> shape;
for (auto dimAttr : shapeArrayAttr.getValue()) {
auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
- if (!dimIntAttr) {
+ if (!dimIntAttr)
return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
"dimension size");
- }
shape.push_back(dimIntAttr.getValue().getSExtValue());
}
typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
More information about the Mlir-commits
mailing list