[Mlir-commits] [mlir] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType` (PR #142784)

Igor Wodiany llvmlistbot at llvm.org
Wed Jun 4 08:11:46 PDT 2025


https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/142784

This is to enable `CooperativeMatrixType` to be used with `DenseElementsAttr`, so that a `spirv.Constant` can be easily built from `OpConstantComposite`. For example:

```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```

Additional constraints are added to the arithmetic operations, because `SameOperandsAndResultType` can no longer fully verify cooperative matrices: for shaped types the verifier only checks element types and shapes, whereas for any other type it requires an exact match. Without the new constraints, two cooperative matrices that differ only in scope or use (e.g. `MatrixA` vs `MatrixB`) would be accepted.

This patch does not enable the actual deserialization; that will be done in a follow-up patch.

>From c2d667e440996bac737bd043fdc7be352c031d0a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 20 May 2025 17:42:12 +0100
Subject: [PATCH] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType`

This is to enable `CooperativeMatrixType` to be used with
`DenseElementsAttr`, so that a `spirv.Constant` can be easily
built from `OpConstantComposite`. For example:

```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```

Additional constraints are added to arithmetic operations, because
`SameOperandsAndResultType` can no longer fully verify cooperative
matrices: for shaped types the verifier only checks element types
and shapes, whereas for any other type it requires an exact match,
so matrices differing only in scope or use would otherwise pass.

This patch does not enable the actual deserialization.
---
 .../Dialect/SPIRV/IR/SPIRVArithmeticOps.td    | 14 ++++++++++--
 .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h        | 20 ++++++++++++++++-
 mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp      | 22 ++++++++++++-------
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  |  4 ++--
 4 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+// Requires identical types whenever both `lhs` and `rhs` are cooperative matrices.
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+  "cooperative matrix types match",
+  CPred<"!(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) && "
+        "::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType())) || "
+        "$" # lhs # ".getType() == $" # rhs # ".getType()">>;
+
 class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                list<Trait> traits = []> :
       // Operands type same as result type.
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand1", "operand2">,
+                                SPIRV_SameCoopMatrix<"operand2", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
       // Operand type same as result type.
       SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
-                            detail::CooperativeMatrixTypeStorage> {
+                            detail::CooperativeMatrixTypeStorage,
+                            ShapedType::Trait> {
 public:
   using Base::Base;
 
@@ -418,6 +419,23 @@ class CooperativeMatrixType
                      std::optional<StorageClass> storage = std::nullopt);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        std::optional<StorageClass> storage = std::nullopt);
+
+  /// Always-valid cast: this type attaches ShapedType::Trait.
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+  /// Returns the [rows, columns] shape of this matrix.
+  ArrayRef<int64_t> getShape() const;
+  /// Cooperative matrices are always ranked.
+  bool hasRank() const { return true; }
+  /// Clones this type with `elementType` and, if given, a static 2-D `shape`.
+  CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                  Type elementType) const {
+    if (!shape)
+      return get(elementType, getRows(), getColumns(), getScope(), getUse());
+    assert(shape->size() == 2 && !ShapedType::isDynamicShape(*shape) &&
+           "cooperative matrix shape must be static and 2-D");
+    return get(elementType, static_cast<uint32_t>((*shape)[0]),
+               static_cast<uint32_t>((*shape)[1]), getScope(), getUse());
+  }
 };
 
 // SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
 
 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   using KeyTy =
-      std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+      std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
 
   static CooperativeMatrixTypeStorage *
   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, rows, columns, scope, use);
+    return key == KeyTy(elementType, shape[0], shape[1], scope, use);
   }
 
   CooperativeMatrixTypeStorage(const KeyTy &key)
-      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
-        columns(std::get<2>(key)), scope(std::get<3>(key)),
+      : elementType(std::get<0>(key)),
+        shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
         use(std::get<4>(key)) {}
 
   Type elementType;
-  uint32_t rows;
-  uint32_t columns;
+  // [#rows, #columns]
+  SmallVector<int64_t, 2> shape;
   Scope scope;
   CooperativeMatrixUseKHR use;
 };
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
   return getImpl()->elementType;
 }
 
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+  return static_cast<uint32_t>(getImpl()->shape[0]);
+}
 
 uint32_t CooperativeMatrixType::getColumns() const {
-  return getImpl()->columns;
+  return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+  return getImpl()->shape;
 }
 
 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
 
 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.IAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
     -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
 
 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.FAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
     -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>



More information about the Mlir-commits mailing list