[Mlir-commits] [mlir] [mlir][spirv] Add support for SPV_ARM_tensors (PR #144667)

Jakub Kuderski llvmlistbot at llvm.org
Wed Jun 18 09:18:30 PDT 2025


================
@@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities(
   llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
 }
 
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+  static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    auto shape = std::get<0>(key);
+    auto elementType = std::get<1>(key);
+    shape = allocator.copyInto(shape);
+    return new (allocator.allocate<TensorArmTypeStorage>())
+        TensorArmTypeStorage(std::move(shape), std::move(elementType));
+  }
+
+  static llvm::hash_code hashKey(const KeyTy &key) {
+    return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+  }
+
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(shape, elementType);
+  }
+
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
+  Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+  return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                       Type elementType) const {
+  return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+  auto shape = getShape();
+  return std::accumulate(shape.begin(), shape.end(), unsigned(1),
+                         std::multiplies<unsigned>());
+}
+
+void TensorArmType::getExtensions(
+    SPIRVType::ExtensionArrayRefVector &extensions,
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
+  extensions.push_back(exts);
+}
+
+void TensorArmType::getCapabilities(
+    SPIRVType::CapabilityArrayRefVector &capabilities,
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability caps[] = {Capability::TensorsARM};
+  capabilities.push_back(caps);
+}
+
+LogicalResult
+TensorArmType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType) {
+  if (std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim == 0; }))
+    return emitError() << "arm.tensor do not support dimensions = 0";
+  if (std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim < 0; }) &&
+      std::any_of(shape.begin(), shape.end(),
+                  [](int64_t dim) { return dim > 0; }))
----------------
kuhar wrote:

also here

https://github.com/llvm/llvm-project/pull/144667


More information about the Mlir-commits mailing list