[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)

Cullen Rhodes llvmlistbot at llvm.org
Wed Jan 24 02:43:10 PST 2024


================
@@ -814,6 +814,649 @@ let arguments = (ins
   }];
 }
 
+class OuterProductWideBase<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes,
+                           int numOuterProducts> :
+  ArmSME_Op<mnemonic, [
+    ArmSMETileOpInterface,
+    AttrSizedOperandSegments,
+    AllTypesMatch<["lhs", "rhs"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+    >,
+    OptionalTypesMatchWith<"result and acc have the same type",
+                           "result", "acc", "::llvm::cast<Type>($_self)">,
+    // This trait ensures the input type matches the correct output type for ops
+    // that take multiple inputs and outputs (i.e., 4-way).
+    PredOpTrait<
+      "tile element size equals lhs element size * " # numOuterProducts,
+      CPred<"getTileType().getElementTypeBitWidth() == "
+            "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+    >,
+  ]> {
+
+  let arguments = (ins
+    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+    Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+    Optional<AnyVector>:$acc);
----------------
c-rhodes wrote:

Functionally it makes no difference. This could literally be `AnyType` (and perhaps it should be), since AFAIK there's no way to express with the operand type alone that the type of `rhs` must equal the type of `lhs`; that takes a constraint, such as the `AllTypesMatch<["lhs", "rhs"]>` trait already on this op. The exception is an op that accepts only one input type, in which case it could be `inputType:$lhs, inputType:$rhs`.

It's a bit unfortunate, because the auto-generated op documentation will describe `rhs` as `vector of any type values`, which we know isn't true. But describing it as `a vector type that matches the size of a SVE vector` wouldn't be true either.
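
To illustrate the two cases, here is a rough sketch. The class names `SingleInputTypeSketch` and `MultiInputTypeSketch` and the `inputType` parameter are hypothetical and not part of the PR; `ArmSME_Op`, `AnyTypeOf`, and `AllTypesMatch` are used as they appear in the hunk above.

```tablegen
// Hypothetical sketch for illustration only; not code from the PR.

// Case 1: the op accepts exactly one input type. Using the same type
// constraint on both operands is enough to force lhs and rhs to match.
class SingleInputTypeSketch<string mnemonic, Type inputType>
    : ArmSME_Op<mnemonic> {
  let arguments = (ins inputType:$lhs, inputType:$rhs);
}

// Case 2: the op accepts several input types. Each operand is checked
// against the list independently, so the operand types alone would still
// allow lhs and rhs to differ. A trait is needed to tie them together.
class MultiInputTypeSketch<string mnemonic,
                           list<Type> allowedInputVectorTypes>
    : ArmSME_Op<mnemonic, [AllTypesMatch<["lhs", "rhs"]>]> {
  let arguments = (ins
    AnyTypeOf<allowedInputVectorTypes>:$lhs,
    AnyTypeOf<allowedInputVectorTypes>:$rhs);
}
```

In the second case the trait does the real checking, so the constraint written on `rhs` mostly affects the generated documentation rather than what the verifier accepts.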

https://github.com/llvm/llvm-project/pull/78975

