[Mlir-commits] [mlir] [mlir][spirv] Add spirv.GL.MatrixInverse Op (PR #193594)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 22 13:59:01 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jack (jack-slingsby)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/193594.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td (+32)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+13)
- (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+26)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
index 01fe12a4660af..7aeabd297211a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td
@@ -441,6 +441,38 @@ def SPIRV_GLInverseSqrtOp : SPIRV_GLUnaryArithmeticOp<"InverseSqrt", 32, SPIRV_F
// -----
+def SPIRV_GLMatrixInverseOp : SPIRV_GLOp<"MatrixInverse", 34,
+    [Pure, SameOperandsAndResultType]> {
+  let summary = "Compute the inverse of a square matrix";
+
+  let description = [{
+    Result is the matrix inverse of `matrix`, which must be a square matrix of
+    floating-point type. The result is undefined if `matrix` is singular.
+
+    Result Type and the type of `matrix` must be the same type.
+
+    #### Example:
+
+    ```mlir
+    %0 = spirv.GL.MatrixInverse %matrix : !spirv.matrix<4 x vector<4xf32>>
+    ```
+  }];
+
+  let arguments = (ins
+    SPIRV_MatrixOf<SPIRV_Float>:$matrix
+  );
+
+  let results = (outs
+    SPIRV_MatrixOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = "$matrix attr-dict `:` type($matrix)";
+
+  let hasVerifier = 1;
+}
+
+// -----
+
def SPIRV_GLLogOp : SPIRV_GLUnaryArithmeticOp<"Log", 28, SPIRV_Float16or32> {
let summary = "Natural logarithm of the operand";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9300483a0f92f..befbb2841fc8b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2054,6 +2054,19 @@ LogicalResult spirv::GLLdexpOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.GL.MatrixInverse
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::GLMatrixInverseOp::verify() {
+  auto type = cast<spirv::MatrixType>(getMatrix().getType());
+  unsigned numCols = type.getNumColumns(), numRows = type.getNumRows();
+  if (numCols == numRows) // Only square matrices have an inverse.
+    return success();
+  return emitOpError("matrix must be square, got ")
+         << numCols << " columns and " << numRows << " rows";
+}
+
//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogicalOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index eea80ca3798a6..fde1bda15589a 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -127,6 +127,32 @@ func.func @inversesqrtvec(%arg0 : vector<3xf16>) -> () {
// -----
+//===----------------------------------------------------------------------===//
+// spirv.GL.MatrixInverse
+//===----------------------------------------------------------------------===//
+
+func.func @matrix_inverse(%arg0 : !spirv.matrix<4 x vector<4xf32>>) -> () {
+  // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<4 x vector<4xf32>>
+  %0 = spirv.GL.MatrixInverse %arg0 : !spirv.matrix<4 x vector<4xf32>>
+  return
+}
+
+func.func @matrix_inverse_2x2(%arg0 : !spirv.matrix<2 x vector<2xf32>>) -> () {
+  // CHECK: spirv.GL.MatrixInverse {{%.*}} : !spirv.matrix<2 x vector<2xf32>>
+  %0 = spirv.GL.MatrixInverse %arg0 : !spirv.matrix<2 x vector<2xf32>>
+  return
+}
+
+// -----
+
+func.func @matrix_inverse_non_square(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
+  // expected-error @+1 {{matrix must be square, got 3 columns and 4 rows}}
+  %0 = spirv.GL.MatrixInverse %arg0 : !spirv.matrix<3 x vector<4xf32>>
+  return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.GL.Sqrt
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/193594
More information about the Mlir-commits mailing list