[Mlir-commits] [mlir] [mlir][spirv] Add definition for VectorTimesMatrixOp (PR #124571)

Igor Wodiany llvmlistbot at llvm.org
Mon Jan 27 11:43:37 PST 2025


================
@@ -198,4 +197,50 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
 
 // -----
 
+def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [Pure]> {
+  let summary = "Linear-algebraic Vector X Matrix.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Vector must be a vector with the same Component Type as the Component
+    Type in Result Type. Its number of components must equal the number of
+    components in each column in Matrix.
+
+    Matrix must be a matrix with the same Component Type as the Component
+    Type in Result Type. Its number of columns must equal the number of
+    components in Result Type.
----------------
IgWod-IMG wrote:

I had a look into using `AllElementTypesMatch`, and it works fine for checking the types of `vector` and `result`; however, it fails with `matrix`. I investigated it, and it seems that the underlying `getElementTypeOrSelf` function used by `AllElementTypesMatch` needs a `ShapedType` to extract the element type, and `SPIRV_Matrix` is not a `ShapedType`. I tried adding `public ShapedType::Trait<MatrixType>` to `MatrixType` as a quick fix, but that does not seem to be the solution, I guess because it's only added as a trait.

Then regarding `TypesMatchWith`: unless I missed something, it only allows transforming one side of the comparison, whereas to compare element types I would need to apply `getElementType()` to both the LHS and the RHS. So I am not sure it can be used here. But again, I may be missing something.

So, I see 4 options here:
1) Make `SPIRV_Matrix` a `ShapedType` - I need a pointer on how to do it.
2) Add a new `TypesMatchWith` that allows transformation of both sides - I guess since it's a more global change, that would have to be a separate PR.
3) Alternatively, I could write a local extended `TypesMatchWith` or specialized `AllElementTypesMatch` just for SPIR-V dialect that does the element comparison for SPIR-V types (matrices).
4) Leave verification in C++.

I'm still learning MLIR, so there is a chance I missed something obvious here and everything I need can be achieved with existing features. Thanks in advance!

https://github.com/llvm/llvm-project/pull/124571

