[Mlir-commits] [mlir] [mlir][spirv] Make `MatrixType` type a `ShapedType` (PR #185470)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 10:35:21 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod)

<details>
<summary>Changes</summary>

This allows some of the type constraints to be enforced in ODS using built-in classes, e.g., `AllElementTypesMatch`.

This is the first PR in a series moving all verification for Matrix ops to ODS -- splitting the work into multiple small PRs makes logical sense and makes the review process easier.

---
Full diff: https://github.com/llvm/llvm-project/pull/185470.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+23-2) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+19-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2c54e95ef11b8..4a0c29d4b5d90 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -425,8 +425,9 @@ class CooperativeMatrixType
 };
 
 // SPIR-V matrix type
-class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
-                                         detail::MatrixTypeStorage> {
+class MatrixType
+    : public Type::TypeBase<MatrixType, CompositeType,
+                            detail::MatrixTypeStorage, ShapedType::Trait> {
 public:
   using Base::Base;
 
@@ -457,6 +458,26 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
 
   /// Returns the elements' type (i.e, single element type).
   Type getElementType() const;
+
+  operator ShapedType() const { return cast<ShapedType>(*this); }
+
+  ArrayRef<int64_t> getShape() const;
+
+  bool hasRank() const { return true; }
+
+  MatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const {
+    if (!shape)
+      return get(elementType, getNumColumns());
+
+    assert(shape.value().size() == 2);
+
+    auto vectorType = cast<VectorType>(elementType);
+    Type newElementType =
+        vectorType.cloneWith({shape.value()[0]}, vectorType.getElementType());
+
+    return get(newElementType, shape.value()[1]);
+  }
 };
 
 /// SPIR-V TensorARM Type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 63b51d1836f75..6b25f1a86b5ee 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -1162,25 +1162,30 @@ llvm::hash_code spirv::hash_value(
 //===----------------------------------------------------------------------===//
 
 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
-  MatrixTypeStorage(Type columnType, uint32_t columnCount)
-      : columnType(columnType), columnCount(columnCount) {}
+  // Use a 64-bit integer as a column count internally to better support a
+  // `ShapedType` interface. See comment in `CooperativeMatrixType` for more
+  // context.
+  using KeyTy = std::tuple<Type, int64_t>;
 
-  using KeyTy = std::tuple<Type, uint32_t>;
+  MatrixTypeStorage(const KeyTy &key)
+      : columnType(std::get<0>(key)),
+        shape({cast<VectorType>(std::get<0>(key)).getShape()[0],
+               std::get<1>(key)}) {}
 
   static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
                                       const KeyTy &key) {
 
     // Initialize the memory using placement new.
-    return new (allocator.allocate<MatrixTypeStorage>())
-        MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
+    return new (allocator.allocate<MatrixTypeStorage>()) MatrixTypeStorage(key);
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(columnType, columnCount);
+    return key == KeyTy(columnType, shape[1]);
   }
 
   Type columnType;
-  const uint32_t columnCount;
+  // [#rows, #columns]
+  std::array<int64_t, 2> shape;
 };
 
 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
@@ -1228,16 +1233,21 @@ Type MatrixType::getElementType() const {
   return cast<VectorType>(getImpl()->columnType).getElementType();
 }
 
-unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
+unsigned MatrixType::getNumColumns() const {
+  assert(getImpl()->shape[1] != ShapedType::kDynamic);
+  return static_cast<uint32_t>(getImpl()->shape[1]);
+}
 
 unsigned MatrixType::getNumRows() const {
   return cast<VectorType>(getImpl()->columnType).getShape()[0];
 }
 
 unsigned MatrixType::getNumElements() const {
-  return (getImpl()->columnCount) * getNumRows();
+  return getNumColumns() * getNumRows();
 }
 
+ArrayRef<int64_t> MatrixType::getShape() const { return getImpl()->shape; }
+
 void TypeCapabilityVisitor::addConcrete(MatrixType type) {
   add(type.getColumnType());
   static constexpr auto cap = Capability::Matrix;

``````````

</details>


https://github.com/llvm/llvm-project/pull/185470


More information about the Mlir-commits mailing list