[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Oct 26 03:37:12 PDT 2023


================
@@ -515,4 +536,93 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand> :
+  OptionalTypesMatchWith<
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs">,
+    HasMatchingMaskTypeConstraint<"rhs">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    OptionalTypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)">
+  ]>
+{
+  let summary = "Vector outerproduct with optional fused add";
+
+  let description = [{
+    This op is based on `vector.outerproduct` with the extra conditions that:
+
+    * AXPY operations are not supported
----------------
MacDue wrote:

I'm not sure this is a TODO; I don't know if we want to support it :)

https://github.com/llvm/llvm-project/pull/69604


More information about the Mlir-commits mailing list