[Mlir-commits] [mlir] [mlir][spirv] Add spirv.GL.MatrixInverse Op (PR #193594)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 22 13:59:01 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jack (jack-slingsby)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/193594.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td (+32) 
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+13) 
- (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+26) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 01fe12a4660af..7aeabd297211a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -441,6 +441,38 @@ def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_F
 
 // -----
 
+def SPIRV_GLMatrixInverseOp : SPIRV_GLOp<"MatrixInverse", 34,
+    [Pure, SameOperandsAndResultType]> {
+  let summary = "Compute the inverse of a matrix";
+
+  let description = [{
+    Result is the inverse of the operand `matrix`, which must be a square
+    matrix of floating-point type. Results are undefined if it is singular.
+
+    Result Type and the type of `matrix` must be the same type.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<4 x vector<4xf32>>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_MatrixOf<SPIRV_Float>:$matrix
+  );
+
+  let results = (outs
+    SPIRV_MatrixOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = "$matrix attr-dict `:` type($matrix)";
+
+  let hasVerifier = 1;
+}
+
+// -----
+
 def SPIRV_GLLogOp : SPIRV_GLUnaryArithmeticOp<"Log", 28, SPIRV_Float16or32> {
   let summary = "Natural logarithm of the operand";
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9300483a0f92f..befbb2841fc8b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2054,6 +2054,19 @@ LogicalResult spirv::GLLdexpOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.GL.MatrixInverse
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GLMatrixInverseOp::verify() {
+  auto type = cast<spirv::MatrixType>(getMatrix().getType());
+  const unsigned cols = type.getNumColumns(), rows = type.getNumRows();
+  if (cols == rows) // Only a square matrix can have an inverse.
+    return success();
+  return emitOpError("matrix must be square, got ")
+         << cols << " columns and " << rows << " rows";
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.ShiftLeftLogicalOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index eea80ca3798a6..fde1bda15589a 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -127,6 +127,32 @@ func.func @inversesqrtvec(%arg0 : vector<3xf16>) -> () {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.GL.MatrixInverse
+//===----------------------------------------------------------------------===//
+
+func.func @matrix_inverse(%arg0 : !spirv.matrix<4 x vector<4xf32>>) -> () {
+  // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<4 x vector<4xf32>>
+  %inv = spirv.GL.MatrixInverse %arg0 : !spirv.matrix<4 x vector<4xf32>>
+  return
+}
+
+func.func @matrix_inverse_2x2(%matrix : !spirv.matrix<2 x vector<2xf64>>) -> () {
+  // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<2 x vector<2xf64>>
+  %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<2 x vector<2xf64>>
+  return
+}
+
+// -----
+
+func.func @matrix_inverse_non_square(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+  // expected-error @+1 {{matrix must be square, got 3 columns and 4 rows}}
+  %inv = spirv.GL.MatrixInverse %arg0 : !spirv.matrix<3 x vector<4xf32>>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.Sqrt
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/193594


More information about the Mlir-commits mailing list