[Mlir-commits] [mlir] [mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix (PR #142786)
Igor Wodiany
llvmlistbot at llvm.org
Wed Jun 4 08:14:12 PDT 2025
https://github.com/IgWod-IMG created https://github.com/llvm/llvm-project/pull/142786
Depends on #142784.
>From c2d667e440996bac737bd043fdc7be352c031d0a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 20 May 2025 17:42:12 +0100
Subject: [PATCH 1/2] [mlir][spirv] Make `CooperativeMatrixType` a `ShapedType`
This enables `CooperativeMatrixType` to be used with
`DenseElementsAttr`, so that a `spirv.Constant` can easily be
built from an `OpConstantComposite`. For example:
```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```
Additional constraints are added to the arithmetic operations, because
`SameOperandsAndResultType` can no longer fully verify cooperative
matrices: for shaped types the verifier only checks that element types
and shapes match, whereas for any other type it requires an exact match.
Without the extra constraint, two cooperative matrices that differ only
in scope or use (e.g. `MatrixA` vs `MatrixB`) would be accepted.
This patch does not yet enable the deserialization itself; that is done in the follow-up patch.
---
.../Dialect/SPIRV/IR/SPIRVArithmeticOps.td | 14 ++++++++++--
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 20 ++++++++++++++++-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 22 ++++++++++++-------
.../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 4 ++--
4 files changed, 47 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+ "cooperative matrix types match",
+ CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
+ "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
+ "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
+>;
+
class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand1", "operand2">,
+ SPIRV_SameCoopMatrix<"operand2", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
- detail::CooperativeMatrixTypeStorage> {
+ detail::CooperativeMatrixTypeStorage,
+ ShapedType::Trait> {
public:
using Base::Base;
@@ -418,6 +419,23 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
+
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ ArrayRef<int64_t> getShape() const;
+
+ bool hasRank() const { return true; }
+
+ CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (shape == std::nullopt)
+ return get(elementType, getRows(), getColumns(), getScope(), getUse());
+ else {
+ assert(shape.value().size() == 2);
+ return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+ getUse());
+ }
+ }
};
// SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
using KeyTy =
- std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+ std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
static CooperativeMatrixTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, scope, use);
+ return key == KeyTy(elementType, shape[0], shape[1], scope, use);
}
CooperativeMatrixTypeStorage(const KeyTy &key)
- : elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<3>(key)),
+ : elementType(std::get<0>(key)),
+ shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
use(std::get<4>(key)) {}
Type elementType;
- uint32_t rows;
- uint32_t columns;
+ // [#rows, #columns]
+ SmallVector<int64_t, 2> shape;
Scope scope;
CooperativeMatrixUseKHR use;
};
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
return getImpl()->elementType;
}
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+ return static_cast<uint32_t>(getImpl()->shape[0]);
+}
uint32_t CooperativeMatrixType::getColumns() const {
- return getImpl()->columns;
+ return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+ return getImpl()->shape;
}
Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.IAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.FAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
>From 4b75dc706bd23c7503c6c83a90ac25507ef327bd Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 3 Jun 2025 17:55:50 +0100
Subject: [PATCH 2/2] [mlir][spirv] Deserialize OpConstantComposite of type
Cooperative Matrix
---
.../SPIRV/Deserialization/Deserializer.cpp | 6 ++--
.../Target/SPIRV/Serialization/Serializer.cpp | 28 ++++++++++++++++---
mlir/test/Target/SPIRV/constant.mlir | 7 +++++
3 files changed, 34 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 7afd6e9b25b77..ab5abb4ca9408 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1449,11 +1449,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto vectorType = dyn_cast<VectorType>(resultType)) {
- auto attr = DenseElementsAttr::get(vectorType, elements);
+ if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
- constantMap.try_emplace(resultID, attr, resultType);
+ constantMap.try_emplace(resultID, attr, shapedType);
} else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
auto attr = opBuilder.getArrayAttr(elements);
constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db..4830223c2ebf5 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
return 0;
}
+ int64_t numberOfConstituents = shapedType.getDimSize(dim);
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
- operands.reserve(shapedType.getDimSize(dim) + 2);
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
- for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
- index[dim] = i;
+
+ // "If the Result Type is a cooperative matrix type, then there must be only
+ // one Constituent, with scalar type matching the cooperative matrix Component
+ // Type, and all components of the matrix are initialized to that value."
+ // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
+ if (isa<spirv::CooperativeMatrixType>(constType)) {
+ // numberOfConstituents is 1, so we only need one more elements in the
+ // SmallVector, so the total is 3 (1 + 2).
+ operands.reserve(3);
+ // We set dim directly to `shapedType.getRank()` so the recursive call
+ // directly returns the scalar type.
if (auto elementID = prepareDenseElementsConstant(
- loc, elementType, valueAttr, dim + 1, index)) {
+ loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
operands.push_back(elementID);
} else {
return 0;
}
+ } else {
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ index[dim] = i;
+ if (auto elementID = prepareDenseElementsConstant(
+ loc, elementType, valueAttr, dim + 1, index)) {
+ operands.push_back(elementID);
+ } else {
+ return 0;
+ }
+ }
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f05..a018692afab81 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -277,4 +277,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%signed_minus_one = spirv.Constant -1 : si16
spirv.ReturnValue %signed_minus_one : si16
}
+
+ // CHECK-LABEL: @coop_matrix_const
+ spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ }
}
More information about the Mlir-commits
mailing list