[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_tensors (PR #144667)

Davide Grohmann llvmlistbot at llvm.org
Wed Jun 25 03:24:54 PDT 2025


https://github.com/davidegrohmann updated https://github.com/llvm/llvm-project/pull/144667

>From 136f901344c205e4fe93f165ff4b7e32490abc98 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Fri, 13 Jun 2025 17:03:09 +0200
Subject: [PATCH 1/4] Add support for SPV_ARM_tensors

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Signed-off-by: Mohammadreza Ameri Mahabadian <mohammadreza.amerimahabadian at arm.com>
Change-Id: If78909a47417ef3dda710847cfe90c34b984ff09
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  34 ++++-
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  35 ++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    |  77 ++++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |   6 +
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 124 +++++++++++++++++-
 .../SPIRV/Deserialization/DeserializeOps.cpp  |   1 +
 .../SPIRV/Deserialization/Deserializer.cpp    |  52 ++++++++
 .../SPIRV/Deserialization/Deserializer.h      |   2 +
 .../Target/SPIRV/Serialization/Serializer.cpp |  48 +++++++
 mlir/test/Dialect/SPIRV/IR/types.mlir         |  51 +++++++
 mlir/test/Target/SPIRV/tensorARM.mlir         |  66 ++++++++++
 11 files changed, 487 insertions(+), 9 deletions(-)
 create mode 100644 mlir/test/Target/SPIRV/tensorARM.mlir
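
For illustration, a minimal sketch of what the new type looks like from C++ once
this patch is applied (my own example, not part of the diff; it assumes an
MLIRContext with the SPIR-V dialect loaded):

  #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
  #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/MLIRContext.h"

  int main() {
    mlir::MLIRContext ctx;
    ctx.loadDialect<mlir::spirv::SPIRVDialect>();
    // Rank-2, fully static shape; prints as !spirv.arm.tensor<2x3xi32>.
    auto i32 = mlir::IntegerType::get(&ctx, 32);
    mlir::spirv::TensorArmType::get({2, 3}, i32).dump();
    // An empty shape encodes "unranked"; prints as !spirv.arm.tensor<*xi32>.
    mlir::spirv::TensorArmType::get({}, i32).dump();
  }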

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76cdad904..d874817e6888d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
+def SPV_ARM_tensors                      : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
 def SPIRV_ExtensionAttr :
     SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
       SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
       SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
       SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
       SPV_EXT_mesh_shader,
+      SPV_ARM_tensors,
       SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
       SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
       SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams                             : I32EnumAttrCase<"Geome
 def SPIRV_C_MultiViewport                               : I32EnumAttrCase<"MultiViewport", 57> {
   list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
 }
+def SPIRV_C_TensorsARM                                  : I32EnumAttrCase<"TensorsARM", 4174> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT        : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT     : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+  list<Availability> availability = [
+    Extension<[SPV_ARM_tensors]>
+  ];
+}
 def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR  : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
   list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
   list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
       SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
       SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
       SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+      SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+      SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
       SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
       SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
       SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
 
 // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
 // for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
                                 "any SPIR-V struct type">;
 def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
                                 "any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+                                 "any SPIR-V tensorArm type">;
 
 def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
     SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
-    SPIRV_AnyImage
+    SPIRV_AnyImage, SPIRV_AnyTensorArm
   ]>;
 
 def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor      : I32EnumAttrCase<"OpGroupNonUnifo
 def SPIRV_OC_OpGroupNonUniformLogicalAnd      : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
 def SPIRV_OC_OpGroupNonUniformLogicalOr       : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
 def SPIRV_OC_OpGroupNonUniformLogicalXor      : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM                  : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
 def SPIRV_OC_OpSubgroupBallotKHR              : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 def SPIRV_OC_OpGroupNonUniformRotateKHR       : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
 def SPIRV_OC_OpSDot                           : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
       SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
       SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
-      SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+      SPIRV_OC_OpGroupNonUniformLogicalXor,
+      SPIRV_OC_OpTypeTensorARM,
+      SPIRV_OC_OpSubgroupBallotKHR,
       SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
       SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
       SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d0a6bd2..7ffea6e7dba81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace spirv {
 namespace detail {
 struct ArrayTypeStorage;
 struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
 struct ImageTypeStorage;
 struct MatrixTypeStorage;
 struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
+// StructType.
 class CompositeType : public SPIRVType {
 public:
   using SPIRVType::SPIRVType;
@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
+// SPIR-V TensorARM Type
+class TensorArmType
+    : public Type::TypeBase<TensorArmType, CompositeType,
+                            detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+  using Base::Base;
+
+  static constexpr StringLiteral name = "spirv.arm.tensor";
+
+  // TensorArm supports a minimum rank of 1, so an empty shape here means
+  // unranked.
+  static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+  TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                          Type elementType) const;
+
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType);
+
+  Type getElementType() const;
+  ArrayRef<int64_t> getShape() const;
+  unsigned getNumElements() const;
+  bool hasRank() const { return !getShape().empty(); }
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+                     std::optional<StorageClass> storage = std::nullopt);
+  void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+                       std::optional<StorageClass> storage = std::nullopt);
+};
+
 } // namespace spirv
 } // namespace mlir
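
A short usage sketch of the accessor surface declared above (again my own
illustration; `ctx` is assumed to be an MLIRContext with the SPIR-V dialect
loaded):

  auto f16 = mlir::Float16Type::get(&ctx);
  auto ranked = mlir::spirv::TensorArmType::get({4, 4}, f16);
  assert(ranked.hasRank() && ranked.getShape().size() == 2);
  assert(ranked.getNumElements() == 16);

  // Per the comment in the header, an empty shape means unranked.
  auto unranked = mlir::spirv::TensorArmType::get({}, f16);
  assert(!unranked.hasRank());

  // cloneWith(std::nullopt, ...) keeps the shape and swaps the element type.
  auto asF32 = ranked.cloneWith(std::nullopt, mlir::Float32Type::get(&ctx));
  assert(asF32.getShape() == ranked.getShape());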
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef1c4b43..15002f1d5d16e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
           << t.getNumElements();
       return Type();
     }
+  } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
+    if (!llvm::isa<ScalarType>(t.getElementType())) {
+      parser.emitError(
+          typeLoc, "only scalar element type allowed in tensor type but found ")
+          << t.getElementType();
+      return Type();
+    }
   } else {
     parser.emitError(typeLoc, "cannot use ")
         << type << " to compose SPIR-V types";
@@ -363,6 +370,54 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
   return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
 }
 
+// tensor-arm-type ::=
+//   `!spirv.arm.tensor` `<` (dim0 `x` dim1 `x` ... `x` dimN `x` | `*` `x`)
+//   element-type `>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+                               DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return {};
+
+  bool unranked = false;
+  SmallVector<int64_t, 4> dims;
+  SMLoc countLoc = parser.getCurrentLocation();
+
+  if (parser.parseOptionalStar().succeeded()) {
+    unranked = true;
+    if (parser.parseXInDimensionList())
+      return {};
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+    return {};
+
+  if (!unranked && dims.empty()) {
+    parser.emitError(countLoc, "arm.tensors do not support rank zero");
+    return {};
+  }
+
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim == 0; })) {
+    parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+    return {};
+  }
+
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim < 0; }) &&
+      std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim > 0; })) {
+    parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+                               "fully dynamic or fully static");
+    return {};
+  }
+
+  auto elementTy = parseAndVerifyType(dialect, parser);
+  if (!elementTy)
+    return {};
+
+  if (parser.parseGreater())
+    return {};
+
+  return TensorArmType::get(dims, elementTy);
+}
+
 // TODO: Reorder methods to be utilities first and parse*Type
 // methods in alphabetical order
 //
@@ -759,6 +814,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
     return parseStructType(*this, parser);
   if (keyword == "matrix")
     return parseMatrixType(*this, parser);
+  if (keyword == "arm.tensor")
+    return parseTensorArmType(*this, parser);
   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
   return Type();
 }
@@ -855,10 +912,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
   os << ">";
 }
 
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+  os << "arm.tensor<";
+
+  llvm::interleave(
+      type.getShape(), os,
+      [&](int64_t dim) {
+        if (ShapedType::isDynamic(dim))
+          os << '?';
+        else
+          os << dim;
+      },
+      "x");
+  if (!type.hasRank()) {
+    os << "*";
+  }
+  os << "x" << type.getElementType() << ">";
+}
+
 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
-            ImageType, SampledImageType, StructType, MatrixType>(
+            ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
           [&](auto type) { print(type, os); })
       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
 }
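
A sketch of exercising the parser and printer added above (assuming
mlir::parseType from mlir/AsmParser/AsmParser.h and a context with the dialect
loaded):

  // Accepted forms round-trip through parseTensorArmType and print():
  mlir::Type t = mlir::parseType("!spirv.arm.tensor<?x?xf32>", &ctx);
  assert(mlir::isa<mlir::spirv::TensorArmType>(t));

  // Rejected forms return a null Type and emit a diagnostic:
  //   !spirv.arm.tensor<i32>      (rank zero)
  //   !spirv.arm.tensor<0xi32>    (zero dimension)
  //   !spirv.arm.tensor<1x?xi32>  (mixed static/dynamic dimensions)
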
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027dae78d..eb2974d62fdd1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
       return failure();
   }
 
+  if (llvm::isa<TensorArmType>(type)) {
+    if (parser.parseOptionalColon().succeeded())
+      if (parser.parseType(type))
+        return failure();
+  }
+
   return parser.addTypeToList(type, result.types);
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b33c546..e4eeb0a7f37d5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,10 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <iterator>
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::spirv;
@@ -96,7 +98,7 @@ bool CompositeType::classof(Type type) {
     return isValid(vectorType);
   return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
                    spirv::MatrixType, spirv::RuntimeArrayType,
-                   spirv::StructType>(type);
+                   spirv::StructType, spirv::TensorArmType>(type);
 }
 
 bool CompositeType::isValid(VectorType type) {
@@ -107,8 +109,8 @@ bool CompositeType::isValid(VectorType type) {
 
 Type CompositeType::getElementType(unsigned index) const {
   return TypeSwitch<Type, Type>(*this)
-      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
-          [](auto type) { return type.getElementType(); })
+      .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+            TensorArmType>([](auto type) { return type.getElementType(); })
       .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
       .Case<StructType>(
           [index](StructType type) { return type.getElementType(index); })
@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+    return tensorArmType.getNumElements();
   if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
         "invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +155,14 @@ void CompositeType::getExtensions(
         return llvm::cast<ScalarType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static const Extension exts[] = {Extension::SPV_ARM_tensors};
+        ArrayRef<Extension> ref(exts, std::size(exts));
+        extensions.push_back(ref);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getExtensions(extensions, storage);
+      })
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -171,6 +183,13 @@ void CompositeType::getCapabilities(
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
       })
+      .Case<TensorArmType>([&](TensorArmType type) {
+        static const Capability caps[] = {Capability::TensorsARM};
+        ArrayRef<Capability> ref(caps, std::size(caps));
+        capabilities.push_back(ref);
+        return llvm::cast<ScalarType>(type.getElementType())
+            .getCapabilities(capabilities, storage);
+      })
       .Default([](Type) { llvm_unreachable("invalid composite type"); });
 }
 
@@ -186,6 +205,13 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
       return std::nullopt;
     return *elementSize * vectorType.getNumElements();
   }
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    std::optional<int64_t> elementSize =
+        llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+    if (!elementSize)
+      return std::nullopt;
+    return *elementSize * tensorArmType.getNumElements();
+  }
   return std::nullopt;
 }
 
@@ -691,6 +717,9 @@ bool SPIRVType::classof(Type type) {
     return true;
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+    return llvm::isa<ScalarType>(tensorArmType.getElementType());
+  }
   return false;
 }
 
@@ -712,6 +741,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
     matrixType.getExtensions(extensions, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getExtensions(extensions, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getExtensions(extensions, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getExtensions");
   }
@@ -732,6 +763,8 @@ void SPIRVType::getCapabilities(
     matrixType.getCapabilities(capabilities, storage);
   } else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
     ptrType.getCapabilities(capabilities, storage);
+  } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+    tensorArmType.getCapabilities(capabilities, storage);
   } else {
     llvm_unreachable("invalid SPIR-V Type to getCapabilities");
   }
@@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities(
   llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
 }
 
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+  static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    auto shape = std::get<0>(key);
+    auto elementType = std::get<1>(key);
+    shape = allocator.copyInto(shape);
+    return new (allocator.allocate<TensorArmTypeStorage>())
+        TensorArmTypeStorage(std::move(shape), std::move(elementType));
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(shape, elementType);
+  }
+
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                       Type elementType) const {
+  return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+  auto shape = getShape();
+  return std::accumulate(shape.begin(), shape.end(), unsigned(1),
+                         std::multiplies<unsigned>());
+}
+
+void TensorArmType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
+  extensions.push_back(exts);
+}
+
+void TensorArmType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability caps[] = {Capability::TensorsARM};
+  capabilities.push_back(caps);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType) {
+  if (std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim == 0; }))
+    return emitError() << "arm.tensors do not support zero dimensions";
+  if (std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim < 0; }) &&
+      std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim > 0; }))
+    return emitError()
+           << "arm.tensor shape dimensions must be either fully dynamic or "
+              "completed shaped";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SPIR-V Dialect
 //===----------------------------------------------------------------------===//
 
 void SPIRVDialect::registerTypes() {
   addTypes<ArrayType, CooperativeMatrixType, ImageType, MatrixType, PointerType,
-           RuntimeArrayType, SampledImageType, StructType>();
+           RuntimeArrayType, SampledImageType, StructType, TensorArmType>();
 }
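
The shape rule that verifyInvariants enforces above reduces to a small
predicate; a restatement for illustration (isValidTensorArmShape is my name,
not part of the patch):

  // Valid iff no dimension is zero and the dimensions are not a mix of
  // static and dynamic (ShapedType::kDynamic is negative).
  static bool isValidTensorArmShape(llvm::ArrayRef<int64_t> shape) {
    if (llvm::is_contained(shape, 0))
      return false;
    bool hasDynamic = llvm::any_of(shape, [](int64_t d) { return d < 0; });
    bool hasStatic = llvm::any_of(shape, [](int64_t d) { return d > 0; });
    return !(hasDynamic && hasStatic);
  }
  // {2, 3}               -> valid (fully static)
  // {kDynamic, kDynamic} -> valid (fully dynamic)
  // {2, kDynamic}        -> invalid (mixed)
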
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index b30da773d4896..55d6a380d0bff 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -164,6 +164,7 @@ LogicalResult spirv::Deserializer::processInstruction(
   case spirv::Opcode::OpTypeRuntimeArray:
   case spirv::Opcode::OpTypeStruct:
   case spirv::Opcode::OpTypePointer:
+  case spirv::Opcode::OpTypeTensorARM:
   case spirv::Opcode::OpTypeCooperativeMatrixKHR:
     return processType(opcode, operands);
   case spirv::Opcode::OpTypeForwardPointer:
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b9d9a9015eb61..f0e42047c559e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -138,6 +138,7 @@ LogicalResult spirv::Deserializer::processHeader() {
       MIN_VERSION_CASE(3);
       MIN_VERSION_CASE(4);
       MIN_VERSION_CASE(5);
+      MIN_VERSION_CASE(6);
 #undef MIN_VERSION_CASE
     default:
       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
@@ -935,6 +936,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     return processStructType(operands);
   case spirv::Opcode::OpTypeMatrix:
     return processMatrixType(operands);
+  case spirv::Opcode::OpTypeTensorARM:
+    return processTensorARMType(operands);
   default:
     return emitError(unknownLoc, "unhandled type instruction");
   }
@@ -1238,6 +1241,55 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
+  unsigned size = operands.size();
+  if (size < 2 || size > 4) {
+    return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+                                 "(result_id, element_type, (rank), (shape))")
+           << size;
+  }
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc,
+                     "OpTypeTensorARM references undefined element type.")
+           << operands[1];
+  }
+  if (size == 2) {
+    typeMap[operands[0]] = TensorArmType::get({}, elementTy);
+    return success();
+  }
+
+  auto rankAttr = getConstantInt(operands[2]);
+  if (!rankAttr)
+    return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
+                                 "scalar integer constant instruction");
+  unsigned rank = rankAttr.getValue().getZExtValue();
+  if (size == 3) {
+    SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
+    typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+    return success();
+  }
+
+  auto shapeInfo = getConstant(operands[3]);
+  if (!shapeInfo) {
+    return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
+                                 "constant instruction of type OpTypeArray");
+  }
+  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
+  SmallVector<int64_t, 1> shape;
+  for (auto dimAttr : shapeArrayAttr.getValue()) {
+    auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
+    if (!dimIntAttr) {
+      return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
+                                   "dimension size");
+    }
+    shape.push_back(dimIntAttr.getValue().getSExtValue());
+  }
+  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
+  return success();
+}
+
 LogicalResult
 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2)
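
For reference, the operand layouts processTensorARMType accepts, as
reconstructed from the code above (the SPV_ARM_tensors spec is the authority
here):

  // operands[0] : result <id>
  // operands[1] : element type <id>                  (always present)
  // operands[2] : rank, from an OpConstant <id>      (optional)
  // operands[3] : shape, from an array constant <id> (optional, needs rank)
  switch (operands.size()) {
  case 2: // unranked: TensorArmType::get({}, elementTy)
    break;
  case 3: // ranked, unknown dims: shape filled with ShapedType::kDynamic
    break;
  case 4: // ranked and static: dims decoded from the array constant
    break;
  default: // rejected: 2-4 operands required
    break;
  }
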
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index e4556e7652b17..1bc9e4a3c75d8 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -291,6 +291,8 @@ class Deserializer {
 
   LogicalResult processMatrixType(ArrayRef<uint32_t> operands);
 
+  LogicalResult processTensorARMType(ArrayRef<uint32_t> operands);
+
   LogicalResult processTypeForwardPointer(ArrayRef<uint32_t> operands);
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index d258bfd852961..ebebd2d283afa 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -729,6 +729,54 @@ LogicalResult Serializer::prepareBasicType(
     return success();
   }
 
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+    uint32_t elementTypeID = 0;
+    uint32_t rank = 0;
+    uint32_t shapeID = 0;
+    uint32_t rankID = 0;
+    if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
+                               elementTypeID, serializationCtx))) {
+      return failure();
+    }
+    if (tensorArmType.hasRank()) {
+      ArrayRef<int64_t> dims = tensorArmType.getShape();
+      rank = dims.size();
+      rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
+      if (rankID == 0) {
+        return failure();
+      }
+
+      bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
+      if (rank > 0 && shaped) {
+        auto I32Type = IntegerType::get(type.getContext(), 32);
+        auto shapeType = ArrayType::get(I32Type, rank);
+        if (rank == 1) {
+          SmallVector<uint64_t, 1> index(rank);
+          shapeID = prepareDenseElementsConstant(
+              loc, shapeType,
+              mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
+              index);
+        } else {
+          shapeID = prepareArrayConstant(
+              loc, shapeType,
+              mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
+        }
+        if (shapeID == 0) {
+          return failure();
+        }
+      }
+    }
+    typeEnum = spirv::Opcode::OpTypeTensorARM;
+    operands.push_back(elementTypeID);
+    if (rankID == 0)
+      return success();
+    operands.push_back(rankID);
+    if (shapeID == 0)
+      return success();
+    operands.push_back(shapeID);
+    return success();
+  }
+
   // TODO: Handle other types.
   return emitError(loc, "unhandled type in serialization: ") << type;
 }
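
A worked example of the flow above for !spirv.arm.tensor<2x3xi32>, written as
approximate SPIR-V assembly (my own reconstruction, not serializer output;
constant <id>s are uniqued, so %c2 can serve as both rank and dimension):

  // %i32   = OpTypeInt 32 0
  // %c2    = OpConstant %i32 2                  ; rankID
  // %c3    = OpConstant %i32 3
  // %arr   = OpTypeArray %i32 %c2               ; i32[2] holding the shape
  // %shape = OpConstantComposite %arr %c2 %c3   ; shapeID
  // %t     = OpTypeTensorARM %i32 %c2 %shape
  //
  // Unranked tensors stop after the element type (rankID == 0 above);
  // ranked tensors with dynamic dims stop after the rank (shapeID == 0).
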
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index c23894c62826b..7d45b5ea82643 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -564,3 +564,54 @@ func.func private @matrix_size_type(!spirv.matrix< x vector<3xi32>>) -> ()
 func.func private @matrix_size_type(!spirv.matrix<2.0 x vector<3xi32>>) -> ()
 
 // -----
+
+//===----------------------------------------------------------------------===//
+// TensorArm
+//===----------------------------------------------------------------------===//
+
+// CHECK: func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>)
+func.func private @arm_tensor_type_single_dim_i32(!spirv.arm.tensor<1xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>)
+func.func private @arm_tensor_type_multi_dim_i32(!spirv.arm.tensor<1x2x3xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>)
+func.func private @arm_tensor_type_single_dim_f16(!spirv.arm.tensor<1xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>)
+func.func private @arm_tensor_type_multi_dim_f16(!spirv.arm.tensor<1x2x3xf16>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>)
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<?xi32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>)
+func.func private @arm_tensor_type_dynamic_dim_2(!spirv.arm.tensor<?x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensor shape dimensions must be either fully dynamic or fully static}}
+func.func private @arm_tensor_type_dynamic_dim(!spirv.arm.tensor<1x?xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support rank zero}}
+func.func private @arm_tensor_rank_zero(!spirv.arm.tensor<i32>) -> ()
+
+// -----
+
+// CHECK: func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>)
+func.func private @arm_tensor_type_unranked(!spirv.arm.tensor<*xi32>) -> ()
+
+// -----
+
+// expected-error @+1 {{arm.tensors do not support zero dimensions}}
+func.func private @arm_tensor_type_zero_dim(!spirv.arm.tensor<0xi32>) -> ()
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
new file mode 100644
index 0000000000000..25c2a25b47d88
--- /dev/null
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.6, [Shader, TensorsARM], [SPV_ARM_tensors]> {
+  // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
+  spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @shaped_rank2_int_arm_tensor(%arg0: !spirv.arm.tensor<2x3xi32>) "None" {
+  spirv.func @shaped_rank2_int_arm_tensor(%arg0 : !spirv.arm.tensor<2x3xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xi64> "None" {
+  spirv.func @ui64_arm_tensor_const() -> !spirv.arm.tensor<3xui64> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xi64>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xui64>
+
+    spirv.ReturnValue %0: !spirv.arm.tensor<3xui64>
+  }
+
+// -----
+
+  // CHECK: spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+  spirv.func @si32_arm_tensor_const() -> !spirv.arm.tensor<3xsi32> "None" {
+    // CHECK: spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+    %0 = spirv.Constant dense<[5, 6, 7]> : !spirv.arm.tensor<3xsi32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xsi32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+  spirv.func @float_arm_tensor_const() -> !spirv.arm.tensor<3xf32> "None" {
+    // CHECK: spirv.Constant dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[3., 4., 5.]> : !spirv.arm.tensor<3xf32>
+
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+// -----
+
+  // CHECK: spirv.func @unranked_int_arm_tensor(%arg0: !spirv.arm.tensor<*xi32>) "None" {
+  spirv.func @unranked_int_arm_tensor(%arg0 : !spirv.arm.tensor<*xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor(%arg0: !spirv.arm.tensor<?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<?xi32>) "None" {
+    spirv.Return
+  }
+
+// -----
+
+  // CHECK: spirv.func @unshaped_int_arm_tensor_2(%arg0: !spirv.arm.tensor<?x?xi32>) "None" {
+  spirv.func @unshaped_int_arm_tensor_2(%arg0 : !spirv.arm.tensor<?x?xi32>) "None" {
+    spirv.Return
+  }
+}

>From 4750c732d13dd609924045df31dc685cf193849b Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 19 Jun 2025 10:36:29 +0200
Subject: [PATCH 2/4] Resolve several review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I9f99f2e5efc1b433bcf885a4890c730cb8cac213
---
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        |  6 ++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp    | 16 ++++-----
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 34 +++++++------------
 .../SPIRV/Deserialization/Deserializer.cpp    | 23 ++++++-------
 4 files changed, 33 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 7ffea6e7dba81..2d3a4e1778490 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -97,8 +97,8 @@ class ScalarType : public SPIRVType {
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
-// StructType.
+// SPIR-V composite type: VectorType, SPIR-V ArrayType, SPIR-V
+// StructType, or SPIR-V TensorArmType.
 class CompositeType : public SPIRVType {
 public:
   using SPIRVType::SPIRVType;
@@ -479,7 +479,7 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
                        std::optional<StorageClass> storage = std::nullopt);
 };
 
-// SPIR-V TensorARM Type
+/// SPIR-V TensorARM Type
 class TensorArmType
     : public Type::TypeBase<TensorArmType, CompositeType,
                             detail::TensorArmTypeStorage, ShapedType::Trait> {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 15002f1d5d16e..88c7adf3dfcb3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,8 +194,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
           << t.getNumElements();
       return Type();
     }
-  } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
-    if (!llvm::isa<ScalarType>(t.getElementType())) {
+  } else if (auto t = dyn_cast<TensorArmType>(type)) {
+    if (!isa<ScalarType>(t.getElementType())) {
       parser.emitError(
           typeLoc, "only scalar element type allowed in tensor type but found ")
           << t.getElementType();
@@ -385,24 +385,22 @@ static Type parseTensorArmType(SPIRVDialect const &dialect,
     unranked = true;
     if (parser.parseXInDimensionList())
       return {};
-  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true)) {
     return {};
+  }
 
   if (!unranked && dims.empty()) {
     parser.emitError(countLoc, "arm.tensors do not support rank zero");
     return {};
   }
 
-  if (std::any_of(dims.begin(), dims.end(),
-                  [](int64_t dim) { return dim == 0; })) {
+  if (llvm::is_contained(dims, 0)) {
     parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
     return {};
   }
 
-  if (std::any_of(dims.begin(), dims.end(),
-                  [](int64_t dim) { return dim < 0; }) &&
-      std::any_of(dims.begin(), dims.end(),
-                  [](int64_t dim) { return dim > 0; })) {
+  if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
     parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
                                "fully dynamic or completed shaped");
     return {};
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index e4eeb0a7f37d5..18e13a7d5d3ec 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -20,7 +20,6 @@
 
 #include <algorithm>
 #include <cstdint>
-#include <iterator>
 #include <numeric>
 
 using namespace mlir;
@@ -127,7 +126,7 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+  if (auto tensorArmType = dyn_cast<TensorArmType>(*this))
     return tensorArmType.getNumElements();
   if (llvm::isa<CooperativeMatrixType>(*this)) {
     llvm_unreachable(
@@ -156,9 +155,7 @@ void CompositeType::getExtensions(
             .getExtensions(extensions, storage);
       })
       .Case<TensorArmType>([&](TensorArmType type) {
-        static const Extension exts[] = {Extension::SPV_ARM_tensors};
-        ArrayRef<Extension> ref(exts, std::size(exts));
-        extensions.push_back(ref);
+        extensions.push_back({Extension::SPV_ARM_tensors});
         return llvm::cast<ScalarType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
@@ -184,9 +181,7 @@ void CompositeType::getCapabilities(
             .getCapabilities(capabilities, storage);
       })
       .Case<TensorArmType>([&](TensorArmType type) {
-        static const Capability caps[] = {Capability::TensorsARM};
-        ArrayRef<Capability> ref(caps, std::size(caps));
-        capabilities.push_back(ref);
+        capabilities.push_back({Capability::TensorsARM});
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
       })
@@ -1245,15 +1240,15 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
 
   static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
                                          const KeyTy &key) {
-    auto shape = std::get<0>(key);
-    auto elementType = std::get<1>(key);
+    auto [shape, elementType] = key;
     shape = allocator.copyInto(shape);
     return new (allocator.allocate<TensorArmTypeStorage>())
         TensorArmTypeStorage(std::move(shape), std::move(elementType));
   }
 
   static llvm::hash_code hashKey(const KeyTy &key) {
-    return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+    auto [shape, elementType] = key;
+    return llvm::hash_combine(shape, elementType);
   }
 
   bool operator==(const KeyTy &key) const {
@@ -1280,7 +1275,7 @@ Type TensorArmType::getElementType() const { return getImpl()->elementType; }
 ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
 
 unsigned TensorArmType::getNumElements() const {
-  auto shape = getShape();
+  ArrayRef<int64_t> shape = getShape();
   return std::accumulate(shape.begin(), shape.end(), unsigned(1),
                          std::multiplies<unsigned>());
 }
@@ -1290,8 +1285,7 @@ void TensorArmType::getExtensions(
     std::optional<StorageClass> storage) {
 
   llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
-  static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
-  extensions.push_back(exts);
+  extensions.push_back({Extension::SPV_ARM_tensors});
 }
 
 void TensorArmType::getCapabilities(
@@ -1299,20 +1293,16 @@ void TensorArmType::getCapabilities(
     std::optional<StorageClass> storage) {
   llvm::cast<SPIRVType>(getElementType())
       .getCapabilities(capabilities, storage);
-  static constexpr Capability caps[] = {Capability::TensorsARM};
-  capabilities.push_back(caps);
+  capabilities.push_back({Capability::TensorsARM});
 }
 
 LogicalResult
 TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType) {
-  if (std::any_of(shape.begin(), shape.end(),
-                  [](int64_t dim) { return dim == 0; }))
+  if (llvm::is_contained(shape, 0))
     return emitError() << "arm.tensor do not support dimensions = 0";
-  if (std::any_of(shape.begin(), shape.end(),
-                  [](int64_t dim) { return dim < 0; }) &&
-      std::any_of(shape.begin(), shape.end(),
-                  [](int64_t dim) { return dim > 0; }))
+  if (llvm::any_of(shape, [](int64_t dim) { return dim < 0; }) &&
+      llvm::any_of(shape, [](int64_t dim) { return dim > 0; }))
     return emitError()
            << "arm.tensor shape dimensions must be either fully dynamic or "
               "completed shaped";
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index f0e42047c559e..b801f5a4660fc 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1244,23 +1244,23 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
 LogicalResult
 spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
   unsigned size = operands.size();
-  if (size < 2 || size > 4) {
+  if (size < 2 || size > 4)
     return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
-                                 "(result_id, element_type, (rank), (shape))")
+                                 "(result_id, element_type, (rank), (shape)) ")
            << size;
-  }
+
   Type elementTy = getType(operands[1]);
-  if (!elementTy) {
+  if (!elementTy)
     return emitError(unknownLoc,
-                     "OpTypeTensorARM references undefined element type.")
+                     "OpTypeTensorARM references undefined element type ")
            << operands[1];
-  }
+
   if (size == 2) {
     typeMap[operands[0]] = TensorArmType::get({}, elementTy);
     return success();
   }
 
-  auto rankAttr = getConstantInt(operands[2]);
+  IntegerAttr rankAttr = getConstantInt(operands[2]);
   if (!rankAttr)
     return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
                                  "scalar integer constant instruction");
@@ -1271,19 +1271,18 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
     return success();
   }
 
-  auto shapeInfo = getConstant(operands[3]);
-  if (!shapeInfo) {
+  std::optional<std::pair<Attribute, Type>> shapeInfo = getConstant(operands[3]);
+  if (!shapeInfo)
     return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
                                  "constant instruction of type OpTypeArray");
-  }
+
   ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
   SmallVector<int64_t, 1> shape;
   for (auto dimAttr : shapeArrayAttr.getValue()) {
     auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
-    if (!dimIntAttr) {
+    if (!dimIntAttr)
       return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
                                    "dimension size");
-    }
     shape.push_back(dimIntAttr.getValue().getSExtValue());
   }
   typeMap[operands[0]] = TensorArmType::get(shape, elementTy);

>From a96196f9a5597d9c6d8774836fc1ab95af4493aa Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Thu, 19 Jun 2025 13:53:13 +0200
Subject: [PATCH 3/4] Fix more comments and formatting

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I527cdeb796135237ef5d0965d5f8d54a36515458
---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp               | 2 +-
 mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 18e13a7d5d3ec..e3ed16da2a6de 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1243,7 +1243,7 @@ struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
     auto [shape, elementType] = key;
     shape = allocator.copyInto(shape);
     return new (allocator.allocate<TensorArmTypeStorage>())
-        TensorArmTypeStorage(std::move(shape), std::move(elementType));
+        TensorArmTypeStorage(shape, elementType);
   }
 
   static llvm::hash_code hashKey(const KeyTy &key) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index b801f5a4660fc..893aa38da93d1 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1271,7 +1271,8 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
     return success();
   }
 
-  std::optional<std::pair<Attribute, Type>> shapeInfo = getConstant(operands[3]);
+  std::optional<std::pair<Attribute, Type>> shapeInfo =
+      getConstant(operands[3]);
   if (!shapeInfo)
     return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
                                  "constant instruction of type OpTypeArray");

>From d4464e595ea313b106bf9069005e99fafc23f237 Mon Sep 17 00:00:00 2001
From: Davide Grohmann <davide.grohmann at arm.com>
Date: Wed, 25 Jun 2025 12:21:41 +0200
Subject: [PATCH 4/4] Resolve more review comments

Signed-off-by: Davide Grohmann <davide.grohmann at arm.com>
Change-Id: I9ee3c5755d123fcf3ef203ac5d76af9511c5ef7a
---
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h   | 10 +++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp          | 15 +++++----------
 .../Target/SPIRV/Deserialization/Deserializer.cpp |  1 -
 mlir/test/Target/SPIRV/tensorARM.mlir             |  2 +-
 4 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2d3a4e1778490..6fa09888e887b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -486,6 +486,15 @@ class TensorArmType
 public:
   using Base::Base;
 
+  using ShapedType::Trait<TensorArmType>::getElementTypeBitWidth;
+  using ShapedType::Trait<TensorArmType>::getRank;
+  using ShapedType::Trait<TensorArmType>::getNumElements;
+  using ShapedType::Trait<TensorArmType>::isDynamicDim;
+  using ShapedType::Trait<TensorArmType>::hasStaticShape;
+  using ShapedType::Trait<TensorArmType>::getNumDynamicDims;
+  using ShapedType::Trait<TensorArmType>::getDimSize;
+  using ShapedType::Trait<TensorArmType>::getDynamicDimIndex;
+
   static constexpr StringLiteral name = "spirv.arm.tensor";
 
   // TensorArm supports a minimum rank of 1, so an empty shape here means
@@ -500,7 +509,6 @@ class TensorArmType
 
   Type getElementType() const;
   ArrayRef<int64_t> getShape() const;
-  unsigned getNumElements() const;
   bool hasRank() const { return !getShape().empty(); }
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index e3ed16da2a6de..b288eb2edc315 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -155,7 +155,8 @@ void CompositeType::getExtensions(
             .getExtensions(extensions, storage);
       })
       .Case<TensorArmType>([&](TensorArmType type) {
-        extensions.push_back({Extension::SPV_ARM_tensors});
+        static constexpr Extension ext{Extension::SPV_ARM_tensors};
+        extensions.push_back(ext);
         return llvm::cast<ScalarType>(type.getElementType())
             .getExtensions(extensions, storage);
       })
@@ -181,7 +182,8 @@ void CompositeType::getCapabilities(
             .getCapabilities(capabilities, storage);
       })
       .Case<TensorArmType>([&](TensorArmType type) {
-        capabilities.push_back({Capability::TensorsARM});
+        static constexpr Capability cap{Capability::TensorsARM};
+        capabilities.push_back(cap);
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
       })
@@ -712,9 +714,8 @@ bool SPIRVType::classof(Type type) {
     return true;
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return CompositeType::isValid(vectorType);
-  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type))
     return llvm::isa<ScalarType>(tensorArmType.getElementType());
-  }
   return false;
 }
 
@@ -1274,12 +1275,6 @@ TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
 Type TensorArmType::getElementType() const { return getImpl()->elementType; }
 ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
 
-unsigned TensorArmType::getNumElements() const {
-  ArrayRef<int64_t> shape = getShape();
-  return std::accumulate(shape.begin(), shape.end(), unsigned(1),
-                         std::multiplies<unsigned>());
-}
-
 void TensorArmType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     std::optional<StorageClass> storage) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 893aa38da93d1..b1abd8b3dffe9 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -138,7 +138,6 @@ LogicalResult spirv::Deserializer::processHeader() {
       MIN_VERSION_CASE(3);
       MIN_VERSION_CASE(4);
       MIN_VERSION_CASE(5);
-      MIN_VERSION_CASE(6);
 #undef MIN_VERSION_CASE
     default:
       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
diff --git a/mlir/test/Target/SPIRV/tensorARM.mlir b/mlir/test/Target/SPIRV/tensorARM.mlir
index 25c2a25b47d88..75b648ebfd008 100644
--- a/mlir/test/Target/SPIRV/tensorARM.mlir
+++ b/mlir/test/Target/SPIRV/tensorARM.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
 
-spirv.module Logical GLSL450 requires #spirv.vce<v1.6, [Shader, TensorsARM], [SPV_ARM_tensors]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, TensorsARM], [SPV_ARM_tensors]> {
   // CHECK: spirv.func @shaped_int_arm_tensor(%arg0: !spirv.arm.tensor<2xi32>) "None" {
   spirv.func @shaped_int_arm_tensor(%arg0 : !spirv.arm.tensor<2xi32>) "None" {
     spirv.Return


