[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (PR #65621)

Cullen Rhodes llvmlistbot at llvm.org
Thu Sep 7 08:36:02 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/65621:

This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and the BFloat16
variant to BFMOPA (non-widening) [2].
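The sketch below shows, without MLIR, which result types the new pattern accepts. The check requires a rank-2, fully scalable vector whose dimensions equal MinStreamingVectorLengthInBits (128) divided by the element bit width, which gives [8]x[8] for f16/bf16, [4]x[4] for f32 and [2]x[2] for f64. The ElemType/VecType structs and the instructionFor helper are hypothetical stand-ins for the MLIR types and are not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Mirrors arm_sme::MinStreamingVectorLengthInBits from the patch.
constexpr unsigned kMinStreamingVectorLengthInBits = 128;

// Hypothetical element-type tag standing in for mlir::Type.
enum class ElemType { I8, F16, BF16, F32, F64 };

unsigned bitWidth(ElemType t) {
  switch (t) {
  case ElemType::I8:
    return 8;
  case ElemType::F16:
  case ElemType::BF16:
    return 16;
  case ElemType::F32:
    return 32;
  case ElemType::F64:
    return 64;
  }
  return 0;
}

// Hypothetical stand-in for the result VectorType of vector.outerproduct.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  ElemType elemType;
};

// Same checks as isSupportedType in VectorOuterProductToArmSMELowering:
// rank 2, all dims scalable, FP element type, and each dim equal to the
// minimum number of elements of that type in a 128-bit streaming vector.
bool isSupportedType(const VecType &t) {
  if (t.shape.size() != 2 || t.scalableDims.size() != 2 ||
      !t.scalableDims[0] || !t.scalableDims[1])
    return false;
  if (t.elemType == ElemType::I8)
    return false;
  int64_t minNumElts = kMinStreamingVectorLengthInBits / bitWidth(t.elemType);
  return t.shape[0] == minNumElts && t.shape[1] == minNumElts;
}

// Which non-widening instruction each supported element type maps to.
std::string instructionFor(ElemType t) {
  if (t == ElemType::I8)
    return ""; // Not lowered by this pattern.
  return t == ElemType::BF16 ? "BFMOPA" : "FMOPA";
}
```

For example, vector<[8]x[8]xf32> is rejected even though f32 is supported, because an f32 tile at the minimum SVL is [4]x[4].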

Note that at the ISA level these variants are implemented by different
architecture features, listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16
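To make the feature split above concrete, here is a hypothetical table that a future feature-aware lowering might consult. Nothing like requiredFeatures or canLower exists in MLIR today; only the feature strings come from the list above. The check is a plain membership test and deliberately does not model implied features.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Hypothetical element-type tag; not an MLIR type.
enum class ElemType { F16, BF16, F32, F64 };

// Features required by each non-widening FMOPA/BFMOPA variant, as listed
// in the description.
std::vector<std::string> requiredFeatures(ElemType t) {
  switch (t) {
  case ElemType::F16:
    return {"+sme2p1", "+sme-f16f16"};
  case ElemType::BF16:
    return {"+sme2p1", "+b16b16"};
  case ElemType::F32:
    return {"+sme"};
  case ElemType::F64:
    return {"+sme-f64f64"};
  }
  return {};
}

// True if every feature required for `t` appears literally in `enabled`.
// Feature implication (one feature enabling another) is not modelled.
bool canLower(ElemType t, const std::vector<std::string> &enabled) {
  for (const std::string &f : requiredFeatures(t))
    if (std::find(enabled.begin(), enabled.end(), f) == enabled.end())
      return false;
  return true;
}
```

With the -mattr=+sve,+sme-f64f64 used by the F64 integration test, canLower would accept the f64 variant but not the f16 or bf16 ones.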

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests, but SME2 support isn't available yet (it's
targeted for QEMU 9.0), so integration tests for the F16 and BF16
variants (which require +sme2p1) are excluded.
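For reference, the integration tests check the non-widening outer-product-and-accumulate semantics: for every active row i and active column j, ZA[i][j] += lhs[i] * rhs[j], with inactive rows and columns left unchanged. The scalar model below is a sketch under simplifying assumptions (a fixed vector length and std::vector in place of a ZA tile); mopa and runTest are illustrative names. With lhs = 1.0, rhs = 2.0 and an accumulator of 10.0 it yields the 12s the F32/F64 tests expect, and with a zeroed tile (what the lowering materialises when no accumulator is given) it yields the 2s.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Tile = std::vector<std::vector<float>>;

// Scalar model of a non-widening FP outer product and accumulate:
// za[i][j] += lhs[i] * rhs[j] for every active row i and active column j.
void mopa(Tile &za, const std::vector<bool> &rowPred,
          const std::vector<bool> &colPred, const std::vector<float> &lhs,
          const std::vector<float> &rhs) {
  const std::size_t n = lhs.size();
  assert(rhs.size() == n && za.size() == n);
  for (std::size_t i = 0; i < n; ++i) {
    if (!rowPred[i])
      continue; // Inactive row: left unchanged.
    for (std::size_t j = 0; j < n; ++j) {
      if (!colPred[j])
        continue; // Inactive column: left unchanged.
      za[i][j] += lhs[i] * rhs[j];
    }
  }
}

// Mirrors the integration tests: splat(1.0) x splat(2.0) with all-active
// predicates, accumulating into a tile initialised to `accInit`.
Tile runTest(std::size_t n, float accInit) {
  Tile za(n, std::vector<float>(n, accInit));
  std::vector<bool> allActive(n, true);
  mopa(za, allActive, allActive, std::vector<float>(n, 1.0f),
       std::vector<float>(n, 2.0f));
  return za;
}
```

runTest(4, 10.0f) corresponds to the minimum 4x4xf32 tile (vscale = 1); on hardware the dimension is 4 * vscale.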

Masking is currently unsupported.

Depends on #65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-

From ca428094f035f1fa332431628cb9cba96e43b87c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 6 Sep 2023 07:20:59 +0000
Subject: [PATCH 1/2] [mlir][llvm] Return failure from type converter for n-D
 scalable vectors

This patch changes vector type conversion to return failure on n-D
scalable vector types instead of asserting.

This is an alternative approach to #65261 that aims to enable lowering
of Vector ops directly to ArmSME intrinsics where possible, and seems
more consistent with other type conversions. It's currently trivial to
hit the assert, which could be misread as n-D scalable vector types
being a bug, when they are in fact valid types in the Vector dialect.

By returning failure it will generally fail more gracefully,
particularly for release builds or other builds where assertions are
disabled.
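A minimal model of the behaviour change follows, using strings in place of MLIR types and std::optional in place of FailureOr<Type> (std::nullopt stands for failure()). VecType is a hypothetical stand-in for VectorType, and the nested-array strings are illustrative rather than MLIR's exact printed syntax.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::VectorType.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  std::string elemType;
};

// Models LLVMTypeConverter::convertVectorType after this patch:
//  * 1-D vectors stay vectors,
//  * fixed-length n-D vectors become nested arrays of 1-D vectors,
//  * n-D scalable vectors return "failure" (std::nullopt) instead of asserting.
std::optional<std::string> convertVectorType(const VecType &type) {
  if (type.shape.empty() || type.shape.size() != type.scalableDims.size())
    return std::nullopt; // 0-D and malformed types are not modelled here.

  bool isScalable = false;
  for (bool s : type.scalableDims)
    if (s)
      isScalable = true;
  if (isScalable && type.shape.size() > 1)
    return std::nullopt; // LLVM has no arrays of scalable vectors.

  // The innermost dimension becomes the LLVM vector type.
  std::string dim = std::to_string(type.shape.back());
  if (type.scalableDims.back())
    dim = "[" + dim + "]";
  std::string result = "vector<" + dim + "x" + type.elemType + ">";

  // Wrap the remaining dimensions as arrays, from shape[rank-2] out to shape[0].
  for (std::size_t i = type.shape.size() - 1; i-- > 0;)
    result = "!llvm.array<" + std::to_string(type.shape[i]) + " x " + result +
             ">";
  return result;
}
```

Callers can then treat std::nullopt as "no conversion", which is what the updated addConversion callback does by returning std::nullopt on failure.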
---
 .../Conversion/LLVMCommon/TypeConverter.h     |  2 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 19 +++++++++++--------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d9..2a4327535c68750 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
   Type convertMemRefToBarePtr(BaseMemRefType type) const;
 
   /// Convert a 1D vector type into an LLVM vector type.
-  Type convertVectorType(VectorType type) const;
+  FailureOr<Type> convertVectorType(VectorType type) const;
 
   /// Options for customizing the llvm lowering.
   LowerToLLVMOptions options;
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..49e0513e629d951 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addConversion([&](MemRefType type) { return convertMemRefType(type); });
   addConversion(
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
-  addConversion([&](VectorType type) { return convertVectorType(type); });
+  addConversion([&](VectorType type) -> std::optional<Type> {
+    FailureOr<Type> llvmType = convertVectorType(type);
+    if (failed(llvmType))
+      return std::nullopt;
+    return llvmType;
+  });
 
   // LLVM-compatible types are legal, so add a pass-through conversion. Do this
   // before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
 ///  * 1-D `vector<axT>` remains as is while,
 ///  * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
 ///    `!llvm.array<ax...array<jxvector<kxT>>>`.
-/// As LLVM does not support arrays of scalable vectors, it is assumed that
-/// scalable vectors are always 1-D. This condition could be relaxed once the
-/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) const {
+/// Returns failure for n-D scalable vector types as LLVM does not support
+/// arrays of scalable vectors.
+FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   auto elementType = convertType(type.getElementType());
   if (!elementType)
     return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
-  assert(
-      (!type.isScalable() || (type.getRank() == 1)) &&
-      "expected 1-D scalable vector (n-D scalable vectors are not supported)");
+  if (type.isScalable() && (type.getRank() > 1))
+    return failure();
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
     vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);

From 1999394c189a4fc2e5489b633a814cf4c7395613 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 30 Aug 2023 14:41:35 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA

This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:

  vector<[8]xf16>,  vector<[8]xf16>  -> vector<[8]x[8]xf16>
  vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
  vector<[4]xf32>,  vector<[4]xf32>  -> vector<[4]x[4]xf32>
  vector<[2]xf64>,  vector<[2]xf64>  -> vector<[2]x[2]xf64>

The FP variants are lowered to FMOPA (non-widening) [1] and the BFloat16
variant to BFMOPA (non-widening) [2].

Note that at the ISA level these variants are implemented by different
architecture features, listed below:

  FMOPA (non-widening)
    * half-precision   - +sme2p1,+sme-f16f16
    * single-precision - +sme
    * double-precision - +sme-f64f64
  BFMOPA (non-widening)
    * half-precision   - +sme2p1,+b16b16

There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests, but SME2 support isn't available yet (it's
targeted for QEMU 9.0), so integration tests for the F16 and BF16
variants (which require +sme2p1) are excluded.

Masking is currently unsupported.

Depends on #65450.

[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |   2 +
 .../Transforms/LegalizeForLLVMExport.cpp      | 116 +++++++++++++++++-
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp       |   2 -
 .../Vector/Transforms/LowerVectorContract.cpp |   5 +-
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    | 107 +++++++++++++++-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |  93 ++++++++++++++
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |  50 ++++++++
 7 files changed, 367 insertions(+), 8 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 554b9f119230667..9e8ad48b3c2db94 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -20,6 +20,8 @@
 namespace mlir {
 namespace arm_sme {
 
+constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
 /// Return minimum number of elements for the given element `type` in
 /// a vector of SVL bits.
 unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 685f8d57f76f52c..ef4ae754f3b3b03 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -361,6 +361,111 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
+/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+///
+/// Example:
+///
+///   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+///     : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+///        vector<[4]xf32>) -> ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
+  using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp,
+                  vector::OuterProductOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto isSupportedType = [](VectorType vectorType) {
+      // TODO: the FP outer product instruction variants are predicated on
+      // different features:
+      //
+      // * FMOPA (non-widening)
+      //   * half-precision   - +sme2p1,+sme-f16f16
+      //   * single-precision - +sme
+      //   * double-precision - +sme-f64f64
+      // * BFMOPA (non-widening)
+      //   * half-precision   - +sme2p1,+b16b16
+      //
+      // It should be possible to control lowering based on target features.
+      if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
+        return false;
+
+      auto elementType = vectorType.getElementType();
+
+      if (!elementType.isF16() && !elementType.isBF16() &&
+          !elementType.isF32() && !elementType.isF64())
+        return false;
+
+      unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
+                            vectorType.getElementTypeBitWidth();
+      if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+        return false;
+
+      return true;
+    };
+
+    auto resultVectorType = outerProductOp.getResultVectorType();
+    if (!isSupportedType(resultVectorType))
+      return outerProductOp.emitError("unsupported type");
+
+    vector::CombiningKind kind = outerProductOp.getKind();
+    if (kind != vector::CombiningKind::ADD)
+      // TODO: support subtract.
+      return outerProductOp.emitError("unsupported kind");
+
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: support masking.
+      return outerProductOp.emitError("masking is currently unsupported");
+
+    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+      // AXPY operation not suited for SME.
+      return failure();
+
+    auto loc = outerProductOp.getLoc();
+
+    Value acc = outerProductOp.getAcc();
+    if (!acc)
+      // Initialize accumulator with zero.
+      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+
+    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
+    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
+        loc, rewriter.getIntegerType(elementWidth), acc);
+
+    // Create all active predicate mask.
+    auto one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto predTy =
+        VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
+                        /*scalableDims=*/{true});
+    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+
+    // Create 'arm_sme.intr.mopa' outer product intrinsic.
+    rewriter.create<arm_sme::aarch64_sme_mopa>(
+        loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
+        outerProductOp.getRhs());
+
+    // Create `CastTileToVectorOp` to use as the output.
+    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
+        outerProductOp, resultVectorType, tileId);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +479,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
       arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
       arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
+      arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
+  target.addIllegalOp<vector::OuterProductOp>();
 
   // Mark 'func.func' ops as legal if either:
   //   1. no 'arm_za' function attribute is present.
@@ -405,7 +512,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-  patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
-               LoadTileSliceToArmSMELowering,
-               MoveVectorToTileSliceToArmSMELowering>(converter);
+  patterns
+      .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
+           LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
+           VectorOuterProductToArmSMELowering>(converter);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index 8b2be7bc1901b9a..b8a47951cc7bbba 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
-static constexpr unsigned MinStreamingVectorLengthInBits = 128;
-
 unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index b66077372164e79..95a010dd59d95bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 
   LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                 PatternRewriter &rewriter) const override {
+    VectorType resType = op.getResultVectorType();
+    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
+      return failure();
+
     auto loc = op.getLoc();
 
     VectorType lhsType = op.getOperandVectorTypeLHS();
     VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
-    VectorType resType = op.getResultVectorType();
     Type eltType = resType.getElementType();
     bool isInt = isa<IntegerType, IndexType>(eltType);
     Value acc = op.getAcc();
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index af528295ef6ee23..ac77c6a897f1a3a 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @transfer_write_2d_zero_i8(
 // CHECK-SAME:                             %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
   return %tile : vector<[1]x[1]xi128>
 }
 
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+  // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
+  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
+  // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_bf16
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f32
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+  // CHECK-NOT: arith.extui
+  // CHECK-NOT: arith.trunci
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f64
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // CHECK: arith.trunci {{.*}} : i64 to i32
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_no_accumulator
+func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
+  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_scalar_rhs
+func.func @vector_outerproduct_scalar_rhs(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
+  // CHECK-NOT: arm_sme
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
+  return %0 : vector<[2]xf64>
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{unsupported type}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{unsupported kind}}
+  %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+  // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+  // expected-error@+1 {{masking is currently unsupported}}
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
new file mode 100644
index 000000000000000..82fc5e751d28c96
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -0,0 +1,93 @@
+// DEFINE: %{entry_point} = test_outerproduct_4x4xf32
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// REDEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-NO-ACC
+
+func.func @test_outerproduct_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f10 = arith.constant 10.0 : f32
+
+  %a = vector.splat %f1 : vector<[4]xf32>
+  %b = vector.splat %f2 : vector<[4]xf32>
+  // TODO: vector.splat doesn't support ArmSME.
+  %c = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+
+  %tile = vector.outerproduct %a, %b, %c : vector<[4]xf32>, vector<[4]xf32>
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_s_size) : memref<?xf32>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+
+  // Reload and print. The smallest SVL is 128 bits, so the tile will be at
+  // least 4x4xf32.
+  //
+  // CHECK:      ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
+    vector.print %tileslice : vector<[4]xf32>
+  }
+
+  return
+}
+
+func.func @test_outerproduct_no_accumulator_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f10 = arith.constant 10.0 : f32
+
+  %a = vector.splat %f1 : vector<[4]xf32>
+  %b = vector.splat %f2 : vector<[4]xf32>
+
+  %tile = vector.outerproduct %a, %b : vector<[4]xf32>, vector<[4]xf32>
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_s_size) : memref<?xf32>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+
+  // Reload and print. The smallest SVL is 128 bits, so the tile will be at
+  // least 4x4xf32.
+  //
+  // CHECK-NO-ACC:      ( 2, 2, 2, 2
+  // CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
+  // CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
+  // CHECK-NO-ACC-NEXT: ( 2, 2, 2, 2
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
+    vector.print %tileslice : vector<[4]xf32>
+  }
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
new file mode 100644
index 000000000000000..344973a4ddc7825
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -0,0 +1,50 @@
+// DEFINE: %{entry_point} = test_outerproduct_2x2xf64
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+func.func @test_outerproduct_2x2xf64() {
+  %c0 = arith.constant 0 : index
+  %f1 = arith.constant 1.0 : f64
+  %f2 = arith.constant 2.0 : f64
+  %f10 = arith.constant 10.0 : f64
+
+  %a = vector.splat %f1 : vector<[2]xf64>
+  %b = vector.splat %f2 : vector<[2]xf64>
+  // TODO: vector.splat doesn't support ArmSME.
+  %c = vector.broadcast %f10 : f64 to vector<[2]x[2]xf64>
+
+  %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
+
+  // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
+  %vscale = vector.vscale
+  %min_elts_d = arith.constant 2 : index
+  %svl_d = arith.muli %min_elts_d, %vscale : index
+  %za_d_size = arith.muli %svl_d, %svl_d : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_d_size) : memref<?xf64>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Reload and print. The smallest SVL is 128 bits, so the tile will be at
+  // least 2x2xf64.
+  //
+  // CHECK:      ( 12, 12
+  // CHECK-NEXT: ( 12, 12
+  scf.for %i = %c0 to %za_d_size step %svl_d {
+    %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
+    vector.print %tileslice : vector<[2]xf64>
+  }
+
+  return
+}


