[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Feb 6 08:43:21 PST 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/79288
>From 1b12e7d2766bdff46d06f8ad367706969cae3804 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 24 Jan 2024 10:15:14 +0000
Subject: [PATCH 1/8] [mlir][ArmSME] Support 4-way widening outer products
This patch introduces support for 4-way widening outer products. This enables
the folding of 4 'arm_sme.outerproduct' operations that are chained via the
accumulator into single widened operations.
Changes:
- Adds the following operations:
  - smopa_4way, smops_4way
  - umopa_4way, umops_4way
  - sumopa_4way, sumops_4way
  - usmopa_4way, usmops_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Extends 'arm-sme-outer-product' pass.
For a detailed description of these operations see the
'arm_sme.smopa_4way' description.
Address comments. Changes:
- add common match failures.
- move isCompatible to static function.
- update isCompatible to take optional `rhsExtType`.
- use isCompatible in-place of isWidenable.
- add canFuseOuterProducts for 4-way.
- llvm::hasSingleElement -> hasOneUse.
- op.erase -> rewriter.eraseOp.
Address comments. Changes:
Same as 2-way comments.
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 333 +++++++
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 16 +
.../ArmSME/Transforms/OuterProductFusion.cpp | 323 ++++++-
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 176 ++++
mlir/test/Dialect/ArmSME/invalid.mlir | 13 +
.../Dialect/ArmSME/outer-product-fusion.mlir | 811 ++++++++++++++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 160 ++++
.../CPU/ArmSME/test-outerproduct-i8i8i32.mlir | 150 ++++
8 files changed, 1944 insertions(+), 38 deletions(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 51fd4b7ca21bd5..08305973b1ee08 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1103,6 +1103,339 @@ def UMops2WayOp
}];
}
+// Base class shared by every 4-way widening outer product op: a thin
+// specialization of OuterProductWideningBase that fixes the number of summed
+// outer products at four (see `arm_sme.smopa_4way` for the semantics).
+class OuterProduct4Way<string mnemonic,
+                       list<Type> allowedInputVectorTypes,
+                       list<Type> allowedResultVectorTypes>
+  : OuterProductWideningBase<mnemonic,
+                             allowedInputVectorTypes,
+                             allowedResultVectorTypes,
+                             /*numOuterProducts=*/4>;
+
+def SMopa4WayOp
+  : OuterProduct4Way<"smopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    This operation represents a sum of 4 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+    For example (i8 to i32):
+
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs :
+      vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+    4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. For i16 to i64 the matrices are instead
+    SVLDx4 and 4xSVLD, where SVLD is the number of 64-bit elements in a vector
+    of SVL bits. To illustrate, below is a breakdown of this operation for i8
+    to i32, SVL=128 (i.e., vscale=1):
+
+    ```
+    LHS
+    [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+
+    RHS
+    [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+    ----------------------------------------------------------------------------
+
+                              implicit layout
+
+                        [A0   A1  A2  A3]    |    [B0 B4  B8 B12]
+                        [A4   A5  A6  A7]    |    [B1 B5  B9 B13]
+                        [A8   A9 A10 A11]    |    [B2 B6 B10 B14]
+                        [A12 A13 A14 A15]    |    [B3 B7 B11 B15]
+
+    ----------------------------------------------------------------------------
+
+                              4 outer products
+
+                 Acol0 ⊗ Brow0           |           Acol1 ⊗ Brow1
+                 -------------           |           -------------
+                                         |
+             [B0 B4 B8 B12]              |       [B1 B5 B9 B13]
+                                         |
+       [A0   [ A0B0  A0B4  A0B8  A0B12]  |  [A1   [ A1B1  A1B5  A1B9  A1B13]
+        A4   [ A4B0  A4B4  A4B8  A4B12]  |   A5   [ A5B1  A5B5  A5B9  A5B13]
+        A8   [ A8B0  A8B4  A8B8  A8B12]  |   A9   [ A9B1  A9B5  A9B9  A9B13]
+        A12] [A12B0 A12B4 A12B8 A12B12]  |   A13] [A13B1 A13B5 A13B9 A13B13]
+                                         |
+                 Acol2 ⊗ Brow2           |           Acol3 ⊗ Brow3
+                 -------------           |           -------------
+                                         |
+             [B2 B6 B10 B14]             |       [B3 B7 B11 B15]
+                                         |
+      [A2   [ A2B2  A2B6  A2B10  A2B14]  |  [A3   [ A3B3  A3B7  A3B11  A3B15]
+       A6   [ A6B2  A6B6  A6B10  A6B14]  |   A7   [ A7B3  A7B7  A7B11  A7B15]
+       A10  [A10B2 A10B6 A10B10 A10B14]  |   A11  [A11B3 A11B7 A11B11 A11B15]
+       A14] [A14B2 A14B6 A14B10 A14B14]  |   A15] [A15B3 A15B7 A15B11 A15B15]
+                                         |
+
+    ----------------------------------------------------------------------------
+
+                          sum of 4 outer products
+
+           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+     [ A0B0 +  A1B1 +  A2B2 +  A3B3 ... ...  A0B12 +  A1B13 +  A2B14 +  A3B15]
+     [ A4B0 +  A5B1 +  A6B2 +  A7B3 ... ...  A4B12 +  A5B13 +  A6B14 +  A7B15]
+     [ A8B0 +  A9B1 + A10B2 + A11B3 ... ...  A8B12 +  A9B13 + A10B14 + A11B15]
+     [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+    ----------------------------------------------------------------------------
+    ```
+
+    This operation enables the folding of 4 outer products chained via the
+    accumulator into a single outer product.
+
+    For example:
+
+    ```mlir
+    %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+    %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+    %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+    %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+    %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+    %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+    %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+    ```
+
+    The 4 outer products in the example above can be fused into a single outer
+    product as follows:
+
+    ```mlir
+    %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+
+    %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+    %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+
+    %0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    This is implemented in the `-arm-sme-outer-product-fusion` pass.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+    [1] https://developer.arm.com/documentation/ddi0616
+  }];
+}
+
+def SMops4WayOp
+  : OuterProduct4Way<"smops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and subtract";
+  let description = [{
+    Equivalent to `smopa_4way` but outer products are subtracted from
+    destination `result`.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.smops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def UMopa4WayOp
+  : OuterProduct4Way<"umopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned integer sum of 4 outer products and accumulate";
+  let description = [{
+    Equivalent to `smopa_4way` but both `lhs` and `rhs` are treated as
+    unsigned integers.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.umopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def UMops4WayOp
+  : OuterProduct4Way<"umops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned integer sum of 4 outer products and subtract";
+  let description = [{
+    Equivalent to `umopa_4way` but outer products are subtracted from
+    destination `result`.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.umops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def SuMopa4WayOp
+  : OuterProduct4Way<"sumopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed by unsigned integer sum of 4 outer products and accumulate";
+  let description = [{
+    Equivalent to `smopa_4way` but `lhs` is treated as signed and `rhs` as
+    unsigned.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def SuMops4WayOp
+  : OuterProduct4Way<"sumops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed by unsigned integer sum of 4 outer products and subtract";
+  let description = [{
+    Equivalent to `sumopa_4way` but outer products are subtracted from
+    destination `result`.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.sumops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def UsMopa4WayOp
+  : OuterProduct4Way<"usmopa_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned by signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    Equivalent to `smopa_4way` but `lhs` is treated as unsigned and `rhs` as
+    signed.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
+def UsMops4WayOp
+  : OuterProduct4Way<"usmops_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Unsigned by signed integer sum of 4 outer products and subtract";
+  let description = [{
+    Equivalent to `usmopa_4way` but outer products are subtracted from
+    destination `result`.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.usmops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    Example: I16 to I64
+    ```mlir
+    %result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+  }];
+}
+
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e73388b0906e84..1ba1b88fc1234b 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -939,6 +939,22 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
arm_sme::aarch64_sme_umopa_za32>,
OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
arm_sme::aarch64_sme_umops_za32>,
+ OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
+ arm_sme::aarch64_sme_smopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
+ arm_sme::aarch64_sme_smops_wide>,
+ OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
+ arm_sme::aarch64_sme_umopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
+ arm_sme::aarch64_sme_umops_wide>,
+ OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
+ arm_sme::aarch64_sme_sumopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
+ arm_sme::aarch64_sme_sumops_wide>,
+ OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
+ arm_sme::aarch64_sme_usmopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
+ arm_sme::aarch64_sme_usmops_wide>,
ZeroOpConversion, GetTileConversion>(patterns, converter);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 7dc2aacd7e5f3b..1f4370aec37a46 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -31,6 +31,55 @@ using namespace mlir;
using namespace mlir::arm_sme;
namespace {
+
+// Common match failure reasons, shared by the 2-way and 4-way outer product
+// fusion patterns. Messages must therefore not assume a particular number of
+// outer products.
+static constexpr StringLiteral
+    MATCH_FAILURE_NO_ACCUMULATOR("no accumulator operand");
+static constexpr StringLiteral MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP(
+    "defining op of accumulator must be 'arm_sme.outerproduct'");
+static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_COMBINING_KIND(
+    "combining kind (add or sub) of outer products must match");
+static constexpr StringLiteral MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE(
+    "outer product(s) not single use and cannot be removed, no benefit to "
+    "fusing");
+static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_MASKING(
+    "unsupported masking, either all outerproducts are masked or none");
+
+// Checks whether `op` can take part in a widening fusion. This holds when:
+//  * the result type of `op` is exactly `resultType`;
+//  * its LHS is produced by a `LhsExtOp` and its RHS by a `RhsExtOp`
+//    (`RhsExtOp` defaults to `LhsExtOp`, i.e. both sides extended alike);
+//  * the source operand of each of those extensions has type `inputType`.
+// On mismatch a match-failure diagnostic is reported through `rewriter` and
+// failure is returned.
+template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
+static LogicalResult isCompatible(PatternRewriter &rewriter,
+                                  arm_sme::OuterProductOp op,
+                                  VectorType resultType, VectorType inputType) {
+  if (op.getResultType() != resultType) {
+    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+      diag << "unsupported result type, expected " << resultType;
+    });
+  }
+
+  LhsExtOp lhsExt = op.getLhs().getDefiningOp<LhsExtOp>();
+  RhsExtOp rhsExt = op.getRhs().getDefiningOp<RhsExtOp>();
+  if (!(lhsExt && rhsExt)) {
+    return rewriter.notifyMatchFailure(
+        op, "defining op of outerproduct operands must be one of: "
+            "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+  }
+
+  // Element types/shapes of the values *before* extension.
+  auto lhsSrcType = cast<VectorType>(lhsExt.getIn().getType());
+  auto rhsSrcType = cast<VectorType>(rhsExt.getIn().getType());
+  bool sourcesMatch = lhsSrcType == inputType && rhsSrcType == inputType;
+  if (!sourcesMatch) {
+    return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+      diag << "unsupported input type, expected " << inputType;
+    });
+  }
+
+  return success();
+}
+
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 2-way outer product operation.
//
@@ -63,18 +112,17 @@ class OuterProductFusion2Way
PatternRewriter &rewriter) const override {
Value acc = op.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
arm_sme::OuterProductOp op2 = op;
if (!op1)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
if (op1.getKind() != op2.getKind())
return rewriter.notifyMatchFailure(
- op, "combining kind (add or sub) of outer products must match");
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
@@ -101,14 +149,12 @@ class OuterProductFusion2Way
// No accumulator would be ok, but it's simpler to prevent this
// altogether, since it has no benefit.
return rewriter.notifyMatchFailure(
- op, "first outer product is not single use and cannot be removed, "
- "no benefit to fusing");
+ op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
}
if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
- return rewriter.notifyMatchFailure(
- op, "unsupported masking, either both outerproducts are masked "
- "or neither");
+ return rewriter.notifyMatchFailure(op,
+ MATCH_FAILURE_INCONSISTENT_MASKING);
if (failed(canFuseOuterProducts(rewriter, op1, op2)))
return failure();
@@ -225,37 +271,238 @@ class OuterProductFusion2Way
return success();
}
+};
+
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a single 4-way widening outer product operation.
+class OuterProductFusion4Way
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
- // An outer product is compatible if all of the following are true:
- // - the result type matches `resultType`.
- // - the defining operations of the inputs are identical and of the type
- // `ExtOp`.
- // - the input types of the defining operations are identical and match
- // `inputType`.
- template <typename ExtOp>
- LogicalResult isCompatible(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op, VectorType resultType,
- VectorType inputType) const {
- if (op.getResultType() != resultType)
- return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
- diag << "unsupported result type, expected " << resultType;
- });
-
- auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
- auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
- if (!lhsDefOp || !rhsDefOp)
+  // Matches a chain of four 'arm_sme.outerproduct' ops linked through their
+  // accumulators (op1 -> op2 -> op3 -> op4, where `op` is op4). If they can be
+  // fused, op4 is replaced by a single 4-way outer product that accumulates
+  // into op1's accumulator (if any), and op1-op3 are erased.
+  LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+                                PatternRewriter &rewriter) const override {
+    Value acc = op.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op4 = op;
+    arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op3)
       return rewriter.notifyMatchFailure(
-          op, "defining op of outerproduct operands must be one of: "
-              "'arith.extf' or 'arith.extsi' or 'arith.extui'");
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op3.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op2)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    acc = op2.getAcc();
+    if (!acc)
+      return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+
+    arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+    if (!op1)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+
+    arm_sme::CombiningKind kind = op1.getKind();
+    if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+
+    // op1-op3 are erased after fusion, so each must be used only by the next
+    // outer product in the chain.
+    if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+
+    // Every outer product in the chain must agree on masking. Compare each op
+    // against op1: a chained `a != b != c != d` only checks parity and would
+    // accept two masked plus two unmasked ops, which then crashes when the
+    // null masks are packed below.
+    bool isMasked = bool(op1.getLhsMask());
+    if (bool(op2.getLhsMask()) != isMasked ||
+        bool(op3.getLhsMask()) != isMasked ||
+        bool(op4.getLhsMask()) != isMasked)
+      return rewriter.notifyMatchFailure(op,
+                                         MATCH_FAILURE_INCONSISTENT_MASKING);
+
+    if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
+      return failure();
+
+    auto loc = op.getLoc();
+
+    // Interleaves `lhs` and `rhs` into one vector with twice as many elements
+    // along dimension 0.
+    auto packInputs = [&](Value lhs, Value rhs) {
+      auto inputType = cast<VectorType>(lhs.getType());
+      VectorType inputTypeX2 =
+          VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+      return rewriter.create<LLVM::experimental_vector_interleave2>(
+          loc, inputTypeX2, lhs, rhs);
+    };
-    auto lhsInType = cast<VectorType>(lhsDefOp.getIn().getType());
-    auto rhsInType = cast<VectorType>(rhsDefOp.getIn().getType());
+    // Pack the pre-extension operands so that element (i * 4 + k) of the packed
+    // vector comes from the (k + 1)-th outer product:
+    //   interleave(interleave(x1, x3), interleave(x2, x4))
+    //     = [x1[0], x2[0], x3[0], x4[0], x1[1], x2[1], ...]
+    auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
+                           op3.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
+                           op4.getLhs().getDefiningOp()->getOperand(0));
+    auto lhs = packInputs(lhs0, lhs1);
-    if (lhsInType != inputType || rhsInType != inputType)
-      return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
-        diag << "unsupported input type, expected " << inputType;
-      });
+    auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
+                           op3.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
+                           op4.getRhs().getDefiningOp()->getOperand(0));
+    auto rhs = packInputs(rhs0, rhs1);
+
+    // Masks are packed with the same layout as the operands. The check above
+    // guarantees that either all four ops carry masks or none do.
+    Value lhsMask, rhsMask;
+    if (isMasked) {
+      auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
+      auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
+      lhsMask = packInputs(lhs0Mask, lhs1Mask);
+
+      auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
+      auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
+      rhsMask = packInputs(rhs0Mask, rhs1Mask);
+    }
+
+    // canFuseOuterProducts() only succeeds when all four ops use the same
+    // (lhs, rhs) extension pair, so op4's extensions select the fused op.
+    auto lhsExtOp = op.getLhs().getDefiningOp();
+    auto rhsExtOp = op.getRhs().getDefiningOp();
+
+    if (kind == arm_sme::CombiningKind::Add) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else if (kind == arm_sme::CombiningKind::Sub) {
+      if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+        rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
+            op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+      else
+        llvm_unreachable("unexpected extend op!");
+    } else {
+      llvm_unreachable("unexpected arm_sme::CombiningKind!");
+    }
+
+    // op4 has been replaced. Erase the rest of the chain from the last user
+    // backwards so that each op has no remaining users when it is erased.
+    rewriter.eraseOp(op3);
+    rewriter.eraseOp(op2);
+    rewriter.eraseOp(op1);
+
+    return success();
+  }
+
+private:
+  // Returns true when every one of op1..op4 passes isCompatible() for the
+  // given extension ops, result type and (pre-extension) input type. Ops are
+  // checked in order and checking stops at the first incompatible op.
+  template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
+  static bool allCompatible(PatternRewriter &rewriter,
+                            arm_sme::OuterProductOp op1,
+                            arm_sme::OuterProductOp op2,
+                            arm_sme::OuterProductOp op3,
+                            arm_sme::OuterProductOp op4, VectorType resultType,
+                            VectorType inputType) {
+    for (arm_sme::OuterProductOp candidate : {op1, op2, op3, op4}) {
+      if (failed(isCompatible<LhsExtOp, RhsExtOp>(rewriter, candidate,
+                                                  resultType, inputType)))
+        return false;
+    }
+    return true;
+  }
+
+  // Decides whether op1..op4 can be folded into one 4-way outer product.
+  // That requires all four ops to share:
+  //  - the same input and result types, and
+  //  - the same (lhs, rhs) pair of integer extensions (signed/unsigned),
+  // and that combination must correspond to an existing 4-way operation
+  // (i8 -> i32 or i16 -> i64, for smop*, umop*, sumop* and usmop*).
+  LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
+                                     arm_sme::OuterProductOp op1,
+                                     arm_sme::OuterProductOp op2,
+                                     arm_sme::OuterProductOp op3,
+                                     arm_sme::OuterProductOp op4) const {
+    // ZA tile (result) types of the supported 4-way operations.
+    auto tileI32 =
+        VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+    auto tileI64 =
+        VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+    // Operand types before packing; each has a quarter of the elements of
+    // the corresponding 4-way operation's input vectors.
+    auto srcI8 = VectorType::get({4}, rewriter.getI8Type(), true);
+    auto srcI16 = VectorType::get({2}, rewriter.getI16Type(), true);
+
+    bool fusable =
+        // signed, i8i8i32
+        allCompatible<arith::ExtSIOp>(rewriter, op1, op2, op3, op4, tileI32,
+                                      srcI8) ||
+        // signed, i16i16i64
+        allCompatible<arith::ExtSIOp>(rewriter, op1, op2, op3, op4, tileI64,
+                                      srcI16) ||
+        // unsigned, i8i8i32
+        allCompatible<arith::ExtUIOp>(rewriter, op1, op2, op3, op4, tileI32,
+                                      srcI8) ||
+        // unsigned, i16i16i64
+        allCompatible<arith::ExtUIOp>(rewriter, op1, op2, op3, op4, tileI64,
+                                      srcI16) ||
+        // signed by unsigned, i8i8i32
+        allCompatible<arith::ExtSIOp, arith::ExtUIOp>(rewriter, op1, op2, op3,
+                                                      op4, tileI32, srcI8) ||
+        // signed by unsigned, i16i16i64
+        allCompatible<arith::ExtSIOp, arith::ExtUIOp>(rewriter, op1, op2, op3,
+                                                      op4, tileI64, srcI16) ||
+        // unsigned by signed, i8i8i32
+        allCompatible<arith::ExtUIOp, arith::ExtSIOp>(rewriter, op1, op2, op3,
+                                                      op4, tileI32, srcI8) ||
+        // unsigned by signed, i16i16i64
+        allCompatible<arith::ExtUIOp, arith::ExtSIOp>(rewriter, op1, op2, op3,
+                                                      op4, tileI64, srcI16);
+
+    if (!fusable)
+      return failure();
     return success();
   }
@@ -380,7 +627,7 @@ void mlir::arm_sme::populateOuterProductFusionPatterns(
// Note: High benefit to ensure extract(extend) are swapped first.
patterns.add<SwapVectorExtractOfArithExtend,
SwapVectorScalableExtractOfArithExtend>(context, 1024);
- patterns.add<OuterProductFusion2Way>(context);
+ patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(context);
}
std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index c41504d0e47245..81087cc02099fb 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -697,3 +697,179 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
%result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// smopa_4way with i8 operands into an i32 tile lowers to the smopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_smopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// smopa_4way with i16 operands into an i64 tile lowers to the smopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_smopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// smops_4way with i8 operands into an i32 tile lowers to the smops.wide intrinsic.
+// CHECK-LABEL: arm_sme_smops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// smops_4way with i16 operands into an i64 tile lowers to the smops.wide intrinsic.
+// CHECK-LABEL: arm_sme_smops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// umopa_4way with i8 operands into an i32 tile lowers to the umopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_umopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// umopa_4way with i16 operands into an i64 tile lowers to the umopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_umopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// umops_4way with i8 operands into an i32 tile lowers to the umops.wide intrinsic.
+// CHECK-LABEL: arm_sme_umops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// umops_4way with i16 operands into an i64 tile lowers to the umops.wide intrinsic.
+// CHECK-LABEL: arm_sme_umops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// sumopa_4way with i8 operands into an i32 tile lowers to the sumopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_sumopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// sumopa_4way with i16 operands into an i64 tile lowers to the sumopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_sumopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// sumops_4way with i8 operands into an i32 tile lowers to the sumops.wide intrinsic.
+// CHECK-LABEL: arm_sme_sumops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// sumops_4way with i16 operands into an i64 tile lowers to the sumops.wide intrinsic.
+// CHECK-LABEL: arm_sme_sumops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// usmopa_4way with i8 operands into an i32 tile lowers to the usmopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_usmopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// usmopa_4way with i16 operands into an i64 tile lowers to the usmopa.wide intrinsic.
+// CHECK-LABEL: arm_sme_usmopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// usmops_4way with i8 operands into an i32 tile lowers to the usmops.wide intrinsic.
+// CHECK-LABEL: arm_sme_usmops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+// usmops_4way with i16 operands into an i64 tile lowers to the usmops.wide intrinsic.
+// CHECK-LABEL: arm_sme_usmops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index dcc231332f2082..cc052fac0d9dc9 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -226,3 +226,16 @@ func.func @arm_sme_fmopa_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vecto
%0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+// i16 inputs with an i32 tile violate the 4x element-size constraint, so the op must be rejected.
+func.func @arm_sme_smopa_4way__bad_tile_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32>
+{
+ // expected-error@+1 {{op failed to verify that tile element size equals input element size * 4}}
+ %0 = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %0 : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index f24943cac4f767..5e5b1905047368 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -213,6 +213,581 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
return %1 : vector<[4]x[4]xi32>
}
+// -----
+// Four extsi i8->i32 outer products chained via the accumulator fuse into one smopa_4way over interleaved operands and masks.
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
+// CHECK-SAME: %[[A0:.*]]: vector<[4]xi8>, %[[B0:.*]]: vector<[4]xi8>, %[[A1:.*]]: vector<[4]xi8>, %[[B1:.*]]: vector<[4]xi8>, %[[A2:.*]]: vector<[4]xi8>, %[[B2:.*]]: vector<[4]xi8>, %[[A3:.*]]: vector<[4]xi8>, %[[B3:.*]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>, %[[A2_MASK:.*]]: vector<[4]xi1>, %[[B2_MASK:.*]]: vector<[4]xi1>, %[[A3_MASK:.*]]: vector<[4]xi1>, %[[B3_MASK:.*]]: vector<[4]xi1>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
+// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four extsi i8->i32 kind<sub> outer products chained via the accumulator fuse into one smops_4way.
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i8i8i32
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four extsi i16->i64 outer products chained via the accumulator fuse into one smopa_4way.
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+// Four extsi i16->i64 kind<sub> outer products chained via the accumulator fuse into one smops_4way.
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+// Four extui i8->i32 outer products chained via the accumulator fuse into one umopa_4way.
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four extui i8->i32 kind<sub> outer products chained via the accumulator fuse into one umops_4way.
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four extui i16->i64 outer products chained via the accumulator fuse into one umopa_4way.
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+// Four extui i16->i64 kind<sub> outer products chained via the accumulator fuse into one umops_4way.
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+// Four i8->i32 outer products (extsi LHS, extui RHS) chained via the accumulator fuse into one sumopa_4way.
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four i8->i32 kind<sub> outer products (extsi LHS, extui RHS) chained via the accumulator fuse into one sumops_4way.
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+// Four i16->i64 outer products (extsi LHS, extui RHS) chained via the accumulator fuse into one sumopa_4way.
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64(
+ %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>,
+ %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>,
+ %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>,
+ %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>,
+ %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>,
+ %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>,
+ %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>,
+ %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %lhs0_wide = arith.extsi %lhs0 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs0_wide = arith.extui %rhs0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs1_wide = arith.extsi %lhs1 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs1_wide = arith.extui %rhs1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs2_wide = arith.extsi %lhs2 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs2_wide = arith.extui %rhs2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs3_wide = arith.extsi %lhs3 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs3_wide = arith.extui %rhs3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %zero = arith.constant dense<0> : vector<[2]x[2]xi64>
+ // Four masked 'sub' outer products chained via the accumulator (expected to fuse into sumops_4way).
+ %op0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%op0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%op1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%op2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %op3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32(
+ %lhs0 : vector<[4]xi8>, %rhs0 : vector<[4]xi8>,
+ %lhs1 : vector<[4]xi8>, %rhs1 : vector<[4]xi8>,
+ %lhs2 : vector<[4]xi8>, %rhs2 : vector<[4]xi8>,
+ %lhs3 : vector<[4]xi8>, %rhs3 : vector<[4]xi8>,
+ %lhs0_mask : vector<[4]xi1>, %rhs0_mask : vector<[4]xi1>,
+ %lhs1_mask : vector<[4]xi1>, %rhs1_mask : vector<[4]xi1>,
+ %lhs2_mask : vector<[4]xi1>, %rhs2_mask : vector<[4]xi1>,
+ %lhs3_mask : vector<[4]xi1>, %rhs3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %lhs0_wide = arith.extui %lhs0 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs0_wide = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs1_wide = arith.extui %lhs1 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs1_wide = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs2_wide = arith.extui %lhs2 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs2_wide = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs3_wide = arith.extui %lhs3 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs3_wide = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %zero = arith.constant dense<0> : vector<[4]x[4]xi32>
+ // Four masked outer products chained via the accumulator (expected to fuse into usmopa_4way).
+ %op0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%op0) masks(%lhs1_mask, %rhs1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide acc(%op1) masks(%lhs2_mask, %rhs2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide acc(%op2) masks(%lhs3_mask, %rhs3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %op3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32(
+ %lhs0 : vector<[4]xi8>, %rhs0 : vector<[4]xi8>,
+ %lhs1 : vector<[4]xi8>, %rhs1 : vector<[4]xi8>,
+ %lhs2 : vector<[4]xi8>, %rhs2 : vector<[4]xi8>,
+ %lhs3 : vector<[4]xi8>, %rhs3 : vector<[4]xi8>,
+ %lhs0_mask : vector<[4]xi1>, %rhs0_mask : vector<[4]xi1>,
+ %lhs1_mask : vector<[4]xi1>, %rhs1_mask : vector<[4]xi1>,
+ %lhs2_mask : vector<[4]xi1>, %rhs2_mask : vector<[4]xi1>,
+ %lhs3_mask : vector<[4]xi1>, %rhs3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %lhs0_wide = arith.extui %lhs0 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs0_wide = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs1_wide = arith.extui %lhs1 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs1_wide = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs2_wide = arith.extui %lhs2 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs2_wide = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %lhs3_wide = arith.extui %lhs3 : vector<[4]xi8> to vector<[4]xi32>
+ %rhs3_wide = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %zero = arith.constant dense<0> : vector<[4]x[4]xi32>
+ // Four masked 'sub' outer products chained via the accumulator (expected to fuse into usmops_4way).
+ %op0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%op0) masks(%lhs1_mask, %rhs1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%op1) masks(%lhs2_mask, %rhs2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %op3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%op2) masks(%lhs3_mask, %rhs3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %op3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64(
+ %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>,
+ %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>,
+ %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>,
+ %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>,
+ %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>,
+ %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>,
+ %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>,
+ %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %lhs0_wide = arith.extui %lhs0 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs0_wide = arith.extsi %rhs0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs1_wide = arith.extui %lhs1 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs1_wide = arith.extsi %rhs1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs2_wide = arith.extui %lhs2 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs2_wide = arith.extsi %rhs2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs3_wide = arith.extui %lhs3 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs3_wide = arith.extsi %rhs3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %zero = arith.constant dense<0> : vector<[2]x[2]xi64>
+ // Four masked outer products chained via the accumulator (expected to fuse into usmopa_4way).
+ %op0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%op0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide acc(%op1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide acc(%op2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %op3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
+ %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>,
+ %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>,
+ %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>,
+ %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>,
+ %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>,
+ %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>,
+ %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>,
+ %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %lhs0_wide = arith.extui %lhs0 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs0_wide = arith.extsi %rhs0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs1_wide = arith.extui %lhs1 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs1_wide = arith.extsi %rhs1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs2_wide = arith.extui %lhs2 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs2_wide = arith.extsi %rhs2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %lhs3_wide = arith.extui %lhs3 : vector<[2]xi16> to vector<[2]xi64>
+ %rhs3_wide = arith.extsi %rhs3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %zero = arith.constant dense<0> : vector<[2]x[2]xi64>
+ // Four masked 'sub' outer products chained via the accumulator (expected to fuse into usmops_4way).
+ %op0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%op0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%op1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %op3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%op2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %op3 : vector<[2]x[2]xi64>
+}
+
/// Tests for related patterns.
// -----
@@ -274,6 +849,34 @@ func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vecto
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__no_acc
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__no_acc(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+ // Only three outer products are chained (the first has no accumulator): no fusion expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %2 : vector<[4]x[4]xi32>
+}
+
+// -----
+
/// Defining op of accumulator operand must be an 'arm_sme.outerproduct'.
// CHECK-LABEL: @outerproduct_widening_2way__bad_acc
@@ -291,6 +894,41 @@ func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vect
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_acc
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__bad_acc(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ // %3 does not accumulate into %2, so there is no 4-op chain to fuse.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ // break chain
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
/// Combining kinds of outer products must match to be fused.
// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
@@ -314,6 +952,40 @@ func.func @outerproduct_widening_2way__bad_combining_kind(
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_combining_kind
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__bad_combining_kind(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ // The first outer product is 'sub' but the others are 'add', so no fusion is expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<add> acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<add> acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<add> acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
/// If the first outer product has uses other than as the input to another
/// outer product, it can't be erased after fusion. This is a problem when
/// it also has an accumulator as this will be used as the root for tile
@@ -344,6 +1016,41 @@ func.func @outerproduct_widening_2way__cant_erase(
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__cant_erase
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__cant_erase(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ // %1 has an extra use ("fake.use"), so it can't be erased after fusion: no fusion expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
// CHECK-NOT: arm_sme.fmopa_2way
// CHECK: arm_sme.outerproduct
@@ -365,6 +1072,40 @@ func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__unsupported_type_f16f16f64
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__unsupported_type_f16f16f64(
+ %a0 : vector<[2]xf16>, %b0 : vector<[2]xf16>,
+ %a1 : vector<[2]xf16>, %b1 : vector<[2]xf16>,
+ %a2 : vector<[2]xf16>, %b2 : vector<[2]xf16>,
+ %a3 : vector<[2]xf16>, %b3 : vector<[2]xf16>) -> vector<[2]x[2]xf64> {
+ %a0_ext = arith.extf %a0 : vector<[2]xf16> to vector<[2]xf64>
+ %b0_ext = arith.extf %b0 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a1_ext = arith.extf %a1 : vector<[2]xf16> to vector<[2]xf64>
+ %b1_ext = arith.extf %b1 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a2_ext = arith.extf %a2 : vector<[2]xf16> to vector<[2]xf64>
+ %b2_ext = arith.extf %b2 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a3_ext = arith.extf %a3 : vector<[2]xf16> to vector<[2]xf64>
+ %b3_ext = arith.extf %b3 : vector<[2]xf16> to vector<[2]xf64>
+ // f16 -> f64 is not a supported 4-way widening, so no fusion is expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[2]xf64>, vector<[2]xf64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[2]xf64>, vector<[2]xf64>
+
+ return %3 : vector<[2]x[2]xf64>
+}
+
+// -----
+
/// Fusion only occurs if either both outer products are masked, or neither.
// CHECK-LABEL: @outerproduct_widening_2way__bad_masking
@@ -389,6 +1130,41 @@ func.func @outerproduct_widening_2way__bad_masking(
// -----
+// CHECK-LABEL: @outerproduct_widening_4way__bad_masking
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__bad_masking(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ // Only %2 is masked (all or none must be masked to fuse), so no fusion is expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
/// Defining op of outer product must be a supported extension op.
// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
@@ -404,6 +1180,37 @@ func.func @outerproduct_widening_2way__bad_defining_op(
return %1 : vector<[4]x[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
+// CHECK-NOT: arm_sme.{{.*}}_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.{{.*}}_4way
+func.func @outerproduct_widening_4way__bad_defining_op(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ // %a2/%b2 are not defined by an extension op, so no fusion is expected.
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
/// Negative tests for related patterns.
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ca096363e7283d..ab46c7adca5966 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1243,3 +1243,163 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
%result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.smopa_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smops_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.smops_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_smops_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.smops_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.umopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.umopa_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umops_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.umops_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_umops_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.umops_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.sumopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.sumopa_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%lhs: vector<[16]xi8>, %rhs: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %res = arm_sme.sumops_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %res : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%lhs: vector<[8]xi16>, %rhs: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %res = arm_sme.sumops_4way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %res : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
new file mode 100644
index 00000000000000..1770e579f0bd68
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
@@ -0,0 +1,150 @@
+// DEFINE: %{entry} = main
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE: -arm-sme-outer-product-fusion \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE: -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils,%arm_sme_abi_shlib
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// NOTE: QEMU gives incorrect results for the SME SMOPA 4-way outer product
+// instruction (versions <= 8.2.0, the latest at the time of writing); see
+// https://gitlab.com/qemu-project/qemu/-/issues/2083. This test is expected
+// to fail (the CHECK lines are correct) until a fixed QEMU version is used.
+
+// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+// (and installed on CI buildbot).
+// XFAIL: *
+
+// NOTE: there is no non-widening variant for these types, so this test can't
+// be lowered without the widening (outer product fusion) pass. Hence, unlike
+// 'test-outerproduct-f16f16f32.mlir', we can't compare the result against one
+// computed without the widening pass.
+
+func.func @main() {
+  %svl_bits = arith.constant 128 : i32
+  func.call @setArmSVLBits(%svl_bits) : (i32) -> ()
+  // Unmasked chain of four i8 -> i32 outer products.
+  func.call @test_outerproduct_i8i8i32() : () -> ()
+  // The same chain with each outer product under a different 2-D mask.
+  func.call @test_masked_outerproduct_i8i8i32() : () -> ()
+
+  return
+}
+
+func.func @test_outerproduct_i8i8i32() {
+  %init = llvm.mlir.undef : vector<[4]xi8>
+  // LHS inputs of the four chained outer products.
+  %lhs0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %lhs1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %lhs2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %lhs3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+  // RHS inputs of the four chained outer products.
+  %rhs0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %rhs1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %rhs2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %rhs3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+  // Insert the fixed-length data into the low lanes of scalable vectors.
+  %lhs0 = vector.scalable.insert %lhs0_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs1 = vector.scalable.insert %lhs1_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs2 = vector.scalable.insert %lhs2_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs3 = vector.scalable.insert %lhs3_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs0 = vector.scalable.insert %rhs0_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs1 = vector.scalable.insert %rhs1_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs2 = vector.scalable.insert %rhs2_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs3 = vector.scalable.insert %rhs3_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  // Sign-extend every operand to the i32 accumulator element type.
+  %lhs0_ext = arith.extsi %lhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs1_ext = arith.extsi %lhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs2_ext = arith.extsi %lhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs3_ext = arith.extsi %lhs3 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs0_ext = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs1_ext = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs2_ext = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs3_ext = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>
+  // Four outer products chained through the accumulator (4-way fusion candidate).
+  %0 = vector.outerproduct %lhs0_ext, %rhs0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = vector.outerproduct %lhs1_ext, %rhs1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+  %2 = vector.outerproduct %lhs2_ext, %rhs2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+  %3 = vector.outerproduct %lhs3_ext, %rhs3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+
+  // CHECK: ( 110, 134, 158, 182 )
+  // CHECK-NEXT: ( 390, 478, 566, 654 )
+  // CHECK-NEXT: ( 670, 822, 974, 1126 )
+  // CHECK-NEXT: ( 950, 1166, 1382, 1598 )
+  vector.print %3 : vector<[4]x[4]xi32>
+
+  return
+}
+
+func.func @test_masked_outerproduct_i8i8i32() {
+  %init = llvm.mlir.undef : vector<[4]xi8>
+  // LHS inputs of the four chained outer products.
+  %lhs0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %lhs1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %lhs2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %lhs3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+  // RHS inputs of the four chained outer products.
+  %rhs0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %rhs1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %rhs2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %rhs3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+  // Insert the fixed-length data into the low lanes of scalable vectors.
+  %lhs0 = vector.scalable.insert %lhs0_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs1 = vector.scalable.insert %lhs1_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs2 = vector.scalable.insert %lhs2_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs3 = vector.scalable.insert %lhs3_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs0 = vector.scalable.insert %rhs0_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs1 = vector.scalable.insert %rhs1_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs2 = vector.scalable.insert %rhs2_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs3 = vector.scalable.insert %rhs3_data, %init[0] : vector<4xi8> into vector<[4]xi8>
+  // Sign-extend every operand to the i32 accumulator element type.
+  %lhs0_ext = arith.extsi %lhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs1_ext = arith.extsi %lhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs2_ext = arith.extsi %lhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs3_ext = arith.extsi %lhs3 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs0_ext = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs1_ext = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs2_ext = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs3_ext = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>
+  // Index constants used as mask bounds.
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  // Active regions of 1x1, 1x2, 2x3 and 3x4 elements, one per outer product.
+  %mask0 = vector.create_mask %c1, %c1 : vector<[4]x[4]xi1>
+  %mask1 = vector.create_mask %c1, %c2 : vector<[4]x[4]xi1>
+  %mask2 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %mask3 = vector.create_mask %c3, %c4 : vector<[4]x[4]xi1>
+  // Every element starts at 2; masked-off elements keep their accumulator value.
+  %acc = arith.constant dense<2> : vector<[4]x[4]xi32>
+  %0 = vector.mask %mask0 {
+    vector.outerproduct %lhs0_ext, %rhs0_ext, %acc : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %1 = vector.mask %mask1 {
+    vector.outerproduct %lhs1_ext, %rhs1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %2 = vector.mask %mask2 {
+    vector.outerproduct %lhs2_ext, %rhs2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+  %3 = vector.mask %mask3 {
+    vector.outerproduct %lhs3_ext, %rhs3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+
+  // CHECK: ( 112, 136, 135, 95 )
+  // CHECK-NEXT: ( 243, 295, 347, 219 )
+  // CHECK-NEXT: ( 211, 255, 299, 343 )
+  // CHECK-NEXT: ( 2, 2, 2, 2 )
+  vector.print %3 : vector<[4]x[4]xi32>
+
+  return
+}
+
+func.func private @setArmSVLBits(i32)
>From 63f14fcdb6060cde8189ada55d518b0973a83ee0 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 2 Feb 2024 15:42:13 +0000
Subject: [PATCH 2/8] Address comments. Changes:
- fix check for consistent masking.
- rewrite as loop that walks outer product chain.
- use lambda for match check.
---
.../ArmSME/Transforms/OuterProductFusion.cpp | 175 ++++++------------
1 file changed, 55 insertions(+), 120 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 1f4370aec37a46..df15fe998c41c8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -282,51 +282,38 @@ class OuterProductFusion4Way
LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
PatternRewriter &rewriter) const override {
- Value acc = op.getAcc();
- if (!acc)
- return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-
- arm_sme::OuterProductOp op4 = op;
- arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
- if (!op3)
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
-
- acc = op3.getAcc();
- if (!acc)
- return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-
- arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
- if (!op2)
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
-
- acc = op2.getAcc();
- if (!acc)
- return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
-
- arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
- if (!op1)
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
-
- arm_sme::CombiningKind kind = op1.getKind();
- if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
-
- if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
-
- if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) !=
- bool(op3.getLhsMask()) != bool(op4.getLhsMask()))
- return rewriter.notifyMatchFailure(op,
- MATCH_FAILURE_INCONSISTENT_MASKING);
+    SmallVector<arm_sme::OuterProductOp, 4> outerProductChain;
+    outerProductChain.push_back(op);
+    // Walk up the accumulator chain; it ends up as {op4, op3, op2, op1}.
+    for (int i = 0; i < 3; ++i) {
+      auto currentOp = outerProductChain.back();
+      auto acc = currentOp.getAcc();
+      if (!acc)
+        return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+      auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
+      if (!previousOp)
+        return rewriter.notifyMatchFailure(
+            op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+      if (!previousOp->hasOneUse())
+        return rewriter.notifyMatchFailure(
+            op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+      if (previousOp.getKind() != currentOp.getKind())
+        return rewriter.notifyMatchFailure(
+            op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+      if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
+        return rewriter.notifyMatchFailure(
+            op, MATCH_FAILURE_INCONSISTENT_MASKING);
+      outerProductChain.push_back(previousOp);
+    }
- if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
+ if (failed(canFuseOuterProducts(rewriter, outerProductChain)))
return failure();
+ arm_sme::OuterProductOp op1 = outerProductChain[3];
+ arm_sme::OuterProductOp op2 = outerProductChain[2];
+ arm_sme::OuterProductOp op3 = outerProductChain[1];
+ arm_sme::OuterProductOp op4 = outerProductChain[0];
+
auto loc = op.getLoc();
auto packInputs = [&](Value lhs, Value rhs) {
@@ -364,6 +351,7 @@ class OuterProductFusion4Way
auto lhsExtOp = op.getLhs().getDefiningOp();
auto rhsExtOp = op.getRhs().getDefiningOp();
+ arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
@@ -414,94 +402,41 @@ class OuterProductFusion4Way
// - a floating-point extension for floating-point types.
// - the types and extension are supported, i.e. there's a 4-way operation
// they can be fused into.
- LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op1,
- arm_sme::OuterProductOp op2,
- arm_sme::OuterProductOp op3,
- arm_sme::OuterProductOp op4) const {
+ LogicalResult
+ canFuseOuterProducts(PatternRewriter &rewriter,
+ SmallVectorImpl<arm_sme::OuterProductOp> &ops) const {
// Supported result types.
auto nxnxv4i32 =
VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
auto nxnxv2i64 =
VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+
// Supported input types.
// Note: this is before packing so these have 1/4 the number of elements
// of the input vector types of the 4-way operations.
auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
- if (
- // signed, i8i8i32
- (failed(
- isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // signed, i16i16i64
- (failed(
- isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(
- isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(
- isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
- nxv2i16))) &&
- // unsigned, i8i8i32
- (failed(
- isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(
- isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // unsigned, i16i16i64
- (failed(
- isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(
- isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(
- isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
- nxv2i16))) &&
- // signed by unsigned, i8i8i32
- (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // signed by unsigned, i16i16i64
- (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op4, nxnxv2i64, nxv2i16))) &&
- // unsigned by signed, i8i8i32
- (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // unsigned by signed, i16i16i64
- (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op4, nxnxv2i64, nxv2i16))))
+
+ auto failedToMatch = [&](VectorType resultType, VectorType inputType,
+ auto lhsExtendOp, auto rhsExtendOp) {
+ using LhsExtendOpTy = decltype(lhsExtendOp);
+ using RhsExtendOpTy = decltype(rhsExtendOp);
+ for (auto op : ops) {
+ if (failed(isCompatible<LhsExtendOpTy, RhsExtendOpTy>(
+ rewriter, op, resultType, inputType)))
+ return true;
+ }
+ return false;
+ };
+
+ if (failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
+ failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
+ failedToMatch(nxnxv4i32, nxv4i8, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
+ failedToMatch(nxnxv4i32, nxv4i8, arith::ExtUIOp{}, arith::ExtSIOp{}) &&
+ failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtSIOp{}) &&
+ failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtUIOp{}) &&
+ failedToMatch(nxnxv2i64, nxv2i16, arith::ExtSIOp{}, arith::ExtUIOp{}) &&
+ failedToMatch(nxnxv2i64, nxv2i16, arith::ExtUIOp{}, arith::ExtSIOp{}))
return failure();
return success();
>From df11371c3ccf18df56252fcc797088f614ff2f3a Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 5 Feb 2024 12:33:22 +0000
Subject: [PATCH 3/8] Address comments
add comment to clarify each variant.
---
.../ArmSME/Transforms/OuterProductFusion.cpp | 36 +++++++++++++------
1 file changed, 26 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index df15fe998c41c8..9196ea8bdc84f3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -353,35 +353,51 @@ class OuterProductFusion4Way
arm_sme::CombiningKind kind = op.getKind();
if (kind == arm_sme::CombiningKind::Add) {
- if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
+ // signed
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
+ isa<arith::ExtUIOp>(rhsExtOp)) {
+ // unsigned
rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
+ isa<arith::ExtUIOp>(rhsExtOp)) {
+ // signed by unsigned
rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
+ isa<arith::ExtSIOp>(rhsExtOp)) {
+ // unsigned by signed
rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
+ } else {
llvm_unreachable("unexpected extend op!");
+ }
} else if (kind == arm_sme::CombiningKind::Sub) {
- if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
+ // signed
rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
+ isa<arith::ExtUIOp>(rhsExtOp)) {
+ // unsigned
rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ } else if (isa<arith::ExtSIOp>(lhsExtOp) &&
+ isa<arith::ExtUIOp>(rhsExtOp)) {
+ // signed by unsigned
rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) &&
+ isa<arith::ExtSIOp>(rhsExtOp)) {
+ // unsigned by signed
rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
+ } else {
llvm_unreachable("unexpected extend op!");
+ }
} else {
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
>From bd0d06705a76e0236ff0917f6a05d3c4c3e2d554 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 5 Feb 2024 14:19:24 +0000
Subject: [PATCH 4/8] Rebase and fix op descriptions
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 24 +++----
.../Dialect/ArmSME/outer-product-fusion.mlir | 66 +++++++++----------
2 files changed, 39 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 08305973b1ee08..c9640bbae69d14 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1235,12 +1235,11 @@ def SMopa4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
| Spec | Features |
| ---- | -------- |
| [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1262,6 +1261,7 @@ def SMops4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1269,8 +1269,6 @@ def SMops4WayOp
| Spec | Features |
| ---- | -------- |
| [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1289,6 +1287,7 @@ def UMopa4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1296,8 +1295,6 @@ def UMopa4WayOp
| Spec | Features |
| ---- | -------- |
| [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1316,6 +1313,7 @@ def UMops4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1323,8 +1321,6 @@ def UMops4WayOp
| Spec | Features |
| ---- | -------- |
| [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1343,6 +1339,7 @@ def SuMopa4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1350,8 +1347,6 @@ def SuMopa4WayOp
| Spec | Features |
| ---- | -------- |
| [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1370,6 +1365,7 @@ def SuMops4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1377,8 +1373,6 @@ def SuMops4WayOp
| Spec | Features |
| ---- | -------- |
| [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1397,6 +1391,7 @@ def UsMopa4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1404,8 +1399,6 @@ def UsMopa4WayOp
| Spec | Features |
| ---- | -------- |
| [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
@@ -1424,6 +1417,7 @@ def UsMops4WayOp
Example: I16 to I64
```mlir
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ ```
Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
detailed description of 4-way outer products.
@@ -1431,8 +1425,6 @@ def UsMops4WayOp
| Spec | Features |
| ---- | -------- |
| [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
-
- ```
}];
}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 5e5b1905047368..6286334ef9a5f6 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -1180,7 +1180,37 @@ func.func @outerproduct_widening_2way__bad_defining_op(
return %1 : vector<[4]x[4]xf32>
}
-<<<<<<< HEAD
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
+// CHECK-NOT: arm_sme.smopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.smopa_4way
+func.func @outerproduct_widening_4way__bad_defining_op(
+  %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+  %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+  %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
+  %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+  // %a2/%b2 are not produced by an extension, so the chain must not be fused.
+  %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+  %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+  %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+  return %3 : vector<[4]x[4]xi32>
+}
/// Negative tests for related patterns.
@@ -1233,37 +1263,3 @@ func.func @scalable_extract_from_non_arith_ext(%src: vector<[8]xf32>) -> vector<
%0 = vector.scalable.extract %src[0] : vector<[4]xf32> from vector<[8]xf32>
return %0 : vector<[4]xf32>
}
-||||||| constructed merge base
-=======
-
-// -----
-
-// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
-// CHECK-NOT: arm_sme.fmopa_4way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_4way
-func.func @outerproduct_widening_4way__bad_defining_op(
- %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
- %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
- %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
- %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
- %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
- %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
-
- %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
- %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
-
- %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
- %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
- %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
-
- return %3 : vector<[4]x[4]xi32>
-}
->>>>>>> [mlir][ArmSME] Support 4-way widening outer products
>From aa6733d2427eaf25d6e666d9074e162b8e94adb8 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 6 Feb 2024 10:36:25 +0000
Subject: [PATCH 5/8] Fix broken doc links to smopa_4way
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index c9640bbae69d14..239c4beab10d2a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1263,7 +1263,7 @@ def SMops4WayOp
%result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1289,7 +1289,7 @@ def UMopa4WayOp
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1315,7 +1315,7 @@ def UMops4WayOp
%result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1341,7 +1341,7 @@ def SuMopa4WayOp
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1367,7 +1367,7 @@ def SuMops4WayOp
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1393,7 +1393,7 @@ def UsMopa4WayOp
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
@@ -1419,7 +1419,7 @@ def UsMops4WayOp
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```
- Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+ Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.
| Spec | Features |
>From 02382d7a7380510571948f6bf6c436c6e8631a08 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 6 Feb 2024 15:27:03 +0000
Subject: [PATCH 6/8] Address comments
---
.../Dialect/ArmSME/outer-product-fusion.mlir | 221 +++++++++---------
1 file changed, 114 insertions(+), 107 deletions(-)
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 6286334ef9a5f6..de9de86003e610 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -216,8 +216,14 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
// -----
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
-// CHECK-SAME: %[[A0:.*]]: vector<[4]xi8>, %[[B0:.*]]: vector<[4]xi8>, %[[A1:.*]]: vector<[4]xi8>, %[[B1:.*]]: vector<[4]xi8>, %[[A2:.*]]: vector<[4]xi8>, %[[B2:.*]]: vector<[4]xi8>, %[[A3:.*]]: vector<[4]xi8>, %[[B3:.*]]: vector<[4]xi8>,
-// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>, %[[A2_MASK:.*]]: vector<[4]xi1>, %[[B2_MASK:.*]]: vector<[4]xi1>, %[[A3_MASK:.*]]: vector<[4]xi1>, %[[B3_MASK:.*]]: vector<[4]xi1>
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[4]xi8>, %[[B0:[a-z0-9]+]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[4]xi8>, %[[B1:[a-z0-9]+]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[4]xi8>, %[[B2:[a-z0-9]+]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[4]xi8>, %[[B3:[a-z0-9]+]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A0_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B0_MASK:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[A1_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B1_MASK:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
@@ -521,24 +527,24 @@ func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32(
%a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
%a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
%a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
- %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
- %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a0_sext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_zext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
- %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
- %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_sext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_zext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
- %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
- %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_sext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_zext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
- %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
- %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_sext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_zext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
%acc = arith.constant dense<0> : vector<[4]x[4]xi32>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %0 = arm_sme.outerproduct %a0_sext, %b0_zext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_sext, %b1_zext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_sext, %b2_zext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_sext, %b3_zext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
return %3 : vector<[4]x[4]xi32>
}
@@ -556,24 +562,24 @@ func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32(
%a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
%a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
%a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
- %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
- %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a0_sext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_zext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
- %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
- %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_sext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_zext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
- %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
- %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_sext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_zext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
- %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
- %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_sext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_zext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
%acc = arith.constant dense<0> : vector<[4]x[4]xi32>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %0 = arm_sme.outerproduct %a0_sext, %b0_zext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_sext, %b1_zext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_sext, %b2_zext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_sext, %b3_zext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
return %3 : vector<[4]x[4]xi32>
}
@@ -591,24 +597,24 @@ func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64(
%a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
%a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
%a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
- %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
- %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+ %a0_sext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_zext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
- %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
- %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+ %a1_sext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_zext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
- %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
- %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+ %a2_sext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_zext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
- %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
- %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+ %a3_sext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_zext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
%acc = arith.constant dense<0> : vector<[2]x[2]xi64>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %0 = arm_sme.outerproduct %a0_sext, %b0_zext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_sext, %b1_zext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_sext, %b2_zext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_sext, %b3_zext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
return %3 : vector<[2]x[2]xi64>
}
@@ -626,24 +632,24 @@ func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64(
%a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
%a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
%a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
- %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
- %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+ %a0_sext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_zext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
- %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
- %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+ %a1_sext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_zext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
- %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
- %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+ %a2_sext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_zext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
- %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
- %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+ %a3_sext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_zext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
%acc = arith.constant dense<0> : vector<[2]x[2]xi64>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %0 = arm_sme.outerproduct %a0_sext, %b0_zext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_sext, %b1_zext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_sext, %b2_zext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_sext, %b3_zext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
return %3 : vector<[2]x[2]xi64>
}
@@ -661,24 +667,24 @@ func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32(
%a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
%a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
%a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
- %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
- %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a0_zext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_sext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
- %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
- %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_zext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_sext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
- %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
- %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_zext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_sext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
- %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
- %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_zext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_sext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
%acc = arith.constant dense<0> : vector<[4]x[4]xi32>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %0 = arm_sme.outerproduct %a0_zext, %b0_sext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_zext, %b1_sext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_zext, %b2_sext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_zext, %b3_sext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
return %3 : vector<[4]x[4]xi32>
}
@@ -696,24 +702,24 @@ func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32(
%a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
%a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
%a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
- %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
- %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a0_zext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_sext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
- %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
- %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_zext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_sext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
- %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
- %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_zext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_sext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
- %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
- %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_zext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_sext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
%acc = arith.constant dense<0> : vector<[4]x[4]xi32>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %0 = arm_sme.outerproduct %a0_zext, %b0_sext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_zext, %b1_sext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_zext, %b2_sext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_zext, %b3_sext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
return %3 : vector<[4]x[4]xi32>
}
@@ -731,24 +737,24 @@ func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64(
%a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
%a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
%a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
- %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
- %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+ %a0_zext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_sext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
- %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
- %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+ %a1_zext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_sext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
- %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
- %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+ %a2_zext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_sext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
- %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
- %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+ %a3_zext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_sext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
%acc = arith.constant dense<0> : vector<[2]x[2]xi64>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %0 = arm_sme.outerproduct %a0_zext, %b0_sext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_zext, %b1_sext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_zext, %b2_sext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_zext, %b3_sext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
return %3 : vector<[2]x[2]xi64>
}
@@ -766,24 +772,24 @@ func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
%a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
%a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
%a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
- %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
- %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+ %a0_zext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_sext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
- %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
- %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+ %a1_zext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_sext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
- %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
- %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+ %a2_zext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_sext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
- %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
- %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+ %a3_zext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_sext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
%acc = arith.constant dense<0> : vector<[2]x[2]xi64>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
- %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
- %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %0 = arm_sme.outerproduct %a0_zext, %b0_sext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_zext, %b1_sext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_zext, %b2_sext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_zext, %b3_sext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
return %3 : vector<[2]x[2]xi64>
}
@@ -894,14 +900,14 @@ func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vect
// -----
-// CHECK-LABEL: @outerproduct_widening_4way__bad_acc
+// CHECK-LABEL: @outerproduct_widening_4way__missing_acc
// CHECK-NOT: arm_sme.fmopa_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_4way
-func.func @outerproduct_widening_4way__bad_acc(
+func.func @outerproduct_widening_4way__missing_acc(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
%a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
%a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
@@ -921,7 +927,7 @@ func.func @outerproduct_widening_4way__bad_acc(
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
- // break chain
+ // Missing accumulator breaks use-def chain.
%3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
return %3 : vector<[4]x[4]xi32>
@@ -952,14 +958,14 @@ func.func @outerproduct_widening_2way__bad_combining_kind(
// -----
-// CHECK-LABEL: @outerproduct_widening_4way__bad_combining_kind
+// CHECK-LABEL: @outerproduct_widening_4way__inconsistent_combining_kind
// CHECK-NOT: arm_sme.fmopa_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_4way
-func.func @outerproduct_widening_4way__bad_combining_kind(
+func.func @outerproduct_widening_4way__inconsistent_combining_kind(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
%a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
%a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
@@ -1016,14 +1022,14 @@ func.func @outerproduct_widening_2way__cant_erase(
// -----
-// CHECK-LABEL: @outerproduct_widening_4way__cant_erase
+// CHECK-LABEL: @outerproduct_widening_4way__multi_use_cant_erase
// CHECK-NOT: arm_sme.fmopa_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_4way
-func.func @outerproduct_widening_4way__cant_erase(
+func.func @outerproduct_widening_4way__multi_use_cant_erase(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
%a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
%a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
@@ -1130,14 +1136,14 @@ func.func @outerproduct_widening_2way__bad_masking(
// -----
-// CHECK-LABEL: @outerproduct_widening_4way__bad_masking
+// CHECK-LABEL: @outerproduct_widening_4way__inconsistent_masking
// CHECK-NOT: arm_sme.fmopa_4way
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK: arm_sme.outerproduct
// CHECK-NOT: arm_sme.fmopa_4way
-func.func @outerproduct_widening_4way__bad_masking(
+func.func @outerproduct_widening_4way__inconsistent_masking(
%a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
%a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
%a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
@@ -1206,6 +1212,7 @@ func.func @outerproduct_widening_4way__bad_defining_op(
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+  // Inputs must come from an arith.extsi/arith.extui op.
%2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
>From 513c2b1849e014ce210379b42dc5e8a1950044d2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 6 Feb 2024 15:45:11 +0000
Subject: [PATCH 7/8] Address comments
---
.../ArmSME/Transforms/OuterProductFusion.cpp | 67 +++++++++----------
1 file changed, 33 insertions(+), 34 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 9196ea8bdc84f3..5f4ce24c848d79 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -34,17 +34,17 @@ namespace {
// Common match failure reasons.
static constexpr StringLiteral
- MATCH_FAILURE_NO_ACCUMULATOR("no accumulator operand");
-static constexpr StringLiteral MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP(
+ matchFailureNoAccumulator("no accumulator operand");
+static constexpr StringLiteral matchFailureExpectedOuterProductDefOp(
"defining op of accumulator must be 'arm_sme.outerproduct'");
-static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_COMBINING_KIND(
+static constexpr StringLiteral matchFailureInconsistentCombiningKind(
"combining kind (add or sub) of outer products must match");
-static constexpr StringLiteral MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE(
- "outer product(s) not single use and cannot be removed, no benefit to "
- "fusing");
-static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_MASKING(
+static constexpr StringLiteral matchFailureInconsistentMasking(
"unsupported masking, either both outerproducts are masked "
"or neither");
+static constexpr StringLiteral matchFailureOuterProductNotSingleUse(
+ "outer product(s) not single use and cannot be removed, no benefit to "
+ "fusing");
// An outer product is compatible if all of the following are true:
// - the result type matches `resultType`.
@@ -80,6 +80,16 @@ static LogicalResult isCompatible(PatternRewriter &rewriter,
return success();
}
+// Create 'llvm.experimental.vector.interleave2' intrinsic from `lhs` and `rhs`.
+static Value createInterleave2Intrinsic(RewriterBase &rewriter, Location loc,
+ Value lhs, Value rhs) {
+ auto inputType = cast<VectorType>(lhs.getType());
+ VectorType inputTypeX2 =
+ VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+ return rewriter.create<LLVM::experimental_vector_interleave2>(
+ loc, inputTypeX2, lhs, rhs);
+}
+
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 2-way outer product operation.
//
@@ -112,17 +122,17 @@ class OuterProductFusion2Way
PatternRewriter &rewriter) const override {
Value acc = op.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+ return rewriter.notifyMatchFailure(op, matchFailureNoAccumulator);
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
arm_sme::OuterProductOp op2 = op;
if (!op1)
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+ return rewriter.notifyMatchFailure(op,
+ matchFailureExpectedOuterProductDefOp);
if (op1.getKind() != op2.getKind())
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ return rewriter.notifyMatchFailure(op,
+ matchFailureInconsistentCombiningKind);
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
@@ -148,25 +158,19 @@ class OuterProductFusion2Way
//
// No accumulator would be ok, but it's simpler to prevent this
// altogether, since it has no benefit.
- return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+ return rewriter.notifyMatchFailure(op,
+ matchFailureOuterProductNotSingleUse);
}
if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
- return rewriter.notifyMatchFailure(op,
- MATCH_FAILURE_INCONSISTENT_MASKING);
+ return rewriter.notifyMatchFailure(op, matchFailureInconsistentMasking);
if (failed(canFuseOuterProducts(rewriter, op1, op2)))
return failure();
auto loc = op.getLoc();
-
auto packInputs = [&](Value lhs, Value rhs) {
- auto inputType = cast<VectorType>(lhs.getType());
- VectorType inputTypeX2 =
- VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
- return rewriter.create<LLVM::experimental_vector_interleave2>(
- loc, inputTypeX2, lhs, rhs);
+ return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
};
auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -289,20 +293,20 @@ class OuterProductFusion4Way
auto currentOp = outerProductChain.back();
auto acc = currentOp.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
+ return rewriter.notifyMatchFailure(op, matchFailureNoAccumulator);
auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
if (!previousOp)
return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
+ op, matchFailureExpectedOuterProductDefOp);
if (!previousOp->hasOneUse())
return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
+ op, matchFailureOuterProductNotSingleUse);
if (previousOp.getKind() != currentOp.getKind())
return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ op, matchFailureInconsistentCombiningKind);
if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
return rewriter.notifyMatchFailure(
- op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
+ op, matchFailureInconsistentCombiningKind);
outerProductChain.push_back(previousOp);
}
@@ -315,13 +319,8 @@ class OuterProductFusion4Way
arm_sme::OuterProductOp op4 = outerProductChain[0];
auto loc = op.getLoc();
-
auto packInputs = [&](Value lhs, Value rhs) {
- auto inputType = cast<VectorType>(lhs.getType());
- VectorType inputTypeX2 =
- VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
- return rewriter.create<LLVM::experimental_vector_interleave2>(
- loc, inputTypeX2, lhs, rhs);
+ return createInterleave2Intrinsic(rewriter, loc, lhs, rhs);
};
auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
@@ -420,7 +419,7 @@ class OuterProductFusion4Way
// they can be fused into.
LogicalResult
canFuseOuterProducts(PatternRewriter &rewriter,
- SmallVectorImpl<arm_sme::OuterProductOp> &ops) const {
+ ArrayRef<arm_sme::OuterProductOp> ops) const {
// Supported result types.
auto nxnxv4i32 =
VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
>From fd24f68c3399c4d741579e534ffa2e689313e373 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 6 Feb 2024 16:41:25 +0000
Subject: [PATCH 8/8] Address comments
---
.../ArmSME/Transforms/OuterProductFusion.cpp | 34 +++++++++----------
1 file changed, 17 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 5f4ce24c848d79..d3751d4ba7e73b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -34,15 +34,15 @@ namespace {
// Common match failure reasons.
static constexpr StringLiteral
- matchFailureNoAccumulator("no accumulator operand");
-static constexpr StringLiteral matchFailureExpectedOuterProductDefOp(
+ kMatchFailureNoAccumulator("no accumulator operand");
+static constexpr StringLiteral kMatchFailureExpectedOuterProductDefOp(
"defining op of accumulator must be 'arm_sme.outerproduct'");
-static constexpr StringLiteral matchFailureInconsistentCombiningKind(
+static constexpr StringLiteral kMatchFailureInconsistentCombiningKind(
"combining kind (add or sub) of outer products must match");
-static constexpr StringLiteral matchFailureInconsistentMasking(
+static constexpr StringLiteral kMatchFailureInconsistentMasking(
"unsupported masking, either both outerproducts are masked "
"or neither");
-static constexpr StringLiteral matchFailureOuterProductNotSingleUse(
+static constexpr StringLiteral kMatchFailureOuterProductNotSingleUse(
"outer product(s) not single use and cannot be removed, no benefit to "
"fusing");
@@ -122,17 +122,17 @@ class OuterProductFusion2Way
PatternRewriter &rewriter) const override {
Value acc = op.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, matchFailureNoAccumulator);
+ return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
arm_sme::OuterProductOp op2 = op;
if (!op1)
- return rewriter.notifyMatchFailure(op,
- matchFailureExpectedOuterProductDefOp);
+ return rewriter.notifyMatchFailure(
+ op, kMatchFailureExpectedOuterProductDefOp);
if (op1.getKind() != op2.getKind())
- return rewriter.notifyMatchFailure(op,
- matchFailureInconsistentCombiningKind);
+ return rewriter.notifyMatchFailure(
+ op, kMatchFailureInconsistentCombiningKind);
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
@@ -159,11 +159,11 @@ class OuterProductFusion2Way
// No accumulator would be ok, but it's simpler to prevent this
// altogether, since it has no benefit.
return rewriter.notifyMatchFailure(op,
- matchFailureOuterProductNotSingleUse);
+ kMatchFailureOuterProductNotSingleUse);
}
if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
- return rewriter.notifyMatchFailure(op, matchFailureInconsistentMasking);
+ return rewriter.notifyMatchFailure(op, kMatchFailureInconsistentMasking);
if (failed(canFuseOuterProducts(rewriter, op1, op2)))
return failure();
@@ -293,20 +293,20 @@ class OuterProductFusion4Way
auto currentOp = outerProductChain.back();
auto acc = currentOp.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, matchFailureNoAccumulator);
+ return rewriter.notifyMatchFailure(op, kMatchFailureNoAccumulator);
auto previousOp = acc.getDefiningOp<arm_sme::OuterProductOp>();
if (!previousOp)
return rewriter.notifyMatchFailure(
- op, matchFailureExpectedOuterProductDefOp);
+ op, kMatchFailureExpectedOuterProductDefOp);
if (!previousOp->hasOneUse())
return rewriter.notifyMatchFailure(
- op, matchFailureOuterProductNotSingleUse);
+ op, kMatchFailureOuterProductNotSingleUse);
if (previousOp.getKind() != currentOp.getKind())
return rewriter.notifyMatchFailure(
- op, matchFailureInconsistentCombiningKind);
+ op, kMatchFailureInconsistentCombiningKind);
if (bool(previousOp.getLhsMask()) != bool(currentOp.getLhsMask()))
return rewriter.notifyMatchFailure(
- op, matchFailureInconsistentCombiningKind);
+ op, kMatchFailureInconsistentMasking);
outerProductChain.push_back(previousOp);
}
More information about the Mlir-commits mailing list