[Mlir-commits] [mlir] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType` (PR #142784)

Igor Wodiany llvmlistbot at llvm.org
Mon Jun 9 07:28:12 PDT 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/142784

>From c2d667e440996bac737bd043fdc7be352c031d0a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 20 May 2025 17:42:12 +0100
Subject: [PATCH 1/3] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType`

This is to enable `CooperativeMatrixType` to be used with
`DenseElementsAttr`, so that a `spirv.Constant` can be easily
built from `OpConstantComposite`. For example:

```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```

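Additional constraints are added to arithmetic operations, as
`SameOperandsAndResultType` can no longer fully verify cooperative matrices.
This is because for shaped types the verifier only checks element types and
shapes, whereas for any other type it requires an exact match. Without the
extra constraint, cooperative matrices that differ only in scope or use
(e.g. `MatrixA` vs. `MatrixB`) would be accepted as operands of the same op.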

This patch does not yet enable the actual deserialization of such constants.
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    | 14 ++++++++++--
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 20 ++++++++++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 22 ++++++++++++-------
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  |  4 ++--
 4 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+  "cooperative matrix types match",
+  CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
+        "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
+        "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
+>;
+
 class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                list<Trait> traits = []> :
       // Operands type same as result type.
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand1", "operand2">,
+                                SPIRV_SameCoopMatrix<"operand2", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
       // Operand type same as result type.
       SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
-                            detail::CooperativeMatrixTypeStorage> {
+                            detail::CooperativeMatrixTypeStorage,
+                            ShapedType::Trait> {
 public:
   using Base::Base;
 
@@ -418,6 +419,23 @@ class CooperativeMatrixType
                      std::optional<StorageClass> storage = std::nullopt);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        std::optional<StorageClass> storage = std::nullopt);
+
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  ArrayRef<int64_t> getShape() const;
+
+  bool hasRank() const { return true; }
+
+  CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                  Type elementType) const {
+    if (shape == std::nullopt)
+      return get(elementType, getRows(), getColumns(), getScope(), getUse());
+    else {
+      assert(shape.value().size() == 2);
+      return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+                 getUse());
+    }
+  }
 };
 
 // SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
 
 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   using KeyTy =
-      std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+      std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
 
   static CooperativeMatrixTypeStorage *
   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, rows, columns, scope, use);
+    return key == KeyTy(elementType, shape[0], shape[1], scope, use);
   }
 
   CooperativeMatrixTypeStorage(const KeyTy &key)
-      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
-        columns(std::get<2>(key)), scope(std::get<3>(key)),
+      : elementType(std::get<0>(key)),
+        shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
         use(std::get<4>(key)) {}
 
   Type elementType;
-  uint32_t rows;
-  uint32_t columns;
+  // [#rows, #columns]
+  SmallVector<int64_t, 2> shape;
   Scope scope;
   CooperativeMatrixUseKHR use;
 };
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
   return getImpl()->elementType;
 }
 
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+  return static_cast<uint32_t>(getImpl()->shape[0]);
+}
 
 uint32_t CooperativeMatrixType::getColumns() const {
-  return getImpl()->columns;
+  return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+  return getImpl()->shape;
 }
 
 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
 
 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.IAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
     -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
 
 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.FAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
     -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

>From 760787b2496a234363b0d162074c100d1398a2f3 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 9 Jun 2025 14:23:29 +0100
Subject: [PATCH 2/3] Address feedback

---
 .../mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td   | 14 ++------------
 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h   | 11 +++++------
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp          | 15 ++++++++++++++-
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir      |  4 ++--
 4 files changed, 23 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 48f525e048e60..309079e549846 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,21 +18,12 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
-  "cooperative matrix types match",
-  CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
-        "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
-        "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
->;
-
 class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                list<Trait> traits = []> :
       // Operands type same as result type.
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType,
-                                SPIRV_SameCoopMatrix<"operand1", "operand2">,
-                                SPIRV_SameCoopMatrix<"operand2", "result">])> {
+                               [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
@@ -51,8 +42,7 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
       // Operand type same as result type.
       SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType,
-                                SPIRV_SameCoopMatrix<"operand", "result">])> {
+                               [Pure, AllTypesMatch<["operand", "result"]>])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index a7b6569245dd5..787535d0a6bd2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -428,13 +428,12 @@ class CooperativeMatrixType
 
   CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                   Type elementType) const {
-    if (shape == std::nullopt)
+    if (!shape)
       return get(elementType, getRows(), getColumns(), getScope(), getUse());
-    else {
-      assert(shape.value().size() == 2);
-      return get(elementType, shape.value()[0], shape.value()[1], getScope(),
-                 getUse());
-    }
+
+    assert(shape.value().size() == 2);
+    return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+               getUse());
   }
 };
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index de2034680cd5f..2ed78db52c87a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -194,6 +194,19 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
 //===----------------------------------------------------------------------===//
 
 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
+  // In the specification, the dimensions of the Cooperative Matrix are 32-bit
+  // integers --- the initial implementation kept those values as such. However,
+  // `ShapedType` expects the shape to be `int64_t`. We could keep the shape as
+  // 32-bit values and expose it as `int64_t` through `getShape`; however, that
+  // method returns an `ArrayRef`, so returning an `ArrayRef<int64_t>` over two
+  // 32-bit integers would require extra logic and storage. So, we diverge
+  // from the spec and internally represent the dimensions as 64-bit integers,
+  // so we can easily return an `ArrayRef` from `getShape` without any extra
+  // logic. Alternatively, we could store both rows and columns (both 32-bit)
+  // and the shape (64-bit), assigning rows and columns to the shape whenever
+  // `getShape` is called. This would come at the cost of extra logic and
+  // storage. Note: because `ArrayRef` is a non-owning view, we cannot build
+  // the shape in `getShape` on the fly (the returned view would dangle).
   using KeyTy =
       std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
 
@@ -214,7 +227,7 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
 
   Type elementType;
   // [#rows, #columns]
-  SmallVector<int64_t, 2> shape;
+  std::array<int64_t, 2> shape;
   Scope scope;
   CooperativeMatrixUseKHR use;
 };
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 4ae8b70bf43ca..8733ff93768ab 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
 
 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
-  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
+  // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
   %q = "spirv.IAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
     -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
 
 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
-  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
+  // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
   %q = "spirv.FAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
     -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

>From ef0629545f9556c75d9b839a609a78ca54cd4e38 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 9 Jun 2025 15:27:00 +0100
Subject: [PATCH 3/3] Add asserts

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 2ed78db52c87a..1aff43c301334 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -245,10 +245,12 @@ Type CooperativeMatrixType::getElementType() const {
 }
 
 uint32_t CooperativeMatrixType::getRows() const {
+  assert(getImpl()->shape[0] != ShapedType::kDynamic);
   return static_cast<uint32_t>(getImpl()->shape[0]);
 }
 
 uint32_t CooperativeMatrixType::getColumns() const {
+  assert(getImpl()->shape[1] != ShapedType::kDynamic);
   return static_cast<uint32_t>(getImpl()->shape[1]);
 }
 



More information about the Mlir-commits mailing list