[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Oct 26 09:22:22 PDT 2023
================
@@ -515,4 +536,93 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
}];
}
+class HasMatchingMaskTypeConstraint<string operand> :
+ OptionalTypesMatchWith<
+ "shape of `" # operand # "Mask` matches `" # operand # "`",
+ operand, operand # "Mask",
+ "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
+def OuterProductOp :
+ ArmSME_Op<"outerproduct", [Pure,
+ AttrSizedOperandSegments,
+ AllElementTypesMatch<["lhs", "rhs", "result"]>,
+ HasMatchingMaskTypeConstraint<"lhs">,
+ HasMatchingMaskTypeConstraint<"rhs">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+ PredOpTrait<
+ "result type is derived from `lhs` and `rhs`",
+ CPred<
+ "getResultType() == VectorType::get({"
+ "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+ "getRhsType().getElementType(),"
+ "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+ OptionalTypesMatchWith<"`result` and `acc` have the same type",
+ "result", "acc",
+ "::llvm::cast<mlir::Type>($_self)">
+ ]>
+{
+ let summary = "Vector outerproduct with optional fused add";
+
+ let description = [{
+ This op is based on `vector.outerproduct` with the extra conditions that:
+
+ * AXPY operations are not supported
+ * The only combining functions are "add" and "sub"
+ * Masking is performed on the inputs (rather than the output)
+
+ This is meant as an intermediate op for lowering `vector.outerproduct` to
+ SME. Types are not restricted to SVE/SME vectors at this level.
+
+ Example 1: Unmasked outerproduct (without accumulator)
+ ```mlir
+ %result = arm_sme.outerproduct $lhs, $rhs
+ : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+ ```
+
+ Example 2: Unmasked outerproduct (with accumulator)
+ ```mlir
+ %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator)
+ : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+ ```
----------------
MacDue wrote:
It's zeroing :)
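The Python sketch below is a small reference model of the semantics in the quoted op description. It is illustrative only and is not the MLIR implementation or API. It rests on these assumptions:

- `outerproduct` and `result_vector_type` are hypothetical names.
- Scalable vectors are modelled as plain fixed-length lists.
- Omitting the accumulator starts from a zero-filled tile. This is one reading of the "It's zeroing" reply.
- Masked-off elements keep the accumulator value. This matches how SME's predicated outer-product instructions leave inactive tile elements unmodified.

```python
from typing import List, Optional, Sequence, Tuple

Tile = List[List[float]]


def result_vector_type(lhs: Tuple[int, bool, str],
                       rhs: Tuple[int, bool, str]):
    """Hypothetical mirror of the verifier's result-type rule.

    Each 1-D operand is modelled as (dim_size, is_scalable, element_type).
    The result is 2-D: rows come from `lhs`, columns from `rhs`,
    scalability is taken per dimension, and the element type from `rhs`.
    """
    (l_dim, l_scalable, _), (r_dim, r_scalable, r_elt) = lhs, rhs
    return (l_dim, r_dim), (l_scalable, r_scalable), r_elt


def outerproduct(lhs: Sequence[float], rhs: Sequence[float],
                 acc: Optional[Tile] = None,
                 lhs_mask: Optional[Sequence[bool]] = None,
                 rhs_mask: Optional[Sequence[bool]] = None,
                 kind: str = "add") -> Tile:
    """Hypothetical reference model of `arm_sme.outerproduct`."""
    if kind not in ("add", "sub"):
        raise ValueError("only the 'add' and 'sub' combining kinds exist")
    # Mirrors the PredOpTrait: both masks or neither.
    if (lhs_mask is None) != (rhs_mask is None):
        raise ValueError("provide both lhs_mask and rhs_mask, or neither")
    if lhs_mask is None or rhs_mask is None:
        lhs_mask, rhs_mask = [True] * len(lhs), [True] * len(rhs)
    # Mirrors HasMatchingMaskTypeConstraint: each mask matches its operand.
    if len(lhs_mask) != len(lhs) or len(rhs_mask) != len(rhs):
        raise ValueError("mask shape must match its operand")
    rows, cols = len(lhs), len(rhs)
    # Assumption: omitting the accumulator starts from a zeroed tile.
    if acc is None:
        acc = [[0.0] * cols for _ in range(rows)]
    if len(acc) != rows or any(len(row) != cols for row in acc):
        raise ValueError("acc must have the same type as the result")
    sign = 1.0 if kind == "add" else -1.0
    result = [list(row) for row in acc]  # copy so `acc` is not mutated
    for i in range(rows):
        for j in range(cols):
            # Masking is on the inputs: (i, j) is active only if both lanes are.
            # Assumption: inactive elements keep the accumulator value.
            if lhs_mask[i] and rhs_mask[j]:
                result[i][j] = acc[i][j] + sign * lhs[i] * rhs[j]
    return result
```

In this model, an unmasked call without `acc` is the plain outer product on a zeroed tile. `kind="sub"` subtracts each product from the accumulator instead of adding it.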
https://github.com/llvm/llvm-project/pull/69604