[Mlir-commits] [mlir] [mlir][spirv] Move some of the verification for Matrix ops to ODS (PR #185597)
Igor Wodiany
llvmlistbot at llvm.org
Tue Mar 10 01:59:59 PDT 2026
https://github.com/IgWod created https://github.com/llvm/llvm-project/pull/185597
This moves C++ verification to ODS where existing constraints can be used. A subsequent patch will focus on removing the remaining C++ verification, introducing new constraint classes where required.
Assisted-by: Codex
>From c9845bf69bd495b5d9ce0c9faac2c5089e73e23d Mon Sep 17 00:00:00 2001
From: Igor Wodiany <dev at wodiany.com>
Date: Fri, 6 Mar 2026 21:53:23 +0000
Subject: [PATCH] [mlir][spirv] Move some of the verification for Matrix ops to
ODS
This moves C++ verification to ODS where existing constraints can be
used. A subsequent patch will focus on removing the remaining C++
verification, introducing new constraint classes where required.
Assisted-by: Codex
---
.../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td | 22 ++++++++--
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 43 -------------------
.../SPIRV/IR/khr-cooperative-matrix-ops.mlir | 2 +-
mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir | 20 ++++++---
4 files changed, 33 insertions(+), 54 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index f2796861cdf56..dc363de3eff31 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -16,7 +16,10 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
-def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
+def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [
+ Pure,
+ AllElementTypesMatch<["leftmatrix", "rightmatrix", "result"]>
+ ]> {
let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
let description = [{
@@ -63,7 +66,11 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
// -----
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [
+ Pure,
+ AllTypesMatch<["matrix", "result"]>,
+ AllElementTypesMatch<["matrix", "scalar"]>
+ ]> {
let summary = "Scale a floating-point matrix.";
let description = [{
@@ -109,12 +116,15 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMat
Extension<[]>,
Capability<[SPIRV_C_Matrix]>
];
+
+ let hasVerifier = 0;
}
// -----
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
Pure,
+ AllElementTypesMatch<["matrix", "result"]>,
AllElementTypesMatch<["vector", "result"]>
]> {
let summary = "Linear-algebraic Matrix X Vector.";
@@ -157,7 +167,10 @@ def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
// -----
-def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
+def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [
+ Pure,
+ AllElementTypesMatch<["matrix", "result"]>
+ ]> {
let summary = "Transpose a matrix.";
let description = [{
@@ -202,7 +215,8 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [
Pure,
- AllElementTypesMatch<["vector", "result"]>
+ AllElementTypesMatch<["vector", "result"]>,
+ AllElementTypesMatch<["matrix", "result"]>
]> {
let summary = "Linear-algebraic Vector X Matrix.";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 34e06bf52f70d..4b993ef700c38 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1697,27 +1697,6 @@ LogicalResult spirv::VectorShuffleOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// spirv.MatrixTimesScalar
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::MatrixTimesScalarOp::verify() {
- Type elementType =
- llvm::TypeSwitch<Type, Type>(getMatrix().getType())
- .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
- [](auto matrixType) { return matrixType.getElementType(); })
- .Default(nullptr);
-
- assert(elementType && "Unhandled type");
-
- // Check that the scalar type is the same as the matrix element type.
- if (getScalar().getType() != elementType)
- return emitOpError("input matrix components' type and scaling value must "
- "have the same type");
-
- return success();
-}
-
//===----------------------------------------------------------------------===//
// spirv.Transpose
//===----------------------------------------------------------------------===//
@@ -1735,11 +1714,6 @@ LogicalResult spirv::TransposeOp::verify() {
return emitError("input matrix columns count must be equal to "
"output matrix rows count");
- // Verify that the input and output matrices have the same component type
- if (inputMatrix.getElementType() != resultMatrix.getElementType())
- return emitError("input and output matrices must have the same "
- "component type");
-
return success();
}
@@ -1762,9 +1736,6 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
<< resultType.getNumElements() << ") must match the matrix rows ("
<< matrixType.getNumRows() << ")";
- if (matrixType.getElementType() != resultType.getElementType())
- return emitOpError("matrix and result element types must match");
-
return success();
}
@@ -1785,10 +1756,6 @@ LogicalResult spirv::VectorTimesMatrixOp::verify() {
return emitOpError("number of columns in matrix must equal the number of "
"components in result");
- if (matrixType.getElementType() != resultType.getElementType())
- return emitOpError("matrix must be a matrix with the same component type "
- "as the component type in result");
-
return success();
}
@@ -1811,16 +1778,6 @@ LogicalResult spirv::MatrixTimesMatrixOp::verify() {
return emitError(
"right and result matrices must have equal columns' count");
- // right and result matrices component type must be the same
- if (rightMatrix.getElementType() != resultMatrix.getElementType())
- return emitError("right and result matrices' component type must"
- " be the same");
-
- // left and result matrices component type must be the same
- if (leftMatrix.getElementType() != resultMatrix.getElementType())
- return emitError("left and result matrices' component type"
- " must be the same");
-
// left and result matrices rows count must be the same
if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
return emitError("left and result matrices must have equal rows' count");
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 491c7a7758ce1..69235eab8d0dc 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -573,7 +573,7 @@ spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
// -----
spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
- // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+ // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
%p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index ba95322dbf38b..9996cb2dc0bef 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
- // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+ // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
return
}
@@ -69,7 +69,7 @@ func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 :
// -----
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
- // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+ // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
return
}
@@ -93,7 +93,7 @@ func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>
// -----
func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
- // expected-error @+1 {{input and output matrices must have the same component type}}
+ // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
return
}
@@ -125,7 +125,7 @@ func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x v
// -----
func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
- // expected-error @+1 {{right and result matrices' component type must be the same}}
+ // expected-error @+1 {{op failed to verify that all of {leftmatrix, rightmatrix, result} have same element type}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
return
}
@@ -133,7 +133,7 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
// -----
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
- // expected-error @+1 {{left and result matrices' component type must be the same}}
+ // expected-error @+1 {{op failed to verify that all of {leftmatrix, rightmatrix, result} have same element type}}
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
return
}
@@ -148,6 +148,14 @@ func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x ve
// -----
+func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf16>>, %arg1: vector<4xf32>) {
+ // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
+ %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf16>>, vector<4xf32> -> vector<3xf32>
+ return
+}
+
+// -----
+
func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
// expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
@@ -189,7 +197,7 @@ func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1:
// -----
func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
- // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+ // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
return
}
More information about the Mlir-commits mailing list