[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Jan 24 02:43:10 PST 2024
================
@@ -814,6 +814,649 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // This trait ensures the input type matches the correct output type for ops
+ // that take multiple inputs and outputs (i.e., 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
----------------
c-rhodes wrote:
> `wide` in `fmopa_wide_2way` is a bit confusing - it's meant to mean "widening", but it's ambiguous. I would consider dropping that - one can infer from element types that this is a "widening" op. Also, aren't all 2-way and 4-way ops widening? Alternatively, why not replace `wide` with `widening`?
Glad you said that as I did consider dropping `wide` from the op name altogether and you've reinforced that. `fmopa_widening_2way` feels a bit long?
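
For reference, the point about inferring "widening" from the element types corresponds to the second `PredOpTrait` in the hunk above: the tile element width must equal the lhs element width times `numOuterProducts`. A minimal sketch of that check in plain C++, assuming nothing beyond the predicate itself; `isValidWideningWidth` is a hypothetical name used for illustration and is not part of the patch:

```cpp
// Hypothetical helper, not part of the patch. It mirrors the CPred in
// OuterProductWideBase:
//   tile element bit width == lhs element bit width * numOuterProducts
constexpr bool isValidWideningWidth(unsigned lhsElemBits,
                                    unsigned tileElemBits,
                                    unsigned numOuterProducts) {
  return tileElemBits == lhsElemBits * numOuterProducts;
}

// 2-way: 16-bit inputs widen into a 32-bit tile.
static_assert(isValidWideningWidth(16, 32, 2), "2-way 16 -> 32 is valid");
// 4-way: 8-bit inputs widen into a 32-bit tile.
static_assert(isValidWideningWidth(8, 32, 4), "4-way 8 -> 32 is valid");
// Same-width inputs and tile do not satisfy a 2-way check.
static_assert(!isValidWideningWidth(32, 32, 2), "32 -> 32 is not 2-way");
```

Since the predicate only holds when the result elements are wider than the inputs, every op built on this base class is widening by construction, which is why the `wide` part of the mnemonic carries no extra information.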
https://github.com/llvm/llvm-project/pull/78975