[Mlir-commits] [mlir] [mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix (PR #142786)
Igor Wodiany
llvmlistbot at llvm.org
Mon Jun 9 10:16:36 PDT 2025
https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/142786
>From 8ae14d8efa667da885d5eb24b5562e44a1a43b06 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 3 Jun 2025 17:55:50 +0100
Subject: [PATCH 1/2] [mlir][spirv] Deserialize OpConstantComposite of type
Cooperative Matrix
---
.../SPIRV/Deserialization/Deserializer.cpp | 6 ++--
.../Target/SPIRV/Serialization/Serializer.cpp | 28 ++++++++++++++++---
mlir/test/Target/SPIRV/constant.mlir | 7 +++++
3 files changed, 34 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3957dbc0db984..c43d584d7b913 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto vectorType = dyn_cast<VectorType>(resultType)) {
- auto attr = DenseElementsAttr::get(vectorType, elements);
+ if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
- constantMap.try_emplace(resultID, attr, resultType);
+ constantMap.try_emplace(resultID, attr, shapedType);
} else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
auto attr = opBuilder.getArrayAttr(elements);
constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..83e6c7ea7af1d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
return 0;
}
+ int64_t numberOfConstituents = shapedType.getDimSize(dim);
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
- operands.reserve(shapedType.getDimSize(dim) + 2);
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
- for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
- index[dim] = i;
+
+ // "If the Result Type is a cooperative matrix type, then there must be only
+ // one Constituent, with scalar type matching the cooperative matrix Component
+ // Type, and all components of the matrix are initialized to that value."
+ // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
+ if (isa<spirv::CooperativeMatrixType>(constType)) {
+ // A cooperative matrix constant has exactly one Constituent, so we need only
+ // one more element in the SmallVector: 3 in total (1 + 2).
+ operands.reserve(3);
+ // Pass `shapedType.getRank()` as dim so the recursive call emits the scalar
+ // element constant right away and returns its result ID.
if (auto elementID = prepareDenseElementsConstant(
- loc, elementType, valueAttr, dim + 1, index)) {
+ loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
operands.push_back(elementID);
} else {
return 0;
}
+ } else {
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ index[dim] = i;
+ if (auto elementID = prepareDenseElementsConstant(
+ loc, elementType, valueAttr, dim + 1, index)) {
+ operands.push_back(elementID);
+ } else {
+ return 0;
+ }
+ }
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f05..a018692afab81 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -277,4 +277,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%signed_minus_one = spirv.Constant -1 : si16
spirv.ReturnValue %signed_minus_one : si16
}
+
+ // CHECK-LABEL: @coop_matrix_const
+ spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ }
}
>From 03402e9b2c6cf0825d4781d388b0a849715c9e51 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 9 Jun 2025 18:11:16 +0100
Subject: [PATCH 2/2] Add non-zero test
---
mlir/test/Target/SPIRV/constant.mlir | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index a018692afab81..05a1001b39f9e 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -278,10 +278,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %signed_minus_one : si16
}
- // CHECK-LABEL: @coop_matrix_const
- spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+ // CHECK-LABEL: @coop_matrix_const_zero
+ spirv.func @coop_matrix_const_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
// CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
%coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
}
+
+ // CHECK-LABEL: @coop_matrix_const_non_zero
+ spirv.func @coop_matrix_const_non_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+ // CHECK: {{%.*}} = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ %coop = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+ }
}
More information about the Mlir-commits
mailing list