[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (PR #65621)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 13 01:42:02 PDT 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-vector
<details>
<summary>Changes</summary>
This patch adds support for lowering vector.outerproduct to the ArmSME
MOPA intrinsic for the following types:
vector<[8]xf16>, vector<[8]xf16> -> vector<[8]x[8]xf16>
vector<[8]xbf16>, vector<[8]xbf16> -> vector<[8]x[8]xbf16>
vector<[4]xf32>, vector<[4]xf32> -> vector<[4]x[4]xf32>
vector<[2]xf64>, vector<[2]xf64> -> vector<[2]x[2]xf64>
The FP variants are lowered to FMOPA (non-widening) [1] and BFloat to BFMOPA
(non-widening) [2].
Note that at the ISA level these variants are implemented by different
architecture features, which are listed below:
FMOPA (non-widening)
* half-precision - +sme2p1,+sme-f16f16
* single-precision - +sme
* double-precision - +sme-f64f64
BFMOPA (non-widening)
* half-precision - +sme2p1,+b16b16
There's currently no way to target different features when lowering to
ArmSME. Integration tests are added for F32 and F64. We use QEMU to run
the integration tests, but SME2 support isn't available yet (it's
targeted for 9.0), so integration tests for the half-precision (F16 and BF16) variants are excluded.
Masking is currently unsupported.
Depends on #65450.
[1] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/FMOPA--non-widening---Floating-point-outer-product-and-accumulate-
[2] https://developer.arm.com/documentation/ddi0602/2023-06/SME-Instructions/BFMOPA--non-widening---BFloat16-floating-point-outer-product-and-accumulate-
--
Patch is 26.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65621.diff
9 Files Affected:
- (modified) mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h (+1-1)
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+2)
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+11-8)
- (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+113-4)
- (modified) mlir/lib/Dialect/ArmSME/Utils/Utils.cpp (-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+4-1)
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+106-1)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+116)
- (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+77)
<pre>
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index ed174699314e8d9..2a4327535c68750 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -239,7 +239,7 @@ class LLVMTypeConverter : public TypeConverter {
Type convertMemRefToBarePtr(BaseMemRefType type) const;
/// Convert a 1D vector type into an LLVM vector type.
- Type convertVectorType(VectorType type) const;
+ FailureOr<Type> convertVectorType(VectorType type) const;
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 554b9f119230667..9e8ad48b3c2db94 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -20,6 +20,8 @@
namespace mlir {
namespace arm_sme {
+constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
/// Return minimum number of elements for the given element `type` in
/// a vector of SVL bits.
unsigned getSMETileSliceMinNumElts(Type type);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index a9e7ce9d42848b5..49e0513e629d951 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -61,7 +61,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
addConversion([&](MemRefType type) { return convertMemRefType(type); });
addConversion(
[&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
- addConversion([&](VectorType type) { return convertVectorType(type); });
+ addConversion([&](VectorType type) -> std::optional<Type> {
+ FailureOr<Type> llvmType = convertVectorType(type);
+ if (failed(llvmType))
+ return std::nullopt;
+ return llvmType;
+ });
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -490,10 +495,9 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
/// * 1-D `vector<axT>` remains as is while,
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
/// `!llvm.array<ax...array<jxvector<kxT>>>`.
-/// As LLVM does not support arrays of scalable vectors, it is assumed that
-/// scalable vectors are always 1-D. This condition could be relaxed once the
-/// missing functionality is added in LLVM
-Type LLVMTypeConverter::convertVectorType(VectorType type) const {
+/// Returns failure for n-D scalable vector types as LLVM does not support
+/// arrays of scalable vectors.
+FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
auto elementType = convertType(type.getElementType());
if (!elementType)
return {};
@@ -503,9 +507,8 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) const {
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
- assert(
- (!type.isScalable() || (type.getRank() == 1)) &&
- "expected 1-D scalable vector (n-D scalable vectors are not supported)");
+ if (type.isScalable() && (type.getRank() > 1))
+ return failure();
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 685f8d57f76f52c..6c8843fbb4546e6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -361,6 +361,112 @@ struct MoveVectorToTileSliceToArmSMELowering
}
};
+/// Lower `vector.outerproduct` to SME MOPA intrinsics.
+///
+/// Example:
+///
+/// %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
+/// : vector<[4]xf32>, vector<[4]xf32>
+///
+/// is converted to:
+///
+/// "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
+/// : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+/// vector<[4]xf32>) -> ()
+///
+/// Currently only supports FMOPA and BFMOPA (non-widening).
+struct VectorOuterProductToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::OuterProductOp> {
+ using ConvertOpToLLVMPattern<vector::OuterProductOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::OuterProductOp outerProductOp,
+ vector::OuterProductOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto isSupportedType = [](VectorType vectorType) {
+ // TODO: the FP outer product instruction variants are predicated on
+ // different features [1]:
+ //
+ // * FMOPA (non-widening)
+ // * half-precision - +sme2p1,+sme-f16f16
+ // * single-precision - +sme
+ // * double-precision - +sme-f64f64
+ // * BFMOPA
+ // * half-precision - +sme2p1,+b16b16
+ //
+ // It should be possible to control lowering based on target features.
+ // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
+ return false;
+
+ auto elementType = vectorType.getElementType();
+
+ if (!elementType.isF16() && !elementType.isBF16() &&
+ !elementType.isF32() && !elementType.isF64())
+ return false;
+
+ unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
+ vectorType.getElementTypeBitWidth();
+ if (vectorType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+ return false;
+
+ return true;
+ };
+
+ auto resultVectorType = outerProductOp.getResultVectorType();
+ if (!isSupportedType(resultVectorType))
+ return outerProductOp.emitError("unsupported type");
+
+ vector::CombiningKind kind = outerProductOp.getKind();
+ if (kind != vector::CombiningKind::ADD)
+ // TODO: support subtract.
+ return outerProductOp.emitError("unsupported kind");
+
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(outerProductOp.getOperation());
+ if (maskableOp.isMasked())
+ // TODO: support masking.
+ return outerProductOp.emitError("masking is currently unsupported");
+
+ if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
+ // AXPY operation not suited for SME.
+ return failure();
+
+ auto loc = outerProductOp.getLoc();
+
+ Value acc = outerProductOp.getAcc();
+ if (!acc)
+ // Initialize accumulator with zero.
+ acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+
+ unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
+ auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
+ loc, rewriter.getIntegerType(elementWidth), acc);
+
+ // Create all active predicate mask.
+ auto one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI1Type(),
+ rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+ auto predTy =
+ VectorType::get(resultVectorType.getShape()[0], rewriter.getI1Type(),
+ /*scalableDims=*/{true});
+ auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+ auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+
+ // Create 'arm_sme.intr.mopa' outer product intrinsic.
+ rewriter.create<arm_sme::aarch64_sme_mopa>(
+ loc, tileI32, allActiveMask, allActiveMask, outerProductOp.getLhs(),
+ outerProductOp.getRhs());
+
+ // Create `CastTileToVectorOp` to use as the output.
+ rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
+ outerProductOp, resultVectorType, tileId);
+
+ return success();
+ }
+};
+
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -374,8 +480,10 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+ arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
+ arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
+ target.addIllegalOp<vector::OuterProductOp>();
// Mark 'func.func' ops as legal if either:
// 1. no 'arm_za' function attribute is present.
@@ -405,7 +513,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
- patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
- LoadTileSliceToArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering>(converter);
+ patterns
+ .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
+ LoadTileSliceToArmSMELowering, MoveVectorToTileSliceToArmSMELowering,
+ VectorOuterProductToArmSMELowering>(converter);
}
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index 8b2be7bc1901b9a..b8a47951cc7bbba 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -17,8 +17,6 @@
using namespace mlir;
using namespace mlir::arm_sme;
-static constexpr unsigned MinStreamingVectorLengthInBits = 128;
-
unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
assert(isValidSMETileElementType(type) && "invalid tile type!");
return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index b66077372164e79..95a010dd59d95bc 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1121,11 +1121,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
+ VectorType resType = op.getResultVectorType();
+ if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
+ return failure();
+
auto loc = op.getLoc();
VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
- VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
bool isInt = isa<IntegerType, IndexType>(eltType);
Value acc = op.getAcc();
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index af528295ef6ee23..687ef79385334cf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,8 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// vector.transfer_write
+//===----------------------------------------------------------------------===//
// CHECK-LABEL: @transfer_write_2d_zero_i8(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
return
}
+//===----------------------------------------------------------------------===//
+// vector.load
+//===----------------------------------------------------------------------===//
+
// -----
// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
return %tile : vector<[1]x[1]xi128>
}
+//===----------------------------------------------------------------------===//
+// vector.store
+//===----------------------------------------------------------------------===//
+
// -----
// CHECK-LABEL: @vector_store_i8(
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
+
+//===----------------------------------------------------------------------===//
+// vector.outerproduct
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f16
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+ // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
+ // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
+ // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+ // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
+ "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_bf16
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
+ // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
+ "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f32
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
+ // CHECK-NOT: arith.extui
+ // CHECK-NOT: arith.trunci
+ // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+ "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_add_f64
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+ // CHECK: arith.trunci {{.*}} : i64 to i32
+ // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+ "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_no_accumulator
+func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+ // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
+ // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+ %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
+ "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @vector_outerproduct_unsupported_axpy
+func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> {
+ // CHECK-NOT: arm_sme
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64
+ return %0 : vector<[2]xf64>
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+ // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+ // expected-error@+1 {{unsupported type}}
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
+ "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+ // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+ // expected-error@+1 {{unsupported kind}}
+ %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
+ "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+}
+
+// -----
+
+func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+ // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}}
+ // expected-error@+1 {{masking is currently unsupported}}
+ %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+ "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
new file mode 100644
index 000000000000000..00f1f6fd3fa8e19
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -0,0 +1,116 @@
+// DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-...
<truncated>
</pre>
</details>
https://github.com/llvm/llvm-project/pull/65621
More information about the Mlir-commits
mailing list