[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 23 07:14:31 PST 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/79152

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes vector operations so that they can be lowered to ArmSME. This initial patch adds decomposition for `vector.outerproduct`, `vector.transfer_read`, and `vector.transfer_write` when they operate on vector types larger than a single SME tile. For example, an [8]x[8]xf32 outer product would be decomposed into four [4]x[4]xf32 outer products, which could then be lowered to ArmSME. These three ops have been picked because supporting them alone is enough to lower matmuls that use all ZA accumulators to ArmSME.

For a vector type to be legalizable it has to be a multiple of an SME tile size, but other than that any shape can be used. E.g. `vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, and `vector<[16]x[4]xf32>` can all be decomposed into four `vector<[4]x[4]xf32>` operations.
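
As a rough sketch of the decomposition (this mirrors the `@outerproduct_f32_scalable_8x8_no_acc` test added in this patch; the 1:N rewrite of the enclosing function's signature and results is elided):

```mlir
// Before legalization: one outer product producing a vector<[8]x[8]xf32>,
// i.e. four times the size of an f32 SME tile (vector<[4]x[4]xf32>).
%0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>

// After legalization: the operands are split in half and four tile-sized
// outer products are emitted, one per SME virtual tile.
%lhs_0 = vector.scalable.extract %lhs[0] : vector<[4]xf32> from vector<[8]xf32>
%lhs_1 = vector.scalable.extract %lhs[4] : vector<[4]xf32> from vector<[8]xf32>
%rhs_0 = vector.scalable.extract %rhs[0] : vector<[4]xf32> from vector<[8]xf32>
%rhs_1 = vector.scalable.extract %rhs[4] : vector<[4]xf32> from vector<[8]xf32>
%top_left     = vector.outerproduct %lhs_0, %rhs_0 : vector<[4]xf32>, vector<[4]xf32>
%top_right    = vector.outerproduct %lhs_0, %rhs_1 : vector<[4]xf32>, vector<[4]xf32>
%bottom_left  = vector.outerproduct %lhs_1, %rhs_0 : vector<[4]xf32>, vector<[4]xf32>
%bottom_right = vector.outerproduct %lhs_1, %rhs_1 : vector<[4]xf32>, vector<[4]xf32>
```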

In the future, this pass will be extended with more SME-specific rewrites to legalize unrolling the reduction dimension of matmuls (which is not type decomposition), which is why the pass has a fairly general name.

From a38f4e9240f9bad2ff96a8b186908db8914ad841 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 13:51:11 +0000
Subject: [PATCH] [mlir][ArmSME] Add initial SME vector legalization pass

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, an [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
because supporting them alone is enough to lower matmuls that use all
ZA accumulators to ArmSME.

For a vector type to be legalizable it has to be a multiple of an SME
tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, and `vector<[16]x[4]xf32>`
can all be decomposed into four `vector<[4]x[4]xf32>` operations.

In the future, this pass will be extended with more SME-specific
rewrites to legalize unrolling the reduction dimension of matmuls
(which is not type decomposition), which is why the pass has a fairly
general name.
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |   3 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  19 ++
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |   7 +
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  22 ++
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |   5 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 308 ++++++++++++++++++
 .../Dialect/ArmSME/vector-legalization.mlir   | 268 +++++++++++++++
 .../Linalg/CPU/ArmSME/multi-tile-matmul.mlir  | 109 +++++++
 .../CPU/ArmSME/test-multi-tile-transpose.mlir | 171 ++++++++++
 9 files changed, 911 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
 create mode 100644 mlir/test/Dialect/ArmSME/vector-legalization.mlir
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7cd..9ba8c43551257b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805..9a6f5446de0094 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def VectorLegalization
+  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+  let summary = "Legalize vectors for ArmSME";
+  let description = [{
+    This pass legalizes vector operations so that they can be lowered to ArmSME.
+    This includes decomposing operations that operate on vector types larger
+    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+    tile-sized operations, as well as rewrites needed to get operations into
+    forms compatible with SME lowerings.
+  }];
+  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "arm_sme::ArmSMEDialect",
+    "vector::VectorDialect",
+    "arith::ArithDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48fb..e89be8ed81e03f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
+/// Returns true if `vType` is a multiple of an SME tile size. Note that this
+/// returns false if `vType` exactly matches the size of a single SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d0391210555662..6a9e0221822267 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
   return forOp;
 }
 
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+  if (vType.getRank() != 2 || !vType.allDimsScalable())
+    return false;
+
+  auto elementType = vType.getElementType();
+  if (!isValidSMETileElementType(elementType))
+    return false;
+
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+  int64_t vectorRows = vType.getDimSize(0);
+  int64_t vectorCols = vType.getDimSize(1);
+
+  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+         vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb5844204384..3c32fc2645ce1b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   TileAllocation.cpp
+  VectorLegalization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRSCFDialect
-  MLIRPass
+  MLIRSCFTransforms
+  MLIRFuncTransforms
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 00000000000000..a801ebe27413d5
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+struct SMETile {
+  // Note: (row, col) are in units of vscale (as SME tiles are scalable).
+  int row{0};
+  int col{0};
+  VectorType type;
+};
+
+/// Adds a constant scalable offset to `indices`, i.e. for 2D:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+                                                Location loc,
+                                                ValueRange indices,
+                                                ArrayRef<int> scalableOffset) {
+  auto vscale = builder.create<vector::VectorScaleOp>(loc);
+  return llvm::map_to_vector(
+      llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+        auto [index, base] = pair;
+        auto offset = builder.create<arith::MulIOp>(
+            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+        return builder.create<arith::AddIOp>(loc, index, offset);
+      });
+}
+
+/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
+/// for one of the SME tiles it will decompose into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+                                             ValueRange indices,
+                                             SMETile tileTile) {
+  return addConstantScalableOffset(builder, loc, indices,
+                                   {tileTile.row, tileTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is either no mask or a `vector.create_mask`.
+bool isSupportedMaskOp(Value mask) {
+  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts a mask for an SME tile from the mask of a larger vector type.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+                     SMETile tileTile) {
+  assert(isSupportedMaskOp(mask));
+  if (!mask)
+    return Value{};
+  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
+  // coordinates where the mask ends, so we subtract where this tile starts
+  // from the mask operands to get the mask parameters for this tile.
+  auto tileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
+  auto createTileMask = builder.create<vector::CreateMaskOp>(
+      loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
+  return createTileMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+                         VectorType smeTileType,
+                         bool transposeIndices = false) {
+  assert(isMultipleOfSMETileVectorType(type));
+  return llvm::map_range(
+      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+                                              smeTileType.getDimSize(1)}),
+      [=](auto indices) {
+        int row = int(indices[0]);
+        int col = int(indices[1]);
+        if (transposeIndices)
+          std::swap(row, col);
+        return SMETile{row, col, smeTileType};
+      });
+}
+
+/// Returns the number of SME tiles that fit within the given vector type.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+  assert(isMultipleOfSMETileVectorType(type));
+  int64_t vectorRows = type.getDimSize(0);
+  int64_t vectorCols = type.getDimSize(1);
+  auto elementType = type.getElementType();
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+struct LegalizeVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::OuterProductOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = outerProductOp.getResultVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    Value mask;
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      mask = maskOp.getMask();
+      rootOp = maskOp;
+    }
+
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    ValueRange accSMETiles = adaptor.getAcc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+    SmallVector<Value> resultSMETiles;
+    for (auto [index, tileTile] :
+         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto lhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+      auto rhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, tileType, lhs, rhs,
+          !accSMETiles.empty() ? accSMETiles[index] : Value{},
+          outerProductOp.getKind());
+
+      auto maskedOuterProduct =
+          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+    }
+
+    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
+// get the help of the type conversion), but doing so results in the type
+// conversion adding target materializations in the `vector.mask` region
+// (invalid). This pattern matches on `vector.mask` then calls into the
+// `vector.outerproduct` pattern to work around this issue.
+struct LegalizeMaskedVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::MaskOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (auto outerProductOp =
+            llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
+      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+          outerProductOp, rewriter);
+    }
+    return failure();
+  }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+struct LegalizeTransferReadOp
+    : public OneToNOpConversionPattern<vector::TransferReadOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = readOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = readOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = readOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = readOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+    SmallVector<Value> resultSMETiles;
+    for (SMETile tileTile :
+         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto transferRead = rewriter.create<vector::TransferReadOp>(
+          loc, tileType, readOp.getSource(),
+          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+          readOp.getInBoundsAttr());
+      resultSMETiles.push_back(transferRead);
+    }
+
+    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+struct LegalizeTransferWriteOp
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = writeOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto inputSMETiles = adaptor.getVector();
+
+    Value destTensorOrMemref = writeOp.getSource();
+    for (auto [index, tileTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, inputSMETiles[index], destTensorOrMemref,
+          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
+          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = tileWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
+struct VectorLegalizationPass
+    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    OneToNTypeConverter converter;
+    RewritePatternSet patterns(context);
+
+    converter.addConversion([](Type type) { return type; });
+    converter.addConversion(
+        [](VectorType vectorType,
+           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+          if (!isMultipleOfSMETileVectorType(vectorType))
+            return std::nullopt;
+          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(tileTileCount, tileType);
+          return success();
+        });
+
+    // Note: High benefit to ensure masked outer products are lowered first.
+    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+                 LegalizeTransferWriteOp>(converter, context);
+    populateFuncTypeConversionPatterns(converter, patterns);
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+  return std::make_unique<VectorLegalizationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 00000000000000..a20abeefedcfd4
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+  return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
+  return %0 : vector<[4]x[16]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
+// CHECK-SAME:                                         %[[LHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                         %[[RHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                         %[[LHS_DIM:.*]]: index,
+// CHECK-SAME:                                         %[[RHS_DIM:.*]]: index)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
+{
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
+  // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
+  return %0 : vector<[16]x[4]xf32>
+}
+
+// -----
+
+/// This demonstrates a rectangular tiling that uses all f64 accumulators.
+
+// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf64>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[4]xf64>)
+// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
+func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
+  return %0 : vector<[8]x[4]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i32_scalable_8x8(
+// CHECK-SAME:                                  %[[SRC:.*]]: memref<?x?xi32>)
+// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
+func.func @transfer_read_i32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
+// CHECK-SAME:                                          %[[SRC:.*]]: memref<?x?xi16>,
+// CHECK-SAME:                                          %[[DIM0:.*]]: index,
+// CHECK-SAME:                                          %[[DIM1:.*]]: index)
+// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
+func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i16
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
+  return %0 : vector<[8]x[16]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
+// CHECK-SAME:                                    %[[DEST:.*]]: memref<?x?xf16>,
+// CHECK-SAME:                                    %[[TOP:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME:                                    %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
+func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
+  return
+}
+
+// -----
+
+/// This is already a legal type. It should be ignored.
+
+// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
+func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
+{
+  // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim0 : vector<[16]x[16]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
+// CHECK-SAME:                                        %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                        %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
+// CHECK-SAME:                                         %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
new file mode 100644
index 00000000000000..fb192a829173cb
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
+// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
+// RUN:   -arm-sme-vector-legalization -canonicalize -cse \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN:   -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
+  %res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+  return
+}
+
+func.func @main() {
+  /// Set SVL to 128-bit. This ensures this small matmul will use all four
+  /// 32-bit SME virtual tiles.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : i32
+  %c7 = arith.constant 7 : index
+
+  %A = arith.constant dense<[
+    [ 1.,  8., 15., 22., 29., 36., 43., 50., 57., 64., 71., 78., 85.],
+    [ 2.,  9., 16., 23., 30., 37., 44., 51., 58., 65., 72., 79., 86.],
+    [ 3., 10., 17., 24., 31., 38., 45., 52., 59., 66., 73., 80., 87.],
+    [ 4., 11., 18., 25., 32., 39., 46., 53., 60., 67., 74., 81., 88.],
+    [ 5., 12., 19., 26., 33., 40., 47., 54., 61., 68., 75., 82., 89.],
+    [ 6., 13., 20., 27., 34., 41., 48., 55., 62., 69., 76., 83., 90.],
+    [ 7., 14., 21., 28., 35., 42., 49., 56., 63., 70., 77., 84., 91.]
+  ]> : tensor<7x13xf32>
+
+  %B_init = tensor.empty() : tensor<13x7xf32>
+  %B = linalg.transpose ins(%A: tensor<7x13xf32>)
+                        outs(%B_init: tensor<13x7xf32>) permutation = [1, 0]
+
+  %A_dyn = tensor.cast %A : tensor<7x13xf32> to tensor<?x?xf32>
+  %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
+
+  %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
+  %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
+  // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
+  // CHECK: [33514, 34086, 34658, 35230, 35802, 36374, 36946]
+  // CHECK: [34073, 34658, 35243, 35828, 36413, 36998, 37583]
+  // CHECK: [34632, 35230, 35828, 36426, 37024, 37622, 38220]
+  // CHECK: [35191, 35802, 36413, 37024, 37635, 38246, 38857]
+  // CHECK: [35750, 36374, 36998, 37622, 38246, 38870, 39494]
+  // CHECK: [36309, 36946, 37583, 38220, 38857, 39494, 40131]
+  call @matmul(%A_dyn, %B_dyn, %C) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.consumed}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile for size [8] x [8], which corresponds to (2 x SVLs) x (2 x SVLs),
+    // where SVLs is the number of 32-bit elements in a vector of SVL bits.
+    // This uses all four 32-bit SME virtual tiles.
+    %tiled_linalg_op, %loop_i, %loop_j, %loop_k = transform.structured.tile_using_for %matmul[[8], [8], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">, !transform.op<"scf.for">)
+
+    // Step 2: Vectorize.
+    transform.structured.vectorize %tiled_linalg_op vector_sizes [[8], [8], 1]
+      : !transform.any_op
+
+    // Step 3: Bufferize ahead of TransferReadDropUnitDimsPattern, which
+    // currently only supports memrefs.
+    %bufferize = transform.bufferization.one_shot_bufferize %module
+      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %bufferize
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns).
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.reduction_to_contract
+    } : !transform.any_op
+
+    // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit
+    // dims, specifically to prevent vector.transfer_read of vector<[8]x1xf32>,
+    // which can't be lowered in the generic path.
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_masks
+      transform.apply_patterns.vector.rank_reducing_subview_patterns
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
new file mode 100644
index 00000000000000..7821f7cd865db7
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-arm-sme-to-llvm \
+// RUN:   -convert-vector-to-llvm=enable-arm-sve \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+  %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+  scf.for %i = %c0 to %rows step %c1 {
+    scf.for %j = %c0 to %cols step %c1 {
+      %n = arith.addi %i, %c1 : index
+      %val = arith.index_cast %n : index to i32
+      %val_f32 = arith.sitofp %val : i32 to f32
+      memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+    }
+  }
+  return
+}
+
+func.func @testTransposedReadWithMask() {
+  %in = memref.alloca() : memref<4x16xf32>
+  %out = memref.alloca() : memref<16x4xf32>
+
+  %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  /// A mask so we only read the first 2x15 portion of %in.
+  %maskRows = arith.constant 2 : index
+  %maskCols = arith.constant 15 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A vector.transfer_read with a transpose permutation map. So the input data
+  /// (and mask) have a [4]x[16] shape, but the output is [16]x[4].
+  %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+    {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// Write the vector back to a memref (that also has a transposed shape).
+  vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 15x2) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @testTransposedWriteWithMask() {
+  %in = memref.alloca() : memref<16x4xf32>
+  %out = memref.alloca() : memref<4x16xf32>
+
+  %fill = arith.constant -1.0 : f32
+  linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+
+  %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A regular read.
+  %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+    : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// A mask so we only write the first 3x8 portion of transpose(%in).
+  %maskRows = arith.constant 3 : index
+  %maskCols = arith.constant 8 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+  /// Write out the data with a transpose. Here (like the read test) the mask
+  /// matches the shape of the memory, not the vector.
+  vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+    : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 3x8) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @main() {
+  /// Set the SVL to 128-bits (i.e. vscale = 1).
+  /// This test is for the use of multiple tiles rather than scalability.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4]
+  //
+  //      CHECK: (Masked 15x2) transposed result:
+  //      CHECK:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [0,   0,   0,   0]
+  func.call @testTransposedReadWithMask() : () -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4]
+  // CHECK-NEXT:  [5,   5,   5,   5]
+  // CHECK-NEXT:  [6,   6,   6,   6]
+  // CHECK-NEXT:  [7,   7,   7,   7]
+  // CHECK-NEXT:  [8,   8,   8,   8]
+  // CHECK-NEXT:  [9,   9,   9,   9]
+  // CHECK-NEXT:  [10,   10,   10,   10]
+  // CHECK-NEXT:  [11,   11,   11,   11]
+  // CHECK-NEXT:  [12,   12,   12,   12]
+  // CHECK-NEXT:  [13,   13,   13,   13]
+  // CHECK-NEXT:  [14,   14,   14,   14]
+  // CHECK-NEXT:  [15,   15,   15,   15]
+  // CHECK-NEXT:  [16,   16,   16,   16]
+  //
+  //      CHECK: (Masked 3x8) transposed result:
+  //      CHECK:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [-1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  func.call @testTransposedWriteWithMask() : () -> ()
+
+  return
+}
+
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)


