[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Jan 30 07:07:07 PST 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/79288
>From 601bbaee30209422c4d1ecdf01fdfbc19502e3c7 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 12 Dec 2023 15:03:34 +0000
Subject: [PATCH 1/6] [mlir][ArmSME] Support 2-way widening outer products
This patch introduces support for 2-way widening outer products. This
enables the folding of 2 'arm_sme.outerproduct' operations that are
chained via the accumulator into single widened operations.
Changes:
- Add 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for 2-way variants.
These map to instruction variants added in SME2 and use different
intrinsics. Intrinsics are already implemented for widening variants
from SME1.
- Adds the following operations:
- fmopa_wide_2way, fmops_wide_2way
- smopa_wide_2way, smops_wide_2way
- umopa_wide_2way, umops_wide_2way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Adds a pass 'arm-sme-outer-product-widening' that folds
'arm_sme.outerproduct' operations.
For a detailed description of these operations see the
'arm_sme.fmopa_wide_2way' description.
The reason for introducing many operations rather than one is the
signed/unsigned variants can't be distinguished with types (e.g., ui16,
si16) since 'arith.extui' and 'arith.extsi' only support signless
integers. A single operation would require this information and an
attribute (for example) for the sign doesn't feel right if
floating-point types are also supported where this wouldn't apply.
Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32)
introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but
no subtract variant. Whilst these are not supported in this patch, it
felt simpler to have separate ops for add/subtract given this.
---
.../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td | 4 +
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 294 ++++++++++++++++++
.../mlir/Dialect/ArmSME/Transforms/Passes.h | 3 +
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 39 +++
.../Dialect/ArmSME/Transforms/Transforms.h | 4 +
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 66 +++-
.../Dialect/ArmSME/Transforms/CMakeLists.txt | 2 +
.../Transforms/OuterProductWidening.cpp | 238 ++++++++++++++
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 96 ++++++
mlir/test/Dialect/ArmSME/invalid.mlir | 53 ++++
.../ArmSME/outer-product-widening.mlir | 192 ++++++++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 112 +++++++
.../ArmSME/test-outerproduct-f16f16f32.mlir | 100 ++++++
mlir/test/Target/LLVMIR/arm-sme.mlir | 12 +
14 files changed, 1213 insertions(+), 2 deletions(-)
create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
create mode 100644 mlir/test/Dialect/ArmSME/outer-product-widening.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc..f051e03efbcda 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
class ArmSME_IntrLoadStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012..3544df494d33d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -814,6 +814,300 @@ let arguments = (ins
}];
}
+class OuterProductWideBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
+ ArmSME_Op<mnemonic, [
+ ArmSMETileOpInterface,
+ AttrSizedOperandSegments,
+ AllTypesMatch<["lhs", "rhs"]>,
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+ PredOpTrait<
+ "both `lhsMask` and `rhsMask` should be provided or neither",
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+ >,
+ OptionalTypesMatchWith<"result and acc have the same type",
+ "result", "acc", "::llvm::cast<Type>($_self)">,
+ // this trait ensures the input type match the correct output type for ops
+ // that takes multiple inputs and outputs (i.e., 4-way).
+ PredOpTrait<
+ "tile element size equals lhs element size * " # numOuterProducts,
+ CPred<"getTileType().getElementTypeBitWidth() == "
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+ >,
+ ]> {
+
+ let arguments = (ins
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+ Optional<AnyVector>:$acc);
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ oilist(
+ `acc` `` `(` $acc `)`
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+ // The outerproduct op allocates a new tile if no accumulator is passed.
+ if (!getAcc())
+ return arm_sme::getSMETileType(getResultType());
+ return std::nullopt;
+ }
+ VectorType getTileType() {
+ return getResultType();
+ }
+ }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+def FMopaWide2WayOp
+ : OuterProductWide2Way<"fmopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and accumulate";
+
+ let description = [{
+ This operation represents a sum of 2 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (fp16 to fp32):
+
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS RHS
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1] |
+ [A2 A3] | [B0 B2 B4 B6]
+ [A4 A5] | [B1 B3 B5 B7]
+ [A6 A7] |
+
+ ----------------------------------------------------------------------------
+
+ 2 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
+ |
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 2 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 2 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ The 2 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+
+ [1] https://developer.arm.com/documentation/ddi0616
+ }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP16 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+ : OuterProductWide2Way<"fmops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+ [nxnxv4f32]> {
+ let summary = "Floating-point sum of 2 outer products and subtract";
+ let description = [{
+ Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+ destination `result`.
+
+ Example: FP16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ Example: BF16 to FP32
+ ```mlir
+ %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+ }];
+}
+
+def SMopaWide2WayOp
+ : OuterProductWide2Way<"smopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ }];
+}
+
+def SMopsWide2WayOp
+ : OuterProductWide2Way<"smops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Signed integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ }];
+}
+
+def UMopaWide2WayOp
+ : OuterProductWide2Way<"umopa_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsigned integer sum of 2 outer products and accumulate";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+ }];
+}
+
+def UMopsWide2WayOp
+ : OuterProductWide2Way<"umops_wide_2way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32]> {
+ let summary = "Unsigned integer sum of 2 outer products and subtract";
+ let description = [{
+ Example:
+ ```mlir
+ %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ ```
+
+ Refer to
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+ detailed description of 2-way outer products.
+
+ | Spec | Features |
+ | ---- | -------- |
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+ }];
+}
+
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7c..d3e4fccd62848 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
+/// Pass that folds 'arm_sme.outerproduct' ops into widening variants.
+std::unique_ptr<Pass> createOuterProductWideningPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e80..aa9ad9b2e3340 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,43 @@ def TileAllocation
let dependentDialects = ["func::FuncDialect"];
}
+def OuterProductWidening
+ : Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
+ let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
+ let description = [{
+ This pass folds 'arm_sme.outerproduct' operations that are chained via the
+ accumulator into 2-way or 4-way ArmSME outer product operations.
+
+ For example:
+ ```mlir
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+ ```
+
+ Becomes:
+
+ ```mlir
+ %undef = llvm.mlir.undef : vector<[8]xf16>
+ %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ ```
+
+ For further information on the widening ops see:
+ https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop
+ https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_wide_4way-arm_smesmopa_wide_4wayop
+ }];
+ let constructor = "mlir::arm_sme::createOuterProductWideningPass()";
+ let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect", "LLVM::LLVMDialect"];
+}
+
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index f622bc0562e9e..09e3b4375fa5f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,10 @@ class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
+namespace arm_sme {
+void populateOuterProductWideningPatterns(RewritePatternSet &patterns);
+} // namespace arm_sme
+
} // namespace mlir
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index bbef3b996e40b..0871658bc3653 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -776,6 +776,49 @@ struct OuterProductOpConversion
}
};
+/// Lower 2-way and 4-way outer products to intrinsics.
+template <class OuterProductWideOp, class OuterProductWideIntrOp>
+struct OuterProductWideOpConversion
+ : public ConvertArmSMEOpToLLVMPattern<OuterProductWideOp> {
+ using ConvertArmSMEOpToLLVMPattern<
+ OuterProductWideOp>::ConvertArmSMEOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(OuterProductWideOp op,
+ typename OuterProductWideOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tileId = getTileIdOrError(op);
+ if (!tileId)
+ return failure();
+
+ Value acc = op.getAcc();
+ if (!acc)
+ // Initialize accumulator with zero.
+ acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
+ rewriter, op.getLoc(), op.getResultType());
+
+ Value lhsMask = op.getLhsMask();
+ Value rhsMask = op.getRhsMask();
+ if (!lhsMask || !rhsMask) {
+ auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
+ Value allActiveMask = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), DenseElementsAttr::get(predTy, true));
+ lhsMask = allActiveMask;
+ rhsMask = allActiveMask;
+ }
+
+ rewriter.create<OuterProductWideIntrOp>(op.getLoc(), tileId, lhsMask,
+ rhsMask, adaptor.getLhs(),
+ adaptor.getRhs());
+
+ // The outerproduct intrinsics have no result, replace
+ // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
+ rewriter.replaceOp(op, acc);
+
+ return success();
+ }
+};
+
/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
///
/// Example:
@@ -854,6 +897,13 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
+ arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
+ arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
+ arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
+ arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
+ arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
+ arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
+ arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect,
@@ -876,8 +926,20 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
addArmSMEConversionPatterns<
LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
- OuterProductOpConversion, ZeroOpConversion, GetTileConversion,
- StreamingVLOpConversion>(patterns, converter);
+ StreamingVLOpConversion, OuterProductOpConversion,
+ OuterProductWideOpConversion<arm_sme::FMopaWide2WayOp,
+ arm_sme::aarch64_sme_mopa_wide>,
+ OuterProductWideOpConversion<arm_sme::FMopsWide2WayOp,
+ arm_sme::aarch64_sme_mops_wide>,
+ OuterProductWideOpConversion<arm_sme::SMopaWide2WayOp,
+ arm_sme::aarch64_sme_smopa_za32>,
+ OuterProductWideOpConversion<arm_sme::SMopsWide2WayOp,
+ arm_sme::aarch64_sme_smops_za32>,
+ OuterProductWideOpConversion<arm_sme::UMopaWide2WayOp,
+ arm_sme::aarch64_sme_umopa_za32>,
+ OuterProductWideOpConversion<arm_sme::UMopsWide2WayOp,
+ arm_sme::aarch64_sme_umops_za32>,
+ ZeroOpConversion, GetTileConversion>(patterns, converter);
}
std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb584420438..24942b6f28d2c 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
+ OuterProductWidening.cpp
TileAllocation.cpp
ADDITIONAL_HEADER_DIRS
@@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
LINK_LIBS PUBLIC
MLIRArmSMEDialect
+ MLIRArmSVEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
new file mode 100644
index 0000000000000..935ed63c84c68
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
@@ -0,0 +1,238 @@
+//===- OuterProductWidening.cpp - Widen 'arm_sme.outerproduct' ops --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that fold 'arm_sme.outerproduct' operations
+// into the 2-way or 4-way widening outerproduct operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "arm-sme-outerproduct-widening"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_OUTERPRODUCTWIDENING
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+// Fold two 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into 2-way outer product operation.
+//
+// For example:
+//
+// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+// %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// Becomes:
+//
+// %a_packed = arm_sve.zip %a0, %a1 : vector<[4]xf16>, vector<[4]xf16> to vector<[8]xf16>
+// %b_packed = arm_sve.zip %b0, %b1 : vector<[4]xf16>, vector<[4]xf16> to vector<[8]xf16>
+// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>,
+// vector<[8]xf16> into vector<[4]x[4]xf32>
+class OuterProduct2WayWidening
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ Value acc = op.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ arm_sme::OuterProductOp op2 = op;
+ if (!op1)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ if (op1.getKind() != op2.getKind())
+ return rewriter.notifyMatchFailure(
+ op, "combining kind (add or sub) of outer products must match");
+
+ if (!llvm::hasSingleElement(op1->getUses())) {
+ // We could still widen, but if the first outer product has an
+ // accumulator it will be used as the root for tile allocation and since
+ // the widening outer product uses the same accumulator it will get
+ // assigned the same tile ID, resulting in 3 outer products and incorrect
+ // results. No accumulator would be ok, but it's simpler to prevent this
+ // altogether, since it has no benefit.
+ return rewriter.notifyMatchFailure(
+ op, "first outer product is not single use and cannot be removed, "
+ "no benefit to widening");
+ }
+
+ auto nxnxv4i32 =
+ VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+ auto nxnxv4f32 =
+ VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
+ auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
+ auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
+ auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
+ if ((failed(
+ isWidenable<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
+ failed(
+ isWidenable<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
+ (failed(
+ isWidenable<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
+ failed(
+ isWidenable<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4bf16))) &&
+ (failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i16))) &&
+ (failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
+ return failure();
+
+ auto loc = op.getLoc();
+
+ // zip(lhs, rhs)
+ auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
+ auto undef = rewriter.create<LLVM::UndefOp>(loc, type);
+ auto insertLHS =
+ rewriter.create<vector::ScalableInsertOp>(loc, lhs, undef, 0);
+ auto insertRHS =
+ rewriter.create<vector::ScalableInsertOp>(loc, rhs, undef, 0);
+ return rewriter.create<arm_sve::Zip1IntrOp>(loc, type, insertLHS,
+ insertRHS);
+ };
+
+ auto extOp = op.getLhs().getDefiningOp();
+ VectorType extSourceVectorType =
+ cast<VectorType>(extOp->getOperand(0).getType());
+ VectorType widenedVectorType =
+ VectorType::Builder(extSourceVectorType)
+ .setDim(0, extSourceVectorType.getShape()[0] * 2);
+ auto lhs = packInputs(widenedVectorType,
+ op1.getLhs().getDefiningOp()->getOperand(0),
+ op2.getLhs().getDefiningOp()->getOperand(0));
+ auto rhs = packInputs(widenedVectorType,
+ op1.getRhs().getDefiningOp()->getOperand(0),
+ op2.getRhs().getDefiningOp()->getOperand(0));
+
+ Value lhsMask, rhsMask;
+ if (op1.getLhsMask() || op2.getLhsMask()) {
+ if (!(op1.getLhsMask() && op2.getLhsMask()))
+ return rewriter.notifyMatchFailure(
+ op, "unsupported masking, either both outerproducts are masked "
+ "or neither");
+
+ VectorType maskType = VectorType::Builder(widenedVectorType)
+ .setElementType(rewriter.getI1Type());
+ lhsMask = packInputs(maskType, op1.getLhsMask(), op2.getLhsMask());
+ rhsMask = packInputs(maskType, op1.getRhsMask(), op2.getRhsMask());
+ }
+
+ arm_sme::CombiningKind kind = op.getKind();
+ assert((kind == arm_sme::CombiningKind::Add ||
+ kind == arm_sme::CombiningKind::Sub) &&
+ "unhandled arm_sme::CombiningKind!");
+
+ if (isa<arith::ExtFOp>(extOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::FMopaWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::FMopsWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else if (isa<arith::ExtSIOp>(extOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::SMopaWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::SMopsWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else if (isa<arith::ExtUIOp>(extOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::UMopaWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::UMopsWide2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else
+ llvm_unreachable("unexpected extend op!");
+
+ op1.erase();
+
+ return success();
+ }
+
+private:
+ template <typename ExtOp>
+ LogicalResult isWidenable(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op, VectorType resultType,
+ VectorType inputType) const {
+ if (op.getResultType() != resultType)
+ return rewriter.notifyMatchFailure(
+ op, "unsupported result type, expected 'vector<[4]x[4]xi32>' or "
+ "'vector<[4]x[4]xf32>'");
+
+ auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
+ auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
+
+ if (!lhsDefOp || !rhsDefOp)
+ return rewriter.notifyMatchFailure(
+ op, "defining op of outerproduct operands must be 'arith.extf' or "
+ "'arith.extsi' or 'arith.extui'");
+
+ auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
+ auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
+
+ if (lhsInType != inputType || rhsInType != inputType)
+ return rewriter.notifyMatchFailure(
+ op, "unsupported input types, expected 'vector<[4]xi16>' or "
+ "'vector<[4]xf16>' or 'vector<[4]xbf16>'");
+ return success();
+ }
+};
+
+struct OuterProductWideningPass
+ : public arm_sme::impl::OuterProductWideningBase<OuterProductWideningPass> {
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateOuterProductWideningPatterns(patterns);
+
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::arm_sme::populateOuterProductWideningPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<OuterProduct2WayWidening>(patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::arm_sme::createOuterProductWideningPass() {
+ return std::make_unique<OuterProductWideningPass>();
+}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index f9cf77ca15ffb..e0e0d90d8e85f 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -601,3 +601,99 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
%svl_d = arm_sme.streaming_vl <double>
return %svl_d : index
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.fmopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_fmopa_wide_2way_f16f16_to_f32
+// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+func.func @arm_sme_fmopa_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_fmopa_wide_2way_bf16bf16_to_f32
+// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_fmopa_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.fmops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_fmops_wide_2way_f16f16_to_f32
+// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+func.func @arm_sme_fmops_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_fmops_wide_2way_bf16bf16_to_f32
+// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_fmops_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_smopa_wide_2way_i16i16_to_i32
+// CHECK: "arm_sme.intr.smopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_smops_wide_2way_i16i16_to_i32
+// CHECK: "arm_sme.intr.smops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_umopa_wide_2way_i16i16_to_i32
+// CHECK: "arm_sme.intr.umopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_umops_wide_2way_i16i16_to_i32
+// CHECK: "arm_sme.intr.umops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 85b95a8b6cf12..1f63de927ea00 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -173,3 +173,56 @@ func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB:
%0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
return %0 : vector<[4]x[4]xf32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.fmopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way__bad_rhs_vector_type(%vecA: vector<[8]xf16>, %vecB: vector<[4]xf32>) -> vector<[4]x[4]xf32>
+{
+ // expected-error at +1 {{op failed to verify that all of {lhs, rhs} have same type}}
+ %0 = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[4]xf32> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way__bad_lhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[4]xi1>, %maskB : vector<[8]xi1>) -> vector<[4]x[4]xf32>
+{
+ // expected-note at -2 {{prior use here}}
+  // expected-error at +1 {{use of value '%maskA' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>'}}
+ %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way__bad_rhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>, %maskB : vector<[4]xi1>) -> vector<[4]x[4]xf32>
+{
+ // expected-note at -2 {{prior use here}}
+ // expected-error at +1 {{use of value '%maskB' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>}}
+ %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way__no_rhs_mask(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>) -> vector<[4]x[4]xf32>
+{
+ // expected-error at +1 {{op failed to verify that both `lhsMask` and `rhsMask` should be provided or neither}}
+ %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA,) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32>
+{
+ %acc = arm_sme.zero : vector<[2]x[2]xi64>
+ // expected-note at -1 {{prior use here}}
+ // expected-error at +1 {{use of value '%acc' expects different type than prior uses: 'vector<[4]x[4]xf32>' vs 'vector<[2]x[2]xi64>'}}
+  %0 = arm_sme.fmopa_wide_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %0 : vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-widening.mlir b/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
new file mode 100644
index 0000000000000..0f3bba6714667
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s -arm-sme-outer-product-widening -cse -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
+// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
+// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
+// CHECK-DAG: %[[VEC_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xf16>
+// CHECK-DAG: %[[A0_INSERT:.*]] = vector.scalable.insert %[[A0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
+// CHECK-DAG: %[[B0_INSERT:.*]] = vector.scalable.insert %[[B0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
+// CHECK-DAG: %[[A1_INSERT:.*]] = vector.scalable.insert %[[A1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
+// CHECK-DAG: %[[B1_INSERT:.*]] = vector.scalable.insert %[[B1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
+// CHECK-DAG: %[[LHS:.*]] = "arm_sve.intr.zip1"(%[[A0_INSERT]], %[[A1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK-DAG: %[[RHS:.*]] = "arm_sve.intr.zip1"(%[[B0_INSERT]], %[[B1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK-DAG: %[[MASK_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xi1>
+// CHECK-DAG: %[[A0_MASK_INSERT:.*]] = vector.scalable.insert %[[A0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
+// CHECK-DAG: %[[B0_MASK_INSERT:.*]] = vector.scalable.insert %[[B0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
+// CHECK-DAG: %[[A1_MASK_INSERT:.*]] = vector.scalable.insert %[[A1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
+// CHECK-DAG: %[[B1_MASK_INSERT:.*]] = vector.scalable.insert %[[B1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[A0_MASK_INSERT]], %[[A1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[B0_MASK_INSERT]], %[[B1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: arm_sme.fmopa_wide_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @outerproduct_add_widening_2way_f16f16f32(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_2way_f16f16f32
+// CHECK: arm_sme.fmops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @outerproduct_sub_widening_2way_f16f16f32(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_2way_bf16bf16f32
+// CHECK: arm_sme.fmopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @outerproduct_add_widening_2way_bf16bf16f32(
+ %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
+ %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_2way_bf16bf16f32
+// CHECK: arm_sme.fmops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @outerproduct_sub_widening_2way_bf16bf16f32(
+ %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
+ %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_2way_signed_i16i16i32
+// CHECK: arm_sme.smopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_2way_signed_i16i16i32(
+ %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
+ %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32>
+ %a1_ext = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %1 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_2way_signed_i16i16i32
+// CHECK: arm_sme.smops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_2way_signed_i16i16i32(
+ %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
+ %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32>
+ %a1_ext = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %1 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_2way_unsigned_i16i16i32
+// CHECK: arm_sme.umopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_2way_unsigned_i16i16i32(
+ %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
+ %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32>
+ %a1_ext = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %1 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_2way_unsigned_i16i16i32
+// CHECK: arm_sme.umops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
+ %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
+ %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32>
+ %a1_ext = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %1 : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 2ad742493408b..a96756f4d3426 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1131,3 +1131,115 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
%svl_d = arm_sme.streaming_vl <double>
return %svl_d : index
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.fmopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way_with_masking(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way_with_acc(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} acc({{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmopa_wide_2way_with_everything(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_wide_2way %vecA, %vecB acc(%acc) masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.fmops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_fmops_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmops_wide_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_fmops_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmops_wide_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smopa_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smops_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umopa_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_wide_2way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umops_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
new file mode 100644
index 0000000000000..8fbdf5d0011ce
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
@@ -0,0 +1,100 @@
+// DEFINE: %{entry} = test_outerproduct_f16f16f32
+// DEFINE: %{widening_opts} = -arm-sme-outer-product-widening
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{widening_opts} \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE: -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+// Check result is the same when outerproducts are not combined into widening
+// variant.
+
+// REDEFINE: %{widening_opts} =
+// RUN: %{run} | FileCheck %s
+
+func.func @test_outerproduct_f16f16f32() {
+ %undef = llvm.mlir.undef : vector<[4]xf16>
+
+ %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16>
+ %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16>
+ %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16>
+ %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16>
+
+ %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32>
+ %0 = vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
+ %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
+
+ // CHECK: ( 79, 95, 111, 127
+ // CHECK-NEXT: ( 99, 123, 147, 171
+ // CHECK-NEXT: ( 119, 151, 183, 215
+ // CHECK-NEXT: ( 139, 179, 219, 259
+ vector.print %1 : vector<[4]x[4]xf32>
+
+ return
+}
+
+// TODO: A bug in QEMU causes masked FMOPAs to hang [1]. It should be fixed
+// in 8.2.0, so this test is currently not run; once that version is
+// available in CI it can be enabled. The check lines here are correct and
+// have been verified on a QEMU version containing the fix.
+// [1] https://gitlab.com/qemu-project/qemu/-/issues/1985
+func.func @test_masked_outerproduct_f16f16f32() {
+ %undef = llvm.mlir.undef : vector<[4]xf16>
+
+ %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16>
+ %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16>
+ %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16>
+ %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16>
+
+ %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+ %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32>
+
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %mask0 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %mask1 = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+
+ %0 = vector.mask %mask0 {
+ vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+ %1 = vector.mask %mask1 {
+ vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+ // MASKED: ( 79, 95, 7, 7
+ // MASKED-NEXT: ( 99, 123, 17, 7
+ // MASKED-NEXT: ( 115, 139, 7, 7
+ // MASKED-NEXT: ( 7, 7, 7, 7
+ vector.print %1 : vector<[4]x[4]xf32>
+
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 7a42033dc04bc..aedb6730b06bb 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -63,6 +63,12 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
// CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
"arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
(vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.smopa.za32.nxv8i16
+ "arm_sme.intr.smopa.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.umopa.za32.nxv8i16
+ "arm_sme.intr.umopa.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
llvm.return
}
@@ -122,6 +128,12 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
// CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
"arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
(vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.smops.za32.nxv8i16
+ "arm_sme.intr.smops.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+ // CHECK: call void @llvm.aarch64.sme.umops.za32.nxv8i16
+ "arm_sme.intr.umops.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+ (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
llvm.return
}
>From 26f705d17db82489437009cc35c815ee4f967e32 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 25 Jan 2024 15:12:42 +0000
Subject: [PATCH 2/6] replace arm_sve.intr.zip1 with target-agnostic
interleave2 intrinsic
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 9 ++-------
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 9 ++-------
.../Transforms/OuterProductWidening.cpp | 20 ++++++++-----------
.../ArmSME/outer-product-widening.mlir | 18 ++++-------------
4 files changed, 16 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 3544df494d33d..c37faa17fd1b6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -958,13 +958,8 @@ def FMopaWide2WayOp
product as follows:
```mlir
- %undef = llvm.mlir.undef : vector<[8]xf16>
- %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
- %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+ %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index aa9ad9b2e3340..d8d2e70e10182 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -143,13 +143,8 @@ def OuterProductWidening
Becomes:
```mlir
- %undef = llvm.mlir.undef : vector<[8]xf16>
- %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
- %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
- %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+ %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
index 935ed63c84c68..437d82d2f8264 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
@@ -48,10 +48,12 @@ namespace {
//
// Becomes:
//
-// %a_packed = arm_sve.zip %a0, %a1 : vector<[8]xf16> to vector<[8]xf16>
-// %b_packed = arm_sve.zip %b0, %b1 : vector<[8]xf16> to vector<[8]xf16>
-// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>,
-// vector<[4]xf32>
+// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed
+// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
class OuterProduct2WayWidening
: public OpRewritePattern<arm_sme::OuterProductOp> {
public:
@@ -113,15 +115,9 @@ class OuterProduct2WayWidening
auto loc = op.getLoc();
- // zip(lhs, rhs)
auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
- auto undef = rewriter.create<LLVM::UndefOp>(loc, type);
- auto insertLHS =
- rewriter.create<vector::ScalableInsertOp>(loc, lhs, undef, 0);
- auto insertRHS =
- rewriter.create<vector::ScalableInsertOp>(loc, rhs, undef, 0);
- return rewriter.create<arm_sve::Zip1IntrOp>(loc, type, insertLHS,
- insertRHS);
+ return rewriter.create<LLVM::experimental_vector_interleave2>(loc, type,
+ lhs, rhs);
};
auto extOp = op.getLhs().getDefiningOp();
diff --git a/mlir/test/Dialect/ArmSME/outer-product-widening.mlir b/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
index 0f3bba6714667..0feb30f950366 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
@@ -4,20 +4,10 @@
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32>
-// CHECK-DAG: %[[VEC_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xf16>
-// CHECK-DAG: %[[A0_INSERT:.*]] = vector.scalable.insert %[[A0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
-// CHECK-DAG: %[[B0_INSERT:.*]] = vector.scalable.insert %[[B0]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
-// CHECK-DAG: %[[A1_INSERT:.*]] = vector.scalable.insert %[[A1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
-// CHECK-DAG: %[[B1_INSERT:.*]] = vector.scalable.insert %[[B1]], %[[VEC_UNDEF]][0] : vector<[4]xf16> into vector<[8]xf16>
-// CHECK-DAG: %[[LHS:.*]] = "arm_sve.intr.zip1"(%[[A0_INSERT]], %[[A1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[RHS:.*]] = "arm_sve.intr.zip1"(%[[B0_INSERT]], %[[B1_INSERT]]) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
-// CHECK-DAG: %[[MASK_UNDEF:.*]] = llvm.mlir.undef : vector<[8]xi1>
-// CHECK-DAG: %[[A0_MASK_INSERT:.*]] = vector.scalable.insert %[[A0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
-// CHECK-DAG: %[[B0_MASK_INSERT:.*]] = vector.scalable.insert %[[B0_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
-// CHECK-DAG: %[[A1_MASK_INSERT:.*]] = vector.scalable.insert %[[A1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
-// CHECK-DAG: %[[B1_MASK_INSERT:.*]] = vector.scalable.insert %[[B1_MASK]], %[[MASK_UNDEF]][0] : vector<[4]xi1> into vector<[8]xi1>
-// CHECK-DAG: %[[LHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[A0_MASK_INSERT]], %[[A1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: %[[RHS_MASK:.*]] = "arm_sve.intr.zip1"(%[[B0_MASK_INSERT]], %[[B1_MASK_INSERT]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
// CHECK-DAG: arm_sme.fmopa_wide_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_f16f16f32(
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
>From 589c0d2f06608e78728b9d35917ed54876d8d6db Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 26 Jan 2024 14:54:53 +0000
Subject: [PATCH 3/6] Address comments. Changes:
- hasOneUse
- Negative tests for pass
- move types and isWidening checks to isSupported method
- negative test for unsupported type
- move masking check
- op1.erase() -> rewriter.eraseOp(op1);
- braces round dangling-else
- TypeSwitch
- rename isWidenable and add comments for clarity
- add TODO/REDEFINE for QEMU bug to make it clearer
- s/widening/fusion/g
- drop wide from op names
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 88 +++---
.../mlir/Dialect/ArmSME/Transforms/Passes.h | 5 +-
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 18 +-
.../Dialect/ArmSME/Transforms/Transforms.h | 2 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 44 +--
.../Dialect/ArmSME/Transforms/CMakeLists.txt | 2 +-
.../ArmSME/Transforms/OuterProductFusion.cpp | 292 ++++++++++++++++++
.../Transforms/OuterProductWidening.cpp | 234 --------------
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 60 ++--
mlir/test/Dialect/ArmSME/invalid.mlir | 22 +-
...idening.mlir => outer-product-fusion.mlir} | 168 +++++++++-
mlir/test/Dialect/ArmSME/roundtrip.mlir | 78 ++---
.../ArmSME/test-outerproduct-f16f16f32.mlir | 12 +-
13 files changed, 620 insertions(+), 405 deletions(-)
create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
delete mode 100644 mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
rename mlir/test/Dialect/ArmSME/{outer-product-widening.mlir => outer-product-fusion.mlir} (54%)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index c37faa17fd1b6..a8ed9d0288707 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -814,10 +814,10 @@ let arguments = (ins
}];
}
-class OuterProductWideBase<string mnemonic,
- list<Type> allowedInputVectorTypes,
- list<Type> allowedResultVectorTypes,
- int numOuterProducts> :
+class OuterProductWideningBase<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes,
+ int numOuterProducts> :
ArmSME_Op<mnemonic, [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
@@ -869,14 +869,14 @@ class OuterProductWideBase<string mnemonic,
}];
}
-class OuterProductWide2Way<string mnemonic,
- list<Type> allowedInputVectorTypes,
- list<Type> allowedResultVectorTypes>
- : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
- allowedResultVectorTypes, /*numOuterProducts=*/2>;
+class OuterProduct2Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
-def FMopaWide2WayOp
- : OuterProductWide2Way<"fmopa_wide_2way",
+def FMopa2WayOp
+ : OuterProduct2Way<"fmopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
[nxnxv4f32]> {
let summary = "Floating-point sum of 2 outer products and accumulate";
@@ -888,14 +888,14 @@ def FMopaWide2WayOp
For example (fp16 to fp32):
```mlir
- %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+ %result = arm_sme.fmopa_2way %lhs, %rhs :
vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
elements in a vector of SVL bits. To illustrate, below is a breakdown of
- this operation for SVL=128 (i.e., vscale=1):
+ this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):
```
LHS RHS
@@ -960,19 +960,19 @@ def FMopaWide2WayOp
```mlir
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
- %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
- This is implemented in the `-arm-sme-outer-product-widening` pass.
+ This is implemented in the `-arm-sme-outer-product-fusion` pass.
Example: FP16 to FP32
```mlir
- %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
Example: BF16 to FP32
```mlir
- %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
```
| Spec | Features |
@@ -989,27 +989,27 @@ def FMopaWide2WayOp
// - FMOPA 4-way FP16 to FP32
// once intrinsic support lands in the backend.
-def FMopsWide2WayOp
- : OuterProductWide2Way<"fmops_wide_2way",
+def FMops2WayOp
+ : OuterProduct2Way<"fmops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
[nxnxv4f32]> {
let summary = "Floating-point sum of 2 outer products and subtract";
let description = [{
- Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+ Equivalent to `fmopa_2way` but outer products are subtracted from
destination `result`.
Example: FP16 to FP32
```mlir
- %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
Example: BF16 to FP32
```mlir
- %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
Refer to
- [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
- detailed description of 2-way outer products.
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
+ description of 2-way outer products.
| Spec | Features |
| ---- | -------- |
@@ -1019,19 +1019,19 @@ def FMopsWide2WayOp
}];
}
-def SMopaWide2WayOp
- : OuterProductWide2Way<"smopa_wide_2way",
+def SMopa2WayOp
+ : OuterProduct2Way<"smopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Signed integer sum of 2 outer products and accumulate";
let description = [{
Example:
```mlir
- %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
Refer to
- [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
- detailed description of 2-way outer products.
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
+ description of 2-way outer products.
| Spec | Features |
| ---- | -------- |
@@ -1040,19 +1040,19 @@ def SMopaWide2WayOp
}];
}
-def SMopsWide2WayOp
- : OuterProductWide2Way<"smops_wide_2way",
+def SMops2WayOp
+ : OuterProduct2Way<"smops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Signed integer sum of 2 outer products and subtract";
let description = [{
Example:
```mlir
- %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
Refer to
- [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
- detailed description of 2-way outer products.
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
+ description of 2-way outer products.
| Spec | Features |
| ---- | -------- |
@@ -1061,19 +1061,19 @@ def SMopsWide2WayOp
}];
}
-def UMopaWide2WayOp
- : OuterProductWide2Way<"umopa_wide_2way",
+def UMopa2WayOp
+ : OuterProduct2Way<"umopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Unsiged integer sum of 2 outer products and accumulate";
let description = [{
Example:
```mlir
- %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
Refer to
- [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
- detailed description of 2-way outer products.
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
+ description of 2-way outer products.
| Spec | Features |
| ---- | -------- |
@@ -1082,19 +1082,19 @@ def UMopaWide2WayOp
}];
}
-def UMopsWide2WayOp
- : OuterProductWide2Way<"umops_wide_2way",
+def UMops2WayOp
+ : OuterProduct2Way<"umops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Unsiged integer sum of 2 outer products and subtract";
let description = [{
Example:
```mlir
- %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
Refer to
- [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
- detailed description of 2-way outer products.
+ [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
+ description of 2-way outer products.
| Spec | Features |
| ---- | -------- |
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index d3e4fccd62848..bb49ce4c62723 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,8 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();
-/// Pass that folds 'arm_sme.outerproduct' ops into widening variants.
-std::unique_ptr<Pass> createOuterProductWideningPass();
+/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
+/// variants.
+std::unique_ptr<Pass> createOuterProductFusionPass();
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index d8d2e70e10182..063cfef404097 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,11 +122,11 @@ def TileAllocation
let dependentDialects = ["func::FuncDialect"];
}
-def OuterProductWidening
- : Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
- let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
+def OuterProductFusion
+ : Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
+ let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
let description = [{
- This pass folds 'arm_sme.outerproduct' operations that are chained via the
+ This pass fuses 'arm_sme.outerproduct' operations that are chained via the
accumulator into 2-way or 4-way ArmSME outer product operations.
For example:
@@ -145,14 +145,14 @@ def OuterProductWidening
```mlir
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
- %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```
- For further information on the widening ops see:
- https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop
- https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_wide_4way-arm_smesmopa_wide_4wayop
+ For further information on the 2-way or 4-way widening ops see:
+ https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
+ https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
}];
- let constructor = "mlir::arm_sme::createOuterProductWideningPass()";
+ let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect", "LLVM::LLVMDialect"];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index 09e3b4375fa5f..e00c7503e6999 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -16,7 +16,7 @@ class LLVMTypeConverter;
class RewritePatternSet;
namespace arm_sme {
-void populateOuterProductWideningPatterns(RewritePatternSet &patterns);
+void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 0871658bc3653..e73388b0906e8 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -776,16 +776,16 @@ struct OuterProductOpConversion
}
};
-/// Lower 2-way and 4-way outer products to intrinsics.
-template <class OuterProductWideOp, class OuterProductWideIntrOp>
-struct OuterProductWideOpConversion
- : public ConvertArmSMEOpToLLVMPattern<OuterProductWideOp> {
+/// Lower 2-way and 4-way widening outer products to intrinsics.
+template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
+struct OuterProductWideningOpConversion
+ : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
using ConvertArmSMEOpToLLVMPattern<
- OuterProductWideOp>::ConvertArmSMEOpToLLVMPattern;
+ OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
LogicalResult
- matchAndRewrite(OuterProductWideOp op,
- typename OuterProductWideOp::Adaptor adaptor,
+ matchAndRewrite(OuterProductWideningOp op,
+ typename OuterProductWideningOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tileId = getTileIdOrError(op);
if (!tileId)
@@ -807,9 +807,9 @@ struct OuterProductWideOpConversion
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideIntrOp>(op.getLoc(), tileId, lhsMask,
- rhsMask, adaptor.getLhs(),
- adaptor.getRhs());
+ rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
+ rhsMask, adaptor.getLhs(),
+ adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -927,18 +927,18 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
StreamingVLOpConversion, OuterProductOpConversion,
- OuterProductWideOpConversion<arm_sme::FMopaWide2WayOp,
- arm_sme::aarch64_sme_mopa_wide>,
- OuterProductWideOpConversion<arm_sme::FMopsWide2WayOp,
- arm_sme::aarch64_sme_mops_wide>,
- OuterProductWideOpConversion<arm_sme::SMopaWide2WayOp,
- arm_sme::aarch64_sme_smopa_za32>,
- OuterProductWideOpConversion<arm_sme::SMopsWide2WayOp,
- arm_sme::aarch64_sme_smops_za32>,
- OuterProductWideOpConversion<arm_sme::UMopaWide2WayOp,
- arm_sme::aarch64_sme_umopa_za32>,
- OuterProductWideOpConversion<arm_sme::UMopsWide2WayOp,
- arm_sme::aarch64_sme_umops_za32>,
+ OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
+ arm_sme::aarch64_sme_mopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
+ arm_sme::aarch64_sme_mops_wide>,
+ OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
+ arm_sme::aarch64_sme_smopa_za32>,
+ OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
+ arm_sme::aarch64_sme_smops_za32>,
+ OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
+ arm_sme::aarch64_sme_umopa_za32>,
+ OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
+ arm_sme::aarch64_sme_umops_za32>,
ZeroOpConversion, GetTileConversion>(patterns, converter);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 24942b6f28d2c..0484843e6b010 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
- OuterProductWidening.cpp
+ OuterProductFusion.cpp
TileAllocation.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
new file mode 100644
index 0000000000000..60e2a020b6712
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -0,0 +1,292 @@
+//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that fuse 'arm_sme.outerproduct' operations
+// into the 2-way or 4-way widening outer product operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#define DEBUG_TYPE "arm-sme-outerproduct-fusion"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_OUTERPRODUCTFUSION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+// Fuse two 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into a 2-way outer product operation.
+//
+// For example:
+//
+// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
+// vector<[4]xf32>
+//
+// Becomes:
+//
+// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
+// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
+// %0 = arm_sme.fmopa_2way %a_packed, %b_packed
+// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+class OuterProductFusion2Way
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ Value acc = op.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ arm_sme::OuterProductOp op2 = op;
+ if (!op1)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ if (op1.getKind() != op2.getKind())
+ return rewriter.notifyMatchFailure(
+ op, "combining kind (add or sub) of outer products must match");
+
+ if (!op1->hasOneUse()) {
+ // If the first outer product has uses other than as the input to another
+ // outer product, it can't be erased after fusion. This is a problem when
+ // it also has an accumulator as this will be used as the root for tile
+ // allocation and since the widening outer product uses the same
+ // accumulator it will get assigned the same tile ID, resulting in 3
+ // outer products accumulating to the same tile and incorrect results.
+ //
+ // Example:
+ //
+ // %acc = arith.constant dense<0.0> ; root for tile allocation
+ // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc)
+ // vector.print %0 ; intermediary use, can't erase %0
+ // %1 = arm_sme.outerproduct %a1, %b1 acc(%0)
+ //
+ // After fusion and tile allocation
+ //
+ // %0 = arm_sme.zero {tile_id = 0 : i32}
+ // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32}
+ // vector.print %1
+ // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32}
+ //
+ // No accumulator would be ok, but it's simpler to prevent this
+ // altogether, since it has no benefit.
+ return rewriter.notifyMatchFailure(
+ op, "first outer product is not single use and cannot be removed, "
+ "no benefit to fusing");
+ }
+
+ if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
+ return rewriter.notifyMatchFailure(
+ op, "unsupported masking, either both outerproducts are masked "
+ "or neither");
+
+ if (failed(canFuseOuterProducts(rewriter, op1, op2)))
+ return failure();
+
+ auto loc = op.getLoc();
+
+ auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
+ return rewriter.create<LLVM::experimental_vector_interleave2>(loc, type,
+ lhs, rhs);
+ };
+
+ auto extOp = op.getLhs().getDefiningOp();
+ VectorType extSourceVectorType =
+ cast<VectorType>(extOp->getOperand(0).getType());
+ VectorType widenedVectorType =
+ VectorType::Builder(extSourceVectorType)
+ .setDim(0, extSourceVectorType.getShape()[0] * 2);
+ auto lhs = packInputs(widenedVectorType,
+ op1.getLhs().getDefiningOp()->getOperand(0),
+ op2.getLhs().getDefiningOp()->getOperand(0));
+ auto rhs = packInputs(widenedVectorType,
+ op1.getRhs().getDefiningOp()->getOperand(0),
+ op2.getRhs().getDefiningOp()->getOperand(0));
+
+ Value lhsMask, rhsMask;
+ if (op1.getLhsMask() || op2.getLhsMask()) {
+ VectorType maskType = VectorType::Builder(widenedVectorType)
+ .setElementType(rewriter.getI1Type());
+ lhsMask = packInputs(maskType, op1.getLhsMask(), op2.getLhsMask());
+ rhsMask = packInputs(maskType, op1.getRhsMask(), op2.getRhsMask());
+ }
+
+ arm_sme::CombiningKind kind = op.getKind();
+ if (kind == arm_sme::CombiningKind::Add) {
+ TypeSwitch<Operation *>(extOp)
+ .Case<arith::ExtFOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::FMopa2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Case<arith::ExtSIOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::SMopa2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Case<arith::ExtUIOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::UMopa2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ } else if (kind == arm_sme::CombiningKind::Sub) {
+ TypeSwitch<Operation *>(extOp)
+ .Case<arith::ExtFOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::FMops2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Case<arith::ExtSIOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::SMops2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Case<arith::ExtUIOp>([&](auto) {
+ rewriter.replaceOpWithNewOp<arm_sme::UMops2WayOp>(
+ op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask,
+ op1.getAcc());
+ })
+ .Default([&](auto) { llvm_unreachable("unexpected extend op!"); });
+ } else {
+ llvm_unreachable("unexpected arm_sme::CombiningKind!");
+ }
+
+ rewriter.eraseOp(op1);
+
+ return success();
+ }
+
+private:
+ // A pair of outer product can be fused if all of the following are true:
+ // - input and result types match.
+ // - the defining operations of the inputs are identical extensions,
+ // specifically either:
+ // - a signed or unsigned extension for integer types.
+ // - a floating-point extension for floating-point types.
+ // - the types and extension are supported, i.e. there's a 2-way operation
+ // they can be fused into.
+ LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op1,
+ arm_sme::OuterProductOp op2) const {
+ // Supported result types.
+ auto nxnxv4i32 =
+ VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+ auto nxnxv4f32 =
+ VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
+ // Supported input types.
+ // Note: this is before packing so these have half the number of elements
+ // of the input vector types of the 2-way operations.
+ auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
+ auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
+ auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
+ if ((failed(
+ isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
+ failed(
+ isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
+ (failed(
+ isCompatible<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
+ failed(isCompatible<arith::ExtFOp>(rewriter, op2, nxnxv4f32,
+ nxv4bf16))) &&
+ (failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
+ failed(isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32,
+ nxv4i16))) &&
+ (failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
+ return failure();
+
+ return success();
+ }
+
+ // An outer product is compatible if all of the following are true:
+ // - the result type matches `resultType`.
+ // - the defining operations of the inputs are identical and of the type
+ // `ExtOp`.
+ // - the input types of the defining operations are identical and match
+ // `inputType`.
+ template <typename ExtOp>
+ LogicalResult isCompatible(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op, VectorType resultType,
+ VectorType inputType) const {
+ if (op.getResultType() != resultType)
+ return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+ diag << "unsupported result type, expected " << resultType;
+ });
+
+ auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
+ auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
+
+ if (!lhsDefOp || !rhsDefOp)
+ return rewriter.notifyMatchFailure(
+ op, "defining op of outerproduct operands must be 'arith.extf' or "
+ "'arith.extsi' or 'arith.extui'");
+
+ auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
+ auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
+
+ if (lhsInType != inputType || rhsInType != inputType)
+ return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+ diag << "unsupported input type, expected " << inputType;
+ });
+
+ return success();
+ }
+};
+
+struct OuterProductFusionPass
+ : public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateOuterProductFusionPatterns(patterns);
+
+ if (failed(
+ applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::arm_sme::populateOuterProductFusionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<OuterProductFusion2Way>(patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
+ return std::make_unique<OuterProductFusionPass>();
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
deleted file mode 100644
index 437d82d2f8264..0000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp
+++ /dev/null
@@ -1,234 +0,0 @@
-//===- OuterProductWidening.cpp - Widen 'arm_sme.outerproduct' ops --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements rewrites that fold 'arm_sme.outerproduct' operations
-// into the 2-way or 4-way widening outerproduct operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
-#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "arm-sme-outerproduct-widening"
-
-namespace mlir::arm_sme {
-#define GEN_PASS_DEF_OUTERPRODUCTWIDENING
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
-} // namespace mlir::arm_sme
-
-using namespace mlir;
-using namespace mlir::arm_sme;
-
-namespace {
-// Fold two 'arm_sme.outerproduct' operations that are chained via the
-// accumulator into 2-way outer product operation.
-//
-// For example:
-//
-// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
-// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
-// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>,
-// vector<[4]xf32>
-//
-// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
-// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
-// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>,
-// vector<[4]xf32>
-//
-// Becomes:
-//
-// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1)
-// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
-// %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed
-// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
-class OuterProduct2WayWidening
- : public OpRewritePattern<arm_sme::OuterProductOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
- PatternRewriter &rewriter) const override {
- Value acc = op.getAcc();
- if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
-
- arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
- arm_sme::OuterProductOp op2 = op;
- if (!op1)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
-
- if (op1.getKind() != op2.getKind())
- return rewriter.notifyMatchFailure(
- op, "combining kind (add or sub) of outer products must match");
-
- if (!llvm::hasSingleElement(op1->getUses())) {
- // We could still widen, but if the first outer product has an
- // accumulator it will be used as the root for tile allocation and since
- // the widening outer product uses the same accumulator it will get
- // assigned the same tile ID, resulting in 3 outer products and incorrect
- // results. No accumulator would be ok, but it's simpler to prevent this
- // altogether, since it has no benefit.
- return rewriter.notifyMatchFailure(
- op, "first outer product is not single use and cannot be removed, "
- "no benefit to widening");
- }
-
- auto nxnxv4i32 =
- VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
- auto nxnxv4f32 =
- VectorType::get({4, 4}, rewriter.getF32Type(), {true, true});
- auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true);
- auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true);
- auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true);
- if ((failed(
- isWidenable<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4f16)) ||
- failed(
- isWidenable<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4f16))) &&
- (failed(
- isWidenable<arith::ExtFOp>(rewriter, op1, nxnxv4f32, nxv4bf16)) ||
- failed(
- isWidenable<arith::ExtFOp>(rewriter, op2, nxnxv4f32, nxv4bf16))) &&
- (failed(
- isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i16))) &&
- (failed(
- isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i16)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i16))))
- return failure();
-
- auto loc = op.getLoc();
-
- auto packInputs = [&](VectorType type, Value lhs, Value rhs) {
- return rewriter.create<LLVM::experimental_vector_interleave2>(loc, type,
- lhs, rhs);
- };
-
- auto extOp = op.getLhs().getDefiningOp();
- VectorType extSourceVectorType =
- cast<VectorType>(extOp->getOperand(0).getType());
- VectorType widenedVectorType =
- VectorType::Builder(extSourceVectorType)
- .setDim(0, extSourceVectorType.getShape()[0] * 2);
- auto lhs = packInputs(widenedVectorType,
- op1.getLhs().getDefiningOp()->getOperand(0),
- op2.getLhs().getDefiningOp()->getOperand(0));
- auto rhs = packInputs(widenedVectorType,
- op1.getRhs().getDefiningOp()->getOperand(0),
- op2.getRhs().getDefiningOp()->getOperand(0));
-
- Value lhsMask, rhsMask;
- if (op1.getLhsMask() || op2.getLhsMask()) {
- if (!(op1.getLhsMask() && op2.getLhsMask()))
- return rewriter.notifyMatchFailure(
- op, "unsupported masking, either both outerproducts are masked "
- "or neither");
-
- VectorType maskType = VectorType::Builder(widenedVectorType)
- .setElementType(rewriter.getI1Type());
- lhsMask = packInputs(maskType, op1.getLhsMask(), op2.getLhsMask());
- rhsMask = packInputs(maskType, op1.getRhsMask(), op2.getRhsMask());
- }
-
- arm_sme::CombiningKind kind = op.getKind();
- assert((kind == arm_sme::CombiningKind::Add ||
- kind == arm_sme::CombiningKind::Sub) &&
- "unhandled arm_sme::CombiningKind!");
-
- if (isa<arith::ExtFOp>(extOp)) {
- if (kind == arm_sme::CombiningKind::Add)
- rewriter.replaceOpWithNewOp<arm_sme::FMopaWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::FMopsWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else if (isa<arith::ExtSIOp>(extOp)) {
- if (kind == arm_sme::CombiningKind::Add)
- rewriter.replaceOpWithNewOp<arm_sme::SMopaWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::SMopsWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else if (isa<arith::ExtUIOp>(extOp)) {
- if (kind == arm_sme::CombiningKind::Add)
- rewriter.replaceOpWithNewOp<arm_sme::UMopaWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::UMopsWide2WayOp>(
- op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else
- llvm_unreachable("unexpected extend op!");
-
- op1.erase();
-
- return success();
- }
-
-private:
- template <typename ExtOp>
- LogicalResult isWidenable(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op, VectorType resultType,
- VectorType inputType) const {
- if (op.getResultType() != resultType)
- return rewriter.notifyMatchFailure(
- op, "unsupported result type, expected 'vector<[4]x[4]xi32>' or "
- "'vector<[4]x[4]xf32>'");
-
- auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
- auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
- if (!lhsDefOp || !rhsDefOp)
- return rewriter.notifyMatchFailure(
- op, "defining op of outerproduct operands must be 'arith.extf' or "
- "'arith.extsi' or 'arith.extui'");
-
- auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
- auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
-
- if (lhsInType != inputType || rhsInType != inputType)
- return rewriter.notifyMatchFailure(
- op, "unsupported input types, expected 'vector<[4]xi16>' or "
- "'vector<[4]xf16>' or 'vector<[4]xbf16>'");
- return success();
- }
-};
-
-struct OuterProductWideningPass
- : public arm_sme::impl::OuterProductWideningBase<OuterProductWideningPass> {
-
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateOuterProductWideningPatterns(patterns);
-
- if (failed(
- applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
- signalPassFailure();
- }
-};
-
-} // namespace
-
-void mlir::arm_sme::populateOuterProductWideningPatterns(
- RewritePatternSet &patterns) {
- patterns.add<OuterProduct2WayWidening>(patterns.getContext());
-}
-
-std::unique_ptr<Pass> mlir::arm_sme::createOuterProductWideningPass() {
- return std::make_unique<OuterProductWideningPass>();
-}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index e0e0d90d8e85f..c41504d0e4724 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -603,97 +603,97 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
}
//===----------------------------------------------------------------------===//
-// arm_sme.fmopa_wide_2way
+// arm_sme.fmopa_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_fmopa_wide_2way_f16f16_to_f32
+// CHECK-LABEL: arm_sme_fmopa_2way_f16f16_to_f32
// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmopa_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-// CHECK-LABEL: arm_sme_fmopa_wide_2way_bf16bf16_to_f32
+// CHECK-LABEL: arm_sme_fmopa_2way_bf16bf16_to_f32
// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmopa_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.fmops_wide_2way
+// arm_sme.fmops_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_fmops_wide_2way_f16f16_to_f32
+// CHECK-LABEL: arm_sme_fmops_2way_f16f16_to_f32
// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
-func.func @arm_sme_fmops_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
- %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-// CHECK-LABEL: arm_sme_fmops_wide_2way_bf16bf16_to_f32
+// CHECK-LABEL: arm_sme_fmops_2way_bf16bf16_to_f32
// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_fmops_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
- %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.smopa_wide_2way
+// arm_sme.smopa_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_smopa_wide_2way_i16i16_to_i32
+// CHECK-LABEL: arm_sme_smopa_2way_i16i16_to_i32
// CHECK: "arm_sme.intr.smopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- %result = arm_sme.smopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.smops_wide_2way
+// arm_sme.smops_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_smops_wide_2way_i16i16_to_i32
+// CHECK-LABEL: arm_sme_smops_2way_i16i16_to_i32
// CHECK: "arm_sme.intr.smops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_smops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- %result = arm_sme.smops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.umopa_wide_2way
+// arm_sme.umopa_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_umopa_wide_2way_i16i16_to_i32
+// CHECK-LABEL: arm_sme_umopa_2way_i16i16_to_i32
// CHECK: "arm_sme.intr.umopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- %result = arm_sme.umopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.umops_wide_2way
+// arm_sme.umops_2way
//===----------------------------------------------------------------------===//
// -----
-// CHECK-LABEL: arm_sme_umops_wide_2way_i16i16_to_i32
+// CHECK-LABEL: arm_sme_umops_2way_i16i16_to_i32
// CHECK: "arm_sme.intr.umops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
-func.func @arm_sme_umops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- %result = arm_sme.umops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 1f63de927ea00..dcc231332f208 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -175,54 +175,54 @@ func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB:
}
//===----------------------------------------------------------------------===//
-// arm_sme.fmopa_wide_2way
+// arm_sme.fmopa_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_fmopa_wide_2way__bad_rhs_vector_type(%vecA: vector<[8]xf16>, %vecB: vector<[4]xf32>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way__bad_rhs_vector_type(%vecA: vector<[8]xf16>, %vecB: vector<[4]xf32>) -> vector<[4]x[4]xf32>
{
// expected-error at +1 {{op failed to verify that all of {lhs, rhs} have same type}}
- %0 = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[4]xf32> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[4]xf32> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way__bad_lhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[4]xi1>, %maskB : vector<[8]xi1>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way__bad_lhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[4]xi1>, %maskB : vector<[8]xi1>) -> vector<[4]x[4]xf32>
{
// expected-note at -2 {{prior use here}}
// expected-error at +1 {{use of value '%maskA' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>}}
- %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way__bad_rhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>, %maskB : vector<[4]xi1>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way__bad_rhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>, %maskB : vector<[4]xi1>) -> vector<[4]x[4]xf32>
{
// expected-note at -2 {{prior use here}}
// expected-error at +1 {{use of value '%maskB' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>}}
- %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way__no_rhs_mask(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way__no_rhs_mask(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>) -> vector<[4]x[4]xf32>
{
// expected-error at +1 {{op failed to verify that both `lhsMask` and `rhsMask` should be provided or neither}}
- %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA,) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA,) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32>
{
%acc = arm_sme.zero : vector<[2]x[2]xi64>
// expected-note at -1 {{prior use here}}
// expected-error at +1 {{use of value '%acc' expects different type than prior uses: 'vector<[4]x[4]xf32>' vs 'vector<[2]x[2]xi64>'}}
- %0 = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %0 = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-widening.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
similarity index 54%
rename from mlir/test/Dialect/ArmSME/outer-product-widening.mlir
rename to mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 0feb30f950366..0383d9aebfef7 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-widening.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arm-sme-outer-product-widening -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
@@ -8,7 +8,7 @@
// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
-// CHECK-DAG: arm_sme.fmopa_wide_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_f16f16f32(
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
%a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
@@ -30,7 +30,7 @@ func.func @outerproduct_add_widening_2way_f16f16f32(
// -----
// CHECK-LABEL: @outerproduct_sub_widening_2way_f16f16f32
-// CHECK: arm_sme.fmops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
func.func @outerproduct_sub_widening_2way_f16f16f32(
%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
%a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
@@ -52,7 +52,7 @@ func.func @outerproduct_sub_widening_2way_f16f16f32(
// -----
// CHECK-LABEL: @outerproduct_add_widening_2way_bf16bf16f32
-// CHECK: arm_sme.fmopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+// CHECK: arm_sme.fmopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
func.func @outerproduct_add_widening_2way_bf16bf16f32(
%a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
%a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
@@ -74,7 +74,7 @@ func.func @outerproduct_add_widening_2way_bf16bf16f32(
// -----
// CHECK-LABEL: @outerproduct_sub_widening_2way_bf16bf16f32
-// CHECK: arm_sme.fmops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
func.func @outerproduct_sub_widening_2way_bf16bf16f32(
%a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>,
%a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>,
@@ -96,7 +96,7 @@ func.func @outerproduct_sub_widening_2way_bf16bf16f32(
// -----
// CHECK-LABEL: @outerproduct_add_widening_2way_signed_i16i16i32
-// CHECK: arm_sme.smopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+// CHECK: arm_sme.smopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_2way_signed_i16i16i32(
%a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
%a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
@@ -118,7 +118,7 @@ func.func @outerproduct_add_widening_2way_signed_i16i16i32(
// -----
// CHECK-LABEL: @outerproduct_sub_widening_2way_signed_i16i16i32
-// CHECK: arm_sme.smops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+// CHECK: arm_sme.smops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_2way_signed_i16i16i32(
%a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
%a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
@@ -140,7 +140,7 @@ func.func @outerproduct_sub_widening_2way_signed_i16i16i32(
// -----
// CHECK-LABEL: @outerproduct_add_widening_2way_unsigned_i16i16i32
-// CHECK: arm_sme.umopa_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+// CHECK: arm_sme.umopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_add_widening_2way_unsigned_i16i16i32(
%a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
%a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
@@ -162,7 +162,7 @@ func.func @outerproduct_add_widening_2way_unsigned_i16i16i32(
// -----
// CHECK-LABEL: @outerproduct_sub_widening_2way_unsigned_i16i16i32
-// CHECK: arm_sme.umops_wide_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+// CHECK: arm_sme.umops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
%a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>,
%a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>,
@@ -180,3 +180,153 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
return %1 : vector<[4]x[4]xi32>
}
+
+/// Negative tests
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_2way__no_acc
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+/// Defining op of accumulator operand must be an 'arm_sme.outerproduct'.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_acc
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+/// Combining kinds of outer products must match to be fused.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_combining_kind(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<add> : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+/// If the first outer product has uses other than as the input to another
+/// outer product, it can't be erased after fusion. This is a problem when
+/// it also has an accumulator as this will be used as the root for tile
+/// allocation and since the widening outer product uses the same
+/// accumulator it will get assigned the same tile ID, resulting in 3
+/// outer products and incorrect results. Check this is prevented.
+
+// CHECK-LABEL: @outerproduct_widening_2way__cant_erase
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__cant_erase(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
+ "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
+ %a0 : vector<[2]xf32>, %b0 : vector<[2]xf32>,
+ %a1 : vector<[2]xf32>, %b1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> {
+ %a0_ext = arith.extf %a0 : vector<[2]xf32> to vector<[2]xf64>
+ %b0_ext = arith.extf %b0 : vector<[2]xf32> to vector<[2]xf64>
+ %a1_ext = arith.extf %a1 : vector<[2]xf32> to vector<[2]xf64>
+ %b1_ext = arith.extf %b1 : vector<[2]xf32> to vector<[2]xf64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
+
+ return %1 : vector<[2]x[2]xf64>
+}
+
+// -----
+
+/// Fusion only occurs if either both outer products are masked, or neither.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_masking
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_masking(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+/// Defining op of outer product must be a supported extension op.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_defining_op(
+ %a0 : vector<[4]xf32>, %b0 : vector<[4]xf32>,
+ %a1 : vector<[4]xf32>, %b1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> {
+ %0 = arm_sme.outerproduct %a0, %b0 : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1, %b1 acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index a96756f4d3426..ca096363e7283 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1133,113 +1133,113 @@ func.func @arm_sme_streaming_vl_double_words() -> index {
}
//===----------------------------------------------------------------------===//
-// arm_sme.fmopa_wide_2way
+// arm_sme.fmopa_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_fmopa_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way_with_masking(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_with_masking(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way_with_acc(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} acc({{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_with_acc(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} acc({{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmopa_wide_2way_with_everything(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmopa_wide_2way {{.*}}, {{.*}} acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmopa_wide_2way %vecA, %vecB acc(%acc) masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmopa_2way_with_everything(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.fmops_wide_2way
+// arm_sme.fmops_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_fmops_wide_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmops_wide_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmops_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
// -----
-func.func @arm_sme_fmops_wide_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
- // CHECK: arm_sme.fmops_wide_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
- %result = arm_sme.fmops_wide_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.fmops_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+ %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
return %result : vector<[4]x[4]xf32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.smopa_wide_2way
+// arm_sme.smopa_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_smopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- // CHECK: arm_sme.smopa_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
- %result = arm_sme.smopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smopa_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.smops_wide_2way
+// arm_sme.smops_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_smops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- // CHECK: arm_sme.smops_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
- %result = arm_sme.smops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smops_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.umopa_wide_2way
+// arm_sme.umopa_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_umopa_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- // CHECK: arm_sme.umopa_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
- %result = arm_sme.umopa_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umopa_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
//===----------------------------------------------------------------------===//
-// arm_sme.umops_wide_2way
+// arm_sme.umops_2way
//===----------------------------------------------------------------------===//
// -----
-func.func @arm_sme_umops_wide_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
- // CHECK: arm_sme.umops_wide_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
- %result = arm_sme.umops_wide_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umops_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
index 8fbdf5d0011ce..4cded2b26559d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
@@ -1,8 +1,8 @@
// DEFINE: %{entry} = test_outerproduct_f16f16f32
-// DEFINE: %{widening_opts} = -arm-sme-outer-product-widening
+// DEFINE: %{fusion_opts} = -arm-sme-outer-product-fusion
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{widening_opts} \
+// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{fusion_opts} \
// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm -o %t
@@ -18,9 +18,15 @@
// Check result is the same when outerproducts are not combined into widening
// variant.
-// REDEFINE: %{widening_opts} =
+// REDEFINE: %{fusion_opts} =
// RUN: %{run} | FileCheck %s
+// TODO: Add run line for masked test once QEMU is fixed.
+// REDEFINE: %{entry} = test_masked_outerproduct_f16f16f32
+
+// TODO: Add run line for masked test once QEMU is fixed.
+// REDEFINE: %{fusion_opts} =
+
func.func @test_outerproduct_f16f16f32() {
%undef = llvm.mlir.undef : vector<[4]xf16>
>From b9e3a5b85f0f60293928c66864d4835fb9e0c75f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 30 Jan 2024 10:28:01 +0000
Subject: [PATCH 4/6] use setArmSVLBits in integration test to set SVL to 128
---
.../ArmSME/test-outerproduct-f16f16f32.mlir | 48 +++++++++++--------
1 file changed, 28 insertions(+), 20 deletions(-)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
index 4cded2b26559d..f081838300a9a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir
@@ -1,15 +1,15 @@
-// DEFINE: %{entry} = test_outerproduct_f16f16f32
+// DEFINE: %{entry} = main
// DEFINE: %{fusion_opts} = -arm-sme-outer-product-fusion
// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{fusion_opts} \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE: -test-lower-to-llvm -o %t
// DEFINE: %{run} = %mcr_aarch64_cmd %t \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry} -entry-point-result=void \
-// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils,%arm_sme_abi_shlib
// RUN: %{compile}
@@ -21,11 +21,21 @@
// REDEFINE: %{fusion_opts} =
// RUN: %{run} | FileCheck %s
-// TODO: Add run line for masked test once QEMU is fixed.
-// REDEFINE: %{entry} = test_masked_outerproduct_f16f16f32
+func.func @main() {
+ %c128 = arith.constant 128 : i32
+ func.call @setArmSVLBits(%c128) : (i32) -> ()
-// TODO: Add run line for masked test once QEMU is fixed.
-// REDEFINE: %{fusion_opts} =
+ func.call @test_outerproduct_f16f16f32() : () -> ()
+
+ // TODO: A bug in QEMU causes masked FMOPAs to hang [1]. Should be fixed in
+ // 8.2.0, this test currently isn't run, once this version is available in CI
+ // it can be run. The output without check lines in the function are correct
+ // and have been verified on a version with the fix.
+ // [1] https://gitlab.com/qemu-project/qemu/-/issues/1985
+ //func.call @test_masked_outerproduct_f16f16f32() : () -> ()
+
+ return
+}
func.func @test_outerproduct_f16f16f32() {
%undef = llvm.mlir.undef : vector<[4]xf16>
@@ -49,20 +59,15 @@ func.func @test_outerproduct_f16f16f32() {
%0 = vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
%1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
- // CHECK: ( 79, 95, 111, 127
- // CHECK-NEXT: ( 99, 123, 147, 171
- // CHECK-NEXT: ( 119, 151, 183, 215
- // CHECK-NEXT: ( 139, 179, 219, 259
+ // CHECK: ( 79, 95, 111, 127 )
+ // CHECK-NEXT: ( 99, 123, 147, 171 )
+ // CHECK-NEXT: ( 119, 151, 183, 215 )
+ // CHECK-NEXT: ( 139, 179, 219, 259 )
vector.print %1 : vector<[4]x[4]xf32>
return
}
-// TODO: A bug in QEMU causes masked FMOPAs to hang [1]. Should be fixed in
-// 8.2.0, this test currently isn't run, once this version is available in CI
-// it can be run. The check lines here are correct and have been verified on a
-// version with the fix.
-// [1] https://gitlab.com/qemu-project/qemu/-/issues/1985
func.func @test_masked_outerproduct_f16f16f32() {
%undef = llvm.mlir.undef : vector<[4]xf16>
@@ -96,11 +101,14 @@ func.func @test_masked_outerproduct_f16f16f32() {
vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
} : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
- // MASKED: ( 79, 95, 7, 7
- // MASKED-NEXT: ( 99, 123, 17, 7
- // MASKED-NEXT: ( 115, 139, 7, 7
- // MASKED-NEXT: ( 7, 7, 7, 7
+ // TODO: CHECK these lines once QEMU is fixed.
+ // ( 79, 95, 7, 7 )
+ // ( 99, 123, 17, 7 )
+ // ( 115, 139, 7, 7 )
+ // ( 7, 7, 7, 7 )
vector.print %1 : vector<[4]x[4]xf32>
return
}
+
+func.func private @setArmSVLBits(%bits : i32)
>From eb2874c9c36c5ccb9d0b0a68f33b15bd3587be92 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 24 Jan 2024 10:15:14 +0000
Subject: [PATCH 5/6] [mlir][ArmSME] Support 4-way widening outer products
This patch introduces support for 4-way widening outer products. This enables
the folding of 4 'arm_sme.outerproduct' operations that are chained via the
accumulator into single widened operations.
Changes:
- Adds the following operations:
- smopa_4way, smops_4way
- umopa_4way, umops_4way
  - sumopa_4way, sumops_4way
  - usmopa_4way, usmops_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Extends 'arm-sme-outer-product-fusion' pass.
For a detailed description of these operations see the
'arm_sme.smopa_4way' description.
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 333 ++++++++++
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 16 +
.../ArmSME/Transforms/OuterProductFusion.cpp | 248 +++++++-
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 176 ++++++
mlir/test/Dialect/ArmSME/invalid.mlir | 13 +
.../Dialect/ArmSME/outer-product-fusion.mlir | 575 ++++++++++++++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 160 +++++
.../CPU/ArmSME/test-outerproduct-i8i8i32.mlir | 142 +++++
mlir/test/Target/LLVMIR/arm-sve.mlir | 7 +
9 files changed, 1669 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index a8ed9d0288707..ede71fdb647eb 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1103,6 +1103,339 @@ def UMops2WayOp
}];
}
+class OuterProduct4Way<string mnemonic,
+ list<Type> allowedInputVectorTypes,
+ list<Type> allowedResultVectorTypes>
+ : OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
+ allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def SMopa4WayOp
+ : OuterProduct4Way<"smopa_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed integer sum of 4 outer products and accumulate";
+ let description = [{
+ This operation represents a sum of 4 widened outer products. It takes 2 1-D
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+ For example (i8 to i32):
+
+ ```mlir
+ %result = arm_sme.smopa_4way $lhs, $rhs :
+ vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+ 4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
+ this operation for i8 to i32, SVL=128 (i.e., vscale=1):
+
+ ```
+ LHS
+      [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+
+ RHS
+ [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+ ----------------------------------------------------------------------------
+
+ implicit layout
+
+ [A0 A1 A2 A3] | [B0 B4 B8 B12]
+ [A4 A5 A6 A7] | [B1 B5 B9 B13]
+ [A8 A9 A10 A11] | [B2 B6 B10 B14]
+ [A12 A13 A14 A15] | [B3 B7 B11 B15]
+
+ ----------------------------------------------------------------------------
+
+ 4 outer products
+
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
+ ------------- | -------------
+ |
+ [B0 B4 B8 B12] | [B1 B5 B9 B13]
+ |
+ [A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
+ A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
+ A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
+ A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
+ |
+ Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
+ ------------- | -------------
+ |
+ [B2, B6, B10, B14] | [B3 B7 B11 B15]
+ |
+ [A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
+ A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
+ A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
+ A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
+ |
+
+ ----------------------------------------------------------------------------
+
+ sum of 4 outer products
+
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+ [ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
+ [ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
+ [ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
+ [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+ ----------------------------------------------------------------------------
+ ```
+
+ This operation enables the folding of 4 outer products chained via the
+ accumulator into a single outer product.
+
+ For example:
+
+ ```mlir
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+ ```
+
+ The 4 outer products in the example above can be fused into a single outer
+ product as follows:
+
+ ```mlir
+ %lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+ %lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+ %lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+
+ %rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+ %rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+ %rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+
+ %0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ This is implemented in the `-arm-sme-outer-product-fusion` pass.
+
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.smopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def SMops4WayOp
+ : OuterProduct4Way<"smops_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed integer sum of 4 outer products and subtract";
+ let description = [{
+ Equivalent to `smopa_4way` but outer products are subtracted from
+ destination `result`.
+
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.smops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def UMopa4WayOp
+ : OuterProduct4Way<"umopa_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Unsigned integer sum of 4 outer products and accumulate";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.umopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def UMops4WayOp
+ : OuterProduct4Way<"umops_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Unsigned integer sum of 4 outer products and subtract";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.umops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def SuMopa4WayOp
+ : OuterProduct4Way<"sumopa_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed by unsigned integer sum of 4 outer products and accumulate";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def SuMops4WayOp
+ : OuterProduct4Way<"sumops_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Signed by unsigned integer sum of 4 outer products and subtract";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.sumops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def UsMopa4WayOp
+ : OuterProduct4Way<"usmopa_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Unsigned by signed integer sum of 4 outer products and accumulate";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
+def UsMops4WayOp
+ : OuterProduct4Way<"usmops_4way",
+ [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+ ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+ [nxnxv4i32, nxnxv2i64]> {
+ let summary = "Unsigned by signed integer sum of 4 outer products and subtract";
+ let description = [{
+ Example: I8 to I32
+ ```mlir
+ %result = arm_sme.usmops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ ```
+
+ Example: I16 to I64
+ ```mlir
+    %result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+    ```
+
+    Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa_4wayop) for a
+    detailed description of 4-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
+
+ }];
+}
+
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e73388b0906e8..1ba1b88fc1234 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -939,6 +939,22 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
arm_sme::aarch64_sme_umopa_za32>,
OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
arm_sme::aarch64_sme_umops_za32>,
+ OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
+ arm_sme::aarch64_sme_smopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
+ arm_sme::aarch64_sme_smops_wide>,
+ OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
+ arm_sme::aarch64_sme_umopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
+ arm_sme::aarch64_sme_umops_wide>,
+ OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
+ arm_sme::aarch64_sme_sumopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
+ arm_sme::aarch64_sme_sumops_wide>,
+ OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
+ arm_sme::aarch64_sme_usmopa_wide>,
+ OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
+ arm_sme::aarch64_sme_usmops_wide>,
ZeroOpConversion, GetTileConversion>(patterns, converter);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 60e2a020b6712..1fc6418c09d49 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -267,6 +267,251 @@ class OuterProductFusion2Way
}
};
+// Fold four 'arm_sme.outerproduct' operations that are chained via the
+// accumulator into 4-way outer product operation.
+class OuterProductFusion4Way
+ : public OpRewritePattern<arm_sme::OuterProductOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arm_sme::OuterProductOp op,
+ PatternRewriter &rewriter) const override {
+ Value acc = op.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op4 = op;
+ arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ if (!op3)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ acc = op3.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ if (!op2)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ acc = op2.getAcc();
+ if (!acc)
+ return rewriter.notifyMatchFailure(op, "no accumulator operand");
+
+ arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
+ if (!op1)
+ return rewriter.notifyMatchFailure(op,
+ "defining op of accumulator operand "
+ "must be an 'arm_sme.outerproduct'");
+
+ arm_sme::CombiningKind kind = op1.getKind();
+ if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
+ return rewriter.notifyMatchFailure(
+ op, "combining kind (add or sub) of outer products must match");
+
+ if (!llvm::hasSingleElement(op1->getUses()) ||
+ !llvm::hasSingleElement(op2->getUses()) ||
+ !llvm::hasSingleElement(op3->getUses()))
+ return rewriter.notifyMatchFailure(
+ op, "outer products are not single use and cannot be removed, "
+ "no benefit to widening");
+
+ auto nxnxv4i32 =
+ VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+ auto nxnxv2i64 =
+ VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+ auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
+ auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
+ if (
+ // signed, i8i8i32
+ (failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // signed, i16i16i64
+ (failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtSIOp>(rewriter, op4, nxnxv2i64, nxv2i16))) &&
+ // unsigned, i8i8i32
+ (failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // unsigned, i16i16i64
+ (failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isWidenable<arith::ExtUIOp>(rewriter, op4, nxnxv2i64, nxv2i16))) &&
+ // signed by unsigned, i8i8i32
+ (failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // signed by unsigned, i16i16i64
+ (failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op4, nxnxv2i64, nxv2i16))) &&
+ // unsigned by signed, i8i8i32
+ (failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // unsigned by signed, i16i16i64
+ (failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op4, nxnxv2i64, nxv2i16))))
+ return failure();
+
+ auto loc = op.getLoc();
+
+ auto packInputs = [&](Value lhs, Value rhs) {
+ auto inputType = cast<VectorType>(lhs.getType());
+ VectorType widenedType =
+ VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2);
+ return rewriter.create<LLVM::experimental_vector_interleave2>(
+ loc, widenedType, lhs, rhs);
+ };
+
+ auto lhsExtOp = op.getLhs().getDefiningOp();
+ auto rhsExtOp = op.getRhs().getDefiningOp();
+ auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0),
+ op3.getLhs().getDefiningOp()->getOperand(0));
+ auto lhs1 = packInputs(op2.getLhs().getDefiningOp()->getOperand(0),
+ op4.getLhs().getDefiningOp()->getOperand(0));
+ auto lhs = packInputs(lhs0, lhs1);
+
+ auto rhs0 = packInputs(op1.getRhs().getDefiningOp()->getOperand(0),
+ op3.getRhs().getDefiningOp()->getOperand(0));
+ auto rhs1 = packInputs(op2.getRhs().getDefiningOp()->getOperand(0),
+ op4.getRhs().getDefiningOp()->getOperand(0));
+ auto rhs = packInputs(rhs0, rhs1);
+
+ Value lhsMask, rhsMask;
+ if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
+ op4.getLhsMask()) {
+ if (!(op1.getLhsMask() && op2.getLhsMask() && op3.getLhsMask() &&
+ op4.getLhsMask()))
+ return rewriter.notifyMatchFailure(
+ op, "unsupported masking, either all outerproducts are masked "
+ "or none");
+
+ auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
+ auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
+ lhsMask = packInputs(lhs0Mask, lhs1Mask);
+
+ auto rhs0Mask = packInputs(op1.getRhsMask(), op3.getRhsMask());
+ auto rhs1Mask = packInputs(op2.getRhsMask(), op4.getRhsMask());
+ rhsMask = packInputs(rhs0Mask, rhs1Mask);
+ }
+
+ assert((kind == arm_sme::CombiningKind::Add ||
+ kind == arm_sme::CombiningKind::Sub) &&
+ "unhandled arm_sme::CombiningKind!");
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
+ if (kind == arm_sme::CombiningKind::Add)
+ rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ } else
+ llvm_unreachable("unexpected extend op!");
+
+ op3.erase();
+ op2.erase();
+ op1.erase();
+
+ return success();
+ }
+
+private:
+ template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
+ LogicalResult isWidenable(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op, VectorType resultType,
+ VectorType inputType) const {
+ if (op.getResultType() != resultType)
+ return rewriter.notifyMatchFailure(
+ op, "unsupported result type, expected 'vector<[4]x[4]xi32>' or "
+ "'vector<[2]x[2]xi64>'");
+
+ auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
+ auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
+
+ if (!lhsDefOp || !rhsDefOp)
+ return rewriter.notifyMatchFailure(
+ op, "defining op of outerproduct operands must be 'arith.extsi' or "
+ "'arith.extui'");
+
+ auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
+ auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
+
+ if (lhsInType != inputType || rhsInType != inputType)
+ return rewriter.notifyMatchFailure(
+ op, "unsupported input types, expected 'vector<[4]xi8>' or "
+ "'vector<[2]xi16>'");
+ return success();
+ }
+};
+
struct OuterProductFusionPass
: public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
@@ -284,7 +529,8 @@ struct OuterProductFusionPass
void mlir::arm_sme::populateOuterProductFusionPatterns(
RewritePatternSet &patterns) {
- patterns.add<OuterProductFusion2Way>(patterns.getContext());
+ patterns.add<OuterProductFusion2Way, OuterProductFusion4Way>(
+ patterns.getContext());
}
std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index c41504d0e4724..81087cc02099f 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -697,3 +697,179 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
%result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_smopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_smopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_smops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_smops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.smops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_umopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_umopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_umops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_umops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.umops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_sumopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_sumopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_sumops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_sumops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.sumops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_usmopa_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_usmopa_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: arm_sme_usmops_4way_i8i8_to_i32
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+  return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: arm_sme_usmops_4way_i16i16_to_i64
+// CHECK: "arm_sme.intr.usmops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+  %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index dcc231332f208..91bb7f51daad0 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -226,3 +226,16 @@ func.func @arm_sme_fmopa_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vecto
%0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
return %0 : vector<[4]x[4]xf32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smopa_4way__bad_tile_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32>
+{
+ // expected-error at +1 {{op failed to verify that tile element size equals lhs element size * 4}}
+ %0 = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+ return %0 : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 0383d9aebfef7..f0e738720932c 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -330,3 +330,578 @@ func.func @outerproduct_widening_2way__bad_defining_op(
return %1 : vector<[4]x[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
+// CHECK-SAME: %[[A0:.*]]: vector<[4]xi8>, %[[B0:.*]]: vector<[4]xi8>, %[[A1:.*]]: vector<[4]xi8>, %[[B1:.*]]: vector<[4]xi8>, %[[A2:.*]]: vector<[4]xi8>, %[[B2:.*]]: vector<[4]xi8>, %[[A3:.*]]: vector<[4]xi8>, %[[B3:.*]]: vector<[4]xi8>,
+// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1>, %[[A2_MASK:.*]]: vector<[4]xi1>, %[[B2_MASK:.*]]: vector<[4]xi1>, %[[A3_MASK:.*]]: vector<[4]xi1>, %[[B3_MASK:.*]]: vector<[4]xi1>
+// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32>
+// CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
+// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
+// CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1>
+// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1>
+// CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i8i8i32
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i16i16i64
+// CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i8i8i32
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i16i16i64
+// CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64
+// CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>,
+ %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %acc = arith.constant dense<0> : vector<[4]x[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64
+// CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
+ %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>,
+ %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>,
+ %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>,
+ %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>,
+ %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>,
+ %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>,
+ %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>,
+ %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> {
+ %a0_ext = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64>
+ %b0_ext = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a1_ext = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64>
+ %b1_ext = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a2_ext = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64>
+ %b2_ext = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64>
+
+ %a3_ext = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64>
+ %b3_ext = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64>
+
+ %acc = arith.constant dense<0> : vector<[2]x[2]xi64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<sub> acc(%1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<sub> acc(%2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64>
+
+ return %3 : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ca096363e7283..ab46c7adca596 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1243,3 +1243,163 @@ func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vecto
%result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
return %result : vector<[4]x[4]xi32>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_smopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.smopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.smopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.smops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_smops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_smops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.smops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.smops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_umopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.umopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.umopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.umops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_umops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_umops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.umops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.umops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_sumopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_sumopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.sumopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.sumopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.sumops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_sumops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_sumops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.sumops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.sumops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmopa_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_usmopa_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_usmopa_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.usmopa_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.usmopa_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.usmops_4way
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_usmops_4way_i8i8_to_i32(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>) -> vector<[4]x[4]xi32> {
+ // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+ return %result : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x[2]xi64> {
+ // CHECK: arm_sme.usmops_4way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ %result = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
+ return %result : vector<[2]x[2]xi64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
new file mode 100644
index 0000000000000..98b26beccc25b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir
@@ -0,0 +1,142 @@
+// DEFINE: %{entry} = test_outerproduct_i8i8i32
+// DEFINE: %{widening_opts} = -arm-sme-outer-product-widening
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
+// DEFINE: -convert-vector-to-arm-sme %{widening_opts} \
+// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE: -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+// REDEFINE: %{entry} = test_masked_outerproduct_i8i8i32
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK
+
+// NOTE: QEMU gives incorrect result for SME SMOPA 4-way outer product
+// instruction (version <= 8.2.0, latest version at time of writing), see:
+// https://gitlab.com/qemu-project/qemu/-/issues/2083
+// This test is expected to fail until a fixed version of QEMU can be used.
+
+// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+// (and installed on CI buildbot).
+// XFAIL: *
+
+// NOTE: there is no non-widening variant for these types and this test can't
+// currently be lowered without the widening pass, therefore we can't check if
+// the result is the same without widening pass like
+// 'test-outerproduct-f16f16f32.mlir' does.
+
+func.func @test_outerproduct_i8i8i32() {
+ %undef = llvm.mlir.undef : vector<[4]xi8>
+
+ %a0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+ %a1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+ %a2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+ %a3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+
+ %b0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+ %b1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+ %b2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+ %b3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+
+ %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a2 = vector.scalable.insert %a2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b2 = vector.scalable.insert %b2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a3 = vector.scalable.insert %a3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b3 = vector.scalable.insert %b3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = vector.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+ %2 = vector.outerproduct %a2_ext, %b2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+ %3 = vector.outerproduct %a3_ext, %b3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+
+ // CHECK: ( 110, 134, 158, 182
+ // CHECK-NEXT: ( 390, 478, 566, 654
+ // CHECK-NEXT: ( 670, 822, 974, 1126
+ // CHECK-NEXT: ( 950, 1166, 1382, 1598
+ vector.print %3 : vector<[4]x[4]xi32>
+
+ return
+}
+
+func.func @test_masked_outerproduct_i8i8i32() {
+ %undef = llvm.mlir.undef : vector<[4]xi8>
+
+ %a0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+ %a1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+ %a2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+ %a3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+
+ %b0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+ %b1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+ %b2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+ %b3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+
+ %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a2 = vector.scalable.insert %a2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b2 = vector.scalable.insert %b2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %a3 = vector.scalable.insert %a3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+ %b3 = vector.scalable.insert %b3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %c4 = arith.constant 4 : index
+
+ %mask0 = vector.create_mask %c1, %c1 : vector<[4]x[4]xi1>
+ %mask1 = vector.create_mask %c1, %c2 : vector<[4]x[4]xi1>
+ %mask2 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+ %mask3 = vector.create_mask %c3, %c4 : vector<[4]x[4]xi1>
+
+ %acc = arith.constant dense<2> : vector<[4]x[4]xi32>
+ %0 = vector.mask %mask0 {
+ vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xi32>, vector<[4]xi32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+ %1 = vector.mask %mask1 {
+ vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xi32>, vector<[4]xi32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+ %2 = vector.mask %mask2 {
+ vector.outerproduct %a2_ext, %b2_ext, %1 : vector<[4]xi32>, vector<[4]xi32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+ %3 = vector.mask %mask3 {
+ vector.outerproduct %a3_ext, %b3_ext, %2 : vector<[4]xi32>, vector<[4]xi32>
+ } : vector<[4]x[4]xi1> -> vector<[4]x[4]xi32>
+
+ // WITH-MASK: ( 112, 136, 135, 95
+ // WITH-MASK-NEXT: ( 243, 295, 347, 219
+ // WITH-MASK-NEXT: ( 211, 255, 299, 343
+ // WITH-MASK-NEXT: ( 2, 2, 2, 2
+ vector.print %3 : vector<[4]x[4]xi32>
+
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index b63d3f0651569..002b1f9d804a7 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,3 +314,10 @@ llvm.func @arm_sve_convert_to_svbool(
: (vector<[1]xi1>) -> vector<[16]xi1>
llvm.return
}
+
+// CHECK-LABEL: @arm_sve_zip1
+// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
+llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
+ %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+ llvm.return %0 : vector<[8]xf16>
+}
>From 79aaf2357c247271487ea9eb5618792ae896ab79 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 30 Jan 2024 13:00:40 +0000
Subject: [PATCH 6/6] Address comments. Changes:
- add common match failures.
- move isCompatible to static function.
- update isCompatible to take optional `rhsExtType`.
- use isCompatible in-place of isWidenable.
- add canFuseOuterProducts for 4-way.
- llvm::hasSingleElement -> hasOneUse.
- op.erase -> rewriter.eraseOp.
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 4 +-
.../ArmSME/Transforms/OuterProductFusion.cpp | 393 ++++++-------
.../Dialect/ArmSME/outer-product-fusion.mlir | 532 +++++++++++++-----
mlir/test/Target/LLVMIR/arm-sve.mlir | 7 -
4 files changed, 581 insertions(+), 355 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ede71fdb647eb..c93f9bdc00ac8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -1213,7 +1213,7 @@ def SMopa4WayOp
The 4 outer products in the example above can be fused into a single outer
product as follows:
- ```mlir
+ ```mlir
%lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
@@ -1223,7 +1223,7 @@ def SMopa4WayOp
%rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>
%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
- ```
+ ```
This is implemented in the `-arm-sme-outer-product-fusion` pass.
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 1fc6418c09d49..507ccdb476924 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -32,6 +32,55 @@ using namespace mlir;
using namespace mlir::arm_sme;
namespace {
+
+// Common match failure reasons.
+static constexpr StringLiteral
+ MATCH_FAILURE_NO_ACCUMULATOR("no accumulator operand");
+static constexpr StringLiteral MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP(
+ "defining op of accumulator must be 'arm_sme.outerproduct'");
+static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_COMBINING_KIND(
+ "combining kind (add or sub) of outer products must match");
+static constexpr StringLiteral MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE(
+ "outer product(s) not single use and cannot be removed, no benefit to "
+ "fusing");
+static constexpr StringLiteral MATCH_FAILURE_INCONSISTENT_MASKING(
+ "unsupported masking, either both outerproducts are masked "
+ "or neither");
+
+// An outer product is compatible if all of the following are true:
+// - the result type matches `resultType`.
+// - the defining operation of LHS is of the type `LhsExtOp`.
+// - the defining operation of RHS is of the type `RhsExtOp`.
+// - the input types of the defining operations are identical and match
+// `inputType`.
+template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
+static LogicalResult isCompatible(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op,
+ VectorType resultType, VectorType inputType) {
+ if (op.getResultType() != resultType)
+ return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+ diag << "unsupported result type, expected " << resultType;
+ });
+
+ auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
+ auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
+
+ if (!lhsDefOp || !rhsDefOp)
+ return rewriter.notifyMatchFailure(
+ op, "defining op of outerproduct operands must be 'arith.extf' or "
+ "'arith.extsi' or 'arith.extui'");
+
+ auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
+ auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
+
+ if (lhsInType != inputType || rhsInType != inputType)
+ return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
+ diag << "unsupported input type, expected " << inputType;
+ });
+
+ return success();
+}
+
// Fuse two 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 2-way outer product operation.
//
@@ -64,18 +113,17 @@ class OuterProductFusion2Way
PatternRewriter &rewriter) const override {
Value acc = op.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
arm_sme::OuterProductOp op2 = op;
if (!op1)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
if (op1.getKind() != op2.getKind())
return rewriter.notifyMatchFailure(
- op, "combining kind (add or sub) of outer products must match");
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
if (!op1->hasOneUse()) {
// If the first outer product has uses other than as the input to another
@@ -102,14 +150,12 @@ class OuterProductFusion2Way
// No accumulator would be ok, but it's simpler to prevent this
// altogether, since it has no benefit.
return rewriter.notifyMatchFailure(
- op, "first outer product is not single use and cannot be removed, "
- "no benefit to fusing");
+ op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
}
if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()))
- return rewriter.notifyMatchFailure(
- op, "unsupported masking, either both outerproducts are masked "
- "or neither");
+ return rewriter.notifyMatchFailure(op,
+ MATCH_FAILURE_INCONSISTENT_MASKING);
if (failed(canFuseOuterProducts(rewriter, op1, op2)))
return failure();
@@ -231,43 +277,9 @@ class OuterProductFusion2Way
return success();
}
-
- // An outer product is compatible if all of the following are true:
- // - the result type matches `resultType`.
- // - the defining operations of the inputs are identical and of the type
- // `ExtOp`.
- // - the input types of the defining operations are identical and match
- // `inputType`.
- template <typename ExtOp>
- LogicalResult isCompatible(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op, VectorType resultType,
- VectorType inputType) const {
- if (op.getResultType() != resultType)
- return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
- diag << "unsupported result type, expected " << resultType;
- });
-
- auto lhsDefOp = op.getLhs().getDefiningOp<ExtOp>();
- auto rhsDefOp = op.getRhs().getDefiningOp<ExtOp>();
-
- if (!lhsDefOp || !rhsDefOp)
- return rewriter.notifyMatchFailure(
- op, "defining op of outerproduct operands must be 'arith.extf' or "
- "'arith.extsi' or 'arith.extui'");
-
- auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
- auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
-
- if (lhsInType != inputType || rhsInType != inputType)
- return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) {
- diag << "unsupported input type, expected " << inputType;
- });
-
- return success();
- }
};
-// Fold four 'arm_sme.outerproduct' operations that are chained via the
+// Fuse four 'arm_sme.outerproduct' operations that are chained via the
// accumulator into 4-way outer product operation.
class OuterProductFusion4Way
: public OpRewritePattern<arm_sme::OuterProductOp> {
@@ -278,126 +290,47 @@ class OuterProductFusion4Way
PatternRewriter &rewriter) const override {
Value acc = op.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
arm_sme::OuterProductOp op4 = op;
arm_sme::OuterProductOp op3 = acc.getDefiningOp<arm_sme::OuterProductOp>();
if (!op3)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
acc = op3.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
arm_sme::OuterProductOp op2 = acc.getDefiningOp<arm_sme::OuterProductOp>();
if (!op2)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
acc = op2.getAcc();
if (!acc)
- return rewriter.notifyMatchFailure(op, "no accumulator operand");
+ return rewriter.notifyMatchFailure(op, MATCH_FAILURE_NO_ACCUMULATOR);
arm_sme::OuterProductOp op1 = acc.getDefiningOp<arm_sme::OuterProductOp>();
if (!op1)
- return rewriter.notifyMatchFailure(op,
- "defining op of accumulator operand "
- "must be an 'arm_sme.outerproduct'");
+ return rewriter.notifyMatchFailure(
+ op, MATCH_FAILURE_EXPECTED_OUTERPRODUCT_DEF_OP);
arm_sme::CombiningKind kind = op1.getKind();
if (op2.getKind() != kind || op3.getKind() != kind || op4.getKind() != kind)
return rewriter.notifyMatchFailure(
- op, "combining kind (add or sub) of outer products must match");
+ op, MATCH_FAILURE_INCONSISTENT_COMBINING_KIND);
- if (!llvm::hasSingleElement(op1->getUses()) ||
- !llvm::hasSingleElement(op2->getUses()) ||
- !llvm::hasSingleElement(op3->getUses()))
+ if (!op1->hasOneUse() || !op2->hasOneUse() || !op3->hasOneUse())
return rewriter.notifyMatchFailure(
- op, "outer products are not single use and cannot be removed, "
- "no benefit to widening");
+ op, MATCH_FAILURE_OUTERPRODUCT_NOT_SINGLE_USE);
- auto nxnxv4i32 =
- VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
- auto nxnxv2i64 =
- VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
- auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
- auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
- if (
- // signed, i8i8i32
- (failed(
- isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // signed, i16i16i64
- (failed(
- isWidenable<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtSIOp>(rewriter, op4, nxnxv2i64, nxv2i16))) &&
- // unsigned, i8i8i32
- (failed(
- isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // unsigned, i16i16i64
- (failed(
- isWidenable<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(
- isWidenable<arith::ExtUIOp>(rewriter, op4, nxnxv2i64, nxv2i16))) &&
- // signed by unsigned, i8i8i32
- (failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // signed by unsigned, i16i16i64
- (failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtSIOp, arith::ExtUIOp>(
- rewriter, op4, nxnxv2i64, nxv2i16))) &&
- // unsigned by signed, i8i8i32
- (failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op1, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op2, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op3, nxnxv4i32, nxv4i8)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op4, nxnxv4i32, nxv4i8))) &&
- // unsigned by signed, i16i16i64
- (failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op1, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op2, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op3, nxnxv2i64, nxv2i16)) ||
- failed(isWidenable<arith::ExtUIOp, arith::ExtSIOp>(
- rewriter, op4, nxnxv2i64, nxv2i16))))
+    if (bool(op1.getLhsMask()) != bool(op2.getLhsMask()) ||
+        bool(op2.getLhsMask()) != bool(op3.getLhsMask()) || bool(op3.getLhsMask()) != bool(op4.getLhsMask()))
+ return rewriter.notifyMatchFailure(op,
+ MATCH_FAILURE_INCONSISTENT_MASKING);
+
+ if (failed(canFuseOuterProducts(rewriter, op1, op2, op3, op4)))
return failure();
auto loc = op.getLoc();
@@ -427,12 +360,6 @@ class OuterProductFusion4Way
Value lhsMask, rhsMask;
if (op1.getLhsMask() || op2.getLhsMask() || op3.getLhsMask() ||
op4.getLhsMask()) {
- if (!(op1.getLhsMask() && op2.getLhsMask() && op3.getLhsMask() &&
- op4.getLhsMask()))
- return rewriter.notifyMatchFailure(
- op, "unsupported masking, either all outerproducts are masked "
- "or none");
-
auto lhs0Mask = packInputs(op1.getLhsMask(), op3.getLhsMask());
auto lhs1Mask = packInputs(op2.getLhsMask(), op4.getLhsMask());
lhsMask = packInputs(lhs0Mask, lhs1Mask);
@@ -442,72 +369,146 @@ class OuterProductFusion4Way
rhsMask = packInputs(rhs0Mask, rhs1Mask);
}
- assert((kind == arm_sme::CombiningKind::Add ||
- kind == arm_sme::CombiningKind::Sub) &&
- "unhandled arm_sme::CombiningKind!");
- if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
- if (kind == arm_sme::CombiningKind::Add)
+ if (kind == arm_sme::CombiningKind::Add) {
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
- op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp)) {
- if (kind == arm_sme::CombiningKind::Add)
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
- op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp)) {
- if (kind == arm_sme::CombiningKind::Add)
+ else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- else
- rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
- op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp)) {
- if (kind == arm_sme::CombiningKind::Add)
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else
+ llvm_unreachable("unexpected extend op!");
+ } else if (kind == arm_sme::CombiningKind::Sub) {
+ if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
+ rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
+ op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
+ else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
- } else
- llvm_unreachable("unexpected extend op!");
+ else
+ llvm_unreachable("unexpected extend op!");
+ } else {
+ llvm_unreachable("unexpected arm_sme::CombiningKind!");
+ }
- op3.erase();
- op2.erase();
- op1.erase();
+ rewriter.eraseOp(op3);
+ rewriter.eraseOp(op2);
+ rewriter.eraseOp(op1);
return success();
}
private:
- template <typename LhsExtOp, typename RhsExtOp = LhsExtOp>
- LogicalResult isWidenable(PatternRewriter &rewriter,
- arm_sme::OuterProductOp op, VectorType resultType,
- VectorType inputType) const {
- if (op.getResultType() != resultType)
- return rewriter.notifyMatchFailure(
- op, "unsupported result type, expected 'vector<[4]x[4]xi32>' or "
- "'vector<[2]x[2]xi64>'");
-
- auto lhsDefOp = op.getLhs().getDefiningOp<LhsExtOp>();
- auto rhsDefOp = op.getRhs().getDefiningOp<RhsExtOp>();
-
- if (!lhsDefOp || !rhsDefOp)
- return rewriter.notifyMatchFailure(
- op, "defining op of outerproduct operands must be 'arith.extsi' or "
- "'arith.extui'");
-
- auto lhsInType = cast<VectorType>(lhsDefOp->getOperand(0).getType());
- auto rhsInType = cast<VectorType>(rhsDefOp->getOperand(0).getType());
+ // Four outer products can be fused if all of the following are true:
+ // - input and result types match.
+ // - the defining operations of the inputs are identical extensions,
+ // specifically either:
+ // - a signed or unsigned extension for integer types.
+ // - a floating-point extension for floating-point types.
+ // - the types and extension are supported, i.e. there's a 4-way operation
+ // they can be fused into.
+ LogicalResult canFuseOuterProducts(PatternRewriter &rewriter,
+ arm_sme::OuterProductOp op1,
+ arm_sme::OuterProductOp op2,
+ arm_sme::OuterProductOp op3,
+ arm_sme::OuterProductOp op4) const {
+ // Supported result types.
+ auto nxnxv4i32 =
+ VectorType::get({4, 4}, rewriter.getI32Type(), {true, true});
+ auto nxnxv2i64 =
+ VectorType::get({2, 2}, rewriter.getI64Type(), {true, true});
+ // Supported input types.
+ // Note: this is before packing so these have 1/4 the number of elements
+ // of the input vector types of the 4-way operations.
+ auto nxv4i8 = VectorType::get({4}, rewriter.getI8Type(), true);
+ auto nxv2i16 = VectorType::get({2}, rewriter.getI16Type(), true);
+ if (
+ // signed, i8i8i32
+ (failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // signed, i16i16i64
+ (failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isCompatible<arith::ExtSIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtSIOp>(rewriter, op4, nxnxv2i64,
+ nxv2i16))) &&
+ // unsigned, i8i8i32
+ (failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // unsigned, i16i16i64
+ (failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(
+ isCompatible<arith::ExtUIOp>(rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtUIOp>(rewriter, op4, nxnxv2i64,
+ nxv2i16))) &&
+ // signed by unsigned, i8i8i32
+ (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // signed by unsigned, i16i16i64
+ (failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtSIOp, arith::ExtUIOp>(
+ rewriter, op4, nxnxv2i64, nxv2i16))) &&
+ // unsigned by signed, i8i8i32
+ (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op1, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op2, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op3, nxnxv4i32, nxv4i8)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op4, nxnxv4i32, nxv4i8))) &&
+ // unsigned by signed, i16i16i64
+ (failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op1, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op2, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op3, nxnxv2i64, nxv2i16)) ||
+ failed(isCompatible<arith::ExtUIOp, arith::ExtSIOp>(
+ rewriter, op4, nxnxv2i64, nxv2i16))))
+ return failure();
- if (lhsInType != inputType || rhsInType != inputType)
- return rewriter.notifyMatchFailure(
- op, "unsupported input types, expected 'vector<[4]xi8>' or "
- "'vector<[2]xi16>'");
return success();
}
};
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index f0e738720932c..c6c17cec457e3 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -181,156 +181,6 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
return %1 : vector<[4]x[4]xi32>
}
-/// Negative tests
-
-// -----
-
-// CHECK-LABEL: @outerproduct_widening_2way__no_acc
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
- %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
- %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
-
- return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-/// Defining op of accumulator operand must be an 'arm_sme.outerproduct'.
-
-// CHECK-LABEL: @outerproduct_widening_2way__bad_acc
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
- %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
- %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
-
- return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-/// Combining kinds of outer products must match to be fused.
-
-// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__bad_combining_kind(
- %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
- %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
- %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
- %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
- %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
- %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<add> : vector<[4]xf32>, vector<[4]xf32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) : vector<[4]xf32>, vector<[4]xf32>
-
- return %1 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-/// If the first outer product has uses other than as the input to another
-/// outer product, it can't be erased after fusion. This is a problem when
-/// it also has an accumulator as this will be used as the root for tile
-/// allocation and since the widening outer product uses the same
-/// accumulator it will get assigned the same tile ID, resulting in 3
-/// outer products and incorrect results. Check this is prevented.
-
-// CHECK-LABEL: @outerproduct_widening_2way__cant_erase
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__cant_erase(
- %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
- %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
- %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
- %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
- %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
- %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
-
- %acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
- "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
-
- return %1 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
- %a0 : vector<[2]xf32>, %b0 : vector<[2]xf32>,
- %a1 : vector<[2]xf32>, %b1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> {
- %a0_ext = arith.extf %a0 : vector<[2]xf32> to vector<[2]xf64>
- %b0_ext = arith.extf %b0 : vector<[2]xf32> to vector<[2]xf64>
- %a1_ext = arith.extf %a1 : vector<[2]xf32> to vector<[2]xf64>
- %b1_ext = arith.extf %b1 : vector<[2]xf32> to vector<[2]xf64>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
-
- return %1 : vector<[2]x[2]xf64>
-}
-
-// -----
-
-/// Fusion only occurs if either both outer products are masked, or neither.
-
-// CHECK-LABEL: @outerproduct_widening_2way__bad_masking
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__bad_masking(
- %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
- %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
- %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
- %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
- %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
- %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
- %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
-
- %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
- %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
-
- return %1 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-/// Defining op of outer product must be a supported extension op.
-
-// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
-// CHECK-NOT: arm_sme.fmopa_2way
-// CHECK: arm_sme.outerproduct
-// CHECK: arm_sme.outerproduct
-// CHECK-NOT: arm_sme.fmopa_2way
-func.func @outerproduct_widening_2way__bad_defining_op(
- %a0 : vector<[4]xf32>, %b0 : vector<[4]xf32>,
- %a1 : vector<[4]xf32>, %b1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> {
- %0 = arm_sme.outerproduct %a0, %b0 : vector<[4]xf32>, vector<[4]xf32>
- %1 = arm_sme.outerproduct %a1, %b1 acc(%0) : vector<[4]xf32>, vector<[4]xf32>
-
- return %1 : vector<[4]x[4]xf32>
-}
-
// -----
// CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32
@@ -905,3 +755,385 @@ func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64(
return %3 : vector<[2]x[2]xi64>
}
+
+/// Negative tests
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_2way__no_acc
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__no_acc
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__no_acc(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %2 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+/// Defining op of accumulator operand must be an 'arm_sme.outerproduct'.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_acc
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %0 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_acc
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_acc(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ // break chain
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+/// Combining kinds of outer products must match to be fused.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_combining_kind(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<add> : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_combining_kind
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_combining_kind(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<add> acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<add> acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<add> acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+/// If the first outer product has uses other than as the input to another
+/// outer product, it can't be erased after fusion. This is a problem when
+/// it also has an accumulator as this will be used as the root for tile
+/// allocation and since the widening outer product uses the same
+/// accumulator it will get assigned the same tile ID, resulting in 3
+/// outer products and incorrect results. Check this is prevented.
+
+// CHECK-LABEL: @outerproduct_widening_2way__cant_erase
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__cant_erase(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
+ "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__cant_erase
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__cant_erase(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__unsupported_type_f32f32f64(
+ %a0 : vector<[2]xf32>, %b0 : vector<[2]xf32>,
+ %a1 : vector<[2]xf32>, %b1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> {
+ %a0_ext = arith.extf %a0 : vector<[2]xf32> to vector<[2]xf64>
+ %b0_ext = arith.extf %b0 : vector<[2]xf32> to vector<[2]xf64>
+ %a1_ext = arith.extf %a1 : vector<[2]xf32> to vector<[2]xf64>
+ %b1_ext = arith.extf %b1 : vector<[2]xf32> to vector<[2]xf64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
+
+ return %1 : vector<[2]x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__unsupported_type_f16f16f64
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__unsupported_type_f16f16f64(
+ %a0 : vector<[2]xf16>, %b0 : vector<[2]xf16>,
+ %a1 : vector<[2]xf16>, %b1 : vector<[2]xf16>,
+ %a2 : vector<[2]xf16>, %b2 : vector<[2]xf16>,
+ %a3 : vector<[2]xf16>, %b3 : vector<[2]xf16>) -> vector<[2]x[2]xf64> {
+ %a0_ext = arith.extf %a0 : vector<[2]xf16> to vector<[2]xf64>
+ %b0_ext = arith.extf %b0 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a1_ext = arith.extf %a1 : vector<[2]xf16> to vector<[2]xf64>
+ %b1_ext = arith.extf %b1 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a2_ext = arith.extf %a2 : vector<[2]xf16> to vector<[2]xf64>
+ %b2_ext = arith.extf %b2 : vector<[2]xf16> to vector<[2]xf64>
+
+ %a3_ext = arith.extf %a3 : vector<[2]xf16> to vector<[2]xf64>
+ %b3_ext = arith.extf %b3 : vector<[2]xf16> to vector<[2]xf64>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[2]xf64>, vector<[2]xf64>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[2]xf64>, vector<[2]xf64>
+
+ return %3 : vector<[2]x[2]xf64>
+}
+
+// -----
+
+/// Fusion only occurs if either both outer products are masked, or neither.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_masking
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_masking(
+ %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>,
+ %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>,
+ %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_masking
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_masking(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>,
+ %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+ %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
+
+// -----
+
+/// Defining op of outer product must be a supported extension op.
+
+// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op
+// CHECK-NOT: arm_sme.fmopa_2way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_2way
+func.func @outerproduct_widening_2way__bad_defining_op(
+ %a0 : vector<[4]xf32>, %b0 : vector<[4]xf32>,
+ %a1 : vector<[4]xf32>, %b1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> {
+ %0 = arm_sme.outerproduct %a0, %b0 : vector<[4]xf32>, vector<[4]xf32>
+ %1 = arm_sme.outerproduct %a1, %b1 acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+
+ return %1 : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op
+// CHECK-NOT: arm_sme.fmopa_4way
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK: arm_sme.outerproduct
+// CHECK-NOT: arm_sme.fmopa_4way
+func.func @outerproduct_widening_4way__bad_defining_op(
+ %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>,
+ %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>,
+ %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>,
+ %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> {
+ %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+ %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+ %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+ %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+ %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+ %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+ %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+
+ return %3 : vector<[4]x[4]xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 002b1f9d804a7..b63d3f0651569 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -314,10 +314,3 @@ llvm.func @arm_sve_convert_to_svbool(
: (vector<[1]xi1>) -> vector<[16]xi1>
llvm.return
}
-
-// CHECK-LABEL: @arm_sve_zip1
-// CHECK-NEXT: call <vscale x 8 x half> @llvm.aarch64.sve.zip1.nxv8f16(<vscale x 8 x half> %{{.*}}, <vscale x 8 x half> {{.*}})
-llvm.func @arm_sve_zip1(%arg0 : vector<[8]xf16>) -> vector<[8]xf16> {
- %0 = "arm_sve.intr.zip1"(%arg0, %arg0) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
- llvm.return %0 : vector<[8]xf16>
-}
More information about the Mlir-commits
mailing list