[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jan 30 04:48:01 PST 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/79152

>From abcbd078a9c0d2e08e8ea2174b7a3b640c9d57f8 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 13:51:11 +0000
Subject: [PATCH 1/5] [mlir][ArmSME] Add initial SME vector legalization pass

This adds a new pass (`-arm-sme-vector-legalization`) which legalizes
vector operations so that they can be lowered to ArmSME. This initial
patch adds decomposition for `vector.outerproduct`,
`vector.transfer_read`, and `vector.transfer_write` when they operate on
vector types larger than a single SME tile. For example, a [8]x[8]xf32
outer product would be decomposed into four [4]x[4]xf32 outer products,
which could then be lowered to ArmSME. These three ops have been picked
as supporting them alone allows lowering matmuls that use all ZA
accumulators to ArmSME.

For it to be possible to legalize a vector type it has to be a multiple
of an SME tile size, but other than that any shape can be used. E.g.
`vector<[8]x[8]xf32>`, `vector<[4]x[16]xf32>`, `vector<[16]x[4]xf32>`
can all be lowered to four `vector<[4]x[4]xf32>` operations.

In future, this pass will be extended with more SME-specific rewrites to
legalize unrolling the reduction dimension of matmuls (which is not
type-decomposition), which is why the pass has quite a general name.
---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |   3 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  19 ++
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |   7 +
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  22 ++
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |   5 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 308 ++++++++++++++++++
 .../Dialect/ArmSME/vector-legalization.mlir   | 268 +++++++++++++++
 .../Linalg/CPU/ArmSME/multi-tile-matmul.mlir  | 109 +++++++
 .../CPU/ArmSME/test-multi-tile-transpose.mlir | 171 ++++++++++
 9 files changed, 911 insertions(+), 1 deletion(-)
 create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
 create mode 100644 mlir/test/Dialect/ArmSME/vector-legalization.mlir
 create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index aef2959265a7..9ba8c4355125 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
+/// Pass that legalizes vectors so they can be lowered to ArmSME.
+std::unique_ptr<Pass> createVectorLegalizationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8d1ba6ed34e8..9a6f5446de00 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,23 @@ def TileAllocation
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def VectorLegalization
+  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
+  let summary = "Legalize vectors for ArmSME";
+  let description = [{
+    This pass legalizes vector operations so that they can be lowered to ArmSME.
+    This includes decomposing operations that operate on vector types larger
+    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
+    tile-sized operations, as well as rewrites needed to get operations into
+    forms compatible with SME lowerings.
+  }];
+  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
+  let dependentDialects = [
+    "func::FuncDialect",
+    "arm_sme::ArmSMEDialect",
+    "vector::VectorDialect",
+    "arith::ArithDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 41702008ee48..e89be8ed81e0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,6 +56,13 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
+/// Returns true if `vType` is a multiple of an SME tile size. Note: returns
+/// false if `vType` exactly matches the size of an SME tile.
+bool isMultipleOfSMETileVectorType(VectorType vType);
+
+/// Creates a vector type for the SME tile of `elementType`.
+VectorType getSMETileTypeForElement(Type elementType);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index d03912105556..6a9e02218222 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -94,4 +94,26 @@ scf::ForOp createLoopOverTileSlices(
   return forOp;
 }
 
+bool isMultipleOfSMETileVectorType(VectorType vType) {
+  if (vType.getRank() != 2 || !vType.allDimsScalable())
+    return false;
+
+  auto elementType = vType.getElementType();
+  if (!isValidSMETileElementType(elementType))
+    return false;
+
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+
+  int64_t vectorRows = vType.getDimSize(0);
+  int64_t vectorCols = vType.getDimSize(1);
+
+  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
+         vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
+}
+
+VectorType getSMETileTypeForElement(Type elementType) {
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 96eb58442043..3c32fc2645ce 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   TileAllocation.cpp
+  VectorLegalization.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -9,10 +10,12 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRPass
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
   MLIRSCFDialect
-  MLIRPass
+  MLIRSCFTransforms
+  MLIRFuncTransforms
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
new file mode 100644
index 000000000000..a801ebe27413
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -0,0 +1,308 @@
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/OneToNTypeConversion.h"
+
+#define DEBUG_TYPE "arm-sme-vector-legalization"
+
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_VECTORLEGALIZATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace mlir::arm_sme
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+struct SMETile {
+  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
+  int row{0};
+  int col{0};
+  VectorType type;
+};
+
+/// Adds a constant scalable offset to `indices`. i.e. for 2D:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
+                                                Location loc,
+                                                ValueRange indices,
+                                                ArrayRef<int> scalableOffset) {
+  auto vscale = builder.create<vector::VectorScaleOp>(loc);
+  return llvm::map_to_vector(
+      llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+        auto [index, base] = pair;
+        auto offset = builder.create<arith::MulIOp>(
+            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
+        return builder.create<arith::AddIOp>(loc, index, offset);
+      });
+}
+
+/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
+/// for one of the SME tiles it will decompose into.
+SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
+                                             ValueRange indices,
+                                             SMETile tileTile) {
+  return addConstantScalableOffset(builder, loc, indices,
+                                   {tileTile.row, tileTile.col});
+}
+
+/// Returns true if `mask` is generated by an operation that can be decomposed
+/// for SME. Currently, that is just no mask, or vector.create_mask.
+bool isSupportedMaskOp(Value mask) {
+  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
+}
+
+/// Extracts the mask for one SME tile from a larger mask (null if no mask).
+Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
+                     SMETile tileTile) {
+  assert(isSupportedMaskOp(mask));
+  if (!mask)
+    return Value{};
+  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
+  // coordinates where the mask ends. So we subtract where this tile starts
+  // from the mask operands to get the parameters for this tile's mask.
+  auto tileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
+  auto createTileMask = builder.create<vector::CreateMaskOp>(
+      loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
+  return createTileMask.getResult();
+}
+
+/// Constructs an iterator that returns each SME tile (with coordinates)
+/// contained within a VectorType.
+auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
+                         VectorType smeTileType,
+                         bool transposeIndices = false) {
+  assert(isMultipleOfSMETileVectorType(type));
+  return llvm::map_range(
+      StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
+                                              smeTileType.getDimSize(1)}),
+      [=](auto indices) {
+        int row = int(indices[0]);
+        int col = int(indices[1]);
+        if (transposeIndices)
+          std::swap(row, col);
+        return SMETile{row, col, smeTileType};
+      });
+}
+
+/// Returns the number of SME tiles that fit into a vector type.
+int getNumberOfSMETilesForVectorType(VectorType type) {
+  assert(isMultipleOfSMETileVectorType(type));
+  int64_t vectorRows = type.getDimSize(0);
+  int64_t vectorCols = type.getDimSize(1);
+  auto elementType = type.getElementType();
+  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
+  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
+}
+
+/// Legalize `vector.outerproduct` operations to fit within SME tiles.
+struct LegalizeVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::OuterProductOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = outerProductOp.getResultVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    Value mask;
+    Operation *rootOp = outerProductOp;
+    auto loc = outerProductOp.getLoc();
+    if (outerProductOp.isMasked()) {
+      auto maskOp = outerProductOp.getMaskingOp();
+      mask = maskOp.getMask();
+      rootOp = maskOp;
+    }
+
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    ValueRange accSMETiles = adaptor.getAcc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+
+    SmallVector<Value> resultSMETiles;
+    for (auto [index, tileTile] :
+         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto lhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+      auto rhs = rewriter.create<vector::ScalableExtractOp>(
+          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
+      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, tileType, lhs, rhs,
+          !accSMETiles.empty() ? accSMETiles[index] : Value{},
+          outerProductOp.getKind());
+
+      auto maskedOuterProduct =
+          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
+    }
+
+    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+// Workaround for `vector.mask`: matching `vector.outerproduct` directly (to get
+// help from the type conversion) makes the conversion add target
+// materializations inside the `vector.mask` region (invalid). So match on
+// `vector.mask` and forward to the `vector.outerproduct` pattern. A mask with
+// an empty region has no maskable op (null), which is rejected with failure().
+struct LegalizeMaskedVectorOuterProductOp
+    : public OneToNOpConversionPattern<vector::MaskOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
+            maskOp.getMaskableOp())) {
+      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
+          outerProductOp, rewriter);
+    }
+    return failure();
+  }
+};
+
+/// Legalize `vector.transfer_read` operations to fit within SME tiles.
+struct LegalizeTransferReadOp
+    : public OneToNOpConversionPattern<vector::TransferReadOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = readOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = readOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = readOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = readOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+
+    SmallVector<Value> resultSMETiles;
+    for (SMETile tileTile :
+         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto transferRead = rewriter.create<vector::TransferReadOp>(
+          loc, tileType, readOp.getSource(),
+          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+          readOp.getInBoundsAttr());
+      resultSMETiles.push_back(transferRead);
+    }
+
+    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    return success();
+  }
+};
+
+/// Legalize `vector.transfer_write` operations to fit within SME tiles.
+struct LegalizeTransferWriteOp
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return failure();
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask))
+      return failure();
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return failure();
+
+    // Note: For 2D vector types the only non-identity permutation is a simple
+    // transpose [1, 0].
+    bool transposed = !permutationMap.isIdentity();
+
+    auto loc = writeOp.getLoc();
+    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto inputSMETiles = adaptor.getVector();
+
+    Value destTensorOrMemref = writeOp.getSource();
+    for (auto [index, tileTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
+      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+          loc, inputSMETiles[index], destTensorOrMemref,
+          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
+          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+      if (writeOp.hasPureTensorSemantics())
+        destTensorOrMemref = tileWrite.getResult();
+    }
+
+    if (writeOp.hasPureTensorSemantics())
+      rewriter.replaceOp(writeOp, destTensorOrMemref);
+    else
+      rewriter.eraseOp(writeOp);
+
+    return success();
+  }
+};
+
+struct VectorLegalizationPass
+    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
+  void runOnOperation() override {
+    auto *context = &getContext();
+    OneToNTypeConverter converter;
+    RewritePatternSet patterns(context);
+
+    converter.addConversion([](Type type) { return type; });
+    converter.addConversion(
+        [](VectorType vectorType,
+           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
+          if (!isMultipleOfSMETileVectorType(vectorType))
+            return std::nullopt;
+          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(tileTileCount, tileType);
+          return success();
+        });
+
+    // Note: High benefit to ensure masked outer products are lowered first.
+    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
+                 LegalizeTransferWriteOp>(converter, context);
+    populateFuncTypeConversionPatterns(converter, patterns);
+    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
+
+    if (failed(applyPartialOneToNConversion(getOperation(), converter,
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
+  return std::make_unique<VectorLegalizationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
new file mode 100644
index 000000000000..a20abeefedcf
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -0,0 +1,268 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @outerproduct_f32_scalable_8x8_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf32>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[8]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_8x8_no_acc(%lhs: vector<[8]xf32>, %rhs: vector<[8]xf32>) -> vector<[8]x[8]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[8]xf32>
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf32>, vector<[8]xf32>
+  return %0 : vector<[8]x[8]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_scalable_4x16_acc(
+// CHECK-SAME:                                      %[[LHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[RHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                      %[[ACC_0:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_1:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_2:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                      %[[ACC_3:[A-Za-z0-9]*]]: vector<[4]x[4]xf32>)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_scalable_4x16_acc(%lhs: vector<[4]xf32>, %rhs: vector<[16]xf32>, %acc: vector<[4]x[16]xf32>) -> vector<[4]x[16]xf32>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_2:.*]] = vector.scalable.extract %[[RHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_3:.*]] = vector.scalable.extract %[[RHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]], %[[ACC_0]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]], %[[ACC_1]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_2]], %[[ACC_2]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_3]], %[[ACC_3]] {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xf32>, vector<[16]xf32>
+  return %0 : vector<[4]x[16]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @outerproduct_f32_masked_scalable_16x4(
+// CHECK-SAME:                                         %[[LHS:.*]]: vector<[16]xf32>,
+// CHECK-SAME:                                         %[[RHS:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                         %[[LHS_DIM:.*]]: index,
+// CHECK-SAME:                                         %[[RHS_DIM:.*]]: index)
+// CHECK-SAME: -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+func.func @outerproduct_f32_masked_scalable_16x4(%lhs: vector<[16]xf32>, %rhs: vector<[4]xf32>, %lhs_dim: index, %rhs_dim: index) -> vector<[16]x[4]xf32>
+{
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_4:.*]] = arith.constant -4 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[MINUS_12:.*]] = arith.constant -12 : index
+  // CHECK-DAG: %[[MINUS_4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_4]] : index
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[MINUS_12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_12]] : index
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][8] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][12] : vector<[4]xf32> from vector<[16]xf32>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[4]xf32> from vector<[4]xf32>
+  // CHECK-DAG: %[[MASK_0:.*]] = vector.create_mask %[[LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_1_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_4_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_1:.*]] = vector.create_mask %[[TILE_1_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_2_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_2:.*]] = vector.create_mask %[[TILE_2_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[TILE_3_LHS_DIM:.*]] = arith.addi %[[LHS_DIM]], %[[MINUS_12_VSCALE]] : index
+  // CHECK-DAG: %[[MASK_3:.*]] = vector.create_mask %[[TILE_3_LHS_DIM]], %[[RHS_DIM]] : vector<[4]x[4]xi1>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.mask %[[MASK_0]] { vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.mask %[[MASK_1]] { vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.mask %[[MASK_2]] { vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.mask %[[MASK_3]] { vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]] : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+  %mask = vector.create_mask %lhs_dim, %rhs_dim : vector<[16]x[4]xi1>
+  %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs : vector<[16]xf32>, vector<[4]xf32> } : vector<[16]x[4]xi1> -> vector<[16]x[4]xf32>
+  return %0 : vector<[16]x[4]xf32>
+}
+
+// -----
+
+/// This demonstrates a rectangular tiling that uses all f64 accumulators.
+
+// CHECK-LABEL: @outerproduct_f64_scalable_8x4_no_acc(
+// CHECK-SAME:                                        %[[LHS:.*]]: vector<[8]xf64>,
+// CHECK-SAME:                                        %[[RHS:.*]]: vector<[4]xf64>)
+// CHECK-SAME: -> (vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>)
+func.func @outerproduct_f64_scalable_8x4_no_acc(%lhs: vector<[8]xf64>, %rhs: vector<[4]xf64>) -> vector<[8]x[4]xf64>
+{
+  // CHECK-DAG: %[[LHS_0:.*]] = vector.scalable.extract %[[LHS]][0] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_1:.*]] = vector.scalable.extract %[[LHS]][2] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_2:.*]] = vector.scalable.extract %[[LHS]][4] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[LHS_3:.*]] = vector.scalable.extract %[[LHS]][6] : vector<[2]xf64> from vector<[8]xf64>
+  // CHECK-DAG: %[[RHS_0:.*]] = vector.scalable.extract %[[RHS]][0] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RHS_1:.*]] = vector.scalable.extract %[[RHS]][2] : vector<[2]xf64> from vector<[4]xf64>
+  // CHECK-DAG: %[[RES_0:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_1:.*]] = vector.outerproduct %[[LHS_0]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_2:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_3:.*]] = vector.outerproduct %[[LHS_1]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_4:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_5:.*]] = vector.outerproduct %[[LHS_2]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_6:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_0]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-DAG: %[[RES_7:.*]] = vector.outerproduct %[[LHS_3]], %[[RHS_1]] : vector<[2]xf64>, vector<[2]xf64>
+  // CHECK-NEXT: return %[[RES_0]], %[[RES_1]], %[[RES_2]], %[[RES_3]], %[[RES_4]], %[[RES_5]], %[[RES_6]], %[[RES_7]] : vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>, vector<[2]x[2]xf64>
+  %0 = vector.outerproduct %lhs, %rhs : vector<[8]xf64>, vector<[4]xf64>
+  return %0 : vector<[8]x[4]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_f32_scalable_8x8(
+// CHECK-SAME:                                  %[[SRC:.*]]: memref<?x?xi32>)
+// CHECK-SAME: -> (vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>)
+func.func @transfer_read_f32_scalable_8x8(%src: memref<?x?xi32>) -> vector<[8]x[8]xi32>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[TOP_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[TOP_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C0]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-DAG: %[[BOTTOM_RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C4_VSCALE]], %[[C4_VSCALE]]], %[[C0_I32]] {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[TOP_LEFT]], %[[TOP_RIGHT]], %[[BOTTOM_LEFT]], %[[BOTTOM_RIGHT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_i16_scalable_8x16_masked(
+// CHECK-SAME:                                          %[[SRC:.*]]: memref<?x?xi16>,
+// CHECK-SAME:                                          %[[DIM0:.*]]: index,
+// CHECK-SAME:                                          %[[DIM1:.*]]: index)
+// CHECK-SAME: -> (vector<[8]x[8]xi16>, vector<[8]x[8]xi16>)
+func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0: index, %dim1: index) -> vector<[8]x[16]xi16>
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[MINUS_8:.*]] = arith.constant -8 : index
+  // CHECK-DAG: %[[C0_I16:.*]] = arith.constant 0 : i16
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[MINUS_8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[MINUS_8]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[RIGHT_DIM_1:.*]] = arith.addi %[[DIM1]], %[[MINUS_8_VSCALE]] : index
+  // CHECK-DAG: %[[LEFT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[DIM1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[RIGHT_MASK:.*]] = vector.create_mask %[[DIM0]], %[[RIGHT_DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-DAG: %[[LEFT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[C0_I16]], %[[LEFT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-DAG: %[[RIGHT:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[C0_I16]], %[[RIGHT_MASK]] {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK-NEXT: return %[[LEFT]], %[[RIGHT]] : vector<[8]x[8]xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0 : i16
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[16]xi1>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[16]xi16>
+  return %0 : vector<[8]x[16]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_f16_scalable_16x8(
+// CHECK-SAME:                                    %[[DEST:.*]]: memref<?x?xf16>,
+// CHECK-SAME:                                    %[[TOP:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME:                                    %[[BOTTOM:.*]]: vector<[8]x[8]xf16>)
+func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
+  return
+}
+
+// -----
+
+/// This is already a legal type. It should be ignored.
+
+// CHECK-LABEL: @transfer_write_i8_scalable_16x16_masked
+func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec: vector<[16]x[16]xi8>, %dim0: index, %dim1: index)
+{
+  // CHECK: vector.transfer_write {{.*}} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[16]x[16]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
+// CHECK-SAME:                                        %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                        %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_f32_scalable_4x16_via_write(
+// CHECK-SAME:                                         %[[SRC:.*]]: memref<?x?xf32>,
+// CHECK-SAME:                                         %[[DEST:.*]]: memref<?x?xf32>)
+func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK-DAG: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+  // CHECK-DAG: %[[C12_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C12]] : index
+  // CHECK-DAG: %[[TILE_0:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: return
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[16]xf32>
+  vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
new file mode 100644
index 000000000000..fb192a829173
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -0,0 +1,109 @@
+// RUN: mlir-opt %s \
+// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
+// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
+// RUN:   -arm-sme-vector-legalization -canonicalize -cse \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf=full-unroll -convert-arm-sme-to-llvm \
+// RUN:   -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
+  %res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
+                       outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
+  %xf = tensor.cast %res : tensor<?x?xf32> to tensor<*xf32>
+  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()
+  return
+}
+
+func.func @main() {
+  /// Set SVL to 128-bit. This ensures this small matmul will use all four
+  /// 32-bit SME virtual tiles.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : i32
+  %c7 = arith.constant 7 : index
+
+  %A = arith.constant dense<[
+    [ 1.,  8., 15., 22., 29., 36., 43., 50., 57., 64., 71., 78., 85.],
+    [ 2.,  9., 16., 23., 30., 37., 44., 51., 58., 65., 72., 79., 86.],
+    [ 3., 10., 17., 24., 31., 38., 45., 52., 59., 66., 73., 80., 87.],
+    [ 4., 11., 18., 25., 32., 39., 46., 53., 60., 67., 74., 81., 88.],
+    [ 5., 12., 19., 26., 33., 40., 47., 54., 61., 68., 75., 82., 89.],
+    [ 6., 13., 20., 27., 34., 41., 48., 55., 62., 69., 76., 83., 90.],
+    [ 7., 14., 21., 28., 35., 42., 49., 56., 63., 70., 77., 84., 91.]
+  ]> : tensor<7x13xf32>
+
+  %B_init = tensor.empty() : tensor<13x7xf32>
+  %B = linalg.transpose ins(%A: tensor<7x13xf32>)
+                        outs(%B_init: tensor<13x7xf32>) permutation = [1, 0]
+
+  %A_dyn = tensor.cast %A : tensor<7x13xf32> to tensor<?x?xf32>
+  %B_dyn = tensor.cast %B : tensor<13x7xf32> to tensor<?x?xf32>
+
+  %C_init = bufferization.alloc_tensor(%c7, %c7) : tensor<?x?xf32>
+  %C = linalg.fill ins(%c0 : i32) outs(%C_init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [7, 7] strides = [7, 1] data =
+  // CHECK: [32955, 33514, 34073, 34632, 35191, 35750, 36309]
+  // CHECK: [33514, 34086, 34658, 35230, 35802, 36374, 36946]
+  // CHECK: [34073, 34658, 35243, 35828, 36413, 36998, 37583]
+  // CHECK: [34632, 35230, 35828, 36426, 37024, 37622, 38220]
+  // CHECK: [35191, 35802, 36413, 37024, 37635, 38246, 38857]
+  // CHECK: [35750, 36374, 36998, 37622, 38246, 38870, 39494]
+  // CHECK: [36309, 36946, 37583, 38220, 38857, 39494, 40131]
+  call @matmul(%A_dyn, %B_dyn, %C) : (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> ()
+
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.consumed}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 1: Tile for size [8] x [8], which corresponds to (2 x SVLs) x (2 x SVLs),
+    // where SVLs is the number of 32-bit elements in a vector of SVL bits.
+    // This uses all four 32-bit SME virtual tiles.
+    %tiled_linalg_op, %loop_i, %loop_j, %loop_k = transform.structured.tile_using_for %matmul[[8], [8], 1]
+      : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.op<"scf.for">, !transform.op<"scf.for">)
+
+    // Step 2: Vectorize.
+    transform.structured.vectorize %tiled_linalg_op vector_sizes [[8], [8], 1]
+      : !transform.any_op
+
+    // Step 3: Bufferize ahead of TransferReadDropUnitDimsPattern, which
+    // currently only supports memrefs.
+    %bufferize = transform.bufferization.one_shot_bufferize %module
+      {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
+
+    %func = transform.structured.match ops{["func.func"]} in %bufferize
+      : (!transform.any_op) -> !transform.any_op
+
+    // Step 4: Lower vector.multi_reduction to vector.contract (+ some helpful patterns).
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.reduction_to_contract
+    } : !transform.any_op
+
+    // Step 5: Lower vector.contract to vector.outerproduct. Also drop unit
+    // dims, specifically to prevent vector.transfer_read of vector<[8]x1xf32>,
+    // which can't be lowered in the generic path.
+    transform.apply_patterns to %func {
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_masks
+      transform.apply_patterns.vector.rank_reducing_subview_patterns
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
new file mode 100644
index 000000000000..7821f7cd865d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-arm-sme-to-llvm \
+// RUN:   -convert-vector-to-llvm=enable-arm-sve \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+  %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+  scf.for %i = %c0 to %rows step %c1 {
+    scf.for %j = %c0 to %cols step %c1 {
+      %n = arith.addi %i, %c1 : index
+      %val = arith.index_cast %n : index to i32
+      %val_f32 = arith.sitofp %val : i32 to f32
+      memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+    }
+  }
+  return
+}
+
+func.func @testTransposedReadWithMask() {
+  %in = memref.alloca() : memref<4x16xf32>
+  %out = memref.alloca() : memref<16x4xf32>
+
+  %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  /// A mask so we only read the first 2x15 portion of %in.
+  %maskRows = arith.constant 2 : index
+  %maskCols = arith.constant 15 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A vector.transfer_read with a transpose permutation map. So the input data
+  /// (and mask) have a [4]x[16] shape, but the output is [16]x[4].
+  %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+    {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// Write the vector back to a memref (that also has a transposed shape).
+  vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 15x2) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @testTransposedWriteWithMask() {
+  %in = memref.alloca() : memref<16x4xf32>
+  %out = memref.alloca() : memref<4x16xf32>
+
+  %fill = arith.constant -1.0 : f32
+  linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+
+  %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A regular read.
+  %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+    : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// A mask so we only write the first 3x8 portion of transpose(%in).
+  %maskRows = arith.constant 3 : index
+  %maskCols = arith.constant 8 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+  /// Write out the data with a transpose. Here (like the read test) the mask
+  /// matches the shape of the memory, not the vector.
+  vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+    : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 3x8) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @main() {
+  /// Set the SVL to 128-bits (i.e. vscale = 1).
+  /// This test is for the use of multiple tiles rather than scalability.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4]
+  //
+  //      CHECK: (Masked 15x2) transposed result:
+  //      CHECK:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [1,   2,   0,   0]
+  // CHECK-NEXT:  [0,   0,   0,   0]
+  func.call @testTransposedReadWithMask() : () -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4]
+  // CHECK-NEXT:  [5,   5,   5,   5]
+  // CHECK-NEXT:  [6,   6,   6,   6]
+  // CHECK-NEXT:  [7,   7,   7,   7]
+  // CHECK-NEXT:  [8,   8,   8,   8]
+  // CHECK-NEXT:  [9,   9,   9,   9]
+  // CHECK-NEXT:  [10,   10,   10,   10]
+  // CHECK-NEXT:  [11,   11,   11,   11]
+  // CHECK-NEXT:  [12,   12,   12,   12]
+  // CHECK-NEXT:  [13,   13,   13,   13]
+  // CHECK-NEXT:  [14,   14,   14,   14]
+  // CHECK-NEXT:  [15,   15,   15,   15]
+  // CHECK-NEXT:  [16,   16,   16,   16]
+  //
+  //      CHECK: (Masked 3x8) transposed result:
+  //      CHECK:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  // CHECK-NEXT:  [-1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
+  func.call @testTransposedWriteWithMask() : () -> ()
+
+  return
+}
+
+func.func private @printMemrefF32(%ptr : memref<*xf32>)
+func.func private @setArmSVLBits(%bits : i32)

>From aad4e7ec6d5954fafc85ceae91823db365c6cc5c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 25 Jan 2024 17:40:16 +0000
Subject: [PATCH 2/5] Fixups

---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  4 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 98 ++++++++++++++-----
 2 files changed, 74 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index e89be8ed81e0..027ad8954f92 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -56,8 +56,8 @@ scf::ForOp createLoopOverTileSlices(
     PatternRewriter &rewriter, Location loc, Value initTile,
     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
 
-/// Returns true if `vType` is a multiple of an SME tile size. Note returns
-/// false if the `vType` exactly matches the size of an SME tile.
+/// Returns true if `vType` is a multiple of an SME tile size. Returns false if
+/// the `vType` exactly matches the size of an SME tile.
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index a801ebe27413..526246b646da 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -1,3 +1,11 @@
+//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -19,6 +27,14 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+// Common match failure reasons.
+static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
+    "op vector size is not multiple of SME tiles");
+static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
+    "op mask is unsupported for legalization/decomposition");
+static constexpr StringLiteral
+    MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
+
 struct SMETile {
   // Note: The units of (row, col) are vscale (as SME tiles are scalable).
   int row{0};
@@ -26,8 +42,9 @@ struct SMETile {
   VectorType type;
 };
 
-/// Adds a constant scalable offset to `indices`. i.e. for 2D:
-/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
+/// Adds a constant scalable offset to `indices` (which are of equal length).
+/// For example, in the 2D case this would return:
+/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
 SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                 Location loc,
                                                 ValueRange indices,
@@ -42,8 +59,20 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
       });
 }
 
-/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
-/// for one of the SME tiles it will decompose into.
+/// Remaps `indices` (e.g. from a load/store) for a larger vector type to
+/// indices for one of the SME tiles it will decompose into.
+///
+/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
+/// indices for each tile would need to be remapped as follows:
+///
+/// initial indices = [a,b], initial size = 8x8, target size = 4x4
+/// ┌─────────────┬─────────────┐
+/// │[a,b]        │[a,b+4]      │
+/// │             │             │
+/// ├─────────────┼─────────────┤
+/// │[a+4,b]      │[a+4,b+4]    │
+/// │             │             │
+/// └─────────────┴─────────────┘
 SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
                                              ValueRange indices,
                                              SMETile tileTile) {
@@ -64,7 +93,7 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
   if (!mask)
     return Value{};
   auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
-  // The the operands of `vector.create_mask` (from a 2D perspective) are the
+  // The operands of `vector.create_mask` (from a 2D perspective) are the
   // coordinates where the mask ends. So we subtract where this tile starts,
   // from the mask operands to get the parameters for this tile tile.
   auto tileMaskDims = addConstantScalableOffset(
@@ -75,7 +104,9 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
 }
 
 /// Constructs an iterator that returns each SME tile (with coordinates)
-/// contained within a VectorType.
+/// contained within a VectorType. For example, if decomposing an [8]x[8] into
+/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
+/// (4, 4).
 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                          VectorType smeTileType,
                          bool transposeIndices = false) {
@@ -92,7 +123,8 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
       });
 }
 
-/// Returns the number of SME tiles that fit into the a vector type.
+/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
+/// `type`.
 int getNumberOfSMETilesForVectorType(VectorType type) {
   assert(isMultipleOfSMETileVectorType(type));
   int64_t vectorRows = type.getDimSize(0);
@@ -102,8 +134,9 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
-/// Legalize `vector.outerproduct` operations to fit within SME tiles.
-struct LegalizeVectorOuterProductOp
+/// Legalize `vector.outerproduct` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeVectorOuterProductOpsByDecomposition
     : public OneToNOpConversionPattern<vector::OuterProductOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -112,7 +145,8 @@ struct LegalizeVectorOuterProductOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = outerProductOp.getResultVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     Value mask;
     Operation *rootOp = outerProductOp;
@@ -124,7 +158,8 @@ struct LegalizeVectorOuterProductOp
     }
 
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(outerProductOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     ValueRange accSMETiles = adaptor.getAcc();
     auto tileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -159,7 +194,7 @@ struct LegalizeVectorOuterProductOp
 // conversion adding target materializations in the `vector.mask` region
 // (invalid). This pattern matches on `vector.mask` then calls into the
 // `vector.outerproduct` pattern to work around this issue.
-struct LegalizeMaskedVectorOuterProductOp
+struct LegalizeMaskedVectorOuterProductOpsByDecomposition
     : public OneToNOpConversionPattern<vector::MaskOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -168,7 +203,8 @@ struct LegalizeMaskedVectorOuterProductOp
                   OneToNPatternRewriter &rewriter) const override {
     if (auto outerProductOp =
             llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
-      LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
+      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
+                                                           getContext());
       return static_cast<RewritePattern &>(pattern).matchAndRewrite(
           outerProductOp, rewriter);
     }
@@ -176,8 +212,9 @@ struct LegalizeMaskedVectorOuterProductOp
   }
 };
 
-/// Legalize `vector.transfer_read` operations to fit within SME tiles.
-struct LegalizeTransferReadOp
+/// Legalize `vector.transfer_read` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferReadOpsByDecomposition
     : public OneToNOpConversionPattern<vector::TransferReadOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -186,15 +223,18 @@ struct LegalizeTransferReadOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = readOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     auto mask = readOp.getMask();
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(readOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     auto permutationMap = readOp.getPermutationMap();
     if (!permutationMap.isPermutation())
-      return failure();
+      return rewriter.notifyMatchFailure(readOp,
+                                         MATCH_FAILURE_NON_PERMUTATION_MAP);
 
     // Note: For 2D vector types the only non-identity permutation is a simple
     // tranpose [1, 0].
@@ -220,8 +260,9 @@ struct LegalizeTransferReadOp
   }
 };
 
-/// Legalize `vector.transfer_write` operations to fit within SME tiles.
-struct LegalizeTransferWriteOp
+/// Legalize `vector.transfer_write` operations to fit within SME tiles by
+/// decomposing them into tile-sized operations.
+struct LegalizeTransferWriteOpsByDecomposition
     : public OneToNOpConversionPattern<vector::TransferWriteOp> {
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
 
@@ -230,15 +271,18 @@ struct LegalizeTransferWriteOp
                   OneToNPatternRewriter &rewriter) const override {
     auto vectorType = writeOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
-      return failure();
+      return rewriter.notifyMatchFailure(
+          writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
 
     auto mask = writeOp.getMask();
     if (!isSupportedMaskOp(mask))
-      return failure();
+      return rewriter.notifyMatchFailure(writeOp,
+                                         MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     auto permutationMap = writeOp.getPermutationMap();
     if (!permutationMap.isPermutation())
-      return failure();
+      return rewriter.notifyMatchFailure(writeOp,
+                                         MATCH_FAILURE_NON_PERMUTATION_MAP);
 
     // Note: For 2D vector types the only non-identity permutation is a simple
     // tranpose [1, 0].
@@ -289,9 +333,11 @@ struct VectorLegalizationPass
         });
 
     // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
-                 LegalizeTransferWriteOp>(converter, context);
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
+        converter, context, 1024);
+    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+                 LegalizeTransferReadOpsByDecomposition,
+                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);
     scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
 

>From 0fc3fbd479ada225e972366a72e0602dbc89aeb4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Jan 2024 10:57:14 +0000
Subject: [PATCH 3/5] Fixups

---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |   4 +
 .../ArmSME/Transforms/TileAllocation.cpp      |  11 +-
 .../ArmSME/Transforms/VectorLegalization.cpp  | 131 +++++++++++-------
 3 files changed, 88 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 9a6f5446de00..44269344877d 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -131,6 +131,10 @@ def VectorLegalization
     than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
     tile-sized operations, as well as rewrites needed to get operations into
     forms compatible with SME lowerings.
+
+    Note: Decomposition is currently limited to vector types that are an exact
+    multiple of SME tiles. That is, scalable in two dimensions, with both the
+    rows and columns divisible by the SVE vector length for the element type.
   }];
   let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
   let dependentDialects = [
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b5..19994bf47873 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -244,8 +244,9 @@ struct AssignTileIDsPattern
 
     // Set all operations dependent on `tileOp` to use the same tile ID.
     // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't solve
-    // cases like this (%tileA and %tileB come from different root operations):
+    // example, as this only allocates tile IDs to existing ops, it can't
+    // solve cases like this (%tileA and %tileB come from different root
+    // operations):
     //
     // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
     //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
@@ -254,9 +255,9 @@ struct AssignTileIDsPattern
     // }
     //
     // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
-    // on the %some_cond).
-    // Find all the ops that (transitively) depend on this tile.
+    // scf.if, and moving the contents of %tileA or %tileB to result tile
+    // (based on the %some_cond). Find all the ops that (transitively) depend
+    // on this tile.
     SetVector<Operation *> dependantOps;
     findDependantOps(tileOp->getResult(0), dependantOps);
     auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 526246b646da..e533ff7627fa 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -5,6 +5,14 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This pass legalizes vector operations so they can be lowered to ArmSME.
+// Currently, this only implements the decomposition of vector operations that
+// use vector sizes larger than an SME tile, into multiple SME-sized operations.
+//
+// Note: In the context of this pass 'tile' always refers to an SME tile.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
@@ -35,23 +43,37 @@ static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
 static constexpr StringLiteral
     MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
 
-struct SMETile {
+/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
+/// larger vector type. The (`row`, `col`) are the position of the tile in the
+/// original vector type. For example for an [8]x[8] tile would have four
+/// [4]x[4] sub-tiles.
+///
+///           8 x vscale
+/// ┌─────────────┬─────────────┐
+/// │(0,0)        │(0,4)        │
+/// │             │             │
+/// ├─────────────┼─────────────┤ 8 x vscale
+/// │(4,0)        │(4,4)        │
+/// │             │             │
+/// └─────────────┴─────────────┘
+struct SMESubTile {
   // Note: The units of (row, col) are vscale (as SME tiles are scalable).
   int row{0};
   int col{0};
+  // The SME tile type.
   VectorType type;
 };
 
-/// Adds a constant scalable offset to `indices` (which are of equal length).
-/// For example, in the 2D case this would return:
+/// Adds a constant elementwise scalable offset to `indices` (which are of equal
+/// length). For example, in the 2D case this would return:
 // { indices[0] + offset[0] * vscale, indices[1] + offset[1] *  vscale }
 SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                 Location loc,
                                                 ValueRange indices,
-                                                ArrayRef<int> scalableOffset) {
+                                                ArrayRef<int> scalableOffsets) {
   auto vscale = builder.create<vector::VectorScaleOp>(loc);
   return llvm::map_to_vector(
-      llvm::zip_equal(indices, scalableOffset), [&](auto pair) -> Value {
+      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
         auto [index, base] = pair;
         auto offset = builder.create<arith::MulIOp>(
             loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
@@ -59,11 +81,11 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
       });
 }
 
-/// Remaps `indices` (e.g. from a load/store) for a larger vector type to
-/// indices for one of the SME tiles it will decompose into.
+/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
+/// indices for one of the SME sub-tiles it will decompose into.
 ///
 /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
-/// indices for each tile would need to be remapped as follows:
+/// indices for each tile would need to be adjusted as follows:
 ///
 /// initial indices = [a,b], inital size = 8x8, target size = 4x4
 /// ┌─────────────┬─────────────┐
@@ -73,11 +95,11 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
 /// │[a+4,b]      │[a+4,b+4]    │
 /// │             │             │
 /// └─────────────┴─────────────┘
-SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
-                                             ValueRange indices,
-                                             SMETile tileTile) {
+SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
+                                           ValueRange indices,
+                                           SMESubTile smeTile) {
   return addConstantScalableOffset(builder, loc, indices,
-                                   {tileTile.row, tileTile.col});
+                                   {smeTile.row, smeTile.col});
 }
 
 /// Returns true if `mask` is generated by an operation that can be decomposed
@@ -86,21 +108,21 @@ bool isSupportedMaskOp(Value mask) {
   return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
 }
 
-/// Extracts a mask for an SME tile from the mask of a larger vector type.
+/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
 Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
-                     SMETile tileTile) {
+                     SMESubTile smeTile) {
   assert(isSupportedMaskOp(mask));
   if (!mask)
     return Value{};
   auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
   // The operands of `vector.create_mask` (from a 2D perspective) are the
   // coordinates where the mask ends. So we subtract where this tile starts,
-  // from the mask operands to get the parameters for this tile tile.
-  auto tileMaskDims = addConstantScalableOffset(
-      builder, loc, createMask.getOperands(), {-tileTile.row, -tileTile.col});
-  auto createTileMask = builder.create<vector::CreateMaskOp>(
-      loc, tileTile.type.clone(builder.getI1Type()), tileMaskDims);
-  return createTileMask.getResult();
+  // from the mask operands to get the parameters for this sub-tile.
+  auto smeTileMaskDims = addConstantScalableOffset(
+      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
+  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
+      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
+  return smeTileCreateMask.getResult();
 }
 
 /// Constructs an iterator that returns each SME tile (with coordinates)
@@ -110,7 +132,8 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
 auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                          VectorType smeTileType,
                          bool transposeIndices = false) {
-  assert(isMultipleOfSMETileVectorType(type));
+  assert(isMultipleOfSMETileVectorType(type) &&
+         "`type` not multiple of SME tiles");
   return llvm::map_range(
       StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
                                               smeTileType.getDimSize(1)}),
@@ -119,14 +142,15 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
         int col = int(indices[1]);
         if (transposeIndices)
           std::swap(row, col);
-        return SMETile{row, col, smeTileType};
+        return SMESubTile{row, col, smeTileType};
       });
 }
 
 /// Returns the number of SME tiles that fit into the (2D-scalable) vector type
 /// `type`.
 int getNumberOfSMETilesForVectorType(VectorType type) {
-  assert(isMultipleOfSMETileVectorType(type));
+  assert(isMultipleOfSMETileVectorType(type) &&
+         "`type` not multiple of SME tiles");
   int64_t vectorRows = type.getDimSize(0);
   int64_t vectorCols = type.getDimSize(1);
   auto elementType = type.getElementType();
@@ -162,25 +186,25 @@ struct LegalizeVectorOuterProductOpsByDecomposition
                                          MATCH_FAILURE_UNSUPPORTED_MASK_OP);
 
     ValueRange accSMETiles = adaptor.getAcc();
-    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
-    VectorType sliceType = VectorType::Builder(tileType).dropDim(0);
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
 
     SmallVector<Value> resultSMETiles;
-    for (auto [index, tileTile] :
-         llvm::enumerate(decomposeToSMETiles(rewriter, vectorType, tileType))) {
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
 
-      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
+      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
       auto lhs = rewriter.create<vector::ScalableExtractOp>(
-          loc, sliceType, outerProductOp.getLhs(), tileTile.row);
+          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
       auto rhs = rewriter.create<vector::ScalableExtractOp>(
-          loc, sliceType, outerProductOp.getRhs(), tileTile.col);
-      auto tileOuterProduct = rewriter.create<vector::OuterProductOp>(
-          loc, tileType, lhs, rhs,
+          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
+      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
+          loc, smeTileType, lhs, rhs,
           !accSMETiles.empty() ? accSMETiles[index] : Value{},
           outerProductOp.getKind());
 
       auto maskedOuterProduct =
-          vector::maskOperation(rewriter, tileOuterProduct, tileMask);
+          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
       resultSMETiles.push_back(maskedOuterProduct->getResult(0));
     }
 
@@ -241,18 +265,18 @@ struct LegalizeTransferReadOpsByDecomposition
     bool transposed = !permutationMap.isIdentity();
 
     auto loc = readOp.getLoc();
-    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
 
     SmallVector<Value> resultSMETiles;
-    for (SMETile tileTile :
-         decomposeToSMETiles(rewriter, vectorType, tileType, transposed)) {
-      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
-      auto transferRead = rewriter.create<vector::TransferReadOp>(
-          loc, tileType, readOp.getSource(),
-          remapIndicesForSMETile(rewriter, loc, readOp.getIndices(), tileTile),
-          readOp.getPermutationMapAttr(), readOp.getPadding(), tileMask,
+    for (SMESubTile smeTile :
+         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
+      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
+      auto smeRead = rewriter.create<vector::TransferReadOp>(
+          loc, smeTileType, readOp.getSource(),
+          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
+          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
           readOp.getInBoundsAttr());
-      resultSMETiles.push_back(transferRead);
+      resultSMETiles.push_back(smeRead);
     }
 
     rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
@@ -289,19 +313,19 @@ struct LegalizeTransferWriteOpsByDecomposition
     bool transposed = !permutationMap.isIdentity();
 
     auto loc = writeOp.getLoc();
-    auto tileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
     auto inputSMETiles = adaptor.getVector();
 
     Value destTensorOrMemref = writeOp.getSource();
-    for (auto [index, tileTile] : llvm::enumerate(
-             decomposeToSMETiles(rewriter, vectorType, tileType, transposed))) {
-      auto tileMask = extractSMEMask(rewriter, loc, mask, tileTile);
-      auto tileWrite = rewriter.create<vector::TransferWriteOp>(
+    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
+             rewriter, vectorType, smeTileType, transposed))) {
+      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
+      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
           loc, inputSMETiles[index], destTensorOrMemref,
-          remapIndicesForSMETile(rewriter, loc, writeOp.getIndices(), tileTile),
-          writeOp.getPermutationMapAttr(), tileMask, writeOp.getInBoundsAttr());
+          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
+          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
       if (writeOp.hasPureTensorSemantics())
-        destTensorOrMemref = tileWrite.getResult();
+        destTensorOrMemref = smeWrite.getResult();
     }
 
     if (writeOp.hasPureTensorSemantics())
@@ -326,9 +350,10 @@ struct VectorLegalizationPass
            SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
           if (!isMultipleOfSMETileVectorType(vectorType))
             return std::nullopt;
-          auto tileTileCount = getNumberOfSMETilesForVectorType(vectorType);
-          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
-          types = SmallVector<Type>(tileTileCount, tileType);
+          auto smeTileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto smeTileType =
+              getSMETileTypeForElement(vectorType.getElementType());
+          types = SmallVector<Type>(smeTileTileCount, smeTileType);
           return success();
         });
 

>From 2dede5ef2656feed54cfa0f4e57b82151f5f7744 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Jan 2024 14:47:18 +0000
Subject: [PATCH 4/5] Fixups

---
 .../ArmSME/Transforms/VectorLegalization.cpp  |  5 +-
 .../Linalg/CPU/ArmSME/multi-tile-matmul.mlir  |  6 +++
 .../CPU/ArmSME/test-multi-tile-transpose.mlir | 46 +++++++++----------
 3 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index e533ff7627fa..d436802f98a7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -45,8 +45,8 @@ static constexpr StringLiteral
 
 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
 /// larger vector type. The (`row`, `col`) are the position of the tile in the
-/// original vector type. For example for an [8]x[8] tile would have four
-/// [4]x[4] sub-tiles.
+/// original vector type. For example, for an [8]x[8] tile with four [4]x[4]
+/// sub-tiles, we would have:
 ///
 ///           8 x vscale
 /// ┌─────────────┬─────────────┐
@@ -104,6 +104,7 @@ SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
 
 /// Returns true if `mask` is generated by an operation that can be decomposed
 /// for SME. Currently, that is just no mask, or vector.create_mask.
+/// TODO: Add support for vector.constant_mask once required for SME.
 bool isSupportedMaskOp(Value mask) {
   return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
 }
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index fb192a829173..327f237ba894 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -12,6 +12,12 @@
 // RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
 // RUN: FileCheck %s
 
+/// This is very similar to the SME matmul.mlir test, except that it uses a tile
+/// size of [8]x[8]xf32, which is larger than a 32-bit SME virtual tile, which
+/// would be [4]x[4]xf32. The [8]x[8] tile will be decomposed into four
+/// [4]x[4] tiles by the `-arm-sme-vector-legalization` pass, which should then
+/// use all 32-bit SME accumulators.
+
 func.func @matmul(%A : tensor<?x?xf32>, %B : tensor<?x?xf32>, %C : tensor<?x?xf32>) {
   %res = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>)
                        outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
index 7821f7cd865d..0827d9b7464a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-multi-tile-transpose.mlir
@@ -29,7 +29,7 @@ func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
   return
 }
 
-func.func @testTransposedReadWithMask() {
+func.func @testTransposedReadWithMask(%maskRows: index, %maskCols: index) {
   %in = memref.alloca() : memref<4x16xf32>
   %out = memref.alloca() : memref<16x4xf32>
 
@@ -38,9 +38,7 @@ func.func @testTransposedReadWithMask() {
 
   func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
 
-  /// A mask so we only read the first 2x15 portion of %in.
-  %maskRows = arith.constant 2 : index
-  %maskCols = arith.constant 15 : index
+  /// A mask so we only read the first maskRows x maskCols portion of %in.
   %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
   %pad = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
@@ -59,35 +57,31 @@ func.func @testTransposedReadWithMask() {
   call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
 
   /// Print the result memref.
-  vector.print str "(Masked 15x2) transposed result:"
+  vector.print str "Masked transposed result:"
   %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
   call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
 
   return
 }
 
-func.func @testTransposedWriteWithMask() {
+func.func @testTransposedWriteWithMask(%maskRows: index, %maskCols: index) {
   %in = memref.alloca() : memref<16x4xf32>
   %out = memref.alloca() : memref<4x16xf32>
 
-  %fill = arith.constant -1.0 : f32
-  linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+  %c0_f32 = arith.constant 0.0 : f32
+  linalg.fill ins(%c0_f32 : f32) outs(%out : memref<4x16xf32>)
 
   %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
   %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
 
   func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
 
-  %pad = arith.constant 0.0 : f32
-  %c0 = arith.constant 0 : index
-
   /// A regular read.
-  %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+  %c0 = arith.constant 0 : index
+  %read = vector.transfer_read %inDyn[%c0, %c0], %c0_f32 {in_bounds = [true, true]}
     : memref<?x?xf32>, vector<[16]x[4]xf32>
 
-  /// A mask so we only write the first 3x8 portion of transpose(%in).
-  %maskRows = arith.constant 3 : index
-  %maskCols = arith.constant 8 : index
+  /// A mask so we only write the first maskRows x maskCols portion of transpose(%in).
   %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
 
   /// Write out the data with a transpose. Here (like the read test) the mask
@@ -101,7 +95,7 @@ func.func @testTransposedWriteWithMask() {
   call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
 
   /// Print the result memref.
-  vector.print str "(Masked 3x8) transposed result:"
+  vector.print str "Masked transposed result:"
   %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
   call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
 
@@ -120,7 +114,7 @@ func.func @main() {
   // CHECK-NEXT:  [3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3]
   // CHECK-NEXT:  [4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4]
   //
-  //      CHECK: (Masked 15x2) transposed result:
+  //      CHECK:  Masked transposed result:
   //      CHECK:  [1,   2,   0,   0]
   // CHECK-NEXT:  [1,   2,   0,   0]
   // CHECK-NEXT:  [1,   2,   0,   0]
@@ -137,7 +131,9 @@ func.func @main() {
   // CHECK-NEXT:  [1,   2,   0,   0]
   // CHECK-NEXT:  [1,   2,   0,   0]
   // CHECK-NEXT:  [0,   0,   0,   0]
-  func.call @testTransposedReadWithMask() : () -> ()
+  %readMaskRows = arith.constant 2 : index
+  %readMaskCols = arith.constant 15 : index
+  func.call @testTransposedReadWithMask(%readMaskRows, %readMaskCols) : (index, index) -> ()
 
   //      CHECK: Input memref:
   //      CHECK:  [1,   1,   1,   1]
@@ -157,12 +153,14 @@ func.func @main() {
   // CHECK-NEXT:  [15,   15,   15,   15]
   // CHECK-NEXT:  [16,   16,   16,   16]
   //
-  //      CHECK: (Masked 3x8) transposed result:
-  //      CHECK:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
-  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
-  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
-  // CHECK-NEXT:  [-1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1]
-  func.call @testTransposedWriteWithMask() : () -> ()
+  //      CHECK:  Masked transposed result:
+  //      CHECK:  [1,   2,   3,   4,   5,   6,   7,   8,   0,   0,   0,   0,   0,   0,   0,   0]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   0,   0,   0,   0,   0,   0,   0,   0]
+  // CHECK-NEXT:  [1,   2,   3,   4,   5,   6,   7,   8,   0,   0,   0,   0,   0,   0,   0,   0]
+  // CHECK-NEXT:  [0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]
+  %writeMaskRows = arith.constant 3 : index
+  %writeMaskCols = arith.constant 8 : index
+  func.call @testTransposedWriteWithMask(%writeMaskRows, %writeMaskCols) : (index, index) -> ()
 
   return
 }

>From 0450b4ab02fcabca7f02c082659124027e7f3ce4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Jan 2024 12:46:34 +0000
Subject: [PATCH 5/5] Fixups

---
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 11 +++++------
 .../Dialect/ArmSME/Transforms/VectorLegalization.cpp  |  4 ++--
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 19994bf47873..4acb2a8fb7b5 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -244,9 +244,8 @@ struct AssignTileIDsPattern
 
     // Set all operations dependent on `tileOp` to use the same tile ID.
     // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't
-    // solve cases like this (%tileA and %tileB come from different root
-    // operations):
+    // example, as this only allocates tile IDs to existing ops, it can't solve
+    // cases like this (%tileA and %tileB come from different root operations):
     //
     // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
     //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
@@ -255,9 +254,9 @@ struct AssignTileIDsPattern
     // }
     //
     // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile
-    // (based on the %some_cond). Find all the ops that (transitively) depend
-    // on this tile.
+    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
+    // on the %some_cond).
+    // Find all the ops that (transitively) depend on this tile.
     SetVector<Operation *> dependantOps;
     findDependantOps(tileOp->getResult(0), dependantOps);
     auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index d436802f98a7..85ec53c2618a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -351,10 +351,10 @@ struct VectorLegalizationPass
            SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
           if (!isMultipleOfSMETileVectorType(vectorType))
             return std::nullopt;
-          auto smeTileTileCount = getNumberOfSMETilesForVectorType(vectorType);
+          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
           auto smeTileType =
               getSMETileTypeForElement(vectorType.getElementType());
-          types = SmallVector<Type>(smeTileTileCount, smeTileType);
+          types = SmallVector<Type>(smeTileCount, smeTileType);
           return success();
         });
 



More information about the Mlir-commits mailing list