[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)
llvmlistbot at llvm.org
Mon Jan 22 06:12:26 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Cullen Rhodes (c-rhodes)
<details>
<summary>Changes</summary>
This patch introduces support for widening outer products. This enables the folding of 2 or 4 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations.
Changes:
- Adds 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for the 2-way variants. These map to instruction variants added in SME2 and use different intrinsics. Intrinsics for the widening variants from SME1 are already implemented.
- Marks 'arm_sme.outerproduct' as Pure. This is consistent with 'vector.outerproduct' and enables dead ops to be removed.
- Adds the following operations:
  - fmopa_wide_2way, fmops_wide_2way
  - smopa_wide_2way, smops_wide_2way
  - umopa_wide_2way, umops_wide_2way
  - smopa_wide_4way, smops_wide_4way
  - umopa_wide_4way, umops_wide_4way
  - sumopa_wide_4way, sumops_wide_4way
  - usmopa_wide_4way, usmops_wide_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Adds a pass 'arm-sme-outer-product-widening' that folds 'arm_sme.outerproduct' operations.
For a detailed description of these operations see the 'arm_sme.fmopa_wide_2way' and 'arm_sme.smopa_wide_4way' descriptions.
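To make the folding concrete, here is a minimal numeric sketch in plain Python. It is not part of the patch. It assumes vscale=1 and models vectors as lists. The helper names (`outer`, `add`, `zip1`, `mopa_wide_2way`) are hypothetical and only mirror the implicit layout described for 'arm_sme.fmopa_wide_2way': the interleaved LHS is read as an N x 2 matrix and the interleaved RHS as a 2 x N matrix.

```python
def outer(a, b):
    """Outer product of two 1-D lists, returned as a list of rows."""
    return [[x * y for y in b] for x in a]


def add(m, n):
    """Element-wise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(r, s)] for r, s in zip(m, n)]


def zip1(a, b):
    """Interleave two equal-length lists: [a0, b0, a1, b1, ...].

    Simplified stand-in for the SVE ZIP1 step used when packing operands.
    """
    return [v for pair in zip(a, b) for v in pair]


def mopa_wide_2way(lhs, rhs, acc=None):
    """Reference model (assumption, not the real op) of a 2-way widening
    outer product: tile[i][j] = lhs[2i]*rhs[2j] + lhs[2i+1]*rhs[2j+1]."""
    n = len(lhs) // 2
    tile = [
        [lhs[2 * i] * rhs[2 * j] + lhs[2 * i + 1] * rhs[2 * j + 1]
         for j in range(n)]
        for i in range(n)
    ]
    return tile if acc is None else add(acc, tile)


a0, a1 = [1, 2, 3, 4], [5, 6, 7, 8]
b0, b1 = [9, 10, 11, 12], [13, 14, 15, 16]

# Two outer products chained via the accumulator...
chained = add(outer(a0, b0), outer(a1, b1))
# ...give the same tile as one 2-way op on the interleaved operands.
fused = mopa_wide_2way(zip1(a0, a1), zip1(b0, b1))
assert chained == fused
```

The 4-way case follows the same pattern with four interleaved operands per side and four products summed per tile element.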
Many operations are introduced rather than one because the signed and unsigned variants can't be distinguished by type (e.g., ui16, si16): 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would therefore need the signedness supplied some other way, and an attribute (for example) doesn't feel right when floating-point types, to which it wouldn't apply, are also supported. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this.
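The signedness point can be illustrated with a short Python sketch. It is only an illustration, not code from the patch: `extsi` and `extui` here are hypothetical helpers that mimic sign- and zero-extension on a raw bit pattern. The same signless 8-bit value widens to different integers, so the choice of extension (and hence of smopa versus umopa) cannot be recovered from the element type alone.

```python
def extsi(value, from_bits):
    """Sign-extend: read the low `from_bits` bits as two's complement."""
    value &= (1 << from_bits) - 1
    sign = 1 << (from_bits - 1)
    return (value ^ sign) - sign


def extui(value, from_bits):
    """Zero-extend: read the low `from_bits` bits as unsigned."""
    return value & ((1 << from_bits) - 1)


bits = 0xFF  # one signless i8 bit pattern

# The same bits widen to different values depending on the extension.
assert extsi(bits, 8) == -1
assert extui(bits, 8) == 255

# A widened product therefore differs between signed and unsigned variants.
signed_product = extsi(bits, 8) * extsi(0x02, 8)
unsigned_product = extui(bits, 8) * extui(0x02, 8)
assert signed_product != unsigned_product
```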
---
Patch is 154.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78975.diff
18 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+4)
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+645)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+3)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+39)
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h (+4)
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+4)
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+80-2)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+2)
- (added) mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp (+501)
- (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+272)
- (modified) mlir/test/Dialect/ArmSME/cse.mlir (+15-10)
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+66)
- (added) mlir/test/Dialect/ArmSME/outer-product-widening.mlir (+785)
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+272)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir (+100)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir (+142)
- (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+12)
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+7)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc46..f051e03efbcda64 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
class ArmSME_IntrLoadStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012fe..1365dd38c115ef2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -736,6 +736,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -814,6 +815,650 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ Pure,
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // This trait ensures the input type matches the correct output type for ops
+ // that take multiple inputs and outputs (i.e., 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and accumulate";
+
+ let description = [{
+ This operation represents a sum of 2 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (fp16 to fp32):
+
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS RHS
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1] |
+ [A2 A3] | [B0 B2 B4 B6]
+ [A4 A5] | [B1 B3 B5 B7]
+ [A6 A7] |
+
+ ----------------------------------------------------------------------------
+
+ 2 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
+ |
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 2 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 2 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ The 2 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+
+ [1] https://developer.arm.com/documentation/ddi0616
+ }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP16 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+ : OuterProductWide2Way<"fmops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and subtract";
+ let description = [{
+ Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+ destination `result`.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+ }];
+}
+
+def SMopaWide2WayOp
+ : OuterProductWide2Way<"smopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ }];
+}
+
+def SMopsWide2WayOp
+ : OuterProductWide2Way<"smops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ }];
+}
+
+def UMopaWide2WayOp
+ : OuterProductWide2Way<"umopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsigned integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ }];
+}
+
+def UMopsWide2WayOp
+ : OuterProductWide2Way<"umops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsigned integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ }];
+}
+
+def SMopaWide4WayOp
+ : OuterProductWide4Way<"smopa_wide_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed integer sum of 4 outer products and accumulate";
+ let description = [{
+ This operation represents a sum of 4 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (i8 to i32):
+
+ ```mlir
+ %result = arm_sme.smopa_wide_4way %lhs, %rhs :
+ vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+ 4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS
+ [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+
+ RHS
+ [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1 A2 A3] | [B0 B4 B8 B12]
+ [A4 A5 A6 A7] | [B1 B5 B9 B13]
+ [A8 A9 A10 A11] | [B2 B6 B10 B14]
+ [A12 A13 A14 A15] | [B3 B7 B11 B15]
+
+ ----------------------------------------------------------------------------
+
+ 4 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B4 B8 B12] | [B1 B5 B9 B13]
+ |
+ [A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
+ A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
+ A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
+ A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
+ |
+ Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
+ ------------- | -------------
+ |
+ [B2 B6 B10 B14] | [B3 B7 B11 B15]
+ |
+ [A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
+ A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
+ A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
+ A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 4 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+ [ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
+ [ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
+ [ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
+ [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 4 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+ ```
+
+ The 4 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[16]xi8>
+ %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %a2_ins = vector.scalable.insert %a2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %a3_ins = vector.scalable.insert %a3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %lhs0 = "arm_sve.intr.zip1"(%a0_ins, %a2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+ %lhs1 = "arm_sve.intr.zip1"(%a1_ins, %a3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+ %lhs = "arm_sve.intr.zip1"(%lhs0, %lhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+ %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %b2_ins = vector.scalable.insert %b2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %b3_ins = vector.scalable.insert %b3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+ %rhs0 = "arm_sve.intr.zip1"(%b0_ins, %b2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+ %rhs1 = "arm_sve.intr.zip1"(%b1_ins, %b3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+ %rhs = "arm_sve.intr.zip1"(%rhs0, %rhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+ %0 = arm_sme.smopa_wide_4way %lh...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/78975