[Mlir-commits] [mlir] [mlir][spirv] Reject coop matrix operands on unsupported arithmetic ops (PR #147230)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 6 19:17:41 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Darren Wihandi (fairywreath)

<details>
<summary>Changes</summary>

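Cooperative matrix operands are only supported by the `add/sub/mul/div` binary arithmetic ops, but currently every binary arithmetic op accepts them, including `mod/rem`. This change introduces a separate base class for the ops that do support cooperative matrix operands, so the remaining binary arithmetic ops (such as `mod/rem`) now reject them.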


---
Full diff: https://github.com/llvm/llvm-project/pull/147230.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+38-19) 
- (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+43) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 309079e549846..2601debce3520 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -26,6 +26,25 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOf<type>:$operand1,
+    SPIRV_ScalarOrVectorOf<type>:$operand2
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOf<type>:$result
+  );
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
+class SPIRV_ArithmeticBinaryOpWithCoopMatrix<string mnemonic, Type type,
+                                             list<Trait> traits = []> :
+      // Operands type same as result type.
+      SPIRV_BinaryOp<mnemonic, type, type,
+                   !listconcat(traits,
+                               [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
+  // In addition to normal types these arithmetic instructions can support
+  // cooperative matrix.
   let arguments = (ins
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
     SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
@@ -82,7 +101,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
 
 // -----
 
-def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
+def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FAdd", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -104,7 +123,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
 
 // -----
 
-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
+def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FDiv", SPIRV_Float, []> {
   let summary = "Floating-point division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -154,7 +173,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
+def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FMul", SPIRV_Float, [Commutative]> {
   let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -229,7 +248,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
+def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"FSub", SPIRV_Float, []> {
   let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -251,9 +270,9 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
 
 // -----
 
-def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IAdd",
+                                                          SPIRV_Integer,
+                                                          [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer addition of Operand 1 and Operand 2.";
 
   let description = [{
@@ -322,9 +341,9 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",
 
 // -----
 
-def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
-                                        SPIRV_Integer,
-                                        [Commutative, UsableInSpecConstantOp]> {
+def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"IMul",
+                                                          SPIRV_Integer,
+                                                          [Commutative, UsableInSpecConstantOp]> {
   let summary = "Integer multiplication of Operand 1 and Operand 2.";
 
   let description = [{
@@ -354,9 +373,9 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
 
 // -----
 
-def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"ISub",
+                                                          SPIRV_Integer,
+                                                          [UsableInSpecConstantOp]> {
   let summary = "Integer subtraction of Operand 2 from Operand 1.";
 
   let description = [{
@@ -460,9 +479,9 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
 
 // -----
 
-def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
-                                        SPIRV_Integer,
-                                        [UsableInSpecConstantOp]> {
+def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"SDiv",
+                                                          SPIRV_Integer,
+                                                          [UsableInSpecConstantOp]> {
   let summary = "Signed-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -622,9 +641,9 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",
 
 // -----
 
-def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
-                                        SPIRV_Integer,
-                                        [UnsignedOp, UsableInSpecConstantOp]> {
+def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOpWithCoopMatrix<"UDiv",
+                                                          SPIRV_Integer,
+                                                          [UnsignedOp, UsableInSpecConstantOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 8733ff93768ab..6aff7b5039638 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -549,3 +549,46 @@ spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, Matrix
   %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
   spirv.Return
 }
+
+// -----
+
+// These binary arithmetic instructions do not support coop matrix operands.
+
+spirv.func @fmod(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %p = spirv.FMod %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @frem(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
+  %p = spirv.FRem %a, %b : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+spirv.func @smod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.SMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @srem(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.SRem %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @umod(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>) "None" {
+  // expected-error @+1 {{operand #0 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values of length 2/3/4/8/16}}
+  %p = spirv.UMod %a, %b : !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/147230


More information about the Mlir-commits mailing list