[Mlir-commits] [mlir] [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (PR #69604)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Oct 20 02:16:11 PDT 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/69604

>From 5adf0b6fc6bdc67431f2d65edb34c66fd9f54a29 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 11:58:34 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Support lowering masked
 vector.outerproduct ops to SME

This patch adds support for lowering masked outer products to SME. This
is done in two stages. First, vector.outerproducts (both masked and
non-masked) are rewritten to arm_sme.outerproducts. The
arm_sme.outerproduct op is close to vector.outerproduct, but supports
masking on the operands rather than the result. It also restricts the
cases it handles to those that can be lowered to SME, but does not
require types to match SME tiles at this level.

This currently requires that the source of the mask is a
vector.create_mask op. E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```
is rewritten to:
```mlir
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
```
(The same rewrite also applies to non-masked vector.outerproducts.)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 108 ++++++++++++++++++
 .../VectorToArmSME/VectorToArmSME.cpp         |  77 ++++++++++++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  60 +++++-----
 mlir/test/Dialect/ArmSME/invalid.mlir         |  29 +++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       |  53 +++++++++
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    |  74 +++++++++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  96 ++++++++++++++++
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  70 ++++++++++++
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  35 ++++++
 9 files changed, 562 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..f60126e83603f47 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -79,6 +79,23 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
   let defaultValue = "TileSliceLayout::Horizontal";
 }
 
+def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
+  I32EnumAttrCase<"Add", 0, "add">,
+  I32EnumAttrCase<"Sub", 1, "sub">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0; // Wrapped by ArmSME_CombiningKindAttr.
+}
+
+/// An attribute that specifies how an SME outer product combines its newly
+/// produced value with the accumulator: `add` (the default) or `sub`. This is
+/// similar to vector::CombiningKindAttr, but limited to those two functions.
+def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
+                                          "kind"> {
+  let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "CombiningKind::Add";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -507,4 +524,95 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+/// If set, `<operand>Mask` must match `<operand>`'s type with i1 elements.
+class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :
+  TypesMatchWith<
+    "shape of `" # operand # "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",
+    "!" # maskGetter # "() || std::equal_to<>()">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllElementTypesMatch<["lhs", "rhs", "result"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    PredOpTrait<
+      "result type is derived from `lhs` and `rhs`",
+      CPred<
+        "getResultType() == VectorType::get({"
+          "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
+          "getRhsType().getElementType(),"
+          "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
+    TypesMatchWith<"`result` and `acc` have the same type",
+      "result", "acc",
+      "::llvm::cast<mlir::Type>($_self)",
+      "!getAcc() || std::equal_to<>()">
+  ]> {
+  let summary = "Vector outer product with optional fused add or sub";
+
+  let description = [{
+    This op is based on `vector.outerproduct` with the extra conditions that:
+
+    * AXPY operations are not supported
+    * The only combining functions are "add" (the default) and "sub"
+    * Masking is performed on the inputs (rather than the output)
+
+    This is meant as an intermediate op for lowering `vector.outerproduct` to
+    SME. Types are not restricted to SVE/SME vectors at this level.
+
+    Example 1: Unmasked outerproduct (without accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 2: Unmasked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Masked outerproduct
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 4: Masked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct %lhs, %rhs acc(%accumulator) masks(%lhsMask, %rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    VectorOfRank<[1]>:$lhs, VectorOfRank<[1]>:$rhs,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$lhsMask,
+    Optional<VectorOfRankAndType<[1],[I1]>>:$rhsMask,
+    Optional<VectorOfRank<[2]>>:$acc,
+    ArmSME_CombiningKindAttr:$kind);
+  let results = (outs VectorOfRank<[2]>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `kind` `` $kind
+      | `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict
+    `:` type($lhs) `,` type($rhs) `,` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return getLhs().getType(); }
+    VectorType getRhsType() { return getRhs().getType(); }
+    VectorType getResultType() { return getResult().getType(); }
+  }];
+}
+
 #endif // ARMSME_OPS
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..b7ec2603477d59e 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,6 +427,80 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Lowers a (possibly masked) `vector.outerproduct` to `arm_sme.outerproduct`.
+/// For masked ops the 2-D mask must come from a `vector.create_mask`; it is
+/// decomposed into a 1-D mask for each operand (dim 0 -> lhs, dim 1 -> rhs).
+///
+///  BEFORE:
+///  ```mlir
+///  %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///  %result = vector.mask %mask {
+///               vector.outerproduct %vecA, %vecB
+///                : vector<[4]xf32>, vector<[4]xf32>
+///            } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///  %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+///  ```
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+    // AXPY operation not suited for SME.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
+
+    auto kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      return outerProductOp.emitError("unsupported kind");
+
+    Value lhsMask = {};
+    Value rhsMask = {};
+    Operation *rootOp = outerProductOp;
+    if (outerProductOp.isMasked()) {
+      auto maskingOp = outerProductOp.getMaskingOp();
+
+      // Attempt to extract masks from vector.create_mask.
+      // TODO: Add support for other mask sources.
+      auto createMaskOp =
+          maskingOp.getMask().getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+
+      auto maskType = createMaskOp.getVectorType();
+      if (maskType.getRank() != 2)
+        return failure();
+
+      // The enclosing `vector.mask` is the op being replaced, so build the
+      // new ops before it rather than inside its region.
+      rewriter.setInsertionPoint(maskingOp);
+      rootOp = maskingOp;
+      auto loc = outerProductOp.getLoc();
+
+      // Mask dim 0 bounds the rows (`lhs`), dim 1 the columns (`rhs`). Derive
+      // each 1-D mask type separately so non-square masks stay well-typed.
+      VectorType lhsMaskType = VectorType::Builder(maskType).dropDim(1);
+      VectorType rhsMaskType = VectorType::Builder(maskType).dropDim(0);
+      lhsMask = rewriter.create<vector::CreateMaskOp>(
+          loc, lhsMaskType, createMaskOp.getOperand(0));
+      rhsMask = rewriter.create<vector::CreateMaskOp>(
+          loc, rhsMaskType, createMaskOp.getOperand(1));
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -434,5 +508,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
                SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+               VectorOuterProductToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..0e5c996e518a472 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -458,13 +458,13 @@ struct MoveTileSliceToVectorArmSMELowering
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
-struct VectorOuterProductToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
-  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+struct OuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
+  using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp,
-                  vector::OuterProductOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
+                  arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
@@ -496,24 +496,13 @@ struct VectorOuterProductToArmSMELowering
       return true;
     };
 
-    auto resultVectorType = outerProductOp.getResultVectorType();
-    if (!isSupportedType(resultVectorType))
-      return outerProductOp.emitError("unsupported type");
-
-    vector::CombiningKind kind = outerProductOp.getKind();
-    if (kind != vector::CombiningKind::ADD)
-      // TODO: support subtract.
+    // TODO: Support CombiningKind::Sub for outer products.
+    if (outerProductOp.getKind() != CombiningKind::Add)
       return outerProductOp.emitError("unsupported kind");
 
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: support masking.
-      return outerProductOp.emitError("masking is currently unsupported");
-
-    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      // AXPY operation not suited for SME.
-      return failure();
+    auto resultVectorType = outerProductOp.getResultType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
 
     auto loc = outerProductOp.getLoc();
 
@@ -526,21 +515,24 @@ struct VectorOuterProductToArmSMELowering
     auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(elementWidth), acc);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy =
-        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
-                        /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
-
     auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
 
+    Value lhsMask = outerProductOp.getLhsMask();
+    Value rhsMask = outerProductOp.getRhsMask();
+
+    if (!lhsMask || !rhsMask) {
+      auto predTy =
+          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
+      Value allActiveMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predTy, true));
+      lhsMask = allActiveMask;
+      rhsMask = allActiveMask;
+    }
+
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(
-        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
-        outerProductOp.getRhs());
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+                                               outerProductOp.getLhs(),
+                                               outerProductOp.getRhs());
 
     // Create `CastTileToVectorOp` to use as the output.
     rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
@@ -716,6 +708,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      VectorOuterProductToArmSMELowering, ZeroOpConversion,
+      OuterProductToArmSMELowering, ZeroOpConversion,
       VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..996c26834ae7fa6 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -97,3 +97,32 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_mask_type(%vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<5xi1>, %maskB: vector<[2]xi1>) -> vector<3x[2]xf32>
+{
+  // expected-error@+1 {{use of value '%maskA' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct__bad_result_type(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[2]x1xi16>
+{
+  // expected-error@+1 {{op failed to verify that result type is derived from `lhs` and `rhs`}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[2]x1xi16>
+  return %0 : vector<[2]x1xi16>
+}
+
+// -----
+
+// expected-note@+1 {{prior use here}}
+func.func @arm_sme_outproduct__bad_acc_type(%vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<6x6xi32>) -> vector<7x6xi32>
+{
+  // expected-error@+1 {{use of value '%acc' expects different type than prior uses}}
+  %0 = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %0 : vector<7x6xi32>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..6ee6d7a6149b940 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1135,3 +1135,56 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_outproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>, vector<[8]x[8]xi16>
+  return %result : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_masking(
+  %vecA: vector<3xf32>, %vecB: vector<[2]xf32>, %maskA: vector<3xi1>, %maskB: vector<[2]xi1>
+) -> vector<3x[2]xf32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<3xf32>, vector<[2]xf32>, vector<3x[2]xf32>
+  return %result : vector<3x[2]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_acc(
+  %vecA: vector<7xi32>, %vecB: vector<6xi32>, %acc: vector<7x6xi32>
+) -> vector<7x6xi32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<7xi32>, vector<6xi32>, vector<7x6xi32>
+  return %result : vector<7x6xi32>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  return %result : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_outproduct_with_everything(
+  %vecA: vector<[4]xf16>, %vecB: vector<4xf16>, %acc: vector<[4]x4xf16>,
+  %maskA: vector<[4]xi1>, %maskB: vector<4xi1>
+) -> vector<[4]x4xf16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}})
+  // CHECK-SAME: : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB)
+              : vector<[4]xf16>, vector<4xf16>, vector<[4]x4xf16>
+  return %result : vector<[4]x4xf16>
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 32f46d9fd817c9d..f615a9ef0231443 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -463,9 +463,75 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, {{.*}} : vector<[4]xi32>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_no_accumulator_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_no_accumulator_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0: index, %dim1: index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
 // CHECK-LABEL: @vector_outerproduct_unsupported_axpy
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // CHECK-NOT: arm_sme
+  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
@@ -473,7 +539,6 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 // -----
 
 func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
-  // expected-error at +2 {{failed to legalize operation 'vector.outerproduct'}}
   // expected-error at +1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -490,9 +555,8 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
-  // expected-error@+1 {{masking is currently unsupported}}
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..21edbc259e124c3 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -578,3 +578,99 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[2]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>,
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[4]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0: index, %dim1: index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask {{.*}} : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
+func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>, vector<[2]x[2]xf64>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
+func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>
+func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>, vector<[8]x[8]xf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>
+func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>, vector<[8]x[8]xbf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 4265ca0f599281c..052e74c776108a0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -14,6 +14,12 @@
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
 // RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
 
+// REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+
 llvm.func @printCString(!llvm.ptr<i8>)
 
 func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
@@ -82,5 +88,69 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   return
 }
 
+func.func @test_masked_outerproduct_no_accumulator_4x4xf32() {
+  // Input vector: (step vector + 1) converted to f32, i.e. [1, 2, 3, 4, ...].
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 3 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking only the top-left 3x2 section is computed;
+  // all other elements are zero (there is no accumulator).
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 1, 2, 0, 0
+  // WITH-MASK-NEXT: ( 2, 4, 0, 0
+  // WITH-MASK-NEXT: ( 3, 6, 0, 0
+  // WITH-MASK-NEXT: ( 0, 0, 0, 0
+  // WITH-MASK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  vector.print %tile : vector<[4]x[4]xf32>
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
+  // Accumulator: every element is 10.0; input vector: [1, 2, 3, 4, ...] (f32).
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+  %f10 = arith.constant 10.0 : f32
+
+  %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 2 : index
+  %rhsDim = arith.constant 3 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking only the top-left 2x3 section is updated
+  // (10 + lhs[i] * rhs[j]); masked-off elements keep the accumulator value.
+  // WITH-MASK-AND-ACC:      TILE BEGIN
+  // WITH-MASK-AND-ACC-NEXT: ( 11, 12, 13, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 12, 14, 16, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC:      TILE END
+  func.call @printTileBegin() : () -> ()
+  vector.print %tile : vector<[4]x[4]xf32>
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
 llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
 llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index cb2c6b98a4eef3a..b42c74c911ee306 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -11,6 +11,9 @@
 
 // RUN: %{compile} | %{run} | FileCheck %s
 
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-MASK
+
 llvm.func @printCString(!llvm.ptr<i8>)
 
 func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
@@ -57,5 +60,37 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
   return
 }
 
+func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
+  // Accumulator: every element is 10.0; input vector: [1, 2, ...] (f64).
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %acc = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 1 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
+  } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+
+  // Print the tile. Due to masking only the top-left 1x2 section is updated;
+  // everything else keeps the accumulator value (10). The smallest SVL is
+  // 128-bits so the tile will be at least 2x2xf64.
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 11, 12
+  // WITH-MASK-NEXT: ( 10, 10
+  // WITH-MASK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  vector.print %tile : vector<[2]x[2]xf64>
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
 llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
 llvm.mlir.global internal constant @str_tile_end("TILE END\0A")

>From a545d5735b40b9f59bc0bb2f0149f97f71a8ff8f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 19 Oct 2023 14:35:18 +0000
Subject: [PATCH 2/3] Make comment in VectorToArmSME.cpp match others in style

---
 .../VectorToArmSME/VectorToArmSME.cpp         | 37 ++++++++++---------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index b7ec2603477d59e..44e285211a0566c 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,26 +427,27 @@ struct TransposeOpToArmSMELowering
   }
 };
 
-/// Lowers a masked `vector.outerproduct` to `arm_sme.outerproduct`.
-/// The 2-D mask of the `vector.outerproduct` (if from a `vector.create_mask`)
-/// is decomposed into two 1-D masks for the operands.
+/// Conversion pattern for vector.outerproduct.
 ///
-///  BEFORE:
-///  ```mlir
-///  %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
-///  %result = vector.mask %mask {
-///               vector.outerproduct %vecA, %vecB
-///                : vector<[4]xf32>, vector<[4]xf32>
-///            } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
-///  ```
+/// If the vector.outerproduct is masked (and the mask comes from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///   %result = vector.mask %mask {
+///                vector.outerproduct %vecA, %vecB
+///                 : vector<[4]xf32>, vector<[4]xf32>
+///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///                : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
 ///
-///  AFTER:
-///  ```mlir
-///  %maskA = vector.create_mask %dimA : vector<[4]xi1>
-///  %maskB = vector.create_mask %dimB : vector<[4]xi1>
-///  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
-///              : vector<[4]xf32>, vector<[4]xf32>, vector<[4]x[4]xf32>
-///  ```
 struct VectorOuterProductToArmSMELowering
     : public OpRewritePattern<vector::OuterProductOp> {
 

>From 852a83b5779222d1af16a3b51261159825270895 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 20 Oct 2023 09:14:17 +0000
Subject: [PATCH 3/3] Switch to OptionalTypesMatchWith

---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f60126e83603f47..2c2af897c3a426b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -524,19 +524,18 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
-class HasMatchingMaskTypeConstraint<string operand, string maskGetter> :
-  TypesMatchWith<
+class HasMatchingMaskTypeConstraint<string operand> :
+  OptionalTypesMatchWith<
     "shape of `" # operand #  "Mask` matches `" # operand # "`",
     operand, operand # "Mask",
-    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))",
-    "!" # maskGetter # "() || std::equal_to<>()">;
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [Pure,
     AttrSizedOperandSegments,
     AllElementTypesMatch<["lhs", "rhs", "result"]>,
-    HasMatchingMaskTypeConstraint<"lhs", "getLhsMask">,
-    HasMatchingMaskTypeConstraint<"rhs", "getRhsMask">,
+    HasMatchingMaskTypeConstraint<"lhs">,
+    HasMatchingMaskTypeConstraint<"rhs">,
     PredOpTrait<
       "both `lhsMask` and `rhsMask` should be provided or neither",
       CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
@@ -547,10 +546,9 @@ def OuterProductOp :
           "getLhsType().getDimSize(0), getRhsType().getDimSize(0)},"
           "getRhsType().getElementType(),"
           "{getLhsType().getScalableDims()[0], getRhsType().getScalableDims()[0]})">>,
-    TypesMatchWith<"`result` and `acc` have the same type",
+    OptionalTypesMatchWith<"`result` and `acc` have the same type",
       "result", "acc",
-      "::llvm::cast<mlir::Type>($_self)",
-      "!getAcc() || std::equal_to<>()">
+      "::llvm::cast<mlir::Type>($_self)">
   ]>
 {
   let summary = "Vector outerproduct with optional fused add";



More information about the Mlir-commits mailing list