[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jan 25 09:43:53 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/79152

>From a38f4e9240f9bad2ff96a8b186908db8914ad841 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 13:51:11 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add initial SME vector legalization pass

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |   3 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  19 ++
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |   7 +
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  22 ++
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |   5 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 308 ++++++++++++++++++
 .../Dialect/ArmSME/vector-legalization.mlir   | 268 +++++++++++++++
 .../Linalg/CPU/ArmSME/multi-tile-matmul.mlir  | 109 +++++++
 .../CPU/ArmSME/test-multi-tile-transpose.mlir | 171 ++++++++++
 9 files changed, 911 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
 create mode 100644 mlir/test/Dialect/ArmSME/vector-legalization.mlir
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7cd7..9ba8c43551257b4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e805b..9a6f5446de00946 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def VectorLegalization
+  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+  let summary = "Legalize vectors for ArmSME";
+  let description = [{
+    This pass legalizes vector operations so that they can be lowered to ArmSME.
+    This includes decomposing operations that operate on vector types larger
+    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+    tile-sized operations, as well as rewrites needed to get operations into
+    forms compatible with SME lowerings.
+  }];
+  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "arm_sme::ArmSMEDialect",
+    "vector::VectorDialect",
+    "arith::ArithDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48fbd..e89be8ed81e03fa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
+/// Returns true if `vType` is a multiple of an SME tile size. Note: returns
+/// false if `vType` exactly matches the size of an SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d03912105556625..6a9e0221822267a 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
   return forOp;
 }
 
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+  if (vType.getRank() != 2 || !vType.allDimsScalable())
+    return false;
+
+  auto elementType = vType.getElementType();
+  if (!isValidSMETileElementType(elementType))
+    return false;
+
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+  int64_t vectorRows = vType.getDimSize(0);
+  int64_t vectorCols = vType.getDimSize(1);
+
+  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+         vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb58442043843..3c32fc2645ce1b5 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   TileAllocation.cpp
+  VectorLegalization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRSCFDialect
-  MLIRPass
+  MLIRSCFTransforms
+  MLIRFuncTransforms
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 000000000000000..a801ebe27413d5d
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+struct SMETile {
+  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
+  int row{0};
+  int col{0};
+  VectorType type;
+};
+
+/// Adds a constant scalable offset to `indices`. i.e. for 2D:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+                                                Location loc,
+                                                ValueRange indices,
+                                                ArrayRef<int> scalableOffset) {
+  auto vscale = builder.create<vector::VectorScaleOp>(loc);
+  return llvm::map_to_vector(
+      llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+        auto [index, base] = pair;
+        auto offset = builder.create<arith::MulIOp>(
+            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+        return builder.create<arith::AddIOp>(loc, index, offset);
+      });
+}
+
+/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
+/// for one of the SME tiles it will decompose into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+                                             ValueRange indices,
+                                             SMETile tileTile) {
+  return addConstantScalableOffset(builder, loc, indices,
+                                   {tileTile.row, tileTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is just no mask, or vector.create_mask.
+bool isSupportedMaskOp(Value mask) {
+  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts a mask for an SME tile from the mask of a larger vector type.
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+                     SMETile tileTile) {
+  assert(isSupportedMaskOp(mask));
+  if (!mask)
+    return Value{};
+  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
+  // coordinates where the mask ends. So we subtract where this tile starts,
+  // from the mask operands to get the parameters for this tile.
+  auto tileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
+  auto createTileMask = builder.create<vector::CreateMaskOp>(
+      loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
+  return createTileMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+                         VectorType smeTileType,
+                         bool transposeIndices = false) {
+  assert(isMultipleOfSMETileVectorType(type));
+  return llvm::map_range(
+      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+                                              smeTileType.getDimSize(1)}),
+      [=](auto indices) {
+        int row = int(indices[0]);
+        int col = int(indices[1]);
+        if (transposeIndices)
+          std::swap(row, col);
+        return SMETile{row, col, smeTileType};
+      });
+}
+
+/// Returns the number of SME tiles that fit into a vector type.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+  assert(isMultipleOfSMETileVectorType(type));
+  int64_t vectorRows = type.getDimSize(0);
+  int64_t vectorCols = type.getDimSize(1);
+  auto elementType = type.getElementType();
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+struct LegalizeVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::OuterProductOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = outerProductOp.getResultVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    Value mask;
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      mask = maskOp.getMask();
+      rootOp = maskOp;
+    }
+
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    ValueRange accSMETiles = adaptor.getAcc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+    SmallVector<Value> resultSMETiles;
+    for (auto [index, tileTile] :
+         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto lhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+      auto rhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, tileType, lhs, rhs,
+          !accSMETiles.empty() ? accSMETiles[index] : Value{},
+          outerProductOp.getKind());
+
+      auto maskedOuterProduct =
+          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+    }
+
+    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+// Workaround for `vector.mask`. Matching directly on a masked
+// `vector.outerproduct` makes the type conversion add (invalid) target
+// materializations inside the `vector.mask` region, so instead this pattern
+// matches on `vector.mask` and forwards to the `vector.outerproduct` pattern.
+// Fails to match if the mask region is empty or holds any other operation.
+struct LegalizeMaskedVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::MaskOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
+            maskOp.getMaskableOp())) {
+      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+          outerProductOp, rewriter);
+    }
+    return failure();
+  }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+struct LegalizeTransferReadOp
+    : public OneToNOpConversionPattern<vector::TransferReadOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = readOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = readOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = readOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = readOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+    SmallVector<Value> resultSMETiles;
+    for (SMETile tileTile :
+         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto transferRead = rewriter.create<vector::TransferReadOp>(
+          loc, tileType, readOp.getSource(),
+          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+          readOp.getInBoundsAttr());
+      resultSMETiles.push_back(transferRead);
+    }
+
+    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+struct LegalizeTransferWriteOp
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = writeOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto inputSMETiles = adaptor.getVector();
+
+    Value destTensorOrMemref = writeOp.getSource();
+    for (auto [index, tileTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, inputSMETiles[index], destTensorOrMemref,
+          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
+          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = tileWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
+struct VectorLegalizationPass
+    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    OneToNTypeConverter converter;
+    RewritePatternSet patterns(context);
+
+    converter.addConversion([](Type type) { return type; });
+    converter.addConversion(
+        [](VectorType vectorType,
+           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+          if (!isMultipleOfSMETileVectorType(vectorType))
+            return std::nullopt;
+          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(tileTileCount, tileType);
+          return success();
+        });
+
+    // Note: High benefit to ensure masked outer products are lowered first.
+    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+                 LegalizeTransferWriteOp>(converter, context);
+    populateFuncTypeConversionPatterns(converter, patterns);
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+  return std::make_unique<VectorLegalizationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 000000000000000..a20abeefedcfd46
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+  return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
+  return %0 : vector<[4]x[16]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
+// CHECK-SAME:                                         %[[LHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                         %[[RHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                         %[[LHS_DIM:.*]]: index,
+// CHECK-SAME:                                         %[[RHS_DIM:.*]]: index)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
+{
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
+  // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
+  return %0 : vector<[16]x[4]xf32>
+}
+
+// -----
+
+/// This demonstrates a rectangular tiling that uses all f64 accumulators.
+
+// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf64>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[4]xf64>)
+// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
+func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
+  return %0 : vector<[8]x[4]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i32_scalable_8x8(
+// CHECK-SAME:                                  %[[SRC:.*]]: memref<?x?xi32>)
+// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
+func.func @transfer_read_i32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
+// CHECK-SAME:                                          %[[SRC:.*]]: memref<?x?xi16>,
+// CHECK-SAME:                                          %[[DIM0:.*]]: index,
+// CHECK-SAME:                                          %[[DIM1:.*]]: index)
+// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
+func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i16
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
+  return %0 : vector<[8]x[16]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
+// CHECK-SAME:                                    %[[DEST:.*]]: memref<?x?xf16>,
+// CHECK-SAME:                                    %[[TOP:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME:                                    %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
+func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
+  return
+}
+
+// -----
+
+/// This is already a legal type. It should be ignored.
+
+// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
+func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
+{
+  // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[16]x[16]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
+// CHECK-SAME:                                        %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                        %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
+// CHECK-SAME:                                         %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
new file mode 100644
index 000000000000000..fb192a829173cb8
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
+// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
+// RUN:   -arm-sme-vector-legalization -canonicalize -cse \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN:   -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
+  %res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+  return
+}
+
+func.func @main() {
+  /// Set SVL to 128-bit. This ensures this small matmul will use all four
+  /// 32-bit SME virtual tiles.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : i32
+  %c7 = arith.constant 7 : index
+
+  %A = arith.constant dense<[
+    [ 1.,  8., 15., 22., 29., 36., 43., 50., 57., 64., 71., 78., 85.],
+    [ 2.,  9., 16., 23., 30., 37., 44., 51., 58., 65., 72., 79., 86.],
+    [ 3., 10., 17., 24., 31., 38., 45., 52., 59., 66., 73., 80., 87.],
+    [ 4., 11., 18., 25., 32., 39., 46., 53., 60., 67., 74., 81., 88.],
+    [ 5., 12., 19., 26., 33., 40., 47., 54., 61., 68., 75., 82., 89.],
+    [ 6., 13., 20., 27., 34., 41., 48., 55., 62., 69., 76., 83., 90.],
+    [ 7., 14., 21., 28., 35., 42., 49., 56., 63., 70., 77., 84., 91.]
+  ]> : tensor<7x13xf32>
+
+  %B_init = tensor.empty() : tensor<13x7xf32>
+  %B = linalg.transpose ins(%A: tensor<7x13xf32>)
+                        outs(%B_init: tensor<13x7xf32>) permutation = [1, 0]
+
+  %A_dyn = tensor.cast %A : tensor<7x13xf32> to tensor<?x?xf32>
+  %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
+
+  %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
+  %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
+  // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
+  // CHECK: [33514, 34086, 34658, 35230, 35802, 36374, 36946]
+  // CHECK: [34073, 34658, 35243, 35828, 36413, 36998, 37583]
+  // CHECK: [34632, 35230, 35828, 36426, 37024, 37622, 38220]
+  // CHECK: [35191, 35802, 36413, 37024, 37635, 38246, 38857]
+  // CHECK: [35750, 36374, 36998, 37622, 38246, 38870, 39494]
+  // CHECK: [36309, 36946, 37583, 38220, 38857, 39494, 40131]
+  call @matmul(%A_dyn, %B_dyn, %C) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.consumed}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile for size [8] x [8], which corresponds to (2 x SVLs) x (2 x SVLs),
+    // where SVLs is the number of 32-bit elements in a vector of SVL bits.
+    // This uses all four 32-bit SME virtual tiles.
+    %tiled_linalg_op, %loop_i, %loop_j, %loop_k = transform.structured.tile_using_for %matmul[[8], [8], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">, !transform.op<"scf.for">)
+
+    // Step 2: Vectorize.
+    transform.structured.vectorize %tiled_linalg_op vector_sizes [[8], [8], 1]
+      : !transform.any_op
+
+    // Step 3: Bufferize ahead of TransferReadDropUnitDimsPattern, which
+    // currently only supports memrefs.
+    %bufferize = transform.bufferization.one_shot_bufferize %module
+      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %bufferize
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns).
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.reduction_to_contract
+    } : !transform.any_op
+
+    // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit
+    // dims, specifically to prevent vector.transfer_read of vector<[8]x1xf32>,
+    // which can't be lowered in the generic path.
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_masks
+      transform.apply_patterns.vector.rank_reducing_subview_patterns
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
new file mode 100644
index 000000000000000..7821f7cd865db78
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-arm-sme-to-llvm \
+// RUN:   -convert-vector-to-llvm=enable-arm-sve \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+  %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+  scf.for %i = %c0 to %rows step %c1 {
+    scf.for %j = %c0 to %cols step %c1 {
+      %n = arith.addi %i, %c1 : index
+      %val = arith.index_cast %n : index to i32
+      %val_f32 = arith.sitofp %val : i32 to f32
+      memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+    }
+  }
+  return
+}
+
+func.func @testTransposedReadWithMask() {
+  %in = memref.alloca() : memref<4x16xf32>
+  %out = memref.alloca() : memref<16x4xf32>
+
+  %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  /// A mask so we only read the first 2x15 portion of %in.
+  %maskRows = arith.constant 2 : index
+  %maskCols = arith.constant 15 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A vector.transfer_read with a transpose permutation map. So the input data
+  /// (and mask) have a [4]x[16] shape, but the output is [16]x[4].
+  %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+    {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// Write the vector back to a memref (that also has a transposed shape).
+  vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 15x2) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @testTransposedWriteWithMask() {
+  %in = memref.alloca() : memref<16x4xf32>
+  %out = memref.alloca() : memref<4x16xf32>
+
+  %fill = arith.constant -1.0 : f32
+  linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+
+  %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A regular read.
+  %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+    : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// A mask so we only write the first 3x8 portion of transpose(%in).
+  %maskRows = arith.constant 3 : index
+  %maskCols = arith.constant 8 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+  /// Write out the data with a transpose. Here (like the read test) the mask
+  /// matches the shape of the memory, not the vector.
+  vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+    : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 3x8) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @main() {
+  /// Set the SVL to 128-bits (i.e. vscale = 1).
+  /// This test is for the use of multiple tiles rather than scalability.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4]
+  //
+  //      CHECK: (Masked 15x2) transposed result:
+  //      CHECK:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [0,   0,   0,   0]
+  func.call @testTransposedReadWithMask() : () -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4]
+  // CHECK-NEXT:  [5,   5,   5,   5]
+  // CHECK-NEXT:  [6,   6,   6,   6]
+  // CHECK-NEXT:  [7,   7,   7,   7]
+  // CHECK-NEXT:  [8,   8,   8,   8]
+  // CHECK-NEXT:  [9,   9,   9,   9]
+  // CHECK-NEXT:  [10,   10,   10,   10]
+  // CHECK-NEXT:  [11,   11,   11,   11]
+  // CHECK-NEXT:  [12,   12,   12,   12]
+  // CHECK-NEXT:  [13,   13,   13,   13]
+  // CHECK-NEXT:  [14,   14,   14,   14]
+  // CHECK-NEXT:  [15,   15,   15,   15]
+  // CHECK-NEXT:  [16,   16,   16,   16]
+  //
+  //      CHECK: (Masked 3x8) transposed result:
+  //      CHECK:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [-1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  func.call @testTransposedWriteWithMask() : () -> ()
+
+  return
+}
+
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)

>From 1be013a50cea56f3df09c26ffcc7ff08f4e27f23 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jan 2024 17:40:16 +0000
Subject: [PATCH 2/2] Fixups

---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  4 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 98 ++++++++++++++-----
 2 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index e89be8ed81e03fa..027ad8954f92f5a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,8 +56,8 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
-/// Returns true if `vType` is a multiple of an SME tile size. Note returns
-/// false if the `vType` exactly matches the size of an SME tile.
+/// Returns true if `vType` is a multiple of an SME tile size. Returns false if
+/// the `vType` exactly matches the size of an SME tile.
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index a801ebe27413d5d..526246b646da8fc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -1,3 +1,11 @@
+//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -19,6 +27,14 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+// Common match failure reasons.
+static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
+    "op vector size is not multiple of SME tiles");
+static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
+    "op mask is unsupported for legalization/decomposition");
+static constexpr StringLiteral
+    MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
+
 struct SMETile {
   // Note: The units of (row, col) are vscale (as SME tiles are scalable).
   int row{0};
@@ -26,8 +42,9 @@ struct SMETile {
   VectorType type;
 };
 
-/// Adds a constant scalable offset to `indices`. i.e. for 2D:
-/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+/// Adds a constant scalable offset to `indices` (which are of equal length).
+/// For example, in the 2D case this would return:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
 SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                 Location loc,
                                                 ValueRange indices,
@@ -42,8 +59,20 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
       });
 }
 
-/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
-/// for one of the SME tiles it will decompose into.
+/// Remaps `indices` (e.g. from a load/store) for a larger vector type to
+/// indices for one of the SME tiles it will decompose into.
+///
+/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
+/// indices for each tile would need to be remapped as follows:
+///
+/// initial indices = [a,b], initial size = 8x8, target size = 4x4
+/// ┌─────────────┬─────────────┐
+/// │[a,b]        │[a,b+4]      │
+/// │             │             │
+/// ├─────────────┼─────────────┤
+/// │[a+4,b]      │[a+4,b+4]    │
+/// │             │             │
+/// └─────────────┴─────────────┘
 SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
                                              ValueRange indices,
                                              SMETile tileTile) {
@@ -64,7 +93,7 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
   if (!mask)
     return Value{};
   auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
-  // The the operands of `vector.create_mask` (from a 2D perspective) are the
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
   // coordinates where the mask ends. So we subtract where this tile starts,
   // from the mask operands to get the parameters for this tile tile.
   auto tileMaskDims = addConstantScalableOffset(
@@ -75,7 +104,9 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
 }
 
 /// Constructs an iterator that returns each SME tile (with coordinates)
-/// contained within a VectorType.
+/// contained within a VectorType. For example, if decomposing an [8]x[8] into
+/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
+/// (4, 4).
 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                          VectorType smeTileType,
                          bool transposeIndices = false) {
@@ -92,7 +123,8 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
       });
 }
 
-/// Returns the number of SME tiles that fit into the a vector type.
+/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
+/// `type`.
 int getNumberOfSMETilesForVectorType(VectorType type) {
   assert(isMultipleOfSMETileVectorType(type));
   int64_t vectorRows = type.getDimSize(0);
@@ -102,8 +134,9 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
-/// Legalize `vector.outerproduct` operations to fit within SME tiles.
-struct LegalizeVectorOuterProductOp
+/// Legalize `vector.outerproduct` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeVectorOuterProductOpsByDecomposition
     : public OneToNOpConversionPattern<vector::OuterProductOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -112,7 +145,8 @@ struct LegalizeVectorOuterProductOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = outerProductOp.getResultVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     Value mask;
     Operation *rootOp = outerProductOp;
@@ -124,7 +158,8 @@ struct LegalizeVectorOuterProductOp
     }
 
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     ValueRange accSMETiles = adaptor.getAcc();
     auto tileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -159,7 +194,7 @@ struct LegalizeVectorOuterProductOp
 // conversion adding target materializations in the `vector.mask` region
 // (invalid). This pattern matches on `vector.mask` then calls into the
 // `vector.outerproduct` pattern to work around this issue.
-struct LegalizeMaskedVectorOuterProductOp
+struct LegalizeMaskedVectorOuterProductOpsByDecomposition
     : public OneToNOpConversionPattern<vector::MaskOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -168,7 +203,8 @@ struct LegalizeMaskedVectorOuterProductOp
                   OneToNPatternRewriter &rewriter) const override {
     if (auto outerProductOp =
             llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
-      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
+                                                           getContext());
       return static_cast<RewritePattern &>(pattern).matchAndRewrite(
           outerProductOp, rewriter);
     }
@@ -176,8 +212,9 @@ struct LegalizeMaskedVectorOuterProductOp
   }
 };
 
-/// Legalize `vector.transfer_read` operations to fit within SME tiles.
-struct LegalizeTransferReadOp
+/// Legalize `vector.transfer_read` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferReadOpsByDecomposition
     : public OneToNOpConversionPattern<vector::TransferReadOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -186,15 +223,18 @@ struct LegalizeTransferReadOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = readOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     auto mask = readOp.getMask();
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(readOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     auto permutationMap = readOp.getPermutationMap();
     if (!permutationMap.isPermutation())
-      return failure();
+      return rewriter.notifyMatchFailure(readOp,
+                                         MATCH_FAILURE_NON_PERMUTATION_MAP);
 
     // Note: For 2D vector types the only non-identity permutation is a simple
     // tranpose [1, 0].
@@ -220,8 +260,9 @@ struct LegalizeTransferReadOp
   }
 };
 
-/// Legalize `vector.transfer_write` operations to fit within SME tiles.
-struct LegalizeTransferWriteOp
+/// Legalize `vector.transfer_write` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferWriteOpsByDecomposition
     : public OneToNOpConversionPattern<vector::TransferWriteOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -230,15 +271,18 @@ struct LegalizeTransferWriteOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = writeOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     auto mask = writeOp.getMask();
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(writeOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     auto permutationMap = writeOp.getPermutationMap();
     if (!permutationMap.isPermutation())
-      return failure();
+      return rewriter.notifyMatchFailure(writeOp,
+                                         MATCH_FAILURE_NON_PERMUTATION_MAP);
 
     // Note: For 2D vector types the only non-identity permutation is a simple
     // transpose [1, 0].
@@ -289,9 +333,11 @@ struct VectorLegalizationPass
         });
 
     // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
-                 LegalizeTransferWriteOp>(converter, context);
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
+        converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+                 LegalizeTransferReadOpsByDecomposition,
+                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);
     scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
 



More information about the Mlir-commits mailing list