[Mlir-commits] [mlir] [mlir][spirv] Allow CooperativeMatrixType in Bitcast (PR #196096)
Igor Wodiany
llvmlistbot at llvm.org
Wed May 6 08:30:28 PDT 2026
https://github.com/IgWod updated https://github.com/llvm/llvm-project/pull/196096
>From 505baaecce35f407bf4474d05f059d5ebb003971 Mon Sep 17 00:00:00 2001
From: Igor Wodiany <dev at wodiany.com>
Date: Tue, 24 Mar 2026 20:27:16 +0000
Subject: [PATCH] [mlir][spirv] Allow CooperativeMatrixType in Bitcast
This makes it consistent with the spec: "Allow the use of OpBitcast
on objects of cooperative matrix type whose Component Type are integer
types with the same Width."
Assisted-by: Codex
---
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 2 +
.../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td | 12 ++---
mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 30 ++++++++++++
mlir/test/Dialect/SPIRV/IR/cast-ops.mlir | 46 +++++++++++++++++++
mlir/test/Target/SPIRV/cast-ops.mlir | 10 ++++
5 files changed, 94 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8badb84a879fa..c903da7c1b6f7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4346,6 +4346,8 @@ class SPIRV_MatrixOf<Type type> :
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
+def SPIRV_ScalarOrVectorOrPtrOrCoopMatrix :
+ AnyTypeOf<[SPIRV_ScalarOrVectorOrPtr, SPIRV_AnyCooperativeMatrix]>;
class SPIRV_Vec4<Type type> : VectorOfRankAndLengthAndType<[1], [4], [type]>;
def SPIRV_IntVec4 : SPIRV_Vec4<SPIRV_Integer>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index f05030dfb5d57..a434764750a56 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -40,11 +40,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
let summary = "Bit pattern-preserving type conversion.";
let description = [{
- Result Type must be an OpTypePointer, or a scalar or vector of
- numerical-type.
+ Result Type must be an OpTypePointer, cooperative matrix, or a scalar or
+ vector of numerical-type.
- Operand must have a type of OpTypePointer, or a scalar or vector of
- numerical-type. It must be a different type than Result Type.
+ Operand must have a type of OpTypePointer, cooperative matrix, or a scalar
+ or vector of numerical-type. It must be a different type than Result Type.
If either Result Type or Operand is a pointer, the other must be a
pointer (diverges from the SPIR-V spec).
@@ -71,11 +71,11 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
}];
let arguments = (ins
- SPIRV_ScalarOrVectorOrPtr:$operand
+ SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$operand
);
let results = (outs
- SPIRV_ScalarOrVectorOrPtr:$result
+ SPIRV_ScalarOrVectorOrPtrOrCoopMatrix:$result
);
let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index a5330dc56d48f..ea8cb8de74536 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -87,6 +87,36 @@ LogicalResult BitcastOp::verify() {
if (operandType == resultType) {
return emitError("result type must be different from operand type");
}
+
+ auto operandCoopMatrixType =
+ dyn_cast<spirv::CooperativeMatrixType>(operandType);
+ auto resultCoopMatrixType =
+ dyn_cast<spirv::CooperativeMatrixType>(resultType);
+ if (operandCoopMatrixType || resultCoopMatrixType) {
+ if (!operandCoopMatrixType || !resultCoopMatrixType)
+ return emitError("unhandled bit cast conversion from cooperative matrix "
+ "type to non-cooperative matrix type");
+
+ if (operandCoopMatrixType.getRows() != resultCoopMatrixType.getRows() ||
+ operandCoopMatrixType.getColumns() != resultCoopMatrixType.getColumns())
+ return emitError("cooperative matrix dimensions must match");
+
+ if (operandCoopMatrixType.getScope() != resultCoopMatrixType.getScope())
+ return emitError("cooperative matrix scope must match");
+
+ if (operandCoopMatrixType.getUse() != resultCoopMatrixType.getUse())
+ return emitError("cooperative matrix use must match");
+
+ unsigned operandBitWidth =
+ getBitWidth(operandCoopMatrixType.getElementType());
+ unsigned resultBitWidth =
+ getBitWidth(resultCoopMatrixType.getElementType());
+ if (operandBitWidth != resultBitWidth)
+ return emitOpError("mismatch in result and operand type bitwidth");
+
+ return success();
+ }
+
if (isa<spirv::PointerType>(operandType) &&
!isa<spirv::PointerType>(resultType)) {
return emitError(
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 4480a1f3720f2..6614a49a1253e 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -40,6 +40,12 @@ func.func @cast6(%arg0 : vector<4xf32>) {
return
}
+func.func @cast_coop_matrix(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+ return
+}
+
// -----
func.func @cast1(%arg0 : f32) {
@@ -82,6 +88,46 @@ func.func @cast3(%arg0 : i64) {
// -----
+func.func @cast_coop_matrix_size_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // expected-error @+1 {{cooperative matrix dimensions must match}}
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x2xi32, Subgroup, MatrixA>
+ return
+}
+
+// -----
+
+func.func @cast_coop_matrix_scope_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // expected-error @+1 {{cooperative matrix scope must match}}
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Workgroup, MatrixA>
+ return
+}
+
+// -----
+
+func.func @cast_coop_matrix_use_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // expected-error @+1 {{cooperative matrix use must match}}
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixB>
+ return
+}
+
+// -----
+
+func.func @cast_coop_matrix_bitwidth_mismatch(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // expected-error @+1 {{mismatch in result and operand type bitwidth}}
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi64, Subgroup, MatrixA>
+ return
+}
+
+// -----
+
+func.func @cast_coop_to_non_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) {
+ // expected-error @+1 {{unhandled bit cast conversion from cooperative matrix type to non-cooperative matrix type}}
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to i32
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.ConvertFToS
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index 4f29610f928c4..317264d37b8c8 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -19,6 +19,16 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// -----
+spirv.module Logical Vulkan requires #spirv.vce<v1.6, [Shader, Linkage, CooperativeMatrixKHR, VulkanMemoryModel], [SPV_KHR_cooperative_matrix, SPV_KHR_vulkan_memory_model]> {
+ spirv.func @bit_cast_coop(%arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>) "None" {
+ // CHECK: {{%.*}} = spirv.Bitcast {{%.*}} : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+ %0 = spirv.Bitcast %arg0 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> to !spirv.coopmatrix<4x4xi32, Subgroup, MatrixA>
+ spirv.Return
+ }
+}
+
+// -----
+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, BFloat16TypeKHR, Float64, Int64], [SPV_KHR_bfloat16]> {
spirv.func @convert_f_to_s(%arg0 : f32) -> i32 "None" {
// CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : f32 to i32
More information about the Mlir-commits
mailing list