[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Jan 23 05:35:30 PST 2024
================
@@ -814,6 +814,649 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // This trait ensures that the input element type matches the output (tile)
+ // element type for ops that sum multiple widened outer products (2-way or 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and accumulate";
+
+ let description = [{
+ This operation represents a sum of 2 widened outer products. It takes two 1-D
+ scalable vectors as input and produces a 2-D scalable vector (ZA tile) as output.
+
+ For example (fp16 to fp32):
+
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS RHS
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1] |
+ [A2 A3] | [B0 B2 B4 B6]
+ [A4 A5] | [B1 B3 B5 B7]
+ [A6 A7] |
+
+ ----------------------------------------------------------------------------
+
+ 2 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
+ |
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 2 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 2 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ The 2 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.ins %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
----------------
banach-space wrote:
```suggestion
%a0_ins = vector.scalable.insert %a0_ext, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
```
?
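To make the quoted semantics concrete, here is a minimal Python sketch of what `fmopa_wide_2way` computes. It assumes SVL=128 (vscale=1) and uses Python floats in place of the f16 inputs and the f32 tile, so widening precision is not modelled. The helpers `outerproduct`, `fmopa_wide_2way`, `interleave` and `widening_ok` are hypothetical illustrations, not MLIR APIs. The final check exercises the folding claim: under the implicit SVLSx2 / 2xSVLS layout from the breakdown, two outer products chained via the accumulator match one 2-way op applied to element-wise interleaved inputs. The interleaved packing is an assumption derived from the layout diagram, not necessarily the lowering the PR uses.

```python
# Minimal reference model (an illustrative sketch, not the MLIR/LLVM
# implementation) of arm_sme.fmopa_wide_2way as described in the quoted docs.
# Assumptions: SVL = 128 bits (vscale = 1); Python floats stand in for the
# f16 inputs and the f32 ZA tile, so no widening rounding is modelled.

SVL_BITS = 128                # assumed: vscale = 1
SVLS = SVL_BITS // 32         # number of 32-bit elements in an SVL-bit vector
NUM_OUTER_PRODUCTS = 2        # 2-way widening


def widening_ok(tile_elt_bits, lhs_elt_bits, num_outer_products):
    # Mirrors the PredOpTrait in OuterProductWideBase:
    # tile element size == lhs element size * numOuterProducts.
    return tile_elt_bits == lhs_elt_bits * num_outer_products


def _init_tile(n, acc):
    # Copy the accumulator (if any) so the caller's tile is not mutated.
    if acc is not None:
        return [row[:] for row in acc]
    return [[0.0] * n for _ in range(n)]


def outerproduct(lhs, rhs, acc=None):
    """Plain (non-widening) outer product, optionally accumulating into acc."""
    n = len(lhs)
    tile = _init_tile(n, acc)
    for i in range(n):
        for j in range(n):
            tile[i][j] += lhs[i] * rhs[j]
    return tile


def fmopa_wide_2way(lhs, rhs, acc=None):
    """Sum of 2 outer products: lhs viewed as SVLSx2, rhs as 2xSVLS."""
    assert len(lhs) == len(rhs) == SVLS * NUM_OUTER_PRODUCTS
    tile = _init_tile(SVLS, acc)
    for i in range(SVLS):
        for j in range(SVLS):
            # Row i of the implicit lhs matrix is [lhs[2i], lhs[2i+1]];
            # column j of the implicit rhs matrix is [rhs[2j], rhs[2j+1]].
            for k in range(NUM_OUTER_PRODUCTS):
                tile[i][j] += (lhs[NUM_OUTER_PRODUCTS * i + k]
                               * rhs[NUM_OUTER_PRODUCTS * j + k])
    return tile


def interleave(v0, v1):
    """Hypothetical packing helper: [v0[0], v1[0], v0[1], v1[1], ...]."""
    return [x for pair in zip(v0, v1) for x in pair]


# Folding check: two chained outer products == one 2-way widening op.
a0, a1 = [1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]
b0, b1 = [0.5, 1.0, 1.5, 2.0], [2.0, 0.0, -1.0, 3.0]
chained = outerproduct(a1, b1, acc=outerproduct(a0, b0))
fused = fmopa_wide_2way(interleave(a0, a1), interleave(b0, b1))
assert chained == fused
assert widening_ok(32, 16, NUM_OUTER_PRODUCTS)  # f16 inputs -> f32 tile
```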
https://github.com/llvm/llvm-project/pull/78975