[Mlir-commits] [mlir] [mlir][spirv] Make `MatrixType` type a `ShapedType` (PR #185470)
Igor Wodiany
llvmlistbot at llvm.org
Mon Mar 9 10:34:42 PDT 2026
https://github.com/IgWod created https://github.com/llvm/llvm-project/pull/185470
This makes it possible to enforce some of the type constraints in ODS using builtin classes, e.g., `AllElementTypesMatch`.
This is the first PR in a series of PRs moving all verification for Matrix ops to ODS -- having multiple small PRs makes logical sense and makes the review process easier.
>From 05e11b6fe4afeb65dd67cd8a8114ef84a83e5c19 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <dev at wodiany.com>
Date: Fri, 6 Mar 2026 21:51:11 +0000
Subject: [PATCH] [mlir][spirv] Make `MatrixType` type a `ShapedType`
This makes it possible to enforce some of the type constraints in ODS
using builtin classes, e.g., `AllElementTypesMatch`.
---
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 25 +++++++++++++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 28 +++++++++++++------
2 files changed, 42 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2c54e95ef11b8..4a0c29d4b5d90 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -425,8 +425,9 @@ class CooperativeMatrixType
};
// SPIR-V matrix type
-class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
- detail::MatrixTypeStorage> {
+class MatrixType
+ : public Type::TypeBase<MatrixType, CompositeType,
+ detail::MatrixTypeStorage, ShapedType::Trait> {
public:
using Base::Base;
@@ -457,6 +458,26 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
/// Returns the elements' type (i.e, single element type).
Type getElementType() const;
+
+ operator ShapedType() const { return cast<ShapedType>(*this); }
+
+ ArrayRef<int64_t> getShape() const;
+
+ bool hasRank() const { return true; }
+
+ MatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (!shape)
+ return get(elementType, getNumColumns());
+
+ assert(shape.value().size() == 2);
+
+ auto vectorType = cast<VectorType>(elementType);
+ Type newElementType =
+ vectorType.cloneWith({shape.value()[0]}, vectorType.getElementType());
+
+ return get(newElementType, shape.value()[1]);
+ }
};
/// SPIR-V TensorARM Type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 63b51d1836f75..6b25f1a86b5ee 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1162,25 +1162,30 @@ llvm::hash_code spirv::hash_value(
//===----------------------------------------------------------------------===//
struct spirv::detail::MatrixTypeStorage : public TypeStorage {
- MatrixTypeStorage(Type columnType, uint32_t columnCount)
- : columnType(columnType), columnCount(columnCount) {}
+ // Use a 64-bit integer as a column count internally to better support a
+ // `ShapedType` interface. See comment in `CooperativeMatrixType` for more
+ // context.
+ using KeyTy = std::tuple<Type, int64_t>;
- using KeyTy = std::tuple<Type, uint32_t>;
+ MatrixTypeStorage(const KeyTy &key)
+ : columnType(std::get<0>(key)),
+ shape({cast<VectorType>(std::get<0>(key)).getShape()[0],
+ std::get<1>(key)}) {}
static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
// Initialize the memory using placement new.
- return new (allocator.allocate<MatrixTypeStorage>())
- MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
+ return new (allocator.allocate<MatrixTypeStorage>()) MatrixTypeStorage(key);
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(columnType, columnCount);
+ return key == KeyTy(columnType, shape[1]);
}
Type columnType;
- const uint32_t columnCount;
+ // [#rows, #columns]
+ std::array<int64_t, 2> shape;
};
MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
@@ -1228,16 +1233,21 @@ Type MatrixType::getElementType() const {
return cast<VectorType>(getImpl()->columnType).getElementType();
}
-unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
+unsigned MatrixType::getNumColumns() const {
+ assert(getImpl()->shape[1] != ShapedType::kDynamic);
+ return static_cast<uint32_t>(getImpl()->shape[1]);
+}
unsigned MatrixType::getNumRows() const {
return cast<VectorType>(getImpl()->columnType).getShape()[0];
}
unsigned MatrixType::getNumElements() const {
- return (getImpl()->columnCount) * getNumRows();
+ return getNumColumns() * getNumRows();
}
+ArrayRef<int64_t> MatrixType::getShape() const { return getImpl()->shape; }
+
void TypeCapabilityVisitor::addConcrete(MatrixType type) {
add(type.getColumnType());
static constexpr auto cap = Capability::Matrix;
More information about the Mlir-commits
mailing list