[Mlir-commits] [mlir] [mlir][spirv] Move some of the verification for Matrix ops to ODS (PR #185597)

Igor Wodiany llvmlistbot at llvm.org
Tue Mar 10 01:59:59 PDT 2026


https://github.com/IgWod created https://github.com/llvm/llvm-project/pull/185597

This moves C++ verification to ODS wherever existing constraints can express it. A subsequent patch will focus on removing the remaining C++ verification, introducing new classes where required.

Assisted-by: Codex

>From c9845bf69bd495b5d9ce0c9faac2c5089e73e23d Mon Sep 17 00:00:00 2001
From: Igor Wodiany <dev at wodiany.com>
Date: Fri, 6 Mar 2026 21:53:23 +0000
Subject: [PATCH] [mlir][spirv] Move some of the verification for Matrix ops to
 ODS

This moves C++ verification to ODS wherever existing constraints can express
it. A subsequent patch will focus on removing the remaining C++ verification,
introducing new classes where required.

Assisted-by: Codex
---
 .../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td   | 22 ++++++++--
 mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp        | 43 -------------------
 .../SPIRV/IR/khr-cooperative-matrix-ops.mlir  |  2 +-
 mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir    | 20 ++++++---
 4 files changed, 33 insertions(+), 54 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index f2796861cdf56..dc363de3eff31 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -16,7 +16,10 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // -----
 
-def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
+def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [
+    Pure,
+    AllElementTypesMatch<["leftmatrix", "rightmatrix", "result"]>
+  ]> {
   let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
 
   let description = [{
@@ -63,7 +66,11 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
 
 // -----
 
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [
+    Pure,
+    AllTypesMatch<["matrix", "result"]>,
+    AllElementTypesMatch<["matrix", "scalar"]>
+  ]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -109,12 +116,15 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMat
     Extension<[]>,
     Capability<[SPIRV_C_Matrix]>
   ];
+
+  let hasVerifier = 0;
 }
 
 // -----
 
 def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
     Pure,
+    AllElementTypesMatch<["matrix", "result"]>,
     AllElementTypesMatch<["vector", "result"]>
   ]> {
   let summary = "Linear-algebraic Matrix X Vector.";
@@ -157,7 +167,10 @@ def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
 
 // -----
 
-def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
+def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [
+    Pure,
+    AllElementTypesMatch<["matrix", "result"]>
+  ]> {
   let summary = "Transpose a matrix.";
 
   let description = [{
@@ -202,7 +215,8 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
 
 def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [
     Pure,
-    AllElementTypesMatch<["vector", "result"]>
+    AllElementTypesMatch<["vector", "result"]>,
+    AllElementTypesMatch<["matrix", "result"]>
   ]> {
   let summary = "Linear-algebraic Vector X Matrix.";
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 34e06bf52f70d..4b993ef700c38 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1697,27 +1697,6 @@ LogicalResult spirv::VectorShuffleOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// spirv.MatrixTimesScalar
-//===----------------------------------------------------------------------===//
-
-LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  Type elementType =
-      llvm::TypeSwitch<Type, Type>(getMatrix().getType())
-          .Case<spirv::CooperativeMatrixType, spirv::MatrixType>(
-              [](auto matrixType) { return matrixType.getElementType(); })
-          .Default(nullptr);
-
-  assert(elementType && "Unhandled type");
-
-  // Check that the scalar type is the same as the matrix element type.
-  if (getScalar().getType() != elementType)
-    return emitOpError("input matrix components' type and scaling value must "
-                       "have the same type");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // spirv.Transpose
 //===----------------------------------------------------------------------===//
@@ -1735,11 +1714,6 @@ LogicalResult spirv::TransposeOp::verify() {
     return emitError("input matrix columns count must be equal to "
                      "output matrix rows count");
 
-  // Verify that the input and output matrices have the same component type
-  if (inputMatrix.getElementType() != resultMatrix.getElementType())
-    return emitError("input and output matrices must have the same "
-                     "component type");
-
   return success();
 }
 
@@ -1762,9 +1736,6 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
            << resultType.getNumElements() << ") must match the matrix rows ("
            << matrixType.getNumRows() << ")";
 
-  if (matrixType.getElementType() != resultType.getElementType())
-    return emitOpError("matrix and result element types must match");
-
   return success();
 }
 
@@ -1785,10 +1756,6 @@ LogicalResult spirv::VectorTimesMatrixOp::verify() {
     return emitOpError("number of columns in matrix must equal the number of "
                        "components in result");
 
-  if (matrixType.getElementType() != resultType.getElementType())
-    return emitOpError("matrix must be a matrix with the same component type "
-                       "as the component type in result");
-
   return success();
 }
 
@@ -1811,16 +1778,6 @@ LogicalResult spirv::MatrixTimesMatrixOp::verify() {
     return emitError(
         "right and result matrices must have equal columns' count");
 
-  // right and result matrices component type must be the same
-  if (rightMatrix.getElementType() != resultMatrix.getElementType())
-    return emitError("right and result matrices' component type must"
-                     " be the same");
-
-  // left and result matrices component type must be the same
-  if (leftMatrix.getElementType() != resultMatrix.getElementType())
-    return emitError("left and result matrices' component type"
-                     " must be the same");
-
   // left and result matrices rows count must be the same
   if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
     return emitError("left and result matrices must have equal rows' count");
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 491c7a7758ce1..69235eab8d0dc 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -573,7 +573,7 @@ spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
 // -----
 
 spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
-  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
   %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
   spirv.Return
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index ba95322dbf38b..9996cb2dc0bef 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -61,7 +61,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 // -----
 
 func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
-  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
   return
 }
@@ -69,7 +69,7 @@ func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 :
 // -----
 
 func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
-  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
+  // expected-error @+1 {{op failed to verify that all of {matrix, scalar} have same element type}}
   %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
   return
 }
@@ -93,7 +93,7 @@ func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>
 // -----
 
 func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
-   // expected-error @+1 {{input and output matrices must have the same component type}}
+   // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
    return
 }
@@ -125,7 +125,7 @@ func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x v
 // -----
 
 func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
-   // expected-error @+1 {{right and result matrices' component type must be the same}}
+   // expected-error @+1 {{op failed to verify that all of {leftmatrix, rightmatrix, result} have same element type}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
    return
 }
@@ -133,7 +133,7 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
 // -----
 
 func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
-   // expected-error @+1 {{left and result matrices' component type must be the same}}
+   // expected-error @+1 {{op failed to verify that all of {leftmatrix, rightmatrix, result} have same element type}}
    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
    return
 }
@@ -148,6 +148,14 @@ func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x ve
 
 // -----
 
+func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf16>>, %arg1: vector<4xf32>) {
+  // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
+  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf16>>, vector<4xf32> -> vector<3xf32>
+  return
+}
+
+// -----
+
 func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
   // expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
   %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
@@ -189,7 +197,7 @@ func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1:
 // -----
 
 func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
-  // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+  // expected-error @+1 {{op failed to verify that all of {matrix, result} have same element type}}
   %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
   return
 }



More information about the Mlir-commits mailing list