[Mlir-commits] [mlir] e666295 - [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (#69604)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 31 02:06:24 PDT 2023


Author: Benjamin Maxwell
Date: 2023-10-31T09:06:21Z
New Revision: e66629501189b15f24810b7a65cd814f76a25e23

URL: https://github.com/llvm/llvm-project/commit/e66629501189b15f24810b7a65cd814f76a25e23
DIFF: https://github.com/llvm/llvm-project/commit/e66629501189b15f24810b7a65cd814f76a25e23.diff

LOG: [mlir][ArmSME] Support lowering masked vector.outerproduct ops to SME (#69604)

This patch adds support for lowering masked outer products to SME. This
is done in two stages. First, vector.outerproducts (both masked and
non-masked) are rewritten to arm_sme.outerproducts. The
arm_sme.outerproduct op is close to vector.outerproduct, but supports
masking on the operands rather than the result. It also limits the cases
it handles to things that could be (directly) lowered to SME.

This currently requires that the source of the mask is a
vector.create_mask op. E.g.:

```mlir
%mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
%result = vector.mask %mask {
             vector.outerproduct %vecA, %vecB
              : vector<[4]xf32>, vector<[4]xf32>
          } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
```
Is rewritten to:
```mlir
%maskA = vector.create_mask %dimA : vector<[4]xi1>
%maskB = vector.create_mask %dimB : vector<[4]xi1>
%result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
              : vector<[4]xf32>, vector<[4]xf32>
```
(The same rewrite works for non-masked vector.outerproducts too)

The arm_sme.outerproduct can then be directly lowered to SME intrinsics.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSME/invalid.mlir
    mlir/test/Dialect/ArmSME/roundtrip.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index b30d0fdb866bd23..2f6e52ff2badbeb 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -79,6 +79,27 @@ def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
   let defaultValue = "TileSliceLayout::Horizontal";
 }
 
+def CombiningKind : I32EnumAttr<"CombiningKind", "Kind of combining function", [
+  I32EnumAttrCase<"Add", 0, "add">,
+  I32EnumAttrCase<"Sub", 1, "sub">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies how to combine a newly produced value with the
+/// accumulator. This is similar to vector::CombiningKindAttr, but limited to
+/// the functions that are valid for SME outer products. Add corresponds to a
+/// MOPA and sub to a MOPS.
+/// E.g. For f32:
+/// FMOPA: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
+/// FMOPS: https://developer.arm.com/documentation/ddi0602/2022-03/SME-Instructions/FMOPS--non-widening---Floating-point-outer-product-and-subtract-
+def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
+                                          "kind"> {
+  let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "CombiningKind::Add";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -209,7 +230,7 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let results = (outs SMETile:$res);
   let description = [{
     Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction. 
+    `zero` intrinsic and instruction.
 
     Example 1: Zero an 8-bit element ZA tile.
 
@@ -561,4 +582,90 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 }
 
+class HasMatchingMaskTypeConstraint<string operand> :
+  OptionalTypesMatchWith<
+    "shape of `" # operand #  "Mask` matches `" # operand # "`",
+    operand, operand # "Mask",
+    "::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
+
+class OuterProductResultTileTypeConstraint<string operand> :
+  OptionalTypesMatchWith<operand # "type is derived from `lhs` and `rhs`",
+    "lhs", operand,
+    "[&]{"
+    "  auto vectorType = ::llvm::cast<mlir::VectorType>($_self);"
+    "  int64_t size = vectorType.getDimSize(0);"
+    "  return VectorType::get("
+    "    { size, size }, vectorType.getElementType(), { true, true });"
+    "}()">;
+
+def OuterProductOp :
+  ArmSME_Op<"outerproduct", [Pure,
+    AttrSizedOperandSegments,
+    AllTypesMatch<["lhs", "rhs"]>,
+    HasMatchingMaskTypeConstraint<"lhs">,
+    HasMatchingMaskTypeConstraint<"rhs">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">>,
+    OuterProductResultTileTypeConstraint<"result">,
+    OuterProductResultTileTypeConstraint<"acc">
+  ]>
+{
+  let summary = "Outer product with optional fused add/sub";
+
+  let description = [{
+    This operation represents an outer product that fits within an SME tile.
+    All operands must be SVE vectors and the result a SME tile. Unlike
+    `vector.outerproduct` masking is on the operands (rather than the result),
+    which mirrors the SME instructions.
+
+    Example 1: Unmasked outerproduct (without accumulator)
+    ```mlir
+    // Not specifying an accumulator implicitly zeros the destination tile.
+    %result = arm_sme.outerproduct $lhs, $rhs : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    Example 2: Unmasked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator)
+                : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    Example 3: Masked outerproduct
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs masks($lhsMask, $rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    Example 4: Masked outerproduct (with accumulator)
+    ```mlir
+    %result = arm_sme.outerproduct $lhs, $rhs acc($accumulator) masks($lhsMask, $rhsMask)
+                : vector<[4]xf32>, vector<[4]xf32>
+    ```
+  }];
+
+  let arguments = (ins
+    SVEVector:$lhs, SVEVector:$rhs,
+    Optional<SVEPredicate>:$lhsMask,
+    Optional<SVEPredicate>:$rhsMask,
+    Optional<SMETile>: $acc,
+    ArmSME_CombiningKindAttr:$kind);
+  let results = (outs SMETile:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `kind` `` $kind
+      | `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict `:` type($lhs) `,` type($rhs)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+    VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+    VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+  }];
+}
+
 #endif // ARMSME_OPS

diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..b60c21e2ced7a8f 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -427,6 +427,105 @@ struct TransposeOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.outerproduct.
+///
+/// If the vector.outerproduct is masked (and the mask is from a
+/// vector.create_mask), then the mask is decomposed into two 1-D masks for the
+/// operands.
+///
+/// Example:
+///
+///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
+///   %result = vector.mask %mask {
+///                vector.outerproduct %vecA, %vecB
+///                 : vector<[4]xf32>, vector<[4]xf32>
+///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+///
+/// is converted to:
+///
+///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
+///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
+///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
+///                : vector<[4]xf32>, vector<[4]xf32>
+///
+/// Unmasked outerproducts can be directly replaced with the arm_sme op.
+///
+/// Example:
+///
+///   %result = vector.outerproduct %vecA, %vecB
+///              : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+///   %result = arm_sme.outerproduct %vecA, %vecB
+///              : vector<[4]xf32>, vector<[4]xf32>
+///
+struct VectorOuterProductToArmSMELowering
+    : public OpRewritePattern<vector::OuterProductOp> {
+
+  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
+                                PatternRewriter &rewriter) const override {
+
+    // We don't yet support lowering AXPY operations to SME. These could be
+    // lowered by masking out all but the first element of the LHS.
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      return outerProductOp.emitError("AXPY operations not supported");
+
+    if (!arm_sme::isValidSMETileVectorType(
+            outerProductOp.getResultVectorType()))
+      return outerProductOp.emitError(
+          "outer product does not fit into SME tile");
+
+    auto kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      return outerProductOp.emitError(
+          "unsupported kind (lowering to SME only supports ADD at the moment)");
+
+    Value lhsMask = {};
+    Value rhsMask = {};
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      rewriter.setInsertionPoint(maskOp);
+      rootOp = maskOp;
+      auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
+      if (failed(operandMasks))
+        return failure();
+      std::tie(lhsMask, rhsMask) = *operandMasks;
+    }
+
+    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
+        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
+        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
+
+    return success();
+  }
+
+  static FailureOr<std::pair<Value, Value>>
+  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
+    // Attempt to extract masks from vector.create_mask.
+    // TODO: Add support for other mask sources.
+    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    auto maskType = createMaskOp.getVectorType();
+    Value lhsMaskDim = createMaskOp.getOperand(0);
+    Value rhsMaskDim = createMaskOp.getOperand(1);
+
+    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
+    Value lhsMask =
+        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
+    Value rhsMask =
+        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
+
+    return std::make_pair(lhsMask, rhsMask);
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -434,5 +533,6 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
                SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
+               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+               VectorOuterProductToArmSMELowering>(&ctx);
 }

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 1231da356f8ed95..105f2de207a0843 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -460,11 +460,11 @@ struct MoveTileSliceToVectorArmSMELowering
   }
 };
 
-/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+/// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
 ///
 /// Example:
 ///
-///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
 ///     : vector<[4]xf32>, vector<[4]xf32>
 ///
 /// is converted to:
@@ -474,13 +474,13 @@ struct MoveTileSliceToVectorArmSMELowering
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
-struct VectorOuterProductToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
-  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+struct OuterProductOpConversion
+    : public ConvertOpToLLVMPattern<arm_sme::OuterProductOp> {
+  using ConvertOpToLLVMPattern<arm_sme::OuterProductOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp,
-                  vector::OuterProductOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
+                  arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
@@ -512,24 +512,13 @@ struct VectorOuterProductToArmSMELowering
       return true;
     };
 
-    auto resultVectorType = outerProductOp.getResultVectorType();
-    if (!isSupportedType(resultVectorType))
-      return outerProductOp.emitError("unsupported type");
-
-    vector::CombiningKind kind = outerProductOp.getKind();
-    if (kind != vector::CombiningKind::ADD)
-      // TODO: support subtract.
+    // TODO: Support CombiningKind::Sub for outer products.
+    if (outerProductOp.getKind() != CombiningKind::Add)
       return outerProductOp.emitError("unsupported kind");
 
-    auto maskableOp =
-        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: support masking.
-      return outerProductOp.emitError("masking is currently unsupported");
-
-    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
-      // AXPY operation not suited for SME.
-      return failure();
+    auto resultVectorType = outerProductOp.getResultType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
 
     auto loc = outerProductOp.getLoc();
 
@@ -542,21 +531,24 @@ struct VectorOuterProductToArmSMELowering
     auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(elementWidth), acc);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy =
-        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
-                        /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
-
     auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
 
+    Value lhsMask = outerProductOp.getLhsMask();
+    Value rhsMask = outerProductOp.getRhsMask();
+
+    if (!lhsMask || !rhsMask) {
+      auto predTy =
+          outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
+      Value allActiveMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predTy, true));
+      lhsMask = allActiveMask;
+      rhsMask = allActiveMask;
+    }
+
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(
-        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
-        outerProductOp.getRhs());
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+                                               outerProductOp.getLhs(),
+                                               outerProductOp.getRhs());
 
     // Create `CastTileToVectorOp` to use as the output.
     rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
@@ -733,6 +725,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      VectorOuterProductToArmSMELowering, ZeroOpConversion,
-      VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
+      OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
+      VectorInsertToArmSMELowering>(converter);
 }

diff  --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 25c62f78d843543..dba8b1937936e2c 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -150,3 +150,25 @@ func.func @arm_sme_tile_load__pad_but_no_mask(%src : memref<?x?xf64>, %pad : f64
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_outerproduct__bad_result_type(%vecA: vector<[2]xi16>, %vecB: vector<[2]xi16>) -> vector<[2]x[2]xi16>
+{
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[2]x[2]xi16>'}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[2]xi16>, vector<[2]xi16>
+  return %0 : vector<[2]x[2]xi16>
+}
+
+// -----
+
+func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: vector<[8]xf32>) -> vector<[4]x[4]xf32>
+{
+  // expected-error@+1 {{op failed to verify that all of {lhs, rhs} have same type}}
+  %0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32>
+  return %0 : vector<[4]x[4]xf32>
+}

diff  --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 6866137267dc66a..90b05c54c58d931 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1161,3 +1161,47 @@ func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_outerproduct(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16>
+  %result = arm_sme.outerproduct %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16>
+  return %result : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_outerproduct_with_masking(%vecA: vector<[4]xf32>, %vecB: vector<[4]xf32>, %maskA: vector<[4]xi1>, %maskB: vector<[4]xi1>) -> vector<[4]x[4]xf32> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[4]xf32>, vector<[4]xf32>
+  %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) : vector<[4]xf32>, vector<[4]xf32>
+  return %result : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_outerproduct_with_acc(%vecA: vector<[2]xi64>, %vecB: vector<[2]xi64>, %acc: vector<[2]x[2]xi64>) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} acc({{.*}}) : vector<[2]xi64>, vector<[2]xi64>
+  %result = arm_sme.outerproduct %vecA, %vecB acc(%acc) : vector<[2]xi64>, vector<[2]xi64>
+  return %result : vector<[2]x[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_outerproduct_with_kind(%vecA: vector<[2]xf64>, %vecB: vector<[2]xf64>) -> vector<[2]x[2]xf64>  {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> : vector<[2]xf64>, vector<[2]xf64>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> : vector<[2]xf64>, vector<[2]xf64>
+  return %result : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_outerproduct_with_everything(%vecA: vector<[16]xi8>, %vecB: vector<[16]xi8>, %acc: vector<[16]x[16]xi8>, %maskA: vector<[16]xi1>, %maskB: vector<[16]xi1>) -> vector<[16]x[16]xi8> {
+  // CHECK: arm_sme.outerproduct {{.*}}, {{.*}} kind<sub> acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[16]xi8>, vector<[16]xi8>
+  %result = arm_sme.outerproduct %vecA, %vecB kind<sub> acc(%acc) masks(%maskA, %maskB) : vector<[16]xi8>, vector<[16]xi8>
+  return %result : vector<[16]x[16]xi8>
+}

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 32f46d9fd817c9d..721ff8f2c3589d4 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -463,9 +463,68 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 
 // -----
 
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+  // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32
+  // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM0]] : vector<[4]xi32>
+  // CHECK: %[[DIM1_I32:.*]] = arith.index_cast %[[DIM1]] : index to i32
+  // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertelement %[[DIM1_I32]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[SPLAT_DIM1:.*]] = llvm.shufflevector %[[INSERT_DIM1]], {{.*}} : vector<[4]xi32>
+  // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM1]] : vector<[4]xi32>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32>
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
 // CHECK-LABEL: @vector_outerproduct_unsupported_axpy
 func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
-  // CHECK-NOT: arm_sme
+  // expected-error@+1 {{AXPY operations not supported}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
   return %0 : vector<[2]xf64>
 }
@@ -473,7 +532,6 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 // -----
 
 func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
-  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -490,9 +548,8 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
-  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
-  // expected-error@+1 {{masking is currently unsupported}}
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}}
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..9eb7cd143e5b5ea 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -578,3 +578,99 @@ func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[8]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[8]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xbf16>, vector<[8]xbf16>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[4]xf32>, vector<[4]xf32>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_masked_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+  %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
+  // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[2]xi1>
+  // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[2]xi1>
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[2]xf64>, vector<[2]xf64>
+  %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>
+func.func @vector_outerproduct_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xf16>, vector<[8]xf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_bf16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>
+func.func @vector_outerproduct_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[8]xbf16>, vector<[8]xbf16>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+  "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f32
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>
+func.func @vector_outerproduct_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[4]xf32>, vector<[4]xf32>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_f64
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>
+func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arm_sme.outerproduct %[[LHS]], %[[RHS]] acc(%[[ACC]]) : vector<[2]xf64>, vector<[2]xf64>
+  %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 38ba489e2fafb2c..ae5ad9cc2a5e90c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -3,16 +3,24 @@
 // DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
-// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
 // DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
 
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITHOUT-ACC
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s --check-prefix=WITHOUT-ACC
 
 // REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_4x4xf32
-// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=WITH-ACC
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-ACC
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_4x4xf32
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_4x4xf32
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
 
 func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
@@ -41,7 +49,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   %c0 = arith.constant 0 : index
   %f10 = arith.constant 10.0 : f32
 
-  %acc = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+  %acc = vector.splat %f10 : vector<[4]x[4]xf32>
   %vector_i32 = llvm.intr.experimental.stepvector : vector<[4]xi32>
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
@@ -61,3 +69,67 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
 
   return
 }
+
+func.func @test_masked_outerproduct_no_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 3 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking the result will be the top 3x2xf32 section.
+  //
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 1, 2, 0, 0
+  // WITH-MASK-NEXT: ( 2, 4, 0, 0
+  // WITH-MASK-NEXT: ( 3, 6, 0, 0
+  // WITH-MASK-NEXT: ( 0, 0, 0, 0
+  // WITH-MASK:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[4]x[4]xf32>
+  vector.print str "TILE END"
+
+  return
+}
+
+func.func @test_masked_outerproduct_with_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[4]xi32>
+  %f10 = arith.constant 10.0 : f32
+
+  %acc = vector.splat %f10 : vector<[4]x[4]xf32>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[4]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[4]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
+
+  %lhsDim = arith.constant 2 : index
+  %rhsDim = arith.constant 3 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
+  } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+
+  // Print the tile. Due to masking the result will be the top 2x3xf32 section.
+  //
+  // WITH-MASK-AND-ACC:      TILE BEGIN
+  // WITH-MASK-AND-ACC-NEXT: ( 11, 12, 13, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 12, 14, 16, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10, 10, 10
+  // WITH-MASK-AND-ACC:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[4]x[4]xf32>
+  vector.print str "TILE END"
+
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 82f14595a24da2f..36ce896a4c1bd90 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,34 +1,47 @@
-// DEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
+// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
 // DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
 // DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
-// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
 // DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
 
-// RUN: %{compile} | %{run} | FileCheck %s
+// RUN: %{compile}
 
-func.func @test_outerproduct_with_accumulator_2x2xf64() {
-  %f1 = arith.constant 1.0 : f64
-  %f2 = arith.constant 2.0 : f64
-  %f10 = arith.constant 10.0 : f64
+// RUN: %{run} | FileCheck %s
+
+// REDEFINE: %{entry_point} = test_outerproduct_with_accumulator_2x2xf64
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-ACC
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_no_accumulator_2x2xf64
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK
+
+// REDEFINE: %{entry_point} = test_masked_outerproduct_with_accumulator_2x2xf64
+// RUN: %{run} | FileCheck %s --check-prefix=WITH-MASK-AND-ACC
+
+func.func @test_outerproduct_no_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
 
-  %a = vector.splat %f1 : vector<[2]xf64>
-  %b = vector.splat %f2 : vector<[2]xf64>
-  // TODO: vector.splat doesn't support ArmSME.
-  %c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
 
-  %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
+  %lhsDim = arith.constant 1 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.outerproduct %vector, %vector : vector<[2]xf64>, vector<[2]xf64>
 
   // Print the tile. The smallest SVL is 128-bits so the tile will be at least
   // 2x2xf64.
   //
   // CHECK:      TILE BEGIN
-  // CHECK-NEXT: ( 12, 12
-  // CHECK-NEXT: ( 12, 12
+  // CHECK-NEXT: ( 1, 2
+  // CHECK-NEXT: ( 2, 4
   // CHECK:      TILE END
   vector.print str "TILE BEGIN"
   vector.print %tile : vector<[2]x[2]xf64>
@@ -36,3 +49,90 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
 
   return
 }
+
+func.func @test_outerproduct_with_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %acc = vector.splat %f10 : vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %tile = vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
+
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xf64.
+  //
+  // WITH-ACC:      TILE BEGIN
+  // WITH-ACC-NEXT: ( 11, 12
+  // WITH-ACC-NEXT: ( 12, 14
+  // WITH-ACC:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}
+
+func.func @test_masked_outerproduct_no_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 2 : index
+  %rhsDim = arith.constant 1 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector : vector<[2]xf64>, vector<[2]xf64>
+  } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+
+  // Print the tile. Due to masking the result will be the top 2x1xf64 section.
+  //
+  // WITH-MASK:      TILE BEGIN
+  // WITH-MASK-NEXT: ( 1, 0
+  // WITH-MASK-NEXT: ( 2, 0
+  // WITH-MASK:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}
+
+func.func @test_masked_outerproduct_with_accumulator_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %ones = arith.constant dense<1> : vector<[2]xi32>
+  %f10 = arith.constant 10.0 : f64
+
+  %acc = vector.splat %f10 : vector<[2]x[2]xf64>
+  %step_vector = llvm.intr.experimental.stepvector : vector<[2]xi32>
+  %vector_i32 = arith.addi %step_vector, %ones : vector<[2]xi32>
+  %vector = arith.sitofp %vector_i32 : vector<[2]xi32> to vector<[2]xf64>
+
+  %lhsDim = arith.constant 1 : index
+  %rhsDim = arith.constant 2 : index
+  %mask = vector.create_mask %lhsDim, %rhsDim : vector<[2]x[2]xi1>
+
+  %tile = vector.mask %mask {
+    vector.outerproduct %vector, %vector, %acc : vector<[2]xf64>, vector<[2]xf64>
+  } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
+
+  // Print the tile. Due to masking the result will be the top 1x2xf64 section.
+  //
+  // WITH-MASK-AND-ACC:      TILE BEGIN
+  // WITH-MASK-AND-ACC-NEXT: ( 11, 12
+  // WITH-MASK-AND-ACC-NEXT: ( 10, 10
+  // WITH-MASK-AND-ACC:      TILE END
+  vector.print str "TILE BEGIN"
+  vector.print %tile : vector<[2]x[2]xf64>
+  vector.print str "TILE END"
+
+  return
+}


        


More information about the Mlir-commits mailing list