[Mlir-commits] [mlir] [mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix (PR #142786)

Igor Wodiany llvmlistbot at llvm.org
Tue Jun 10 06:04:12 PDT 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/142786

>From 8ae14d8efa667da885d5eb24b5562e44a1a43b06 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 3 Jun 2025 17:55:50 +0100
Subject: [PATCH 1/3] [mlir][spirv] Deserialize OpConstantComposite of type
 Cooperative Matrix

---
 .../SPIRV/Deserialization/Deserializer.cpp    |  6 ++--
 .../Target/SPIRV/Serialization/Serializer.cpp | 28 ++++++++++++++++---
 mlir/test/Target/SPIRV/constant.mlir          |  7 +++++
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3957dbc0db984..c43d584d7b913 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto vectorType = dyn_cast<VectorType>(resultType)) {
-    auto attr = DenseElementsAttr::get(vectorType, elements);
+  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+    auto attr = DenseElementsAttr::get(shapedType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
-    constantMap.try_emplace(resultID, attr, resultType);
+    constantMap.try_emplace(resultID, attr, shapedType);
   } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
     auto attr = opBuilder.getArrayAttr(elements);
     constantMap.try_emplace(resultID, attr, resultType);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..83e6c7ea7af1d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
     return 0;
   }
 
+  int64_t numberOfConstituents = shapedType.getDimSize(dim);
   uint32_t resultID = getNextID();
   SmallVector<uint32_t, 4> operands = {typeID, resultID};
-  operands.reserve(shapedType.getDimSize(dim) + 2);
   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
-  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
-    index[dim] = i;
+
+  // "If the Result Type is a cooperative matrix type, then there must be only
+  // one Constituent, with scalar type matching the cooperative matrix Component
+  // Type, and all components of the matrix are initialized to that value."
+  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
+  if (isa<spirv::CooperativeMatrixType>(constType)) {
+    // numberOfConstituents is 1, so we only need one more elements in the
+    // SmallVector, so the total is 3 (1 + 2).
+    operands.reserve(3);
+    // We set dim directly to `shapedType.getRank()` so the recursive call
+    // directly returns the scalar type.
     if (auto elementID = prepareDenseElementsConstant(
-            loc, elementType, valueAttr, dim + 1, index)) {
+            loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
       operands.push_back(elementID);
     } else {
       return 0;
     }
+  } else {
+    operands.reserve(numberOfConstituents + 2);
+    for (int i = 0; i < numberOfConstituents; ++i) {
+      index[dim] = i;
+      if (auto elementID = prepareDenseElementsConstant(
+              loc, elementType, valueAttr, dim + 1, index)) {
+        operands.push_back(elementID);
+      } else {
+        return 0;
+      }
+    }
   }
   spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
   encodeInstructionInto(typesGlobalValues, opcode, operands);
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index f3950214a7f05..a018692afab81 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -277,4 +277,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %signed_minus_one = spirv.Constant -1 : si16
     spirv.ReturnValue %signed_minus_one : si16
   }
+
+  // CHECK-LABEL: @coop_matrix_const
+  spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  }
 }

>From 03402e9b2c6cf0825d4781d388b0a849715c9e51 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Mon, 9 Jun 2025 18:11:16 +0100
Subject: [PATCH 2/3] Add non-zero test

---
 mlir/test/Target/SPIRV/constant.mlir | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index a018692afab81..05a1001b39f9e 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -278,10 +278,17 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %signed_minus_one : si16
   }
 
-  // CHECK-LABEL: @coop_matrix_const
-  spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+  // CHECK-LABEL: @coop_matrix_const_zero
+  spirv.func @coop_matrix_const_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
   }
+
+  // CHECK-LABEL: @coop_matrix_const_non_zero
+  spirv.func @coop_matrix_const_non_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    %coop = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  }
 }

>From 91797d662816d8856483c3eebd47727315563e14 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Tue, 10 Jun 2025 12:50:38 +0100
Subject: [PATCH 3/3] Verify non-splat

---
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        |  7 ++++
 .../Target/SPIRV/Serialization/Serializer.cpp |  6 ++++
 mlir/test/Dialect/SPIRV/IR/structure-ops.mlir | 33 +++++++++++++++++++
 mlir/test/Target/SPIRV/constant.mlir          | 24 +++++++++++---
 4 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 3d2cb1dd7a032..7148027dae78d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -558,6 +558,13 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) {
 
 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                                         Type opType) {
+  if (isa<spirv::CooperativeMatrixType>(opType)) {
+    auto denseAttr = dyn_cast<DenseElementsAttr>(value);
+    if (!denseAttr || !denseAttr.isSplat())
+      return op.emitOpError("expected a splat dense attribute for cooperative "
+                            "matrix constant, but found ")
+             << denseAttr;
+  }
   if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
     auto valueType = llvm::cast<TypedAttr>(value).getType();
     if (valueType != opType)
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 83e6c7ea7af1d..647535809554c 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -855,6 +855,12 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
   // Type, and all components of the matrix are initialized to that value."
   // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
   if (isa<spirv::CooperativeMatrixType>(constType)) {
+    if (!valueAttr.isSplat()) {
+      emitError(
+          loc,
+          "cannot serialize a non-splat value for a cooperative matrix type");
+      return 0;
+    }
     // numberOfConstituents is 1, so we only need one more elements in the
     // SmallVector, so the total is 3 (1 + 2).
     operands.reserve(3);
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 5e98b9fdb3c54..207549afdda94 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -62,6 +62,10 @@ func.func @const() -> () {
   // CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   // CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   // CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
+  // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  // CHECK: spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  // CHECK: spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+  // CHECK: spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
 
   %0 = spirv.Constant true
   %1 = spirv.Constant 42 : i32
@@ -73,6 +77,10 @@ func.func @const() -> () {
   %7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>>
   %8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>>
   %9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>>
+  %10 = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  %11 = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+  %12 = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+  %13 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
   return
 }
 
@@ -132,6 +140,31 @@ func.func @value_result_num_elements_mismatch() -> () {
 
 // -----
 
+func.func @coop_matrix_const_non_splat() -> () {
+    // expected-error @+1 {{expected a splat dense attribute for cooperative matrix constant, but found}}
+    %0 = spirv.Constant dense<[[1.0, 2.0], [3.0, 4.0]]> : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
+    return
+}
+
+// -----
+
+func.func @coop_matrix_const_non_dense() -> () {
+    // expected-error @+2 {{floating point value not valid for specified type}}
+    %0 = spirv.Constant 0.000000e+00 : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    return
+}
+
+// -----
+
+func.func @coop_matrix_const_wrong_type() -> () {
+    // expected-error @below {{unexpected decimal integer literal for a floating point value}}
+    // expected-note @+1 {{add a trailing dot to make the literal a float}}
+    %0 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
+    return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.EntryPoint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir
index 05a1001b39f9e..8d4e53418b70f 100644
--- a/mlir/test/Target/SPIRV/constant.mlir
+++ b/mlir/test/Target/SPIRV/constant.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
 
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @bool_const
@@ -278,17 +278,31 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %signed_minus_one : si16
   }
 
-  // CHECK-LABEL: @coop_matrix_const_zero
-  spirv.func @coop_matrix_const_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+  // CHECK-LABEL: @coop_matrix_const_zero_f32
+  spirv.func @coop_matrix_const_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
   }
 
-  // CHECK-LABEL: @coop_matrix_const_non_zero
-  spirv.func @coop_matrix_const_non_zero() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
+  // CHECK-LABEL: @coop_matrix_const_non_zero_f32
+  spirv.func @coop_matrix_const_non_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     %coop = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
     spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
   }
+
+  // CHECK-LABEL: @coop_matrix_const_zero_i8
+  spirv.func @coop_matrix_const_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+    %coop = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+    spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+  }
+
+  // CHECK-LABEL: @coop_matrix_const_non_zero_i8
+  spirv.func @coop_matrix_const_non_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+    %coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+    spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>
+  }
 }



More information about the Mlir-commits mailing list