[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)

Cullen Rhodes llvmlistbot at llvm.org
Fri Feb 2 08:49:30 PST 2024

https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/79288

From 33c0daddf638fb4d2a3a515952308ed27bac4625 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 24 Jan 2024 10:15:14 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Support 4-way widening outer products

This patch introduces support for 4-way widening outer products. This enables
the folding of 4 'arm_sme.outerproduct' operations that are chained via the
accumulator into single widened operations.


- Adds the following operations:
  - smopa_4way, smops_4way
  - umopa_4way, umops_4way
  - sumopa_4way, sumops_4way
  - usmopa_4way, usmops_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Extends the 'arm-sme-outer-product-fusion' pass to fuse chains of four
  outer products.

For a detailed description of these operations see the
'arm_sme.smopa_4way' description.
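
As an illustration only, the fold can be modelled in scalar C++ for the
i8 to i32 case with vscale=1 (SVLS = 4). This is a sketch under stated
assumptions, not code from this patch: the helper names (interleave2, pack4,
outerProduct, smopa4Way, fusedMatchesChain) are hypothetical, and smopa4Way
encodes the semantics as described in the 'arm_sme.smopa_4way' description
rather than the actual lowering.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model for vscale = 1: SVLS = 4 32-bit elements per 128-bit vector.
constexpr std::size_t kSVLS = 4;
using Bytes = std::vector<std::int8_t>;
using Tile = std::array<std::array<std::int32_t, kSVLS>, kSVLS>;

// Models a two-way interleave: [x0, y0, x1, y1, ...].
Bytes interleave2(const Bytes &x, const Bytes &y) {
  assert(x.size() == y.size());
  Bytes out;
  out.reserve(2 * x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out.push_back(x[i]);
    out.push_back(y[i]);
  }
  return out;
}

// Packs four vectors so that packed[4 * i + k] == v_k[i].
Bytes pack4(const Bytes &v0, const Bytes &v1, const Bytes &v2,
            const Bytes &v3) {
  return interleave2(interleave2(v0, v2), interleave2(v1, v3));
}

// One sign-extended outer product accumulated into `acc`
// (models arith.extsi followed by arm_sme.outerproduct with acc).
Tile outerProduct(const Bytes &a, const Bytes &b, Tile acc) {
  for (std::size_t i = 0; i < kSVLS; ++i)
    for (std::size_t j = 0; j < kSVLS; ++j)
      acc[i][j] += std::int32_t(a[i]) * std::int32_t(b[j]);
  return acc;
}

// Assumed reference semantics of smopa_4way on packed inputs:
// acc[i][j] += sum over k of lhs[4i + k] * rhs[4j + k].
Tile smopa4Way(const Bytes &lhs, const Bytes &rhs, Tile acc) {
  for (std::size_t i = 0; i < kSVLS; ++i)
    for (std::size_t j = 0; j < kSVLS; ++j)
      for (std::size_t k = 0; k < 4; ++k)
        acc[i][j] +=
            std::int32_t(lhs[4 * i + k]) * std::int32_t(rhs[4 * j + k]);
  return acc;
}

// Compares a chain of four outer products with the fused 4-way form.
bool fusedMatchesChain() {
  const Bytes a[4] = {{1, -2, 3, -4}, {5, 6, -7, 8},
                      {-9, 10, 11, 12}, {13, -14, 15, -16}};
  const Bytes b[4] = {{2, 3, -5, 7}, {-1, 4, 6, -8},
                      {9, -3, 2, 1}, {-6, 5, 4, -2}};
  Tile chained{}; // zero-initialised accumulator
  for (std::size_t k = 0; k < 4; ++k)
    chained = outerProduct(a[k], b[k], chained);
  Tile fused = smopa4Way(pack4(a[0], a[1], a[2], a[3]),
                         pack4(b[0], b[1], b[2], b[3]), Tile{});
  return chained == fused;
}
```

Calling fusedMatchesChain() compares the two forms on one fixed set of
inputs. The nested interleave order (0 with 2, then 1 with 3) is what places
element i of the k-th input at index 4*i + k of the packed vector.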

Address comments. Changes:

- add common match failures.
- move isCompatible to static function.
- update isCompatible to take optional `rhsExtType`.
- use isCompatible in-place of isWidenable.
- add canFuseOuterProducts for 4-way.
- llvm::hasSingleElement -> hasOneUse.
- op.erase -> rewriter.eraseOp.

Address comments. Changes:

Same as 2-way comments.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 333 ++++++++
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  16 +
 .../ArmSME/Transforms/OuterProductFusion.cpp  | 324 ++++++-
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         | 176 ++++
 mlir/test/Dialect/ArmSME/invalid.mlir         |  13 +
 .../Dialect/ArmSME/outer-product-fusion.mlir  | 807 ++++++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 160 ++++
 .../CPU/ArmSME/test-outerproduct-i8i8i32.mlir | 150 ++++
 8 files changed, 1941 insertions(+), 38 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6b6a67ab727cd..bf6ba187dcf1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1103,6 +1103,339 @@ def UMops2WayOp
+class OuterProduct4Way<string mnemonic,
+                       list<Type> allowedInputVectorTypes,
+                       list<Type> allowedResultVectorTypes>
+  : OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
+                             allowedResultVectorTypes, /*numOuterProducts=*/4>;
+def SMopa4WayOp
+  : OuterProduct4Way<"smopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    This operation represents a sum of 4 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+    For example (i8 to i32):
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs :
+      vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+    4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. To illustrate, below is a breakdown of
+    this operation for i8 to i32, SVL=128 (i.e., vscale=1):
+    ```
+                                        LHS
+              [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+                                        RHS
+              [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+    ----------------------------------------------------------------------------
+                                  implicit layout
+                    [A0   A1  A2  A3]    |    [B0 B4  B8 B12]
+                    [A4   A5  A6  A7]    |    [B1 B5  B9 B13]
+                    [A8   A9 A10 A11]    |    [B2 B6 B10 B14]
+                    [A12 A13 A14 A15]    |    [B3 B7 B11 B15]
+    ----------------------------------------------------------------------------
+                                  4 outer products
+                 Acol0 ⊗ Brow0           |            Acol1 ⊗ Brow1
+                 -------------           |            -------------
+                                         |
+             [B0 B4 B8 B12]              |        [B1 B5 B9 B13]
+                                         |
+       [A0   [ A0B0  A0B4  A0B8  A0B12]  |  [A1   [ A1B1  A1B5  A1B9  A1B13]
+        A4   [ A4B0  A4B4  A4B8  A4B12]  |   A5   [ A5B1  A5B5  A5B9  A5B13]
+        A8   [ A8B0  A8B4  A8B8  A8B12]  |   A9   [ A9B1  A9B5  A9B9  A9B13]
+        A12] [A12B0 A12B4 A12B8 A12B12]  |   A13] [A13B1 A13B5 A13B9 A13B13]
+                                         |
+                 Acol2 ⊗ Brow2           |            Acol3 ⊗ Brow3
+                 -------------           |            -------------
+                                         |
+             [B2 B6 B10 B14]             |        [B3 B7 B11 B15]
+                                         |
+       [A2   [ A2B2  A2B6  A2B10  A2B14] |  [A3   [ A3B3  A3B7  A3B11  A3B15]
+        A6   [ A6B2  A6B6  A6B10  A6B14] |   A7   [ A7B3  A7B7  A7B11  A7B15]
+        A10  [A10B2 A10B6 A10B10 A10B14] |   A11  [A11B3 A11B7 A11B11 A11B15]
+        A14] [A14B2 A14B6 A14B10 A14B14] |   A15] [A15B3 A15B7 A15B11 A15B15]
+                                         |
+    ----------------------------------------------------------------------------
+                              sum of 4 outer products
+           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+     [ A0B0 +  A1B1 +  A2B2 +  A3B3 ... ...  A0B12 +  A1B13 +  A2B14 +  A3B15]
+     [ A4B0 +  A5B1 +  A6B2 +  A7B3 ... ...  A4B12 +  A5B13 +  A6B14 +  A7B15]
+     [ A8B0 +  A9B1 + A10B2 + A11B3 ... ...  A8B12 +  A9B13 + A10B14 + A11B15]
+     [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+    ----------------------------------------------------------------------------
+    ```
+    This operation enables the folding of 4 outer products chained via the
+    accumulator into a single outer product.
+    For example:
+    ```mlir
+    %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+    %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+    %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+    %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+    %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+    %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+    %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+    %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+    %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+    %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+    ```
+    The 4 outer products in the example above can be fused into a single outer
+    product as follows:
+    ```mlir
+    %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+    %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+    %0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    This is implemented in the `-arm-sme-outer-product-fusion` pass.
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def SMops4WayOp
+  : OuterProduct4Way<"smops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and subtract";
+  let description = [{
+    Equivalent to `smopa_4way` but outer products are subtracted from
+    destination `result`.
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.smops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def UMopa4WayOp
+  : OuterProduct4Way<"umopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned integer sum of 4 outer products and accumulate";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.umopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def UMops4WayOp
+  : OuterProduct4Way<"umops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned integer sum of 4 outer products and subtract";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.umops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def SuMopa4WayOp
+  : OuterProduct4Way<"sumopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed by unsigned integer sum of 4 outer products and accumulate";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def SuMops4WayOp
+  : OuterProduct4Way<"sumops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed by unsigned integer sum of 4 outer products and subtract";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.sumops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def UsMopa4WayOp
+  : OuterProduct4Way<"usmopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned by signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+def UsMops4WayOp
+  : OuterProduct4Way<"usmops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned by signed integer sum of 4 outer products and subtract";
+  let description = [{
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.usmops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
 def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
   let summary = "Query the streaming vector length";
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e73388b0906e8..1ba1b88fc1234 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -939,6 +939,22 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
+                                       arm_sme::aarch64_sme_smopa_wide>,
+      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
+                                       arm_sme::aarch64_sme_smops_wide>,
+      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
+                                       arm_sme::aarch64_sme_umopa_wide>,
+      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
+                                       arm_sme::aarch64_sme_umops_wide>,
+      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
+                                       arm_sme::aarch64_sme_sumopa_wide>,
+      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
+                                       arm_sme::aarch64_sme_sumops_wide>,
+      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
+                                       arm_sme::aarch64_sme_usmopa_wide>,
+      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
+                                       arm_sme::aarch64_sme_usmops_wide>,
       ZeroOpConversion, GetTileConversion>(patterns, converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index fa55f9b7b31e1..2d404bc593b72 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -31,6 +31,55 @@ using namespace mlir;
 using namespace mlir::arm_sme;
 namespace {
+// Common match failure reasons.
+static constexpr StringLiteral
+    MATCH_FAILURE_NO_ACCUMULATOR("no accumulator operand");
+    "defining op of accumulator must be 'arm_sme.outerproduct'");
+    "combining kind (add or sub) of outer products must match");
+    "outer product(s) not single use and cannot be removed, no benefit to "
+    "fusing");
+static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_MASKING(
+    "unsupported masking, either both outerproducts are masked "
+    "or neither");
+// An outer product is compatible if all of the following are true:
+// - the result type matches `resultType`.
+// - the defining operation of LHS is of the type `LhsExtOp`.
+// - the defining operation of RHS is of the type `RhsExtOp`.
+// - the input types of the defining operations are identical and match
+//   `inputType`.
+template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
+static LogicalResult isCompatible(PatternRewriter &rewriter,
+                                  arm_sme::OuterProductOp op,
+                                  VectorType resultType, VectorType inputType) {
+  if (op.getResultType() != resultType)
+    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+      diag << "unsupported result type, expected " << resultType;
+    });
+  auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
+  auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
+  if (!lhsDefOp || !rhsDefOp)
+    return rewriter.notifyMatchFailure(
+        op, "defining op of outerproduct operands must be one of: "
+            "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+  auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
+  auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
+  if (lhsInType != inputType || rhsInType != inputType)
+    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+      diag << "unsupported input type, expected " << inputType;
+    });
+  return success();
 // Fuse two 'arm_sme.outerproduct' operations that are chained via the
 // accumulator into 2-way outer product operation.
@@ -63,18 +112,17 @@ class OuterProductFusion2Way
                                 PatternRewriter &rewriter) const override {
     Value acc = op.getAcc();
     if (!acc)
-      return rewriter.notifyMatchFailure(op, "no accumulator operand");
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
     arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
     arm_sme::OuterProductOp op2 = op;
     if (!op1)
-      return rewriter.notifyMatchFailure(op,
-                                         "defining op of accumulator operand "
-                                         "must be an 'arm_sme.outerproduct'");
+      return rewriter.notifyMatchFailure(
     if (op1.getKind() != op2.getKind())
       return rewriter.notifyMatchFailure(
-          op, "combining kind (add or sub) of outer products must match");
     if (!op1->hasOneUse()) {
       // If the first outer product has uses other than as the input to another
@@ -101,14 +149,12 @@ class OuterProductFusion2Way
       // No accumulator would be ok, but it's simpler to prevent this
       // altogether, since it has no benefit.
       return rewriter.notifyMatchFailure(
-          op, "first outer product is not single use and cannot be removed, "
-              "no benefit to fusing");
     if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
-      return rewriter.notifyMatchFailure(
-          op, "unsupported masking, either both outerproducts are masked "
-              "or neither");
+      return rewriter.notifyMatchFailure(op,
+                                         MATCH_FAILURE_INCONSISTENT_MASKING);
     if (failed(canFuseOuterProducts(rewriter, op1, op2)))
       return failure();
@@ -225,37 +271,238 @@ class OuterProductFusion2Way
     return success();
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into 4-way outer product operation.
+class OuterProductFusion4Way
+    : public OpRewritePattern<arm_sme::OuterProductOp> {
+  using OpRewritePattern::OpRewritePattern;
-  // An outer product is compatible if all of the following are true:
-  // - the result type matches `resultType`.
-  // - the defining operations of the inputs are identical and of the type
-  //   `ExtOp`.
-  // - the input types of the defining operations are identical and match
-  //   `inputType`.
-  template <typename ExtOp>
-  LogicalResult isCompatible(PatternRewriter &rewriter,
-                             arm_sme::OuterProductOp op, VectorType resultType,
-                             VectorType inputType) const {
-    if (op.getResultType() != resultType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported result type, expected " << resultType;
-      });
-    auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
-    auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-    if (!lhsDefOp || !rhsDefOp)
+  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    Value acc = op.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+    arm_sme::OuterProductOp op4 = op;
+    arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op3)
       return rewriter.notifyMatchFailure(
-          op, "defining op of outerproduct operands must be one of: "
-              "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+    acc = op3.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+    arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op2)
+      return rewriter.notifyMatchFailure(
+    acc = op2.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op1)
+      return rewriter.notifyMatchFailure(
+    arm_sme::CombiningKind kind = op1.getKind();
+    if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
+      return rewriter.notifyMatchFailure(
+    if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
+      return rewriter.notifyMatchFailure(
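+    // All four outer products must agree on whether they are masked. A
+    // chained `a != b != c != d` does not test this.
+    bool isMasked = bool(op1.getLhsMask());
+    if (bool(op2.getLhsMask()) != isMasked ||
+        bool(op3.getLhsMask()) != isMasked ||
+        bool(op4.getLhsMask()) != isMasked)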
+      return rewriter.notifyMatchFailure(op,
+                                         MATCH_FAILURE_INCONSISTENT_MASKING);
+    if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
+      return failure();
+    auto loc = op.getLoc();
+    auto packInputs = [&](Value lhs, Value rhs) {
+      auto inputType = cast<VectorType>(lhs.getType());
+      VectorType inputTypeX2 =
+          VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+      return rewriter.create<LLVM::experimental_vector_interleave2>(
+          loc, inputTypeX2, lhs, rhs);
+    };
-    auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
-    auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
+    auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
+                           op3.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
+                           op4.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs = packInputs(lhs0, lhs1);
-    if (lhsInType != inputType || rhsInType != inputType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported input type, expected " << inputType;
-      });
+    auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
+                           op3.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
+                           op4.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs = packInputs(rhs0, rhs1);
+    Value lhsMask, rhsMask;
+    if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
+        op4.getLhsMask()) {
+      auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
+      auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
+      lhsMask = packInputs(lhs0Mask, lhs1Mask);
+      auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
+      auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
+      rhsMask = packInputs(rhs0Mask, rhs1Mask);
+    }
+    auto lhsExtOp = op.getLhs().getDefiningOp();
+    auto rhsExtOp = op.getRhs().getDefiningOp();
+    if (kind == arm_sme::CombiningKind::Add) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else if (kind == arm_sme::CombiningKind::Sub) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else {
+      llvm_unreachable("unexpected arm_sme::CombiningKind!");
+    }
+    rewriter.eraseOp(op3);
+    rewriter.eraseOp(op2);
+    rewriter.eraseOp(op1);
+    return success();
+  }
+  // Four outer products can be fused if all of the following are true:
+  // - input and result types match.
+  // - the defining operations of the inputs are identical extensions,
+  //   specifically either:
+  //     - a signed or unsigned extension for integer types.
+  //     - a floating-point extension for floating-point types.
+  // - the types and extension are supported, i.e. there's a 4-way operation
+  //   they can be fused into.
+  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
+                                     arm_sme::OuterProductOp op1,
+                                     arm_sme::OuterProductOp op2,
+                                     arm_sme::OuterProductOp op3,
+                                     arm_sme::OuterProductOp op4) const {
+    // Supported result types.
+    auto nxnxv4i32 =
+        VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+    auto nxnxv2i64 =
+        VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+    // Supported input types.
+    // Note: this is before packing so these have 1/4 the number of elements
+    // of the input vector types of the 4-way operations.
+    auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
+    auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
+    if (
+        // signed, i8i8i32
+        (failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+        // signed, i16i16i64
+        (failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+         failed(
+             isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
+                                             nxv2i16))) &&
+        // unsigned, i8i8i32
+        (failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+         failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+        // unsigned, i16i16i64
+        (failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+         failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+         failed(
+             isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
+                                             nxv2i16))) &&
+        // signed by unsigned, i8i8i32
+        (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op1, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op2, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op3, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op4, nxnxv4i32, nxv4i8))) &&
+        // signed by unsigned, i16i16i64
+        (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op1, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op2, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op3, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+             rewriter, op4, nxnxv2i64, nxv2i16))) &&
+        // unsigned by signed, i8i8i32
+        (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op1, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op2, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op3, nxnxv4i32, nxv4i8)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op4, nxnxv4i32, nxv4i8))) &&
+        // unsigned by signed, i16i16i64
+        (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op1, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op2, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op3, nxnxv2i64, nxv2i16)) ||
+         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+             rewriter, op4, nxnxv2i64, nxv2i16))))
+      return failure();
     return success();
@@ -278,7 +525,8 @@ struct OuterProductFusionPass
 void mlir::arm_sme::populateOuterProductFusionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<OuterProductFusion2Way>(patterns.getContext());
+  patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(
+      patterns.getContext());
 std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
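
Note on the match condition in OuterProductFusion4Way above: it returns failure only when every (extension kind, element type) combination is rejected by at least one of the four ops. Equivalently, the match succeeds when one combination is accepted by all four ops. The standalone sketch below shows that logic without any MLIR dependency. OuterProduct, Combo and canFuse4Way are hypothetical names for illustration, and this isCompatible only compares plain fields; it is not the helper from the patch.

```cpp
#include <algorithm>
#include <array>
#include <cassert>

enum class Ext { Signed, Unsigned };

// Hypothetical stand-in for one arm_sme.outerproduct whose operands are
// produced by arith.extsi / arith.extui.
struct OuterProduct {
  Ext lhsExt;
  Ext rhsExt;
  unsigned inputBits;  // element width before extension (8 or 16)
  unsigned resultBits; // element width of the accumulator tile (32 or 64)
};

// One supported (lhs extension, rhs extension, input, result) combination.
struct Combo {
  Ext lhsExt;
  Ext rhsExt;
  unsigned inputBits;
  unsigned resultBits;
};

// Illustrative compatibility test: all fields must agree.
bool isCompatible(const OuterProduct &op, const Combo &c) {
  return op.lhsExt == c.lhsExt && op.rhsExt == c.rhsExt &&
         op.inputBits == c.inputBits && op.resultBits == c.resultBits;
}

// Succeeds if a single combination is accepted by all four ops.
bool canFuse4Way(const std::array<OuterProduct, 4> &ops) {
  static const Combo combos[] = {
      {Ext::Signed, Ext::Signed, 8, 32},     // signed, i8i8i32
      {Ext::Signed, Ext::Signed, 16, 64},    // signed, i16i16i64
      {Ext::Unsigned, Ext::Unsigned, 8, 32}, // unsigned, i8i8i32
      {Ext::Unsigned, Ext::Unsigned, 16, 64},
      {Ext::Signed, Ext::Unsigned, 8, 32},   // signed by unsigned
      {Ext::Signed, Ext::Unsigned, 16, 64},
      {Ext::Unsigned, Ext::Signed, 8, 32},   // unsigned by signed
      {Ext::Unsigned, Ext::Signed, 16, 64},
  };
  for (const Combo &c : combos) {
    bool allMatch = std::all_of(
        ops.begin(), ops.end(),
        [&](const OuterProduct &op) { return isCompatible(op, c); });
    if (allMatch)
      return true;
  }
  return false;
}
```

Writing the check as a loop over a table of combinations keeps the "one combination for all four ops" rule visible. Whether the pass itself would benefit from this restructuring is a separate question.
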
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index c41504d0e4724..81087cc02099f 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -697,3 +697,179 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
   %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
   return %result : vector<[4]x[4]xi32>
+// arm_sme.smopa_4way
+// -----
+// CHECK-LABEL: arm_sme_smopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_smopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.smops_4way
+// -----
+// CHECK-LABEL: arm_sme_smops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_smops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.umopa_4way
+// -----
+// CHECK-LABEL: arm_sme_umopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_umopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.umops_4way
+// -----
+// CHECK-LABEL: arm_sme_umops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_umops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.sumopa_4way
+// -----
+// CHECK-LABEL: arm_sme_sumopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_sumopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.sumops_4way
+// -----
+// CHECK-LABEL: arm_sme_sumops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_sumops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.usmopa_4way
+// -----
+// CHECK-LABEL: arm_sme_usmopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_usmopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+// arm_sme.usmops_4way
+// -----
+// CHECK-LABEL: arm_sme_usmops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: arm_sme_usmops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index dcc231332f208..cc052fac0d9dc 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -226,3 +226,16 @@ func.func @arm_sme_fmopa_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vecto
   %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
   return %0 : vector<[4]x[4]xf32>
+// arm_sme.smopa_4way
+// -----
+func.func @arm_sme_smopa_4way__bad_tile_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32>
+  // expected-error@+1 {{op failed to verify that tile element size equals input element size * 4}}
+  %0 = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+  return %0 : vector<[4]x[4]xi32>
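
Note on the negative test above: the diagnostic says the tile element size must be four times the input element size, so i16 inputs (16 bits) cannot accumulate into an i32 tile (32 bits). The sketch below restates that arithmetic. tileElementIs4xInput is a hypothetical helper for illustration only, not the verifier from the patch.

```cpp
#include <cassert>

// Returns true when the accumulator tile element is exactly four times as
// wide as the input vector element, as the diagnostic in the test requires.
// Illustrative only; both widths are given in bits.
bool tileElementIs4xInput(unsigned inputElementBits,
                          unsigned tileElementBits) {
  assert(inputElementBits > 0 && tileElementBits > 0);
  return tileElementBits == inputElementBits * 4;
}
```

Under this check, i8 into i32 and i16 into i64 pass, while the i16 into i32 case used by the test is rejected.
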
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 6a200848c2d8e..67448d111b935 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -213,6 +213,581 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
   return %1 : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
+// CHECK-SAME:    %[[A0:.*]]: vector<[4]xi8>, %[[B0:.*]]: vector<[4]xi8>, %[[A1:.*]]: vector<[4]xi8>, %[[B1:.*]]: vector<[4]xi8>, %[[A2:.*]]: vector<[4]xi8>, %[[B2:.*]]: vector<[4]xi8>, %[[A3:.*]]: vector<[4]xi8>, %[[B3:.*]]: vector<[4]xi8>,
+// CHECK-SAME:    %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>, %[[A2_MASK:.*]]: vector<[4]xi1>, %[[B2_MASK:.*]]: vector<[4]xi1>, %[[A3_MASK:.*]]: vector<[4]xi1>, %[[B3_MASK:.*]]: vector<[4]xi1>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
+// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
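
Note on the CHECK lines of the test above: the operands are packed with two levels of interleave2, first (A0, A2) and (A1, A3), then the two results together. This places lane i of each of the four inputs in consecutive positions of the packed vector. The sketch below models that ordering on fixed-length std::vector values. The real vectors are scalable, and interleave2 and pack4Way here are illustrative stand-ins rather than MLIR or LLVM APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Models the lane order of an interleave2 intrinsic on fixed-length data:
// result = a[0], b[0], a[1], b[1], ...
template <typename T>
std::vector<T> interleave2(const std::vector<T> &a, const std::vector<T> &b) {
  assert(a.size() == b.size());
  std::vector<T> out;
  out.reserve(a.size() * 2);
  for (std::size_t i = 0; i < a.size(); ++i) {
    out.push_back(a[i]);
    out.push_back(b[i]);
  }
  return out;
}

// Two-level packing as in the CHECK lines: (v0, v2) and (v1, v3) first,
// then the two intermediate results.
template <typename T>
std::vector<T> pack4Way(const std::vector<T> &v0, const std::vector<T> &v1,
                        const std::vector<T> &v2, const std::vector<T> &v3) {
  return interleave2(interleave2(v0, v2), interleave2(v1, v3));
}
```

For inputs of two lanes each, the packed result is v0[0], v1[0], v2[0], v3[0], v0[1], v1[1], v2[1], v3[1]. Pairing (v0, v2) and (v1, v3) at the first level is what yields this order after the second interleave.
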
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i8i8i32
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+}
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+}
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+    %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+    %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+}
+// -----
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
+    %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+    %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+    %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+    %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+    %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+    %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+    %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+    %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+  %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+  %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+  %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+  %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+  %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+  %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+  %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+  %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+  %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+  return %3 : vector<[2]x[2]xi64>
+}
 /// Negative tests
 // -----
@@ -232,6 +807,34 @@ func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vecto
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__no_acc
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__no_acc(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  return %2 : vector<[4]x[4]xi32>
+}
+// -----
 /// Defining op of accumulator operand must be an 'arm_sme.outerproduct'.
 // CHECK-LABEL: @outerproduct_widening_2way__bad_acc
@@ -249,6 +852,41 @@ func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vect
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_acc
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_acc(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  // break chain
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
 /// Combining kinds of outer products must match to be fused.
 // CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
@@ -272,6 +910,40 @@ func.func @outerproduct_widening_2way__bad_combining_kind(
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_combining_kind
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_combining_kind(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<add> acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<add> acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<add> acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
 /// If the first outer product has uses other than as the input to another
 /// outer product, it can't be erased after fusion. This is a problem when
 /// it also has an accumulator as this will be used as the root for tile
@@ -302,6 +974,41 @@ func.func @outerproduct_widening_2way__cant_erase(
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__cant_erase
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__cant_erase(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
 // CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
 // CHECK-NOT: arm_sme.fmopa_2way
 // CHECK: arm_sme.outerproduct
@@ -323,6 +1030,40 @@ func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__unsupported_type_f16f16f64
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__unsupported_type_f16f16f64(
+    %a0 : vector<[2]xf16>, %b0 : vector<[2]xf16>,
+    %a1 : vector<[2]xf16>, %b1 : vector<[2]xf16>,
+    %a2 : vector<[2]xf16>, %b2 : vector<[2]xf16>,
+    %a3 : vector<[2]xf16>, %b3 : vector<[2]xf16>) -> vector<[2]x[2]xf64> {
+  %a0_ext = arith.extf %a0 : vector<[2]xf16> to vector<[2]xf64>
+  %b0_ext = arith.extf %b0 : vector<[2]xf16> to vector<[2]xf64>
+  %a1_ext = arith.extf %a1 : vector<[2]xf16> to vector<[2]xf64>
+  %b1_ext = arith.extf %b1 : vector<[2]xf16> to vector<[2]xf64>
+  %a2_ext = arith.extf %a2 : vector<[2]xf16> to vector<[2]xf64>
+  %b2_ext = arith.extf %b2 : vector<[2]xf16> to vector<[2]xf64>
+  %a3_ext = arith.extf %a3 : vector<[2]xf16> to vector<[2]xf64>
+  %b3_ext = arith.extf %b3 : vector<[2]xf16> to vector<[2]xf64>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[2]xf64>, vector<[2]xf64>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[2]xf64>, vector<[2]xf64>
+  return %3 : vector<[2]x[2]xf64>
+}
+// -----
 /// Fusion only occurs if either both outer products are masked, or neither.
 // CHECK-LABEL: @outerproduct_widening_2way__bad_masking
@@ -347,6 +1088,41 @@ func.func @outerproduct_widening_2way__bad_masking(
 // -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_masking
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_masking(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+    %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
+// -----
 /// Defining op of outer product must be a supported extension op.
 // CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
@@ -362,3 +1138,34 @@ func.func @outerproduct_widening_2way__bad_defining_op(
   return %1 : vector<[4]x[4]xf32>
 }
+// -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_defining_op(
+    %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+    %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+    %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
+    %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+  return %3 : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ca096363e7283..ab46c7adca596 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1243,3 +1243,163 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
   %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
   return %result : vector<[4]x[4]xi32>
 }
+// arm_sme.smopa_4way
+// -----
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.smops_4way
+// -----
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.umopa_4way
+// -----
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.umops_4way
+// -----
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.sumopa_4way
+// -----
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.sumops_4way
+// -----
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.usmopa_4way
+// -----
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+// arm_sme.usmops_4way
+// -----
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+// -----
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
new file mode 100644
index 0000000000000..1770e579f0bd6
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
@@ -0,0 +1,150 @@
+// DEFINE: %{entry} = main
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -arm-sme-outer-product-fusion \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE:   -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils,%arm_sme_abi_shlib
+// RUN: %{compile} | %{run} | FileCheck %s
+// NOTE: QEMU gives an incorrect result for the SME SMOPA 4-way outer product
+// instruction (versions <= 8.2.0, the latest at the time of writing), see:
+// https://gitlab.com/qemu-project/qemu/-/issues/2083
+// This test is expected to fail (the CHECK lines are correct) until a fixed
+// version of QEMU can be used.
+// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+// (and installed on the CI buildbot).
+// XFAIL: *
+// NOTE: There is no non-widening variant for these types, so this test can't
+// be lowered without the widening pass. We therefore can't check that the
+// result matches the one produced without the widening pass, as
+// 'test-outerproduct-f16f16f32.mlir' does.
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+  func.call @test_outerproduct_i8i8i32() : () -> ()
+  func.call @test_masked_outerproduct_i8i8i32() : () -> ()
+  return
+}
+func.func @test_outerproduct_i8i8i32() {
+  %undef = llvm.mlir.undef : vector<[4]xi8>
+  %a0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %a1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %a2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %a3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+  %b0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %b1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %b2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %b3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a2 = vector.scalable.insert %a2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b2 = vector.scalable.insert %b2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a3 = vector.scalable.insert %a3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b3 = vector.scalable.insert %b3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %0 = vector.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+  %2 = vector.outerproduct %a2_ext, %b2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+  %3 = vector.outerproduct %a3_ext, %b3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+  // CHECK:      ( 110,  134,  158,  182 )
+  // CHECK-NEXT: ( 390,  478,  566,  654 )
+  // CHECK-NEXT: ( 670,  822,  974, 1126 )
+  // CHECK-NEXT: ( 950, 1166, 1382, 1598 )
+  vector.print %3 : vector<[4]x[4]xi32>
+  return
+}
+func.func @test_masked_outerproduct_i8i8i32() {
+  %undef = llvm.mlir.undef : vector<[4]xi8>
+  %a0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %a1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %a2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %a3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+  %b0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %b1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %b2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %b3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a2 = vector.scalable.insert %a2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b2 = vector.scalable.insert %b2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a3 = vector.scalable.insert %a3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b3 = vector.scalable.insert %b3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %mask0 = vector.create_mask %c1, %c1 : vector<[4]x[4]xi1>
+  %mask1 = vector.create_mask %c1, %c2 : vector<[4]x[4]xi1>
+  %mask2 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %mask3 = vector.create_mask %c3, %c4 : vector<[4]x[4]xi1>
+  %acc = arith.constant dense<2> : vector<[4]x[4]xi32>
+  %0 = vector.mask %mask0 {
+    vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %1 = vector.mask %mask1 {
+    vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %2 = vector.mask %mask2 {
+    vector.outerproduct %a2_ext, %b2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %3 = vector.mask %mask3 {
+    vector.outerproduct %a3_ext, %b3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  // CHECK:      ( 112, 136, 135,  95 )
+  // CHECK-NEXT: ( 243, 295, 347, 219 )
+  // CHECK-NEXT: ( 211, 255, 299, 343 )
+  // CHECK-NEXT: (   2,   2,   2,   2 )
+  vector.print %3 : vector<[4]x[4]xi32>
+  return
+}
+func.func private @setArmSVLBits(%bits : i32)
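
The CHECK lines in the integration test above can be reproduced with a small host-side model. The following is only a sketch, not part of the patch: it assumes SVL = 128 bits (so each scalable vector holds exactly four i32 lanes, matching the setArmSVLBits call), and it assumes that lanes disabled by a mask keep the accumulator value. The names Vec4, Tile, makeInput, accumulateMasked, unmaskedResult and maskedResult are hypothetical helpers introduced for illustration. All inputs are below 128, so sign extension from i8 does not change them.

```cpp
#include <array>
#include <cassert>

// Illustrative model only: a 4x4 tile stands in for vector<[4]x[4]xi32> at
// SVL = 128 bits (vscale = 1).
using Vec4 = std::array<int, 4>;
using Tile = std::array<Vec4, 4>;

// Inputs used by the test: a_k = [k, 4+k, 8+k, 12+k] (base 0) and
// b_k = [16+k, 20+k, 24+k, 28+k] (base 16).
inline Vec4 makeInput(int base, int k) {
  return {base + k, base + 4 + k, base + 8 + k, base + 12 + k};
}

// Accumulate one outer product into `acc`. Only lanes with row < rows and
// col < cols are updated, mirroring vector.create_mask %rows, %cols. Assumed
// semantics: every other lane keeps its accumulator value.
inline void accumulateMasked(Tile &acc, const Vec4 &a, const Vec4 &b,
                             int rows, int cols) {
  assert(rows >= 0 && rows <= 4 && cols >= 0 && cols <= 4);
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j)
      acc[i][j] += a[i] * b[j];
}

// Model of @test_outerproduct_i8i8i32: four chained, unmasked outer products
// starting from a zero accumulator.
inline Tile unmaskedResult() {
  Tile acc{};
  for (int k = 0; k < 4; ++k)
    accumulateMasked(acc, makeInput(0, k), makeInput(16, k), 4, 4);
  return acc;
}

// Model of @test_masked_outerproduct_i8i8i32: the accumulator starts at 2 and
// the k-th outer product uses the dimensions of %mask0..%mask3.
inline Tile maskedResult() {
  Tile acc;
  for (auto &row : acc)
    row.fill(2);
  const int maskDims[4][2] = {{1, 1}, {1, 2}, {2, 3}, {3, 4}};
  for (int k = 0; k < 4; ++k)
    accumulateMasked(acc, makeInput(0, k), makeInput(16, k), maskDims[k][0],
                     maskDims[k][1]);
  return acc;
}
```

Under these assumptions, unmaskedResult() and maskedResult() give the same rows as the two groups of CHECK lines, for example 110 = 0*16 + 1*17 + 2*18 + 3*19 in the top-left lane of the unmasked case.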

>From 43152fc5681633dfbcfec7aa0254f8c3b76dadab Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 2 Feb 2024 15:42:13 +0000
Subject: [PATCH 2/2] Address comments. Changes:

- fix check for consistent masking.
- rewrite as loop that walks outer product chain.
- use lambda for match check.
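
As a rough illustration of the masking fix listed above (a standalone sketch, not code from the patch): the removed condition chained four `!=` comparisons, which C++ evaluates left to right as ((a != b) != c) != d. That expression is true only when an odd number of the flags are set, so some inconsistent chains are not reported. MaskFlags and the two function names below are hypothetical; each flag stands in for bool(op.getLhsMask()) of one outer product in the chain.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in: flags[i] models bool(op_i.getLhsMask()).
using MaskFlags = std::array<bool, 4>;

// Shape of the removed check, with the implicit grouping written out.
// It computes the parity of the flags, not "are they all equal".
inline bool chainedCheckFlagsMismatch(const MaskFlags &m) {
  return ((m[0] != m[1]) != m[2]) != m[3];
}

// Shape of the replacement: compare every op with its predecessor in the
// chain, so any single disagreement is reported.
inline bool pairwiseCheckFlagsMismatch(const MaskFlags &m) {
  for (std::size_t i = 1; i < m.size(); ++i)
    if (m[i] != m[i - 1])
      return true;
  return false;
}

// Two masked ops followed by two unmasked ops: the chained form misses the
// mismatch, the pairwise form reports it.
inline void mixedMaskingExample() {
  const MaskFlags mixed = {true, true, false, false};
  assert(!chainedCheckFlagsMismatch(mixed));
  assert(pairwiseCheckFlagsMismatch(mixed));
}
```

The loop in the patch performs the pairwise comparison while walking the accumulator chain, so it rejects a chain at the first op whose masking differs from its user.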
 .../ArmSME/Transforms/OuterProductFusion.cpp  | 175 ++++++------------
 1 file changed, 55 insertions(+), 120 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 2d404bc593b72..6c3863a60e693 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -282,51 +282,38 @@ class OuterProductFusion4Way
   LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
                                 PatternRewriter &rewriter) const override {
-    Value acc = op.getAcc();
-    if (!acc)
-      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-    arm_sme::OuterProductOp op4 = op;
-    arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
-    if (!op3)
-      return rewriter.notifyMatchFailure(
-    acc = op3.getAcc();
-    if (!acc)
-      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-    arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
-    if (!op2)
-      return rewriter.notifyMatchFailure(
-    acc = op2.getAcc();
-    if (!acc)
-      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
-    if (!op1)
-      return rewriter.notifyMatchFailure(
-    arm_sme::CombiningKind kind = op1.getKind();
-    if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
-      return rewriter.notifyMatchFailure(
-    if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
-      return rewriter.notifyMatchFailure(
-    if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) !=
-        bool(op3.getLhsMask()) != bool(op4.getLhsMask()))
-      return rewriter.notifyMatchFailure(op,
-                                         MATCH_FAILURE_INCONSISTENT_MASKING);
+    SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
+    outerProductChain.push_back(op);
+    for (int i = 0; i < 3; ++i) {
+      auto currentOp = outerProductChain.back();
+      auto acc = currentOp.getAcc();
+      if (!acc)
+        return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+      auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
+      if (!previousOp)
+        return rewriter.notifyMatchFailure(
+      if (!previousOp->hasOneUse())
+        return rewriter.notifyMatchFailure(
+      if (previousOp.getKind() != currentOp.getKind())
+        return rewriter.notifyMatchFailure(
+      if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
+        return rewriter.notifyMatchFailure(
+      outerProductChain.push_back(previousOp);
+    }
-    if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
+    if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
       return failure();
+    arm_sme::OuterProductOp op1 = outerProductChain[3];
+    arm_sme::OuterProductOp op2 = outerProductChain[2];
+    arm_sme::OuterProductOp op3 = outerProductChain[1];
+    arm_sme::OuterProductOp op4 = outerProductChain[0];
     auto loc = op.getLoc();
     auto packInputs = [&](Value lhs, Value rhs) {
@@ -364,6 +351,7 @@ class OuterProductFusion4Way
     auto lhsExtOp = op.getLhs().getDefiningOp();
     auto rhsExtOp = op.getRhs().getDefiningOp();
+    arm_sme::CombiningKind kind = op.getKind();
     if (kind == arm_sme::CombiningKind::Add) {
       if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
@@ -414,94 +402,41 @@ class OuterProductFusion4Way
   //     - a floating-point extension for floating-point types.
   // - the types and extension are supported, i.e. there's a 4-way operation
   //   they can be fused into.
-  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
-                                     arm_sme::OuterProductOp op1,
-                                     arm_sme::OuterProductOp op2,
-                                     arm_sme::OuterProductOp op3,
-                                     arm_sme::OuterProductOp op4) const {
+  LogicalResult
+  canFuseOuterProducts(PatternRewriter &rewriter,
+                       SmallVectorImpl<arm_sme::OuterProductOp> &ops) const {
     // Supported result types.
     auto nxnxv4i32 =
         VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
     auto nxnxv2i64 =
         VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
     // Supported input types.
     // Note: this is before packing so these have 1/4 the number of elements
     // of the input vector types of the 4-way operations.
     auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
     auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
-    if (
-        // signed, i8i8i32
-        (failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
-        // signed, i16i16i64
-        (failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
-         failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
-         failed(
-             isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
-                                             nxv2i16))) &&
-        // unsigned, i8i8i32
-        (failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
-         failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
-        // unsigned, i16i16i64
-        (failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
-         failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
-         failed(
-             isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
-                                             nxv2i16))) &&
-        // signed by unsigned, i8i8i32
-        (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op1, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op2, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op3, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op4, nxnxv4i32, nxv4i8))) &&
-        // signed by unsigned, i16i16i64
-        (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op1, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op2, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op3, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
-             rewriter, op4, nxnxv2i64, nxv2i16))) &&
-        // unsigned by signed, i8i8i32
-        (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op1, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op2, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op3, nxnxv4i32, nxv4i8)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op4, nxnxv4i32, nxv4i8))) &&
-        // unsigned by signed, i16i16i64
-        (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op1, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op2, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op3, nxnxv2i64, nxv2i16)) ||
-         failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
-             rewriter, op4, nxnxv2i64, nxv2i16))))
+    auto failedToMatch = [&](VectorType resultType, VectorType inputType,
+                             auto lhsExtendOp, auto rhsExtendOp) {
+      using LhsExtendOpTy = decltype(lhsExtendOp);
+      using RhsExtendOpTy = decltype(rhsExtendOp);
+      for (auto op : ops) {
+        if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
+                rewriter, op, resultType, inputType)))
+          return true;
+      }
+      return false;
+    };
+    if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
+        failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
+        failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
+        failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
+        failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
+        failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
+        failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
+        failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
       return failure();
     return success();
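
The failedToMatch lambda above passes default-constructed op objects only to carry their types, which are then recovered with decltype and forwarded as template arguments. A minimal standalone sketch of that pattern follows. It does not use MLIR: SignExt, ZeroExt, Ext, FakeOuterProduct, extKind and canFuse are hypothetical stand-ins, and isCompatible here is a simplified placeholder for the real helper, which also checks result and input vector types.

```cpp
#include <cassert>
#include <type_traits>
#include <vector>

// Hypothetical tag types standing in for arith::ExtSIOp / arith::ExtUIOp.
struct SignExt {};
struct ZeroExt {};

// Hypothetical model of an outer product: records which extension produced
// each operand.
enum class Ext { Signed, Unsigned };
struct FakeOuterProduct {
  Ext lhs;
  Ext rhs;
};

template <typename ExtTy>
constexpr Ext extKind() {
  return std::is_same<ExtTy, SignExt>::value ? Ext::Signed : Ext::Unsigned;
}

// Placeholder for isCompatible<LhsExtOp, RhsExtOp>(...). The extension types
// are template parameters, so they cannot be passed as ordinary values.
template <typename LhsExtTy, typename RhsExtTy>
bool isCompatible(const FakeOuterProduct &op) {
  return op.lhs == extKind<LhsExtTy>() && op.rhs == extKind<RhsExtTy>();
}

// Returns true if the whole chain matches at least one extension combination.
inline bool canFuse(const std::vector<FakeOuterProduct> &ops) {
  // Generic lambda: the tag objects carry no data, only their types matter.
  auto failedToMatch = [&](auto lhsTag, auto rhsTag) {
    using LhsExtTy = decltype(lhsTag);
    using RhsExtTy = decltype(rhsTag);
    for (const auto &op : ops)
      if (!isCompatible<LhsExtTy, RhsExtTy>(op))
        return true;
    return false;
  };
  // Fusion is rejected only if every combination fails to match.
  return !(failedToMatch(SignExt{}, SignExt{}) &&
           failedToMatch(ZeroExt{}, ZeroExt{}) &&
           failedToMatch(SignExt{}, ZeroExt{}) &&
           failedToMatch(ZeroExt{}, SignExt{}));
}

// A chain that uses one combination throughout can be fused; a chain that
// mixes combinations cannot.
inline void canFuseExample() {
  const Ext S = Ext::Signed, U = Ext::Unsigned;
  assert(canFuse({{S, U}, {S, U}, {S, U}, {S, U}}));
  assert(!canFuse({{S, S}, {S, S}, {U, U}, {S, S}}));
}
```

Compared with the removed version, which spelled out all four ops for each of the eight combinations, this shape keeps one line per combination and works for a chain of any length.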
