[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Oct 30 06:56:10 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/69604

>From 1e2197d37e087f68a94d906346476fd1fda8041d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 11:58:34 +0000
Subject: [PATCH 1/9] [mlir][ArmSME] Support lowering masked
 vector.outerproduct ops to SME

This patch adds support for lowering masked outer products to SME. This
is done in two stages. First, vector.outerproducts (both masked and
non-masked) are rewritten to arm_sme.outerproducts. The
arm_sme.outerproduct op is close to vector.outerproduct, but supports
masking on the operands rather than the result. It also limits the
cases it handles to things that could be lowered to SME, but does not
enforce that everything matches SME tiles at this level.

This currently requires that the source of the mask is a
vector.create_mask op. E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```
This is rewritten to:
```mlir
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
```
(The same rewrite also applies to non-masked vector.outerproducts.)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 108 ++++++++++++++++++
 .../VectorToArmSME/VectorToArmSME.cpp         |  77 ++++++++++++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  60 +++++-----
 mlir/test/Dialect/ArmSME/invalid.mlir         |  33 ++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       |  53 +++++++++
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  74 +++++++++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  96 ++++++++++++++++
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  70 ++++++++++++
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  35 ++++++
 9 files changed, 566 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index b30d0fdb866bd23..3e22398e513992a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -79,6 +79,23 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
   let defaultValue = "TileSliceLayout::Horizontal";
 }
 
+def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
+  I32EnumAttrCase<"Add", 0, "add">,
+  I32EnumAttrCase<"Sub", 1, "sub">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies how to combine a newly produced value with the
+/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
+/// the functions that are valid for SME outer products.
+def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
+                                          "kind"> {
+  let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "CombiningKind::Add";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -561,4 +578,95 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :
+  TypesMatchWith<
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",
+    "!" # maskGetter # "() || std::equal_to<>()">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    TypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)",
+      "!getAcc() || std::equal_to<>()">
+  ]>
+{
+  let summary = "Vector outerproduct with optional fused add or subtract";
+
+  let description = [{
+    This op is based on `vector.outerproduct` with the extra conditions that:
+
+    * AXPY operations are not supported
+    * The only combining functions are "add" and "sub"
+    * Masking is performed on the inputs (rather than the output)
+
+    This is meant as an intermediate op for lowering `vector.outerproduct` to
+    SME. Types are not restricted to SVE/SME vectors at this level.
+
+    Example 1: Unmasked outerproduct (without accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 2: Unmasked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%acc)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Masked outerproduct
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 4: Masked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%acc) masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfRank<[1]>:$lhs, VectorOfRank<[1]>:$rhs,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$lhsMask,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$rhsMask,
+    Optional<VectorOfRank<[2]>>:$acc,
+    ArmSME_CombiningKindAttr:$kind);
+  let results = (outs VectorOfRank<[2]>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `kind` `` $kind
+      | `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict
+    `:` type($lhs) `,` type($rhs) `,` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return getLhs().getType(); }
+    VectorType getRhsType() { return getRhs().getType(); }
+    VectorType getResultType() { return getResult().getType(); }
+  }];
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..b7ec2603477d59e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,6 +427,80 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Lowers `vector.outerproduct` (masked or unmasked) to `arm_sme.outerproduct`.
+/// A 2-D mask from a `vector.create_mask` (the only supported mask source) is
+/// decomposed into two 1-D masks, one per operand.
+///
+///  BEFORE:
+///  ```mlir
+///  %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///  %result = vector.mask %mask {
+///               vector.outerproduct %vecA, %vecB
+///                : vector<[4]xf32>, vector<[4]xf32>
+///            } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///  %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///  ```
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+    // AXPY operation not suited for SME.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
+
+    if (outerProductOp.getKind() != vector::CombiningKind::ADD)
+      return outerProductOp.emitError("unsupported kind");
+
+    Value lhsMask, rhsMask;
+    Operation *rootOp = outerProductOp;
+    if (outerProductOp.isMasked()) {
+      auto maskingOp = outerProductOp.getMaskingOp();
+      // arm_sme.outerproduct has no passthru operand; rewriting a vector.mask
+      // that carries one would silently drop it.
+      if (maskingOp.hasPassthru())
+        return failure();
+
+      // Attempt to extract masks from vector.create_mask.
+      // TODO: Add support for other mask sources.
+      auto createMaskOp =
+          maskingOp.getMask().getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+      auto maskType = createMaskOp.getVectorType();
+      if (maskType.getRank() != 2)
+        return failure();
+
+      // The match is now certain, so it is safe to start rewriting.
+      rewriter.setInsertionPoint(maskingOp);
+      rootOp = maskingOp;
+      auto loc = outerProductOp.getLoc();
+      // Mask dim 0 indexes lhs (rows) and dim 1 indexes rhs (columns), so each
+      // operand's mask keeps only that operand's dim (works for non-square).
+      VectorType lhsMaskType = VectorType::Builder(maskType).dropDim(1);
+      VectorType rhsMaskType = VectorType::Builder(maskType).dropDim(0);
+      lhsMask = rewriter.create<vector::CreateMaskOp>(
+          loc, lhsMaskType, createMaskOp.getOperand(0));
+      rhsMask = rewriter.create<vector::CreateMaskOp>(
+          loc, rhsMaskType, createMaskOp.getOperand(1));
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -434,5 +508,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
                SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+               VectorOuterProductToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 1231da356f8ed95..32c01d3b1f49c26 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -474,13 +474,13 @@ struct MoveTileSliceToVectorArmSMELowering
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
-struct VectorOuterProductToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
-  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+struct OuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
+  using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp,
-                  vector::OuterProductOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
+                  arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
@@ -512,24 +512,13 @@ struct VectorOuterProductToArmSMELowering
       return true;
     };
 
-    auto resultVectorType = outerProductOp.getResultVectorType();
-    if (!isSupportedType(resultVectorType))
-      return outerProductOp.emitError("unsupported type");
-
-    vector::CombiningKind kind = outerProductOp.getKind();
-    if (kind != vector::CombiningKind::ADD)
-      // TODO: support subtract.
+    // TODO: Support CombiningKind::Sub for outer products.
+    if (outerProductOp.getKind() != CombiningKind::Add)
       return outerProductOp.emitError("unsupported kind");
 
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: support masking.
-      return outerProductOp.emitError("masking is currently unsupported");
-
-    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      // AXPY operation not suited for SME.
-      return failure();
+    auto resultVectorType = outerProductOp.getResultType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
 
     auto loc = outerProductOp.getLoc();
 
@@ -542,21 +531,24 @@ struct VectorOuterProductToArmSMELowering
     auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(elementWidth), acc);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy =
-        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
-                        /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
-
     auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
 
+    Value lhsMask = outerProductOp.getLhsMask();
+    Value rhsMask = outerProductOp.getRhsMask();
+
+    if (!lhsMask || !rhsMask) {
+      auto predTy =
+          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
+      Value allActiveMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predTy, true));
+      lhsMask = allActiveMask;
+      rhsMask = allActiveMask;
+    }
+
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(
-        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
-        outerProductOp.getRhs());
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+                                               outerProductOp.getLhs(),
+                                               outerProductOp.getRhs());
 
     // Create `CastTileToVectorOp` to use as the output.
     rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
@@ -733,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      VectorOuterProductToArmSMELowering, ZeroOpConversion,
+      OuterProductToArmSMELowering, ZeroOpConversion,
       VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 25c62f78d843543..d373953e6ed4836 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -150,3 +150,36 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// expected-note at +1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_mask_type(%vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<5xi1>, %maskB: vector<[2]xi1>)-> vector<3x[2]xf32>
+{
+  // expected-error at +1 {{use of value '%maskA' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x1xi16>
+{
+  // expected-error at +1 {{op failed to verify that result type is derived from `lhs` and `rhs`}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[2]x1xi16>
+  return %0 : vector<[2]x1xi16>
+}
+
+// -----
+
+// expected-note at +1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_acc_type(%vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<6x6xi32>) -> vector<7x6xi32>
+{
+  // expected-error at +1 {{use of value '%acc' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %0 : vector<7x6xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 6866137267dc66a..7a1fa26d0e64807 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1161,3 +1161,56 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  return %result : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_masking(
+  %vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<3xi1>, %maskB: vector<[2]xi1>
+) -> vector<3x[2]xf32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %result : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_acc(
+  %vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<7x6xi32>
+) -> vector<7x6xi32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %result : vector<7x6xi32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  return %result : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_everything(
+  %vecA: vector<[4]xf16>, %vecB: vector<4xf16>, %acc: vector<[4]x4xf16>,
+  %maskA: vector<[4]xi1>, %maskB: vector<4xi1>
+) -> vector<[4]x4xf16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}})
+  // CHECK-SAME: : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB)
+              : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  return %result : vector<[4]x4xf16>
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 32f46d9fd817c9d..f615a9ef0231443 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -463,9 +463,75 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_no_accumulator_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_no_accumulator_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
 // CHECK-LABEL: @vector_outerproduct_unsupported_axpy
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // CHECK-NOT: arm_sme
+  // expected-error at +1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
@@ -473,7 +539,6 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 // -----
 
 func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
-  // expected-error at +2 {{failed to legalize operation 'vector.outerproduct'}}
   // expected-error at +1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -490,9 +555,8 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // expected-error at +2 {{failed to legalize operation 'vector.outerproduct'}}
-  // expected-error at +1 {{masking is currently unsupported}}
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // expected-error at +1 {{failed to legalize operation 'vector.outerproduct'}}
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..21edbc259e124c3 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -578,3 +578,99 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
+func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
+func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>
+func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>
+func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 38ba489e2fafb2c..ab51efcc8036507 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -14,6 +14,12 @@
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
 
+// REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+
 func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
 
@@ -61,3 +67,67 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
 
   return
 }
+
+func.func @test_masked_outerproduct_no_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 3 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking the result will be the top 3x2xf32 section.
+  //
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 1, 2, 0, 0
+  // WITH-MASK-NEXT: ( 2, 4, 0, 0
+  // WITH-MASK-NEXT: ( 3, 6, 0, 0
+  // WITH-MASK-NEXT: ( 0, 0, 0, 0
+  // WITH-MASK:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[4]x[4]xf32>
+  vector.print str "TILE END"
+
+  return
+}
+
+func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+  %f10 = arith.constant 10.0 : f32
+
+  %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 2 : index
+  %rhsDim = arith.constant 3 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking the result will be the top 2x3xf32 section.
+  //
+  // WITH-MASK-AND-ACC:      TILE BEGIN
+  // WITH-MASK-AND-ACC-NEXT: ( 11, 12, 13, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 12, 14, 16, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[4]x[4]xf32>
+  vector.print str "TILE END"
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 82f14595a24da2f..33ac503c027b3b5 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -11,6 +11,9 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+
 func.func @test_outerproduct_with_accumulator_2x2xf64() {
   %f1 = arith.constant 1.0 : f64
   %f2 = arith.constant 2.0 : f64
@@ -36,3 +39,35 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
 
   return
 }
+
+func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %acc = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 1 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
+  } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xf64.
+  //
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 11, 12
+  // WITH-MASK-NEXT: ( 10, 10
+  // WITH-MASK:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}

>From 895233682ecfdeacc7649611020bcc97326118ac Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 14:35:18 +0000
Subject: [PATCH 2/9] Make comment in VectorToArmSME.cpp match others in style

---
 .../VectorToArmSME/VectorToArmSME.cpp         | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index b7ec2603477d59e..44e285211a0566c 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,26 +427,27 @@ struct TransposeOpToArmSMELowering
   }
 };
 
-/// Lowers a masked `vector.outerproduct` to `arm_sme.outerproduct`.
-/// The 2-D mask of the `vector.outerproduct` (if from a `vector.create_mask`)
-/// is decomposed into two 1-D masks for the operands.
+/// Conversion pattern for vector.outerproduct.
 ///
-///  BEFORE:
-///  ```mlir
-///  %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
-///  %result = vector.mask %mask {
-///               vector.outerproduct %vecA, %vecB
-///                : vector<[4]xf32>, vector<[4]xf32>
-///            } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
-///  ```
+/// If the vector.outerproduct is masked (and the mask comes from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///   %result = vector.mask %mask {
+///                vector.outerproduct %vecA, %vecB
+///                 : vector<[4]xf32>, vector<[4]xf32>
+///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
 ///
-///  AFTER:
-///  ```mlir
-///  %maskA = vector.create_mask %dimA : vector<[4]xi1>
-///  %maskB = vector.create_mask %dimB : vector<[4]xi1>
-///  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
-///              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
-///  ```
 struct VectorOuterProductToArmSMELowering
     : public OpRewritePattern<vector::OuterProductOp> {
 

>From 9c5a737a5661811b38fd5bed1f25069367f9db87 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 20 Oct 2023 09:14:17 +0000
Subject: [PATCH 3/9] Switch to OptionalTypesMatchWith

---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 3e22398e513992a..2fc25f618a315ae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -578,19 +578,18 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
-class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :
-  TypesMatchWith<
+class HasMatchingMaskTypeConstraint<string operand> :
+  OptionalTypesMatchWith<
     "shape of `" # operand #  "Mask` matches `" # operand # "`",
     operand, operand # "Mask",
-    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",
-    "!" # maskGetter # "() || std::equal_to<>()">;
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [Pure,
     AttrSizedOperandSegments,
     AllElementTypesMatch<["lhs", "rhs", "result"]>,
-    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
-    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    HasMatchingMaskTypeConstraint<"lhs">,
+    HasMatchingMaskTypeConstraint<"rhs">,
     PredOpTrait<
       "both `lhsMask` and `rhsMask` should be provided or neither",
       CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
@@ -601,10 +600,9 @@ def OuterProductOp :
           "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
           "getRhsType().getElementType(),"
           "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
-    TypesMatchWith<"`result` and `acc` have the same type",
+    OptionalTypesMatchWith<"`result` and `acc` have the same type",
       "result", "acc",
-      "::llvm::cast<mlir::Type>($_self)",
-      "!getAcc() || std::equal_to<>()">
+      "::llvm::cast<mlir::Type>($_self)">
   ]>
 {
   let summary = "Vector outerproduct with optional fused add";

>From b0776f14a80c0c07110f403daddc80b9d139763e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 25 Oct 2023 15:57:33 +0000
Subject: [PATCH 4/9] Fixups

This also adds some no-accumulator tests to test-outerproduct-f64.mlir
(and removes some now-resolved TODOs, e.g. vector.splat support).
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |   8 +-
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  12 +-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  24 ++--
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |   4 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     | 103 ++++++++++++++----
 5 files changed, 112 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 2fc25f618a315ae..bacfa282f6880a5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -89,7 +89,11 @@ def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
 
 /// An attribute that specifies how to combine a newly produced value with the
 /// accumulator. This is similar to vector::CombiningKindAttr, but limited to
-/// the functions that are valid for SME outer products.
+/// the functions that are valid for SME outer products. Add corresponds to a
+/// MOPA and sub to a MOPS.
+/// For example, for f32:
+/// FMOPA: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
+/// FMOPS: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPS--non-widening---Floating-point-outer-product-and-subtract-
 def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
                                           "kind"> {
   let assemblyFormat = "`<` $value `>`";
@@ -226,7 +230,7 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let results = (outs SMETile:$res);
   let description = [{
     Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction. 
+    `zero` intrinsic and instruction.
 
     Example 1: Zero an 8-bit element ZA tile.
 
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index f615a9ef0231443..13f24ba8f0e7e3d 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -464,10 +464,16 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
-  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
-  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
+  // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM0]] : vector<[4]xi32>
+  // CHECK: %[[DIM1_I32:.*]] = arith.index_cast %[[DIM1]] : index to i32
+  // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertelement %[[DIM1_I32]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[SPLAT_DIM1:.*]] = llvm.shufflevector %[[INSERT_DIM1]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM1]] : vector<[4]xi32>
   // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
   // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 21edbc259e124c3..c62eb1bd7dfbe71 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -586,11 +586,11 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
   // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
@@ -599,11 +599,11 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
   // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
@@ -612,11 +612,11 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
   // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
@@ -625,11 +625,11 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
 func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
   // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index ab51efcc8036507..5ed6bffd9335de4 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -47,7 +47,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
   %f10 = arith.constant 10.0 : f32
 
-  %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+  %acc = vector.splat %f10 : vector<[4]x[4]xf32>
   %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32>
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
@@ -104,7 +104,7 @@ func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
   %ones = arith.constant dense<1> : vector<[4]xi32>
   %f10 = arith.constant 10.0 : f32
 
-  %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+  %acc = vector.splat %f10 : vector<[4]x[4]xf32>
   %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
   %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 33ac503c027b3b5..fe5d655daf770ac 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,4 +1,4 @@
-// DEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
+// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
@@ -11,27 +11,35 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
-// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
+// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_2x2xf64
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
 
-func.func @test_outerproduct_with_accumulator_2x2xf64() {
-  %f1 = arith.constant 1.0 : f64
-  %f2 = arith.constant 2.0 : f64
-  %f10 = arith.constant 10.0 : f64
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+
+func.func @test_outerproduct_no_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
 
-  %a = vector.splat %f1 : vector<[2]xf64>
-  %b = vector.splat %f2 : vector<[2]xf64>
-  // TODO: vector.splat doesn't support ArmSME.
-  %c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 1 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
 
-  %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
+  %tile = vector.outerproduct %vector, %vector : vector<[2]xf64>, vector<[2]xf64>
 
   // Print the tile. The smallest SVL is 128-bits so the tile will be at least
   // 2x2xf64.
   //
   // CHECK:      TILE BEGIN
-  // CHECK-NEXT: ( 12, 12
-  // CHECK-NEXT: ( 12, 12
+  // CHECK-NEXT: ( 1, 2
+  // CHECK-NEXT: ( 2, 4
   // CHECK:      TILE END
   vector.print str "TILE BEGIN"
   vector.print %tile : vector<[2]x[2]xf64>
@@ -40,12 +48,68 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
   return
 }
 
+func.func @test_outerproduct_with_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %acc = vector.splat %f10 : vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %tile = vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
+
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xf64.
+  //
+  // WITH-ACC:      TILE BEGIN
+  // WITH-ACC-NEXT: ( 11, 12
+  // WITH-ACC-NEXT: ( 12, 14
+  // WITH-ACC:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}
+
+func.func @test_masked_outerproduct_no_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 2 : index
+  %rhsDim = arith.constant 1 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector : vector<[2]xf64>, vector<[2]xf64>
+  } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+
+  // Print the tile. Due to masking the result will be the top 2x1xf64 section.
+  //
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 1, 0
+  // WITH-MASK-NEXT: ( 2, 0
+  // WITH-MASK:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}
+
 func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
   %c0 = arith.constant 0 : index
   %ones = arith.constant dense<1> : vector<[2]xi32>
   %f10 = arith.constant 10.0 : f64
 
-  %acc = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+  %acc = vector.splat %f10 : vector<[2]x[2]xf64>
   %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
   %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
   %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
@@ -58,13 +122,12 @@ func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
     vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
   } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
 
-  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
-  // 2x2xf64.
+  // Print the tile. Due to masking the result will be the top 1x2xf64 section.
   //
-  // WITH-MASK:      TILE BEGIN
-  // WITH-MASK-NEXT: ( 11, 12
-  // WITH-MASK-NEXT: ( 10, 10
-  // WITH-MASK:      TILE END
+  // WITH-MASK-AND-ACC:      TILE BEGIN
+  // WITH-MASK-AND-ACC-NEXT: ( 11, 12
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10
+  // WITH-MASK-AND-ACC:      TILE END
   vector.print str "TILE BEGIN"
   vector.print %tile : vector<[2]x[2]xf64>
   vector.print str "TILE END"

>From d8206d28fe1247b697ffd7c3a2feb4d2fd44eda0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 26 Oct 2023 16:09:14 +0000
Subject: [PATCH 5/9] Rework a little

The arm_sme.outerproduct op has now been updated to match SME tile sizes,
which cleans things up a little. A few other fixups are also included.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 67 +++++++++----------
 .../VectorToArmSME/VectorToArmSME.cpp         | 32 ++++++---
 mlir/test/Dialect/ArmSME/invalid.mlir         | 27 +++-----
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 41 +++++-------
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 16 ++---
 5 files changed, 88 insertions(+), 95 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index bacfa282f6880a5..2f6e52ff2badbeb 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -588,71 +588,69 @@ class HasMatchingMaskTypeConstraint<string operand> :
     operand, operand # "Mask",
     "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
 
+/// Requires `operand` (if present) to be the square scalable tile type built from `lhs`.
+class OuterProductResultTileTypeConstraint<string operand> :
+  OptionalTypesMatchWith<"`" # operand # "` type is derived from `lhs` and `rhs`",
+    "lhs", operand,
+    "[&]{"
+    "  auto vectorType = ::llvm::cast<mlir::VectorType>($_self);"
+    "  int64_t size = vectorType.getDimSize(0);"
+    "  return VectorType::get({size, size}, vectorType.getElementType(), {true, true});"
+    "}()">;
+
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [Pure,
     AttrSizedOperandSegments,
-    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    AllTypesMatch<["lhs", "rhs"]>,
     HasMatchingMaskTypeConstraint<"lhs">,
     HasMatchingMaskTypeConstraint<"rhs">,
     PredOpTrait<
       "both `lhsMask` and `rhsMask` should be provided or neither",
       CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
-    PredOpTrait<
-      "result type is derived from `lhs` and `rhs`",
-      CPred<
-        "getResultType() == VectorType::get({"
-          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
-          "getRhsType().getElementType(),"
-          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
-    OptionalTypesMatchWith<"`result` and `acc` have the same type",
-      "result", "acc",
-      "::llvm::cast<mlir::Type>($_self)">
+    OuterProductResultTileTypeConstraint<"result">,
+    OuterProductResultTileTypeConstraint<"acc">
   ]>
 {
-  let summary = "Vector outerproduct with optional fused add";
+  let summary = "Outer product with optional fused add/sub";
 
   let description = [{
-    This op is based on `vector.outerproduct` with the extra conditions that:
-
-    * AXPY operations are not supported
-    * The only combining functions are "add" and "sub"
-    * Masking is performed on the inputs (rather than the output)
-
-    This is meant as an intermediate op for lowering `vector.outerproduct` to
-    SME. Types are not restricted to SVE/SME vectors at this level.
+    This operation represents an outer product that fits within an SME tile.
+    `lhs` and `rhs` must be SVE vectors; `acc` and the result are SME tiles.
+    Unlike `vector.outerproduct`, masking is on the operands (rather than the
+    result), which mirrors the SME instructions.
 
     Example 1: Unmasked outerproduct (without accumulator)
     ```mlir
-    %result = arm_sme.outerproduct $lhs, $rhs
-                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    // Not specifying an accumulator implicitly zeros the destination tile.
+    %result = arm_sme.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
     ```
 
     Example 2: Unmasked outerproduct (with accumulator)
     ```mlir
     %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator)
-                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+                : vector<[4]xf32>, vector<[4]xf32>
     ```
 
     Example 3: Masked outerproduct
     ```mlir
     %result = arm_sme.outerproduct $lhs, $rhs masks($lhsMask, $rhsMask)
-                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+                : vector<[4]xf32>, vector<[4]xf32>
     ```
 
     Example 4: Masked outerproduct (with accumulator)
     ```mlir
     %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator) masks($lhsMask, $rhsMask)
-                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+                : vector<[4]xf32>, vector<[4]xf32>
     ```
   }];
 
   let arguments = (ins
-    VectorOfRank<[1]>:$lhs, VectorOfRank<[1]>:$rhs,
-    Optional<VectorOfRankAndType<[1],[I1]>>:$lhsMask,
-    Optional<VectorOfRankAndType<[1],[I1]>>:$rhsMask,
-    Optional<VectorOfRank<[2]>>: $acc,
+    SVEVector:$lhs, SVEVector:$rhs,
+    Optional<SVEPredicate>:$lhsMask,
+    Optional<SVEPredicate>:$rhsMask,
+    Optional<SMETile>: $acc,
     ArmSME_CombiningKindAttr:$kind);
-  let results = (outs VectorOfRank<[2]>:$result);
+  let results = (outs SMETile:$result);
 
   let assemblyFormat = [{
     $lhs `,` $rhs
@@ -660,14 +658,13 @@ def OuterProductOp :
         `kind` `` $kind
       | `acc` `` `(` $acc `)`
       | `masks` `` `(` $lhsMask `,` $rhsMask `)`
-    ) attr-dict
-    `:` type($lhs) `,` type($rhs) `,` type($result)
+    ) attr-dict `:` type($lhs) `,` type($rhs)
   }];
 
   let extraClassDeclaration = [{
-    VectorType getLhsType() { return getLhs().getType(); }
-    VectorType getRhsType() { return getRhs().getType(); }
-    VectorType getResultType() { return getResult().getType(); }
+    VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+    VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+    VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
   }];
 }
 
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 44e285211a0566c..5c91e97249809da 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -446,7 +446,19 @@ struct TransposeOpToArmSMELowering
 ///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
 ///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
 ///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
-///                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///                : vector<[4]xf32>, vector<[4]xf32>
+///
+/// Unmasked outerproducts can be directly replaced with the arm_sme op.
+///
+/// Example:
+///
+///   %result = vector.outerproduct %vecA, %vecB
+///              : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+///   %result = arm_sme.outerproduct %vecA, %vecB
+///              : vector<[4]xf32>, vector<[4]xf32>
 ///
 struct VectorOuterProductToArmSMELowering
     : public OpRewritePattern<vector::OuterProductOp> {
@@ -455,13 +467,21 @@ struct VectorOuterProductToArmSMELowering
 
   LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
                                 PatternRewriter &rewriter) const override {
-    // AXPY operation not suited for SME.
+
+    // We don't yet support lowering AXPY operations to SME. These could be
+    // lowered by masking out all but the first element of the LHS.
     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
       return outerProductOp.emitError("AXPY operations not supported");
 
+    if (!arm_sme::isValidSMETileVectorType(
+            outerProductOp.getResultVectorType()))
+      return outerProductOp.emitError(
+          "outer product does not fit into SME tile");
+
     auto kind = outerProductOp.getKind();
     if (kind != vector::CombiningKind::ADD)
-      return outerProductOp.emitError("unsupported kind");
+      return outerProductOp.emitError(
+          "unsupported kind (lowering to SME only supports ADD at the moment)");
 
     Value lhsMask = {};
     Value rhsMask = {};
@@ -478,12 +498,8 @@ struct VectorOuterProductToArmSMELowering
       if (!createMaskOp)
         return failure();
 
-      auto maskType = createMaskOp.getVectorType();
-      if (maskType.getRank() != 2)
-        return failure();
-
       auto loc = outerProductOp.getLoc();
-
+      auto maskType = createMaskOp.getVectorType();
       Value lhsMaskDim = createMaskOp.getOperand(0);
       Value rhsMaskDim = createMaskOp.getOperand(1);
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index d373953e6ed4836..716dd86b00872d2 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -157,29 +157,18 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
 
 // -----
 
-// expected-note at +1 {{prior use here}}
-func.func @arm_sme_outproduct__bad_mask_type(%vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<5xi1>, %maskB: vector<[2]xi1>)-> vector<3x[2]xf32>
+func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
 {
-  // expected-error at +1 {{use of value '%maskA' expects different type than prior uses}}
-  %0 = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
-  return %0 : vector<3x[2]xf32>
+  // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
+  return %0 : vector<[2]x[2]xi16>
 }
 
 // -----
 
-func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x1xi16>
+func.func @arm_sme_outproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
 {
-  // expected-error at +1 {{op failed to verify that result type is derived from `lhs` and `rhs`}}
-  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[2]x1xi16>
-  return %0 : vector<[2]x1xi16>
-}
-
-// -----
-
-// expected-note at +1 {{prior use here}}
-func.func @arm_sme_outproduct__bad_acc_type(%vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<6x6xi32>) -> vector<7x6xi32>
-{
-  // expected-error at +1 {{use of value '%acc' expects different type than prior uses}}
-  %0 = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
-  return %0 : vector<7x6xi32>
+  // expected-error at +1 {{op failed to verify that all of {lhs, rhs} have same type}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
+  return %0 : vector<[4]x[4]xf32>
 }
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 7a1fa26d0e64807..49d79f0cfe9b3e2 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1169,48 +1169,39 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>
 // -----
 
 func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
-  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
-  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>
+  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>
   return %result : vector<[8]x[8]xi16>
 }
 
 // -----
 
-func.func @arm_sme_outproduct_with_masking(
-  %vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<3xi1>, %maskB: vector<[2]xi1>
-) -> vector<3x[2]xf32> {
-  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
-  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
-  return %result : vector<3x[2]xf32>
+func.func @arm_sme_outproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[4]xf32>, vector<[4]xf32>
+  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<[4]xf32>, vector<[4]xf32>
+  return %result : vector<[4]x[4]xf32>
 }
 
 // -----
 
-func.func @arm_sme_outproduct_with_acc(
-  %vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<7x6xi32>
-) -> vector<7x6xi32> {
-  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
-  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
-  return %result : vector<7x6xi32>
+func.func @arm_sme_outproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<[2]xi64>, vector<[2]xi64>
+  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<[2]xi64>, vector<[2]xi64>
+  return %result : vector<[2]x[2]xi64>
 }
 
 // -----
 
 func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
-  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
-  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>
   return %result : vector<[2]x[2]xf64>
 }
 
 // -----
 
-func.func @arm_sme_outproduct_with_everything(
-  %vecA: vector<[4]xf16>, %vecB: vector<4xf16>, %acc: vector<[4]x4xf16>,
-  %maskA: vector<[4]xi1>, %maskB: vector<4xi1>
-) -> vector<[4]x4xf16> {
-  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}})
-  // CHECK-SAME: : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
-  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB)
-              : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
-  return %result : vector<[4]x4xf16>
+func.func @arm_sme_outproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[16]xi8>, vector<[16]xi8>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
+  return %result : vector<[16]x[16]xi8>
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index c62eb1bd7dfbe71..07d295870969d3b 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -591,7 +591,7 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -604,7 +604,7 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -617,7 +617,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
 }
@@ -630,7 +630,7 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
 }
@@ -640,7 +640,7 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 // CHECK-LABEL: @vector_outerproduct_f64
 // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
 func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -650,7 +650,7 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
 // CHECK-LABEL: @vector_outerproduct_f32
 // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
 func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -660,7 +660,7 @@ func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32
 // CHECK-LABEL: @vector_outerproduct_f16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>
 func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
 }
@@ -670,7 +670,7 @@ func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16
 // CHECK-LABEL: @vector_outerproduct_bf16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>
 func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
 }

>From 05967be331c49f714e54a9a6be8ab1344f077f74 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 26 Oct 2023 16:31:15 +0000
Subject: [PATCH 6/9] Fixups

---
 mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5c91e97249809da..ae5056702a8128d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -486,6 +486,7 @@ struct VectorOuterProductToArmSMELowering
     Value lhsMask = {};
     Value rhsMask = {};
     Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
     if (outerProductOp.isMasked()) {
       auto maskingOp = outerProductOp.getMaskingOp();
       rewriter.setInsertionPoint(maskingOp);
@@ -498,7 +499,6 @@ struct VectorOuterProductToArmSMELowering
       if (!createMaskOp)
         return failure();
 
-      auto loc = outerProductOp.getLoc();
       auto maskType = createMaskOp.getVectorType();
       Value lhsMaskDim = createMaskOp.getOperand(0);
       Value rhsMaskDim = createMaskOp.getOperand(1);

>From a14773ac54c2f130948057348ce4e4cd1949db91 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 26 Oct 2023 16:40:50 +0000
Subject: [PATCH 7/9] Tidy up rewrite

---
 .../VectorToArmSME/VectorToArmSME.cpp         | 46 +++++++++++--------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index ae5056702a8128d..94ebe58dc0dc60f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -488,26 +488,13 @@ struct VectorOuterProductToArmSMELowering
     Operation *rootOp = outerProductOp;
     auto loc = outerProductOp.getLoc();
     if (outerProductOp.isMasked()) {
-      auto maskingOp = outerProductOp.getMaskingOp();
-      rewriter.setInsertionPoint(maskingOp);
-      rootOp = maskingOp;
-
-      // Attempt to extract masks from vector.create_mask.
-      // TODO: Add support for other mask sources.
-      auto mask = maskingOp.getMask();
-      auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
-      if (!createMaskOp)
+      auto maskOp = outerProductOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+      auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
+      if (failed(operandMasks))
         return failure();
-
-      auto maskType = createMaskOp.getVectorType();
-      Value lhsMaskDim = createMaskOp.getOperand(0);
-      Value rhsMaskDim = createMaskOp.getOperand(1);
-
-      VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
-      lhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
-                                                      lhsMaskDim);
-      rhsMask = rewriter.create<vector::CreateMaskOp>(loc, operandMaskType,
-                                                      rhsMaskDim);
+      std::tie(lhsMask, rhsMask) = *operandMasks;
     }
 
     rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
@@ -516,6 +503,27 @@ struct VectorOuterProductToArmSMELowering
 
     return success();
   }
+
+  static FailureOr<std::pair<Value, Value>>
+  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
+    // Attempt to extract masks from vector.create_mask.
+    // TODO: Add support for other mask sources.
+    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    auto maskType = createMaskOp.getVectorType();
+    Value lhsMaskDim = createMaskOp.getOperand(0);
+    Value rhsMaskDim = createMaskOp.getOperand(1);
+
+    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
+    Value lhsMask =
+        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
+    Value rhsMask =
+        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
+
+    return std::make_pair(lhsMask, rhsMask);
+  }
 };
 
 } // namespace

>From a4d95a442046f10ba3d22469e9757d6cba9ef4dc Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 30 Oct 2023 10:40:03 +0000
Subject: [PATCH 8/9] Only compile tests once

---
 .../Vector/CPU/ArmSME/test-outerproduct-f32.mlir   | 14 ++++++++------
 .../Vector/CPU/ArmSME/test-outerproduct-f64.mlir   | 14 ++++++++------
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 5ed6bffd9335de4..ae5ad9cc2a5e90c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -3,22 +3,24 @@
 // DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
-// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
 // DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
 
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=WITHOUT-ACC
 
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-ACC
 
 // REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_4x4xf32
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK
 
 // REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_4x4xf32
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
 
 func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index fe5d655daf770ac..36ce896a4c1bd90 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -3,22 +3,24 @@
 // DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
-// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
 // DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
 
-// RUN: %{compile} | %{run} | FileCheck %s
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
 
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-ACC
 
 // REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_2x2xf64
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK
 
 // REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
 
 func.func @test_outerproduct_no_accumulator_2x2xf64() {
   %c0 = arith.constant 0 : index

>From 452f4a112b0a7a2183fea03044d4f0ced6a73a19 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 30 Oct 2023 13:18:48 +0000
Subject: [PATCH 9/9] Fixups

---
 .../VectorToArmSME/VectorToArmSME.cpp         |  2 +-
 .../Transforms/LegalizeForLLVMExport.cpp      | 10 +--
 mlir/test/Dialect/ArmSME/invalid.mlir         |  4 +-
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 10 +--
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    | 23 ++----
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 80 +++++++++----------
 6 files changed, 58 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 94ebe58dc0dc60f..b60c21e2ced7a8f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -429,7 +429,7 @@ struct TransposeOpToArmSMELowering
 
 /// Conversion pattern for vector.outerproduct.
 ///
-/// If the vector.outerproduct is masked (and the mask from a
+/// If the vector.outerproduct is masked (and the mask is from a
 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
 /// operands.
 ///
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 32c01d3b1f49c26..105f2de207a0843 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -460,11 +460,11 @@ struct MoveTileSliceToVectorArmSMELowering
   }
 };
 
-/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
 ///
 /// Example:
 ///
-///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
 ///     : vector<[4]xf32>, vector<[4]xf32>
 ///
 /// is converted to:
@@ -474,7 +474,7 @@ struct MoveTileSliceToVectorArmSMELowering
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
-struct OuterProductToArmSMELowering
+struct OuterProductOpConversion
     : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
   using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
 
@@ -725,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      OuterProductToArmSMELowering, ZeroOpConversion,
-      VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
+      OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
+      VectorInsertToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 716dd86b00872d2..dba8b1937936e2c 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -157,7 +157,7 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
 
 // -----
 
-func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
+func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
 {
   // expected-error at +1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
   %0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
@@ -166,7 +166,7 @@ func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: ve
 
 // -----
 
-func.func @arm_sme_outproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
+func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
 {
   // expected-error at +1 {{op failed to verify that all of {lhs, rhs} have same type}}
   %0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 49d79f0cfe9b3e2..90b05c54c58d931 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1168,7 +1168,7 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>
 
 // -----
 
-func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
+func.func @arm_sme_outerproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
   // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>
   %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>
   return %result : vector<[8]x[8]xi16>
@@ -1176,7 +1176,7 @@ func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) ->
 
 // -----
 
-func.func @arm_sme_outproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+func.func @arm_sme_outerproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
   // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[4]xf32>, vector<[4]xf32>
   %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<[4]xf32>, vector<[4]xf32>
   return %result : vector<[4]x[4]xf32>
@@ -1184,7 +1184,7 @@ func.func @arm_sme_outproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector
 
 // -----
 
-func.func @arm_sme_outproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
+func.func @arm_sme_outerproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
   // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<[2]xi64>, vector<[2]xi64>
   %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<[2]xi64>, vector<[2]xi64>
   return %result : vector<[2]x[2]xi64>
@@ -1192,7 +1192,7 @@ func.func @arm_sme_outproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]
 
 // -----
 
-func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
+func.func @arm_sme_outerproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
   // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>
   %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>
   return %result : vector<[2]x[2]xf64>
@@ -1200,7 +1200,7 @@ func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2
 
 // -----
 
-func.func @arm_sme_outproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
+func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
   // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[16]xi8>, vector<[16]xi8>
   %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
   return %result : vector<[16]x[16]xi8>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 13f24ba8f0e7e3d..721ff8f2c3589d4 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -465,7 +465,7 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
 // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
   // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
   // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
   // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
@@ -483,22 +483,9 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
-  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
-  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
-  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
-  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
-}
-
-// -----
-
 // CHECK-LABEL: @vector_outerproduct_masked_f16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
   // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
   // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
@@ -511,7 +498,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
   // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
   // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
@@ -522,9 +509,9 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-LABEL: @vector_outerproduct_masked_f64
 // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
   // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
   // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
   // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 07d295870969d3b..9eb7cd143e5b5ea 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -585,35 +585,9 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
-  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
-  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
-  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
-}
-
-// -----
-
-// CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
-  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
-  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
-  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
-  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
-  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
-}
-
-// -----
-
 // CHECK-LABEL: @vector_outerproduct_masked_f16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -626,7 +600,7 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
@@ -637,22 +611,28 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
-func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
-  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
-  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
 }
 
 // -----
 
-// CHECK-LABEL: @vector_outerproduct_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
-func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
-  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
-  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
-  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
 
 // -----
@@ -674,3 +654,23 @@ func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xb
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
+func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
+func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}



More information about the Mlir-commits mailing list