[Mlir-commits] [mlir] 042800a - [mlir][ArmSME] Add initial SME vector legalization pass (#79152)
llvmlistbot at llvm.org
Wed Jan 31 03:55:26 PST 2024
Author: Benjamin Maxwell
Date: 2024-01-31T11:55:22Z
New Revision: 042800a4dd79375ec0895c8959a43c86149232f3
URL: https://github.com/llvm/llvm-project/commit/042800a4dd79375ec0895c8959a43c86149232f3
DIFF: https://github.com/llvm/llvm-project/commit/042800a4dd79375ec0895c8959a43c86149232f3.diff
LOG: [mlir][ArmSME] Add initial SME vector legalization pass (#79152)
This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, an [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which can then be lowered to ArmSME. These three ops were picked because
supporting them alone is enough to lower matmuls that use all ZA
accumulators to ArmSME.
For a vector type to be legalizable it has to be a multiple of an SME tile
size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, and `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.
In the future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not a type
decomposition), which is why the pass has a fairly general name.
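As a concrete illustration (a minimal sketch mirroring the first case in the new
vector-legalization.mlir test below; the function name is only illustrative), an
outer product producing a `vector<[8]x[8]xf32>`:

func.func @example(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32> {
  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
  return %0 : vector<[8]x[8]xf32>
}

is rewritten by `-arm-sme-vector-legalization` into four `vector.outerproduct` ops on
`vector.scalable.extract` slices of %lhs and %rhs (taken at offsets 0 and 4), each
producing a `vector<[4]x[4]xf32>` tile; see the CHECK lines of the new test for the
exact expected output.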
Added:
mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
mlir/test/Dialect/ArmSME/vector-legalization.mlir
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
mlir/lib/Dialect/ArmSME/IR/Utils.cpp
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index bb49ce4c62723..c2f1b1f1b874e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -36,6 +36,9 @@ std::unique_ptr<Pass> createTileAllocationPass();
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 844e1957efc0a..66027c5ba77bd 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -156,4 +156,27 @@ def OuterProductFusion
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
}
+def VectorLegalization
+ : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+ let summary = "Legalize vectors for ArmSME";
+ let description = [{
+ This pass legalizes vector operations so that they can be lowered to ArmSME.
+ This includes decomposing operations that operate on vector types larger
+ than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+ tile-sized operations, as well as rewrites needed to get operations into
+ forms compatible with SME lowerings.
+
+ Note: Decomposition is currently limited to vector types that are an exact
+ multiple of SME tiles. That is, scalable in two dimensions, with both the
+ rows and columns divisible by the SVE vector length for the element type.
+ }];
+ let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+ let dependentDialects = [
+ "func::FuncDialect",
+ "arm_sme::ArmSMEDialect",
+ "vector::VectorDialect",
+ "arith::ArithDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
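To make the note above concrete (a sketch, assuming f32, for which a single SME tile
is `vector<[4]x[4]xf32>`; the function names are only illustrative):

// Multiples of the f32 tile type: each decomposes into four [4]x[4] tiles.
func.func private @multiple_of_tiles(vector<[8]x[8]xf32>, vector<[4]x[16]xf32>)
// Exactly one tile: already legal, so left to the normal ArmSME lowering.
func.func private @single_tile(vector<[4]x[4]xf32>)
// Not an exact multiple (6 is not divisible by 4): not decomposed by this pass.
func.func private @not_a_multiple(vector<[4]x[6]xf32>)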
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48f..027ad8954f92f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
PatternRewriter &rewriter, Location loc, Value initTile,
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
+/// Returns true if `vType` is a multiple of an SME tile size. Note that this
+/// returns false if `vType` exactly matches the size of a single SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d039121055566..6a9e022182226 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
return forOp;
}
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+ if (vType.getRank() != 2 || !vType.allDimsScalable())
+ return false;
+
+ auto elementType = vType.getElementType();
+ if (!isValidSMETileElementType(elementType))
+ return false;
+
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+ int64_t vectorRows = vType.getDimSize(0);
+ int64_t vectorCols = vType.getDimSize(1);
+
+ return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+ vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+ return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
} // namespace mlir::arm_sme
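As a worked example of these two helpers (assuming the minimum 128-bit SVL, so a tile
slice holds 128 / bitwidth elements; values taken from the tests below):

// getSMETileTypeForElement(f64) -> vector<[2]x[2]xf64>
// getSMETileTypeForElement(f32) -> vector<[4]x[4]xf32>
// getSMETileTypeForElement(f16) -> vector<[8]x[8]xf16>
// getSMETileTypeForElement(i8)  -> vector<[16]x[16]xi8>
// isMultipleOfSMETileVectorType(vector<[8]x[4]xf64>)  -> true  (eight [2]x[2] tiles)
// isMultipleOfSMETileVectorType(vector<[16]x[16]xi8>) -> false (exactly one tile)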
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index c06f9d3cc7a9f..600f2ecdb51bc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
OuterProductFusion.cpp
TileAllocation.cpp
+ VectorLegalization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -10,10 +11,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRArmSMETransformsIncGen
LINK_LIBS PUBLIC
+ MLIRPass
MLIRArmSMEDialect
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRSCFDialect
- MLIRPass
+ MLIRSCFTransforms
+ MLIRFuncTransforms
)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 0000000000000..85ec53c2618aa
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,380 @@
+//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass legalizes vector operations so they can be lowered to ArmSME.
+// Currently, this only implements the decomposition of vector operations that
+// use vector sizes larger than an SME tile, into multiple SME-sized operations.
+//
+// Note: In the context of this pass 'tile' always refers to an SME tile.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+// Common match failure reasons.
+static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
+ "op vector size is not multiple of SME tiles");
+static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
+ "op mask is unsupported for legalization/decomposition");
+static constexpr StringLiteral
+ MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
+
+/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
+/// larger vector type. The (`row`, `col`) are the position of the tile in the
+/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
+/// sub-tiles, we would have:
+///
+/// 8 x vscale
+/// ┌─────────────┬─────────────┐
+/// │(0,0) │(0,4) │
+/// │ │ │
+/// ├─────────────┼─────────────┤ 8 x vscale
+/// │(4,0) │(4,4) │
+/// │ │ │
+/// └─────────────┴─────────────┘
+struct SMESubTile {
+ // Note: The units of (row, col) are vscale (as SME tiles are scalable).
+ int row{0};
+ int col{0};
+ // The SME tile type.
+ VectorType type;
+};
+
+/// Adds a constant elementwise scalable offset to `indices` (which are of equal
+/// length). For example, in the 2D case this would return:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+ Location loc,
+ ValueRange indices,
+ ArrayRef<int> scalableOffsets) {
+ auto vscale = builder.create<vector::VectorScaleOp>(loc);
+ return llvm::map_to_vector(
+ llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
+ auto [index, base] = pair;
+ auto offset = builder.create<arith::MulIOp>(
+ loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+ return builder.create<arith::AddIOp>(loc, index, offset);
+ });
+}
+
+/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
+/// indices for one of the SME sub-tiles it will decompose into.
+///
+/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
+/// indices for each tile would need to be adjusted as follows:
+///
+/// initial indices = [a,b], initial size = 8x8, target size = 4x4
+/// ┌─────────────┬─────────────┐
+/// │[a,b] │[a,b+4] │
+/// │ │ │
+/// ├─────────────┼─────────────┤
+/// │[a+4,b] │[a+4,b+4] │
+/// │ │ │
+/// └─────────────┴─────────────┘
+SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
+ ValueRange indices,
+ SMESubTile smeTile) {
+ return addConstantScalableOffset(builder, loc, indices,
+ {smeTile.row, smeTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is just no mask, or vector.create_mask.
+/// TODO: Add support for vector.constant_mask once required for SME.
+bool isSupportedMaskOp(Value mask) {
+ return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+ SMESubTile smeTile) {
+ assert(isSupportedMaskOp(mask));
+ if (!mask)
+ return Value{};
+ auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+ // The operands of `vector.create_mask` (from a 2D perspective) are the
+ // coordinates where the mask ends. So we subtract this sub-tile's start
+ // position from the mask operands to get the parameters for this sub-tile.
+ auto smeTileMaskDims = addConstantScalableOffset(
+ builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
+ auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
+ loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
+ return smeTileCreateMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType. For example, if decomposing an [8]x[8] into
+/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
+/// (4, 4).
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+ VectorType smeTileType,
+ bool transposeIndices = false) {
+ assert(isMultipleOfSMETileVectorType(type) &&
+ "`type` not multiple of SME tiles");
+ return llvm::map_range(
+ StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+ smeTileType.getDimSize(1)}),
+ [=](auto indices) {
+ int row = int(indices[0]);
+ int col = int(indices[1]);
+ if (transposeIndices)
+ std::swap(row, col);
+ return SMESubTile{row, col, smeTileType};
+ });
+}
+
+/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
+/// `type`.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+ assert(isMultipleOfSMETileVectorType(type) &&
+ "`type` not multiple of SME tiles");
+ int64_t vectorRows = type.getDimSize(0);
+ int64_t vectorCols = type.getDimSize(1);
+ auto elementType = type.getElementType();
+ unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+ return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeVectorOuterProductOpsByDecomposition
+ : public OneToNOpConversionPattern<vector::OuterProductOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = outerProductOp.getResultVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(
+ outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+
+ Value mask;
+ Operation *rootOp = outerProductOp;
+ auto loc = outerProductOp.getLoc();
+ if (outerProductOp.isMasked()) {
+ auto maskOp = outerProductOp.getMaskingOp();
+ mask = maskOp.getMask();
+ rootOp = maskOp;
+ }
+
+ if (!isSupportedMaskOp(mask))
+ return rewriter.notifyMatchFailure(outerProductOp,
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+
+ ValueRange accSMETiles = adaptor.getAcc();
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+ VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+ SmallVector<Value> resultSMETiles;
+ for (auto [index, smeTile] : llvm::enumerate(
+ decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+
+ auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
+ auto lhs = rewriter.create<vector::ScalableExtractOp>(
+ loc, sliceType, outerProductOp.getLhs(), smeTile.row);
+ auto rhs = rewriter.create<vector::ScalableExtractOp>(
+ loc, sliceType, outerProductOp.getRhs(), smeTile.col);
+ auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
+ loc, smeTileType, lhs, rhs,
+ !accSMETiles.empty() ? accSMETiles[index] : Value{},
+ outerProductOp.getKind());
+
+ auto maskedOuterProduct =
+ vector::maskOperation(rewriter, smeOuterProduct, smeMask);
+ resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+ }
+
+ rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+ return success();
+ }
+};
+
+// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
+// get the help of the type conversion), but doing so results in the type
+// conversion adding target materializations in the `vector.mask` region
+// (invalid). This pattern matches on `vector.mask` then calls into the
+// `vector.outerproduct` pattern to work around this issue.
+struct LegalizeMaskedVectorOuterProductOpsByDecomposition
+ : public OneToNOpConversionPattern<vector::MaskOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ if (auto outerProductOp =
+ llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+ LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
+ getContext());
+ return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+ outerProductOp, rewriter);
+ }
+ return failure();
+ }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferReadOpsByDecomposition
+ : public OneToNOpConversionPattern<vector::TransferReadOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = readOp.getVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(
+ readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+
+ auto mask = readOp.getMask();
+ if (!isSupportedMaskOp(mask))
+ return rewriter.notifyMatchFailure(readOp,
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+
+ auto permutationMap = readOp.getPermutationMap();
+ if (!permutationMap.isPermutation())
+ return rewriter.notifyMatchFailure(readOp,
+ MATCH_FAILURE_NON_PERMUTATION_MAP);
+
+ // Note: For 2D vector types the only non-identity permutation is a simple
+ // transpose [1, 0].
+ bool transposed = !permutationMap.isIdentity();
+
+ auto loc = readOp.getLoc();
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+
+ SmallVector<Value> resultSMETiles;
+ for (SMESubTile smeTile :
+ decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
+ auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
+ auto smeRead = rewriter.create<vector::TransferReadOp>(
+ loc, smeTileType, readOp.getSource(),
+ getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
+ readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
+ readOp.getInBoundsAttr());
+ resultSMETiles.push_back(smeRead);
+ }
+
+ rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+ return success();
+ }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferWriteOpsByDecomposition
+ : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = writeOp.getVectorType();
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(
+ writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
+
+ auto mask = writeOp.getMask();
+ if (!isSupportedMaskOp(mask))
+ return rewriter.notifyMatchFailure(writeOp,
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isPermutation())
+ return rewriter.notifyMatchFailure(writeOp,
+ MATCH_FAILURE_NON_PERMUTATION_MAP);
+
+ // Note: For 2D vector types the only non-identity permutation is a simple
+ // transpose [1, 0].
+ bool transposed = !permutationMap.isIdentity();
+
+ auto loc = writeOp.getLoc();
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+ auto inputSMETiles = adaptor.getVector();
+
+ Value destTensorOrMemref = writeOp.getSource();
+ for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
+ rewriter, vectorType, smeTileType, transposed))) {
+ auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
+ auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, inputSMETiles[index], destTensorOrMemref,
+ getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
+ writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
+ if (writeOp.hasPureTensorSemantics())
+ destTensorOrMemref = smeWrite.getResult();
+ }
+
+ if (writeOp.hasPureTensorSemantics())
+ rewriter.replaceOp(writeOp, destTensorOrMemref);
+ else
+ rewriter.eraseOp(writeOp);
+
+ return success();
+ }
+};
+
+struct VectorLegalizationPass
+ : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+ void runOnOperation() override {
+ auto *context = &getContext();
+ OneToNTypeConverter converter;
+ RewritePatternSet patterns(context);
+
+ converter.addConversion([](Type type) { return type; });
+ converter.addConversion(
+ [](VectorType vectorType,
+ SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return std::nullopt;
+ auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
+ auto smeTileType =
+ getSMETileTypeForElement(vectorType.getElementType());
+ types = SmallVector<Type>(smeTileCount, smeTileType);
+ return success();
+ });
+
+ // Note: High benefit to ensure masked outer products are lowered first.
+ patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
+ converter, context, 1024);
+ patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+ LegalizeTransferReadOpsByDecomposition,
+ LegalizeTransferWriteOpsByDecomposition>(converter, context);
+ populateFuncTypeConversionPatterns(converter, patterns);
+ scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+ if (failed(applyPartialOneToNConversion(getOperation(), converter,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+ return std::make_unique<VectorLegalizationPass>();
+}
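To illustrate the mask handling (a sketch matching the masked outer-product case in
the new test below; SSA names are only illustrative): given a mask over a full
[16]x[4] f32 operation, the sub-tile starting at scalable position (4, 0) gets a mask
whose bounds are shifted by that start position:

%mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
// For the sub-tile at (4, 0), extractSMEMask produces (roughly):
%vscale = vector.vscale
%minus_four = arith.constant -4 : index
%offset = arith.muli %minus_four, %vscale : index
%sub_rows = arith.addi %lhs_dim, %offset : index
%sub_mask = vector.create_mask %sub_rows, %rhs_dim : vector<[4]x[4]xi1>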
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 0000000000000..a20abeefedcfd
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME: %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+ // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+ %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+ return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME: %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
+{
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+ // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+ %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
+ return %0 : vector<[4]x[16]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
+// CHECK-SAME: %[[LHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[LHS_DIM:.*]]: index,
+// CHECK-SAME: %[[RHS_DIM:.*]]: index)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
+{
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
+ // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+ // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
+ // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
+ // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+ // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+ // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+ // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
+ // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+ // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
+ // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+ // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
+ // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+ // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+ // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+ %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
+ %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
+ return %0 : vector<[16]x[4]xf32>
+}
+
+// -----
+
+/// This demonstrates a rectangular tiling that uses all f64 accumulators.
+
+// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
+// CHECK-SAME: %[[LHS:.*]]: vector<[8]xf64>,
+// CHECK-SAME: %[[RHS:.*]]: vector<[4]xf64>)
+// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
+func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
+{
+ // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
+ // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
+ // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
+ // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
+ // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
+ // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
+ // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+ // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
+ %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
+ return %0 : vector<[8]x[4]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i32_scalable_8x8(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>)
+// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
+func.func @transfer_read_i32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
+{
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
+ return %0 : vector<[8]x[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi16>,
+// CHECK-SAME: %[[DIM0:.*]]: index,
+// CHECK-SAME: %[[DIM1:.*]]: index)
+// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
+func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
+{
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+ // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
+ // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
+ // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
+ // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0 : i16
+ %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
+ %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
+ return %0 : vector<[8]x[16]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf16>,
+// CHECK-SAME: %[[TOP:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME: %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
+func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
+{
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+ // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+ // CHECK-NEXT: return
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
+ return
+}
+
+// -----
+
+/// This is already a legal type. It should be ignored.
+
+// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
+func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
+{
+ // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %dim0, %dim0 : vector<[16]x[16]xi1>
+ vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+ return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+ // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: return
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+ vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+ // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+ // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+ // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+ // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ // CHECK-NEXT: return
+ %c0 = arith.constant 0 : index
+ %pad = arith.constant 0.0 : f32
+ %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
+ vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
new file mode 100644
index 0000000000000..327f237ba8948
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt %s \
+// RUN: -transform-interpreter -test-transform-dialect-erase-schedule \
+// RUN: -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
+// RUN: -arm-sme-vector-legalization -canonicalize -cse \
+// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN: -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN: -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -e=main -entry-point-result=void \
+// RUN: -march=aarch64 -mattr="+sve,+sme" \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+/// This is very similar to the SME matmul.mlir test, except that it uses a tile
+/// size of [8]x[8]xf32, which is larger than a 32-bit SME virtual tile
+/// ([4]x[4]xf32). The [8]x[8] tile will be decomposed into four [4]x[4] tiles
+/// by the `-arm-sme-vector-legalization` pass, which should then use all four
+/// 32-bit SME accumulators.
+
+func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
+ %res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+ outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+ %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
+ call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+ return
+}
+
+func.func @main() {
+ /// Set SVL to 128-bit. This ensures this small matmul will use all four
+ /// 32-bit SME virtual tiles.
+ %c128 = arith.constant 128 : i32
+ func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+ %c0 = arith.constant 0 : i32
+ %c7 = arith.constant 7 : index
+
+ %A = arith.constant dense<[
+ [ 1., 8., 15., 22., 29., 36., 43., 50., 57., 64., 71., 78., 85.],
+ [ 2., 9., 16., 23., 30., 37., 44., 51., 58., 65., 72., 79., 86.],
+ [ 3., 10., 17., 24., 31., 38., 45., 52., 59., 66., 73., 80., 87.],
+ [ 4., 11., 18., 25., 32., 39., 46., 53., 60., 67., 74., 81., 88.],
+ [ 5., 12., 19., 26., 33., 40., 47., 54., 61., 68., 75., 82., 89.],
+ [ 6., 13., 20., 27., 34., 41., 48., 55., 62., 69., 76., 83., 90.],
+ [ 7., 14., 21., 28., 35., 42., 49., 56., 63., 70., 77., 84., 91.]
+ ]> : tensor<7x13xf32>
+
+ %B_init = tensor.empty() : tensor<13x7xf32>
+ %B = linalg.transpose ins(%A: tensor<7x13xf32>)
+ outs(%B_init: tensor<13x7xf32>) permutation = [1, 0]
+
+ %A_dyn = tensor.cast %A : tensor<7x13xf32> to tensor<?x?xf32>
+ %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
+
+ %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
+ %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
+ // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
+ // CHECK: [33514, 34086, 34658, 35230, 35802, 36374, 36946]
+ // CHECK: [34073, 34658, 35243, 35828, 36413, 36998, 37583]
+ // CHECK: [34632, 35230, 35828, 36426, 37024, 37622, 38220]
+ // CHECK: [35191, 35802, 36413, 37024, 37635, 38246, 38857]
+ // CHECK: [35750, 36374, 36998, 37622, 38246, 38870, 39494]
+ // CHECK: [36309, 36946, 37583, 38220, 38857, 39494, 40131]
+ call @matmul(%A_dyn, %B_dyn, %C) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module : !transform.any_op {transform.consumed}) {
+ %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 1: Tile for size [8] x [8], which corresponds to (2 x SVLs) x (2 x SVLs),
+ // where SVLs is the number of 32-bit elements in a vector of SVL bits.
+ // This uses all four 32-bit SME virtual tiles.
+ %tiled_linalg_op, %loop_i, %loop_j, %loop_k = transform.structured.tile_using_for %matmul[[8], [8], 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">, !transform.op<"scf.for">)
+
+ // Step 2: Vectorize.
+ transform.structured.vectorize %tiled_linalg_op vector_sizes [[8], [8], 1]
+ : !transform.any_op
+
+ // Step 3: Bufferize ahead of TransferReadDropUnitDimsPattern, which
+ // currently only supports memrefs.
+ %bufferize = transform.bufferization.one_shot_bufferize %module
+ {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+ %func = transform.structured.match ops{["func.func"]} in %bufferize
+ : (!transform.any_op) -> !transform.any_op
+
+ // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns).
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_masked_transfers
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ transform.apply_patterns.vector.reduction_to_contract
+ } : !transform.any_op
+
+ // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit
+ // dims, specifically to prevent vector.transfer_read of vector<[8]x1xf32>,
+ // which can't be lowered in the generic path.
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_masks
+ transform.apply_patterns.vector.rank_reducing_subview_patterns
+ } : !transform.any_op
+
+ transform.yield
+ }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
new file mode 100644
index 0000000000000..0827d9b7464ad
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN: -convert-arm-sme-to-llvm \
+// RUN: -convert-vector-to-llvm=enable-arm-sve \
+// RUN: -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -e=main -entry-point-result=void \
+// RUN: -march=aarch64 -mattr="+sve,+sme" \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+ %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+ scf.for %i = %c0 to %rows step %c1 {
+ scf.for %j = %c0 to %cols step %c1 {
+ %n = arith.addi %i, %c1 : index
+ %val = arith.index_cast %n : index to i32
+ %val_f32 = arith.sitofp %val : i32 to f32
+ memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+ }
+ }
+ return
+}
+
+func.func @testTransposedReadWithMask(%maskRows: index, %maskCols: index) {
+ %in = memref.alloca() : memref<4x16xf32>
+ %out = memref.alloca() : memref<16x4xf32>
+
+ %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+ %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+ func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+ /// A mask so we only read the first maskRows x maskCols portion of %in.
+ %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+ %pad = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+
+ /// A vector.transfer_read with a transpose permutation map. So the input data
+ /// (and mask) have a [4]x[16] shape, but the output is [16]x[4].
+ %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+ {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+ /// Write the vector back to a memref (that also has a transposed shape).
+ vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+ /// Print the input memref.
+ vector.print str "Input memref:"
+ %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+ /// Print the result memref.
+ vector.print str "Masked transposed result:"
+ %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+ return
+}
+
+func.func @testTransposedWriteWithMask(%maskRows: index, %maskCols: index) {
+ %in = memref.alloca() : memref<16x4xf32>
+ %out = memref.alloca() : memref<4x16xf32>
+
+ %c0_f32 = arith.constant 0.0 : f32
+ linalg.fill ins(%c0_f32 : f32) outs(%out : memref<4x16xf32>)
+
+ %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+ %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+ func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+ /// A regular read.
+ %c0 = arith.constant 0 : index
+ %read = vector.transfer_read %inDyn[%c0, %c0], %c0_f32 {in_bounds = [true, true]}
+ : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+ /// A mask so we only write the first maskRows x maskCols portion of transpose(%in).
+ %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+ /// Write out the data with a transpose. Here (like the read test) the mask
+ /// matches the shape of the memory, not the vector.
+ vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+ : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+ /// Print the input memref.
+ vector.print str "Input memref:"
+ %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+ /// Print the result memref.
+ vector.print str "Masked transposed result:"
+ %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+ return
+}
+
+func.func @main() {
+ /// Set the SVL to 128-bits (i.e. vscale = 1).
+ /// This test is for the use of multiple tiles rather than scalability.
+ %c128 = arith.constant 128 : i32
+ func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+ // CHECK: Input memref:
+ // CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+ // CHECK-NEXT: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
+ // CHECK-NEXT: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
+ //
+ // CHECK: Masked transposed result:
+ // CHECK: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [1, 2, 0, 0]
+ // CHECK-NEXT: [0, 0, 0, 0]
+ %readMaskRows = arith.constant 2 : index
+ %readMaskCols = arith.constant 15 : index
+ func.call @testTransposedReadWithMask(%readMaskRows, %readMaskCols) : (index, index) -> ()
+
+ // CHECK: Input memref:
+ // CHECK: [1, 1, 1, 1]
+ // CHECK-NEXT: [2, 2, 2, 2]
+ // CHECK-NEXT: [3, 3, 3, 3]
+ // CHECK-NEXT: [4, 4, 4, 4]
+ // CHECK-NEXT: [5, 5, 5, 5]
+ // CHECK-NEXT: [6, 6, 6, 6]
+ // CHECK-NEXT: [7, 7, 7, 7]
+ // CHECK-NEXT: [8, 8, 8, 8]
+ // CHECK-NEXT: [9, 9, 9, 9]
+ // CHECK-NEXT: [10, 10, 10, 10]
+ // CHECK-NEXT: [11, 11, 11, 11]
+ // CHECK-NEXT: [12, 12, 12, 12]
+ // CHECK-NEXT: [13, 13, 13, 13]
+ // CHECK-NEXT: [14, 14, 14, 14]
+ // CHECK-NEXT: [15, 15, 15, 15]
+ // CHECK-NEXT: [16, 16, 16, 16]
+ //
+ // CHECK: Masked transposed result:
+ // CHECK: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [1, 2, 3, 4, 5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0]
+ // CHECK-NEXT: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ %writeMaskRows = arith.constant 3 : index
+ %writeMaskCols = arith.constant 8 : index
+ func.call @testTransposedWriteWithMask(%writeMaskRows, %writeMaskCols) : (index, index) -> ()
+
+ return
+}
+
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)