[Mlir-commits] [mlir] [mlir][ArmSME] Support widening outer products (PR #78975)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 22 06:12:26 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)

<details>
<summary>Changes</summary>

This patch introduces support for widening outer products. This enables the folding of 2 or 4 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations.
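The sketch below is a plain C++ reference model of the 2-way fold. It is not MLIR code and is not taken from the patch. It assumes vscale = 1 and checks one property: two accumulator-chained outer products of 4-element vectors give the same 4x4 tile as a single 2-way widened outer product whose inputs were packed with ZIP1. All function names (`outerProduct`, `insertLow`, `zip1`, `mopaWide2Way`) are hypothetical. The undefined upper half left by the insert is modelled as zeros. Values are stored as `float` small integers so that every sum is exact, and FP16 precision is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model; a tile is an n x n matrix stored row-major.
using Vec = std::vector<float>;

// Model of arm_sme.outerproduct with an accumulator:
// acc[i][j] += lhs[i] * rhs[j].
Vec outerProduct(const Vec &lhs, const Vec &rhs, Vec acc) {
  const std::size_t n = lhs.size();
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      acc[i * n + j] += lhs[i] * rhs[j];
  return acc;
}

// Model of vector.scalable.insert at position 0 into a vector twice as long.
// The undefined upper half is modelled as zeros.
Vec insertLow(const Vec &v) {
  Vec out(2 * v.size(), 0.0f);
  for (std::size_t i = 0; i < v.size(); ++i)
    out[i] = v[i];
  return out;
}

// Model of SVE ZIP1: interleave the low halves of x and y.
Vec zip1(const Vec &x, const Vec &y) {
  Vec out(x.size(), 0.0f);
  for (std::size_t i = 0; i < x.size() / 2; ++i) {
    out[2 * i] = x[i];
    out[2 * i + 1] = y[i];
  }
  return out;
}

// Model of a 2-way widening outer product: lhs and rhs hold 2n lanes and
// acc[i][j] += lhs[2i] * rhs[2j] + lhs[2i+1] * rhs[2j+1].
Vec mopaWide2Way(const Vec &lhs, const Vec &rhs, Vec acc) {
  const std::size_t n = lhs.size() / 2;
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < 2; ++k)
        acc[i * n + j] += lhs[2 * i + k] * rhs[2 * j + k];
  return acc;
}
```

The check works because ZIP1 places `a0[i]` and `a1[i]` in lanes `2i` and `2i+1`. That matches the implicit SVLSx2 layout that the widened op reads.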

Changes:

- Add 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for the 2-way variants. These map to instruction variants added in SME2, which use different intrinsics from the widening variants in SME1. Intrinsics for the SME1 variants are already implemented.
- Mark 'arm_sme.outerproduct' as Pure. This is consistent with 'vector.outerproduct' and enables them to be removed if dead.
- Add the following operations:
  - fmopa_wide_2way, fmops_wide_2way
  - smopa_wide_2way, smops_wide_2way
  - umopa_wide_2way, umops_wide_2way
  - smopa_wide_4way, smops_wide_4way
  - umopa_wide_4way, umops_wide_4way
  - sumopa_wide_4way, sumops_wide_4way
  - usmopa_wide_4way, usmops_wide_4way
- Implement conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Add a pass, 'arm-sme-outer-product-widening', that folds chained 'arm_sme.outerproduct' operations into the widened operations above.
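A trait in the patch verifies each widened op. It requires the result tile's element width to equal the lhs element width multiplied by the number of folded outer products (2 or 4). The C++ sketch below mirrors that rule. `WideningShape` and `isValidWidening` are illustrative, hypothetical names. The patch has no separate trait for the lane-count check. That check is an extra consistency condition implied by the allowed types, for example `vector<[8]xf16>` into `vector<[4]x[4]xf32>`.

```cpp
#include <cassert>

// Hypothetical description of a widening outer product's operand/result
// shapes, using the minimum (vscale = 1) lane counts.
struct WideningShape {
  unsigned inElemBits;  // e.g. 16 for vector<[8]xf16>
  unsigned inLanes;     // e.g. 8 for vector<[8]xf16>
  unsigned outElemBits; // e.g. 32 for vector<[4]x[4]xf32>
  unsigned outLanes;    // lanes per tile row, e.g. 4 for vector<[4]x[4]xf32>
};

// Mirrors "tile element size equals lhs element size * N", where N is the
// number of folded outer products, plus a lane-count consistency check.
bool isValidWidening(const WideningShape &s, unsigned numOuterProducts) {
  if (numOuterProducts != 2 && numOuterProducts != 4)
    return false;
  return s.outElemBits == s.inElemBits * numOuterProducts &&
         s.inLanes == s.outLanes * numOuterProducts;
}
```

For example, `{16, 8, 32, 4}` with 2 ways describes the FP16 to FP32 case. The same shape is rejected as a 4-way op.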

For a detailed description of these operations, see the 'arm_sme.fmopa_wide_2way' and 'arm_sme.smopa_wide_4way' op descriptions.

Multiple operations are introduced rather than one because the signed and unsigned variants can't be distinguished by type (e.g., ui16 vs. si16): 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would therefore need to carry the signedness some other way, for example as an attribute. That doesn't feel right when floating-point types, for which signedness doesn't apply, are also supported.

Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variants. These are not supported in this patch. Given this, however, it felt simpler to have separate ops for add and subtract.
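A small scalar model shows the sign issue. The vector types are signless, so the same 8-bit lane contents can give different sums of products. The result depends on whether each operand is sign-extended or zero-extended. The op must therefore record which extension applies. The sketch below is a hypothetical C++ model of one tile element of a 4-way integer outer product. `extsi8`, `extui8` and `dot4` are illustrative names, not patch code.

```cpp
#include <cassert>
#include <cstdint>

// Models of arith.extsi / arith.extui from a signless 8-bit lane to 32 bits.
int32_t extsi8(uint8_t bits) { return static_cast<int32_t>(static_cast<int8_t>(bits)); }
int32_t extui8(uint8_t bits) { return static_cast<int32_t>(bits); }

// One tile element of a 4-way widening outer product:
// sum over k of ext(a[k]) * ext(b[k]), with the extension chosen per operand.
int32_t dot4(const uint8_t a[4], const uint8_t b[4], bool aSigned, bool bSigned) {
  int32_t sum = 0;
  for (int k = 0; k < 4; ++k) {
    int32_t x = aSigned ? extsi8(a[k]) : extui8(a[k]);
    int32_t y = bSigned ? extsi8(b[k]) : extui8(b[k]);
    sum += x * y;
  }
  return sum;
}
```

Take the bit patterns `a = {0xFF, 0x01, 0x02, 0x80}` and `b = {0x02, 0x03, 0xFF, 0x01}`. The four signed/unsigned combinations give four different results: -129, 1151, 383 and 639. That is why the signed, unsigned and mixed-sign variants are separate ops.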

---

Patch is 154.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78975.diff


18 Files Affected:

- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+645) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+3) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+39) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h (+4) 
- (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+4) 
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+80-2) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+2) 
- (added) mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp (+501) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+272) 
- (modified) mlir/test/Dialect/ArmSME/cse.mlir (+15-10) 
- (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+66) 
- (added) mlir/test/Dialect/ArmSME/outer-product-widening.mlir (+785) 
- (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+272) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir (+100) 
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir (+142) 
- (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+12) 
- (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc46..f051e03efbcda64 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
 def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
 
 class ArmSME_IntrLoadStoreOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012fe..1365dd38c115ef2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -736,6 +736,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -814,6 +815,650 @@ let arguments = (ins
   }];
 }
 
+class OuterProductWideBase<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes,
+                           int numOuterProducts> :
+  ArmSME_Op<mnemonic, [
+    Pure,
+    ArmSMETileOpInterface,
+    AttrSizedOperandSegments,
+    AllTypesMatch<["lhs", "rhs"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+    >,
+    OptionalTypesMatchWith<"result and acc have the same type",
+                           "result", "acc", "::llvm::cast<Type>($_self)">,
+    // This trait ensures the input element type matches the output element
+    // type: the tile element is numOuterProducts (2 or 4) times wider.
+    PredOpTrait<
+      "tile element size equals lhs element size * " # numOuterProducts,
+      CPred<"getTileType().getElementTypeBitWidth() == "
+            "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+    >,
+  ]> {
+
+  let arguments = (ins
+    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+    Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+    Optional<AnyVector>:$acc);
+  let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+    VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+    VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
+    VectorType getTileType() {
+      return getResultType();
+    }
+  }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+  : OuterProductWide2Way<"fmopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+      [nxnxv4f32]> {
+  let summary = "Floating-point sum of 2 outer products and accumulate";
+
+  let description = [{
+    This operation represents a sum of 2 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+    For example (fp16 to fp32):
+
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+      vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+    2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. To illustrate, below is a breakdown of
+    this operation for SVL=128 (i.e., vscale=1):
+
+    ```
+                          LHS                          RHS
+               [A0 A1 A2 A3 A4 A5 A6 A7]    [B0 B1 B2 B3 B4 B5 B6 B7]
+
+    ----------------------------------------------------------------------------
+
+                                  implicit layout
+
+                              [A0 A1]    |
+                              [A2 A3]    |    [B0 B2 B4 B6]
+                              [A4 A5]    |    [B1 B3 B5 B7]
+                              [A6 A7]    |
+
+    ----------------------------------------------------------------------------
+
+                                  2 outer products
+
+                      Acol0 ⊗ Brow0      |           Acol1 ⊗ Brow1
+                      -------------      |           -------------
+                                         |
+                  [B0 B2 B4 B6]          |       [B1 B3 B5 B7]
+                                         |
+             [A0  [A0B0 A0B2 A0B4 A0B6]  |  [A1  [A1B1 A1B3 A1B5 A1B7]
+              A2  [A2B0 A2B2 A2B4 A2B6]  |   A3  [A3B1 A3B3 A3B5 A3B7]
+              A4  [A4B0 A4B2 A4B4 A4B6]  |   A5  [A5B1 A5B3 A5B5 A5B7]
+              A6] [A6B0 A6B2 A6B4 A6B6]  |   A7] [A7B1 A7B3 A7B5 A7B7]
+                                         |
+
+    ----------------------------------------------------------------------------
+
+                              sum of 2 outer products
+
+                           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+                 [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+                 [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+                 [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+                 [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+    ----------------------------------------------------------------------------
+    ```
+
+    This operation enables the folding of 2 outer products chained via the
+    accumulator into a single outer product.
+
+    For example:
+
+    ```mlir
+    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    The 2 outer products in the example above can be fused into a single outer
+    product as follows:
+
+    ```mlir
+    %undef = llvm.mlir.undef : vector<[8]xf16>
+    %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+    %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+    %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+    Example: FP16 to FP32
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    Example: BF16 to FP32
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+    ```
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+    | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+
+    [1] https://developer.arm.com/documentation/ddi0616
+  }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP8 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+  : OuterProductWide2Way<"fmops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+      [nxnxv4f32]> {
+  let summary = "Floating-point sum of 2 outer products and subtract";
+  let description = [{
+    Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+    destination `result`.
+
+    Example: FP16 to FP32
+    ```mlir
+    %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    Example: BF16 to FP32
+    ```mlir
+    %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+    | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+  }];
+}
+
+def SMopaWide2WayOp
+  : OuterProductWide2Way<"smopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Signed integer sum of 2 outer products and accumulate";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+  }];
+}
+
+def SMopsWide2WayOp
+  : OuterProductWide2Way<"smops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Signed integer sum of 2 outer products and subtract";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+  }];
+}
+
+def UMopaWide2WayOp
+  : OuterProductWide2Way<"umopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Unsigned integer sum of 2 outer products and accumulate";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+  }];
+}
+
+def UMopsWide2WayOp
+  : OuterProductWide2Way<"umops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Unsigned integer sum of 2 outer products and subtract";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+  }];
+}
+
+def SMopaWide4WayOp
+  : OuterProductWide4Way<"smopa_wide_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    This operation represents a sum of 4 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+    For example (i8 to i32):
+
+    ```mlir
+    %result = arm_sme.smopa_wide_4way $lhs, $rhs :
+      vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+    4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. To illustrate, below is a breakdown of
+    this operation for SVL=128 (i.e., vscale=1):
+
+    ```
+                                        LHS
+              [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+
+                                        RHS
+              [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+    ----------------------------------------------------------------------------
+
+                                  implicit layout
+
+                    [A0   A1  A2  A3]    |    [B0 B4  B8 B12]
+                    [A4   A5  A6  A7]    |    [B1 B5  B9 B13]
+                    [A8   A9 A10 A11]    |    [B2 B6 B10 B14]
+                    [A12 A13 A14 A15]    |    [B3 B7 B11 B15]
+
+    ----------------------------------------------------------------------------
+
+                                  4 outer products
+
+                 Acol0 ⊗ Brow0           |            Acol1 ⊗ Brow1
+                 -------------           |            -------------
+                                         |
+             [B0 B4 B8 B12]              |        [B1 B5 B9 B13]
+                                         |
+       [A0   [ A0B0  A0B4  A0B8  A0B12]  |  [A1   [ A1B1  A1B5  A1B9  A1B13]
+        A4   [ A4B0  A4B4  A4B8  A4B12]  |   A5   [ A5B1  A5B5  A5B9  A5B13]
+        A8   [ A8B0  A8B4  A8B8  A8B12]  |   A9   [ A9B1  A9B5  A9B9  A9B13]
+        A12] [A12B0 A12B4 A12B8 A12B12]  |   A13] [A13B1 A13B5 A13B9 A13B13]
+                                         |
+                 Acol2 ⊗ Brow2           |            Acol3 ⊗ Brow3
+                 -------------           |            -------------
+                                         |
+             [B2 B6 B10 B14]             |        [B3 B7 B11 B15]
+                                         |
+       [A2   [ A2B2  A2B6  A2B10  A2B14] |  [A3   [ A3B3  A3B7  A3B11  A3B15]
+        A6   [ A6B2  A6B6  A6B10  A6B14] |   A7   [ A7B3  A7B7  A7B11  A7B15]
+        A10  [A10B2 A10B6 A10B10 A10B14] |   A11  [A11B3 A11B7 A11B11 A11B15]
+        A14] [A14B2 A14B6 A14B10 A14B14] |   A15] [A15B3 A15B7 A15B11 A15B15]
+                                         |
+
+    ----------------------------------------------------------------------------
+
+                              sum of 4 outer products
+
+           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+     [ A0B0 +  A1B1 +  A2B2 +  A3B3 ... ...  A0B12 +  A1B13 +  A2B14 +  A3B15]
+     [ A4B0 +  A5B1 +  A6B2 +  A7B3 ... ...  A4B12 +  A5B13 +  A6B14 +  A7B15]
+     [ A8B0 +  A9B1 + A10B2 + A11B3 ... ...  A8B12 +  A9B13 + A10B14 + A11B15]
+     [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+    ----------------------------------------------------------------------------
+    ```
+
+    This operation enables the folding of 4 outer products chained via the
+    accumulator into a single outer product.
+
+    For example:
+
+    ```mlir
+    %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+    %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+    %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+    %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+    %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+    %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+    %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+    ```
+
+    The 4 outer products in the example above can be fused into a single outer
+    product as follows:
+
+    ```mlir
+    %undef = llvm.mlir.undef : vector<[16]xi8>
+    %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a2_ins = vector.scalable.insert %a2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a3_ins = vector.scalable.insert %a3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %lhs0 = "arm_sve.intr.zip1"(%a0_ins, %a2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %lhs1 = "arm_sve.intr.zip1"(%a1_ins, %a3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %lhs = "arm_sve.intr.zip1"(%lhs0, %lhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+    %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b2_ins = vector.scalable.insert %b2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b3_ins = vector.scalable.insert %b3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %rhs0 = "arm_sve.intr.zip1"(%b0_ins, %b2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %rhs1 = "arm_sve.intr.zip1"(%b1_ins, %b3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %rhs = "arm_sve.intr.zip1"(%rhs0, %rhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+    %0 = arm_sme.smopa_wide_4way %lh...
[truncated]

``````````
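The `smopa_wide_4way` description in the diff packs each side in two steps. It first applies ZIP1 to `a0`/`a2` and to `a1`/`a3`, then applies ZIP1 to those two results. A reference model similar to the 2-way one can check this packing. The sketch makes several assumptions. vscale = 1, and each side has four 4-element inputs. The undefined insert tail is modelled as zeros. The operands are already sign-extended to `int32_t`, whereas the real op extends internally. All function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference model; a tile is an n x n matrix stored row-major.
using Vec = std::vector<int32_t>;

// Model of arm_sme.outerproduct with an accumulator: acc[i][j] += lhs[i] * rhs[j].
Vec outerProduct(const Vec &lhs, const Vec &rhs, Vec acc) {
  const std::size_t n = lhs.size();
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      acc[i * n + j] += lhs[i] * rhs[j];
  return acc;
}

// Model of vector.scalable.insert at position 0 into a vector of `len` lanes;
// the undefined tail is modelled as zeros.
Vec insertLow(const Vec &v, std::size_t len) {
  Vec out(len, 0);
  for (std::size_t i = 0; i < v.size(); ++i)
    out[i] = v[i];
  return out;
}

// Model of SVE ZIP1: interleave the low halves of x and y.
Vec zip1(const Vec &x, const Vec &y) {
  Vec out(x.size(), 0);
  for (std::size_t i = 0; i < x.size() / 2; ++i) {
    out[2 * i] = x[i];
    out[2 * i + 1] = y[i];
  }
  return out;
}

// Pack four n-lane operands into one 4n-lane operand:
// zip1(zip1(v0, v2), zip1(v1, v3)) gives lanes v0[i], v1[i], v2[i], v3[i], ...
Vec pack4(const Vec &v0, const Vec &v1, const Vec &v2, const Vec &v3) {
  const std::size_t len = 4 * v0.size();
  Vec p02 = zip1(insertLow(v0, len), insertLow(v2, len));
  Vec p13 = zip1(insertLow(v1, len), insertLow(v3, len));
  return zip1(p02, p13);
}

// Model of a 4-way widening outer product:
// acc[i][j] += sum over k < 4 of lhs[4i + k] * rhs[4j + k].
Vec mopaWide4Way(const Vec &lhs, const Vec &rhs, Vec acc) {
  const std::size_t n = lhs.size() / 4;
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      for (std::size_t k = 0; k < 4; ++k)
        acc[i * n + j] += lhs[4 * i + k] * rhs[4 * j + k];
  return acc;
}
```

The first step pairs `a0` with `a2` rather than with `a1`. That ordering lets the final ZIP1 place the operands as `a0, a1, a2, a3` within each group of four lanes, which matches the implicit SVLSx4 layout.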

</details>


https://github.com/llvm/llvm-project/pull/78975


More information about the Mlir-commits mailing list