[Mlir-commits] [mlir] [mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix (PR #142786)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 9 09:28:52 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

<details>
<summary>Changes</summary>

Depends on #<!-- -->142784.

---
Full diff: https://github.com/llvm/llvm-project/pull/142786.diff


3 Files Affected:

- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+3-3) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+24-4) 
- (modified) mlir/test/Target/SPIRV/constant.mlir (+7) 


``````````diff
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3957dbc0db984..c43d584d7b913 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
-    auto attr = DenseElementsAttr::get(vectorType, elements);
+  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+    auto attr = DenseElementsAttr::get(shapedType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
-    constantMap.try_emplace(resultID, attr, resultType);
+    constantMap.try_emplace(resultID, attr, shapedType);
   } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..83e6c7ea7af1d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
     return 0;
   }
 
+  int64_t numberOfConstituents = shapedType.getDimSize(dim);
   uint32_t resultID = getNextID();
   SmallVector<uint32_t, 4> operands = {typeID, resultID};
-  operands.reserve(shapedType.getDimSize(dim) + 2);
   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
-  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
-    index[dim] = i;
+
+  // "If the Result Type is a cooperative matrix type, then there must be only
+  // one Constituent, with scalar type matching the cooperative matrix Component
+  // Type, and all components of the matrix are initialized to that value."
+  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
+  if (isa<spirv::CooperativeMatrixType>(constType)) {
+    // numberOfConstituents is 1, so we only need one more element in the
+    // SmallVector, so the total is 3 (1 + 2).
+    operands.reserve(3);
+    // We set dim directly to `shapedType.getRank()` so the recursive call
+    // directly returns the scalar type.
     if (auto elementID = prepareDenseElementsConstant(
-            loc, elementType, valueAttr, dim + 1, index)) {
+            loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
       operands.push_back(elementID);
     } else {
       return 0;
     }
+  } else {
+    operands.reserve(numberOfConstituents + 2);
+    for (int i = 0; i < numberOfConstituents; ++i) {
+      index[dim] = i;
+      if (auto elementID = prepareDenseElementsConstant(
+              loc, elementType, valueAttr, dim + 1, index)) {
+        operands.push_back(elementID);
+      } else {
+        return 0;
+      }
+    }
   }
   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
   encodeInstructionInto(typesGlobalValues, opcode, operands);
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f05..a018692afab81 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -277,4 +277,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %signed_minus_one = spirv.Constant -1 : si16
     spirv.ReturnValue %signed_minus_one : si16
   }
+
+  // CHECK-LABEL: @coop_matrix_const
+  spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/142786


More information about the Mlir-commits mailing list