[Mlir-commits] [mlir] [mlir][spirv] Allow CooperativeMatrixType in Bitcast (PR #196096)

Igor Wodiany llvmlistbot at llvm.org
Wed May 6 08:30:28 PDT 2026


https://github.com/IgWod updated https://github.com/llvm/llvm-project/pull/196096

>From 505baaecce35f407bf4474d05f059d5ebb003971 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <dev at wodiany.com>
Date: Tue, 24 Mar 2026 20:27:16 +0000
Subject: [PATCH] [mlir][spirv] Allow CooperativeMatrixType in Bitcast

This makes it consistent with the spec, which states: "Allow the use of
OpBitcast on objects of cooperative matrix type whose Component Type are
integer types with the same Width."

Assisted-by: Codex
---
 .../mlir/Dialect/SPIRV/IR/SPIRVBase.td        |  2 +
 .../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td     | 12 ++---
 mlir/lib/Dialect/SPIRV/IR/CastOps.cpp         | 30 ++++++++++++
 mlir/test/Dialect/SPIRV/IR/cast-ops.mlir      | 46 +++++++++++++++++++
 mlir/test/Target/SPIRV/cast-ops.mlir          | 10 ++++
 5 files changed, 94 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8badb84a879fa..c903da7c1b6f7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4346,6 +4346,8 @@ class SPIRV_MatrixOf<Type type> :
 
 def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
 def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
+def SPIRV_ScalarOrVectorOrPtrOrCoopMatrix :
+    AnyTypeOf<[SPIRV_ScalarOrVectorOrPtr, SPIRV_AnyCooperativeMatrix]>;
 
 class SPIRV_Vec4<Type type> : VectorOfRankAndLengthAndType<[1], [4], [type]>;
 def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index f05030dfb5d57..a434764750a56 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -40,11 +40,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
   let summary = "Bit pattern-preserving type conversion.";
 
   let description = [{
-    Result Type must be an OpTypePointer, or a scalar or vector of
-    numerical-type.
+    Result Type must be an OpTypePointer, cooperative matrix, or a scalar or
+    vector of numerical-type.
 
-    Operand must have a type of OpTypePointer, or a scalar or vector of
-    numerical-type. It must be a different type than Result Type.
+    Operand must have a type of OpTypePointer, cooperative matrix, or a scalar
+    or vector of numerical-type. It must be a different type than Result Type.
 
     If either Result Type or Operand is a pointer, the other must be a
     pointer (diverges from the SPIR-V spec).
@@ -71,11 +71,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
   }];
 
   let arguments = (ins
-    SPIRV_ScalarOrVectorOrPtr:$operand
+    SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$operand
   );
 
   let results = (outs
-    SPIRV_ScalarOrVectorOrPtr:$result
+    SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$result
   );
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index a5330dc56d48f..ea8cb8de74536 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -87,6 +87,36 @@ LogicalResult BitcastOp::verify() {
   if (operandType == resultType) {
     return emitError("result type must be different from operand type");
   }
+
+  auto operandCoopMatrixType =
+      dyn_cast<spirv::CooperativeMatrixType>(operandType);
+  auto resultCoopMatrixType =
+      dyn_cast<spirv::CooperativeMatrixType>(resultType);
+  if (operandCoopMatrixType || resultCoopMatrixType) {
+    if (!operandCoopMatrixType || !resultCoopMatrixType)
+      return emitError("unhandled bit cast conversion from cooperative matrix "
+                       "type to non-cooperative matrix type");
+
+    if (operandCoopMatrixType.getRows() != resultCoopMatrixType.getRows() ||
+        operandCoopMatrixType.getColumns() != resultCoopMatrixType.getColumns())
+      return emitError("cooperative matrix dimensions must match");
+
+    if (operandCoopMatrixType.getScope() != resultCoopMatrixType.getScope())
+      return emitError("cooperative matrix scope must match");
+
+    if (operandCoopMatrixType.getUse() != resultCoopMatrixType.getUse())
+      return emitError("cooperative matrix use must match");
+
+    unsigned operandBitWidth =
+        getBitWidth(operandCoopMatrixType.getElementType());
+    unsigned resultBitWidth =
+        getBitWidth(resultCoopMatrixType.getElementType());
+    if (operandBitWidth != resultBitWidth)
+      return emitOpError("mismatch in result and operand type bitwidth");
+
+    return success();
+  }
+
   if (isa<spirv::PointerType>(operandType) &&
       !isa<spirv::PointerType>(resultType)) {
     return emitError(
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4480a1f3720f2..6614a49a1253e 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -40,6 +40,12 @@ func.func @cast6(%arg0 : vector<4xf32>) {
   return
 }
 
+func.func @cast_coop_matrix(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+  return
+}
+
 // -----
 
 func.func @cast1(%arg0 : f32) {
@@ -82,6 +88,46 @@ func.func @cast3(%arg0 : i64) {
 
 // -----
 
+func.func @cast_coop_matrix_size_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // expected-error @+1 {{cooperative matrix dimensions must match}}
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x2xi32, Subgroup, MatrixA>
+  return
+}
+
+// -----
+
+func.func @cast_coop_matrix_scope_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // expected-error @+1 {{cooperative matrix scope must match}}
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Workgroup, MatrixA>
+  return
+}
+
+// -----
+
+func.func @cast_coop_matrix_use_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // expected-error @+1 {{cooperative matrix use must match}}
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixB>
+  return
+}
+
+// -----
+
+func.func @cast_coop_matrix_bitwidth_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // expected-error @+1 {{mismatch in result and operand type bitwidth}}
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi64, Subgroup, MatrixA>
+  return
+}
+
+// -----
+
+func.func @cast_coop_to_non_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+  // expected-error @+1 {{unhandled bit cast conversion from cooperative matrix type to non-cooperative matrix type}}
+  %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to i32
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertFToS
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 4f29610f928c4..317264d37b8c8 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -19,6 +19,16 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
 
 // -----
 
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [Shader, Linkage, CooperativeMatrixKHR, VulkanMemoryModel], [SPV_KHR_cooperative_matrix, SPV_KHR_vulkan_memory_model]> {
+  spirv.func @bit_cast_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) "None" {
+    // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+    %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+    spirv.Return
+  }
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, BFloat16TypeKHR, Float64, Int64], [SPV_KHR_bfloat16]> {
   spirv.func @convert_f_to_s(%arg0 : f32) -> i32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : f32 to i32



More information about the Mlir-commits mailing list