[Mlir-commits] [mlir] [mlir][ArmSME] Add rudimentary support for tile spills to the stack (PR #76086)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Dec 21 08:09:18 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/76086
>From cf39a7aea7ef4f0eb4ddd6a03b9171209ee68f6c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 20 Dec 2023 16:24:27 +0000
Subject: [PATCH 1/6] [mlir][ArmSME] Add rudimentary support for tile spills to
the stack
This adds very basic and inelegant support for something like spilling
and reloading tiles if you use more SME tiles than physically exist.
This is purely implemented to prevent the compiler from aborting if a
function uses too many tiles (i.e. due to bad unrolling), but is
expected to perform very poorly.
Currently, this works in two stages:
During tile allocation, if we run out of tiles instead of giving up, we
switch to allocating 'in-memory' tile IDs. These are tile IDs that start
at 16 (which is higher than any real tile ID). A warning will also be
emitted for each (root) tile op assigned an in-memory tile ID:
```
warning: failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation
```
Everything after this works like normal until `-convert-arm-sme-to-llvm`
Here the in-memory tile op:
```mlir
arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
```
Is lowered to:
```mlir
// At function entry:
%alloca = memref.alloca ... : memref<?x?xty>
// Around the op:
// Swap the contents of %alloca and tile 0.
scf.for %slice_idx {
%current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
"arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
vector.store %current_slice, %alloca[%slice_idx, %c0]
}
// Execute op using tile 0.
arm_sme.tile_op { tile_id = 0 }
// Swap the contents of %alloca and tile 0.
// This restores tile 0 to its original state.
scf.for %slice_idx {
%current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
"arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
vector.store %current_slice, %alloca[%slice_idx, %c0]
}
```
This is inserted during the lowering to LLVM as spilling/reloading
registers is a very low-level concept that can't really be modeled
correctly at a high level in MLIR.
Note: This is always doing the worst-case full-tile swap. This could be
optimized to only spill/load the data the tile op will use, which could
be just a slice. It's also not making any use of liveness, which could
allow reusing tiles. But this is not seen as important, as correct code
should only use the available number of tiles.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h | 3 +-
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 33 ++
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 458 +++++++++++++-----
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 1 -
.../ArmSME/Transforms/TileAllocation.cpp | 66 ++-
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 96 ++++
mlir/test/Dialect/ArmSME/tile-allocation.mlir | 18 +-
.../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 78 +++
8 files changed, 597 insertions(+), 156 deletions(-)
create mode 100644 mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
create mode 100644 mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index 9982d4278b6033..c507cea5357a74 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -25,8 +25,9 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir::arm_sme {
+static constexpr unsigned kInMemoryTileIdBase = 16;
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-}
+} // namespace mlir::arm_sme
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index f7cc1d3fe7517f..adb3fae87e1017 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -97,6 +97,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
// This operation does not allocate a tile.
return std::nullopt;
}]
+ >,
+ InterfaceMethod<
+ "Returns the VectorType of the tile used by this operation.",
+ /*returnType=*/"VectorType",
+ /*methodName=*/"getTileType",
+ /*arguments=*/(ins),
+ /*methodBody=*/[{}]
>
];
@@ -117,6 +124,11 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
rewriter.replaceOp($_op, newOp);
return newOp;
}
+
+ bool isInMemoryTile() {
+ auto tileId = getTileId();
+ return tileId && tileId.getInt() >= kInMemoryTileIdBase;
+ }
}];
let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
@@ -316,6 +328,9 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
+ VectorType getTileType() {
+ return getVectorType();
+ }
}];
let assemblyFormat = "attr-dict `:` type($res)";
}
@@ -392,6 +407,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
return arm_sme::getSMETileType(getVectorType());
}
+ VectorType getTileType() {
+ return getVectorType();
+ }
}];
let builders = [
@@ -460,6 +478,9 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getValueToStore().getType());
}
+ VectorType getTileType() {
+ return getVectorType();
+ }
}];
let builders = [
@@ -524,6 +545,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
+ VectorType getTileType() {
+ return getVectorType();
+ }
}];
let assemblyFormat = [{
@@ -581,6 +605,9 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
+ VectorType getTileType() {
+ return getVectorType();
+ }
}];
let assemblyFormat = [{
@@ -673,6 +700,9 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
let extraClassDeclaration = [{
VectorType getSliceType() { return getResult().getType(); }
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getTile().getType());
+ }
}];
let assemblyFormat = [{
@@ -765,6 +795,9 @@ let arguments = (ins
return arm_sme::getSMETileType(getResultType());
return std::nullopt;
}
+ VectorType getTileType() {
+ return getResultType();
+ }
}];
}
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index f9d6f04a811f3e..131f734b4c7485 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -32,6 +33,97 @@ using namespace mlir;
namespace {
+static constexpr StringLiteral kInMemoryTileId("arm_sme.in_memory_tile_id");
+
+/// Helper to create a arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
+static Operation *createLoadTileSliceIntrinsic(
+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+ IntegerAttr tileId, Value tileSliceI32) {
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (type) {
+ case arm_sme::ArmSMETileType::ZAB:
+ return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAH:
+ return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAS:
+ return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAD:
+ return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAQ:
+ return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ }
+ } else {
+ switch (type) {
+ case arm_sme::ArmSMETileType::ZAB:
+ return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAH:
+ return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAS:
+ return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAD:
+ return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAQ:
+ return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ break;
+ }
+ }
+}
+
+/// Helper to create a arm_sme.intr.st1*.(horiz|vert)' intrinsic.
+static Operation *createStoreTileSliceIntrinsic(
+ RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
+ arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
+ IntegerAttr tileId, Value tileSliceI32) {
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (type) {
+ case arm_sme::ArmSMETileType::ZAB:
+ return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAH:
+ return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAS:
+ return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAD:
+ return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAQ:
+ return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ }
+ } else {
+ switch (type) {
+ case arm_sme::ArmSMETileType::ZAB:
+ return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAH:
+ return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAS:
+ return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAD:
+ return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ case arm_sme::ArmSMETileType::ZAQ:
+ return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
+ loc, maskOp, ptr, tileId, tileSliceI32);
+ }
+ }
+}
+
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
auto tileId = op.getTileId();
if (!tileId)
@@ -40,6 +132,209 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
return tileId;
}
+/// Creates a alloca matching the size of tile used by `tileOp`. The alloca is
+/// placed in the first block of the function.
+static memref::AllocaOp
+createAllocaForTile(RewriterBase &rewriter, Location loc,
+ FunctionOpInterface func,
+ arm_sme::ArmSMETileOpInterface tileOp) {
+ RewriterBase::InsertionGuard g(rewriter);
+ // Move to the first operation in the function.
+ rewriter.setInsertionPoint(&func.getBlocks().front().front());
+ // Create an alloca matching the tile size of the `tileOp`.
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto tileElementType =
+ llvm::cast<VectorType>(tileOp.getTileType()).getElementType();
+ auto memrefType = MemRefType::get(
+ {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
+ auto minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
+ auto minElementsOp =
+ rewriter.create<arith::ConstantIndexOp>(loc, minElements);
+ auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
+ auto alloca = rewriter.create<memref::AllocaOp>(
+ loc, memrefType, ValueRange{vectorLen, vectorLen});
+ return alloca;
+}
+
+/// Finds or creates an alloca for a spill of a tile.
+static memref::AllocaOp
+getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
+ FunctionOpInterface func,
+ arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
+ // Find an alloca at the top of the function tagged with a
+ // 'arm_sme.in_memory_tile_id' that matches `tileId`.
+ for (auto &op : func.getBlocks().front()) {
+ auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
+ if (!alloca)
+ continue;
+ auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
+ alloca->getDiscardableAttr(kInMemoryTileId));
+ if (!inMemoryTileId)
+ continue;
+ if (inMemoryTileId.getInt() == tileId)
+ return alloca;
+ }
+ // Otherwise, create a new alloca:
+ auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
+ alloca->setDiscardableAttr(kInMemoryTileId,
+ rewriter.getI32IntegerAttr(tileId));
+ return alloca;
+}
+
+/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
+/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
+/// the op to tile 0, then emitting a full tile swap between ZA and memory
+/// before + after the tile op.
+///
+/// Example:
+///
+/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
+///
+/// is converted to:
+/// // At function entry:
+/// %alloca = memref.alloca ... : memref<?x?xty>
+///
+/// // Around op:
+/// scf.for %slice_idx {
+/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %alloca[%slice_idx, %c0]
+/// }
+/// arm_sme.tile_op { tile_id = 0 }
+/// scf.for %slice_idx {
+/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %alloca[%slice_idx, %c0]
+/// }
+///
+/// Note that these spills/fills are not inserted earlier as concept of a
+/// register, and the need to swap the contents, can't really be represented
+/// correctly at a high level in MLIR.
+///
+/// TODO: Reduce the spills/reloads to single slices where possible.
+struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
+
+ ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
+ const LLVMTypeConverter &typeConverter,
+ PatternBenefit benefit)
+ : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
+ typeConverter, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
+ // Tile has a real (hardware) tile. No spills/reloads required.
+ if (!tileOp.isInMemoryTile())
+ return failure();
+
+ // Step 1. Create an alloca for the tile at the top of the function (if one
+ // does not already exist).
+ auto loc = tileOp.getLoc();
+ auto func = tileOp->getParentOfType<FunctionOpInterface>();
+ auto tileAlloca = getOrCreateTileMemory(rewriter, loc, func, tileOp,
+ tileOp.getTileId().getInt());
+
+ // Step 2. Assign the op a real tile ID.
+ // For simplicity, we always use tile 0.
+ auto zeroTileId = rewriter.getI32IntegerAttr(0);
+ {
+ rewriter.startRootUpdate(tileOp);
+ tileOp.setTileId(zeroTileId);
+ rewriter.finalizeRootUpdate(tileOp);
+ }
+
+ VectorType tileVectorType = tileOp.getTileType();
+ auto sliceType = VectorType::Builder(tileOp.getTileType()).dropDim(0);
+ auto emitTileSwap = [&] {
+ emitFullTileSwap(rewriter, loc, tileAlloca,
+ *arm_sme::getSMETileType(tileVectorType), sliceType,
+ zeroTileId);
+ };
+
+ // Step 3. Emit tile swaps before and after the op.
+ // TODO: Reduce the amount spilled to the amount of data the `tileOp`
+ // touches (i.e. a single tile slice).
+ {
+ rewriter.setInsertionPoint(op);
+ // Swap the in-memory tile's contents into ZA before the op.
+ emitTileSwap();
+ rewriter.setInsertionPointAfter(op);
+ // Swap the tile back out to memory again after the op.
+ emitTileSwap();
+ }
+
+ return success();
+ }
+
+ /// Extracts a pointer to a slice of an in-memory tile.
+ Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
+ Value tileMemory, Value sliceIndex) const {
+ auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
+ auto descriptor =
+ rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
+ auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
+ auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI64Type(), sliceIndex);
+ return getStridedElementPtr(
+ loc, llvm::cast<MemRefType>(tileMemory.getType()),
+ descriptor.getResult(0), {sliceIndexI64, zero},
+ static_cast<ConversionPatternRewriter &>(rewriter));
+ }
+
+ /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
+ /// tile-sized memref (`tileAlloca`).
+ void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+ arm_sme::ArmSMETileType tileType, VectorType sliceType,
+ IntegerAttr tileId, Value sliceIndex) const {
+ // Cast the slice index to an i32.
+ auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), sliceIndex);
+ // Create an all-true predicate for the slice.
+ auto predicateType = sliceType.clone(rewriter.getI1Type());
+ auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+ // Create zero padding vector (never used due to all-true predicate).
+ auto zeroVector = rewriter.create<arith::ConstantOp>(
+ loc, sliceType, rewriter.getZeroAttr(sliceType));
+ // Get a pointer to the current slice.
+ auto slicePtr =
+ getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
+ // Read the value of the current slice from ZA.
+ auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+ loc, sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32);
+ // Load the new tile slice back from memory into ZA.
+ createLoadTileSliceIntrinsic(
+ rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
+ allTruePredicate, slicePtr, tileId, sliceIndexI32);
+ // Store the current tile slice to memory.
+ auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
+ ValueRange{sliceIndex, zero});
+ }
+
+ /// Emits a full in-place swap of the contents of a tile in ZA and a
+ /// tile-sized memref (`tileAlloca`).
+ void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
+ arm_sme::ArmSMETileType tileType, VectorType sliceType,
+ IntegerAttr tileId) const {
+ RewriterBase::InsertionGuard guard(rewriter);
+ // Create an scf.for over all tile slices.
+ auto minNumElts =
+ rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(
+ loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ // Emit a swap for each tile slice.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ auto sliceIndex = forOp.getInductionVar();
+ emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
+ sliceIndex);
+ }
+};
+
struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
@@ -75,8 +370,8 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
- unsigned tileElementWidth =
- zero.getVectorType().getElementType().getIntOrFloatBitWidth();
+ arm_sme::ArmSMETileType tileType =
+ *arm_sme::getSMETileType(zero.getVectorType());
auto tileId = getTileIdOrError(zero);
if (!tileId)
@@ -87,22 +382,22 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
// These masks are derived from:
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
auto baseMaskForSize = [&] {
- switch (tileElementWidth) {
- case 8:
+ switch (tileType) {
+ case arm_sme::ArmSMETileType::ZAB:
// Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
// 64-bit element tiles named ZA0.D to ZA7.D.
return 0b1111'1111;
- case 16:
- // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element
- // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D.
- // Shift this left once for ZA1.H.
+ case arm_sme::ArmSMETileType::ZAH:
+ // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
+ // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
+ // once for ZA1.H.
return 0b0101'0101;
- case 32:
+ case arm_sme::ArmSMETileType::ZAS:
// Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
// element tiles named ZA0.D and ZA4.D.
// Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
return 0b0001'0001;
- case 64:
+ case arm_sme::ArmSMETileType::ZAD:
// Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
// setting the bit for that tile.
return 0b0000'0001;
@@ -172,63 +467,14 @@ struct LoadTileSliceConversion
// Create all active predicate mask.
auto maskOp = loadTileSliceOp.getMask();
- auto tileType = loadTileSliceOp.getVectorType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+ auto tileVectorType = loadTileSliceOp.getVectorType();
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
- // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- }
- } else {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
- tileId, tileSliceI32);
- break;
- }
- }
+ // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile
+ // slice.
+ createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
+ tileId, tileSliceI32);
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
// the input tile to preserve dataflow.
@@ -249,9 +495,7 @@ struct StoreTileSliceConversion
arm_sme::StoreTileSliceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = storeTileSliceOp.getLoc();
- auto tileType = storeTileSliceOp.getVectorType();
- auto tileElementType = tileType.getElementType();
- unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+ auto tileVectorType = storeTileSliceOp.getVectorType();
auto tileId = getTileIdOrError(storeTileSliceOp);
if (!tileId)
@@ -271,58 +515,12 @@ struct StoreTileSliceConversion
auto maskOp = storeTileSliceOp.getMask();
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+ arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
- if (layout == arm_sme::TileSliceLayout::Horizontal) {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- }
- } else {
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
- storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
- break;
- }
- }
+ rewriter.replaceOp(storeTileSliceOp,
+ createStoreTileSliceIntrinsic(rewriter, loc, tileType,
+ layout, maskOp, ptr,
+ tileId, tileSliceI32));
return success();
}
@@ -548,6 +746,17 @@ struct ConvertArmSMEToLLVMPass
}
};
+template <typename... TileOp>
+static void addSpillAndFillsForTileOp(RewritePatternSet &patterns,
+ LLVMTypeConverter const &typeConverter) {
+ // Add spill/fill conversions with a very high benefit to ensure they are
+ // lowered first.
+ (patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(TileOp::getOperationName(),
+ typeConverter,
+ /*benefit=*/1337),
+ ...);
+}
+
} // namespace
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
@@ -567,7 +776,10 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
- target.addLegalDialect<arith::ArithDialect>();
+ target.addLegalDialect<arith::ArithDialect,
+ /* The following are used to lower tile spills/fills */
+ vector::VectorDialect, scf::SCFDialect,
+ memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
}
@@ -581,6 +793,12 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
return std::nullopt;
});
+ // Register ops that need spills/fills.
+ addSpillAndFillsForTileOp<
+ arm_sme::LoadTileSliceOp, arm_sme::MoveTileSliceToVectorOp,
+ arm_sme::MoveVectorToTileSliceOp, arm_sme::StoreTileSliceOp,
+ arm_sme::OuterProductOp, arm_sme::ZeroOp>(patterns, converter);
+
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6105cd62252830..1fa060cafc0bc6 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -69,7 +69,6 @@ LogicalResult verifyOperationHasValidTileId(Operation *op) {
return success(); // Not having a tile ID (yet) is okay.
if (!tileId.getType().isSignlessInteger(32))
return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
- // TODO: Verify value of tile ID is in range.
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 8aa51f352f822d..a77b218bc1a60b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -61,7 +61,9 @@ using namespace mlir::arm_sme;
namespace {
-static constexpr char kTilesInUseAttr[] = "arm_sme.tiles_in_use";
+static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
+static constexpr StringLiteral
+ kNextTileMemoryIndex("arm_sme.next_in_memory_tile_id");
enum class TileMask : unsigned {
// clang-format off
@@ -200,37 +202,49 @@ static void findDependantOps(Value rootValue,
});
}
}
-
struct AssignTileIDsPattern
: public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
PatternRewriter &rewriter) const override {
+ auto func = tileOp->getParentOfType<FunctionOpInterface>();
if (tileOp.getTileId())
return failure();
+ auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
+ if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
+ func->getDiscardableAttr(name)))
+ return unsigned(attr.getInt());
+ return defaultVal;
+ };
+
+ auto setDiscardableIntAttr = [&](StringRef name, auto value) {
+ rewriter.updateRootInPlace(tileOp, [&] {
+ func->setDiscardableAttr(name,
+ rewriter.getI32IntegerAttr((unsigned)value));
+ });
+ };
+
std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
if (!tileType)
return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
- TileMask tilesInUse = TileMask::kNone;
- if (auto tilesInUseAttr = llvm::dyn_cast_or_null<IntegerAttr>(
- func->getDiscardableAttr(kTilesInUseAttr)))
- tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
-
+ TileMask tilesInUse =
+ static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
auto tileId = allocateTileId(*tileType, tilesInUse);
- if (failed(tileId))
- return tileOp.emitError("ran out of SME virtual tiles!");
-
- rewriter.updateRootInPlace(func, [&]() {
- func->setDiscardableAttr(
- kTilesInUseAttr, rewriter.getI32IntegerAttr((unsigned)tilesInUse));
- });
-
- // Find all the ops that (transitively) depend on this tile.
- SetVector<Operation *> dependantOps;
- findDependantOps(tileOp->getResult(0), dependantOps);
+ bool tileIsInMemory = failed(tileId);
+ if (!tileIsInMemory)
+ setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+ else {
+ // If we could not find a real tile, set use a virtual tile ID (ID >= 16).
+ // A later pass will insert the necessary spills and reloads.
+ tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase);
+ setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1);
+ tileOp->emitWarning(
+ "failed to allocate physical tile to operation, all tile "
+ "operations will go through memory, expect "
+ "performance degradation");
+ }
// Set all operations dependent on `tileOp` to use the same tile ID.
// This is a naive tile allocation scheme, but works for common cases. For
@@ -246,16 +260,18 @@ struct AssignTileIDsPattern
// This case would require allocating a new tile for the result of the
// scf.if, and moving the contents of %tileA or %tileB to result tile (based
// on the %some_cond).
+ // Find all the ops that (transitively) depend on this tile.
+ SetVector<Operation *> dependantOps;
+ findDependantOps(tileOp->getResult(0), dependantOps);
auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
- rewriter.updateRootInPlace(tileOp, [&]() { tileOp.setTileId(tileIDAttr); });
+ rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
for (auto *op : dependantOps) {
- if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- auto currentTileId = tileOp.getTileId();
+ if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+ auto currentTileId = dependantTileOp.getTileId();
if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
- return tileOp.emitOpError(
+ return dependantTileOp.emitOpError(
"already assigned different SME virtual tile!");
- rewriter.updateRootInPlace(tileOp,
- [&]() { tileOp.setTileId(tileIDAttr); });
+ dependantTileOp.setTileId(tileIDAttr);
}
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
new file mode 100644
index 00000000000000..9908f04b7c8557
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN: -split-input-file -verify-diagnostics | \
+// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING
+
+// -----
+
+/// Checks tile spill/reloads are inserted around in-memory tiles (i.e. tiles
+/// that were not assigned a physical SME tile).
+///
+/// These spills are currently very naive and paranoid and will spill/reload
+/// entire tiles around ArmSME ops.
+///
+/// The general pattern is:
+///
+/// During tile allocation if there's not a physical tile ID available an op
+/// will be assigned an in-memory tile ID (which is a tile ID >= 16).
+///
+/// Example:
+///
+/// arm_sme.zero : vector<[8]x[8]xi16>
+///
+/// Becomes:
+///
+/// arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16>
+///
+/// This works like normal till the final lowering to LLVM, where spills and
+/// reloads will be inserted around uses of in-memory tiles.
+///
+/// So the above example becomes:
+///
+/// // Placed at the top of the function:
+/// %tileAlloca = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+///
+/// Then around the op:
+///
+/// // Swap contents of %tileAlloca and tile 0
+/// scf.for %sliceIdx ... {
+/// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
+/// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
+/// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
+/// }
+/// // Execute the op using tile 0
+/// arm_sme.intr.zero
+/// // Swap contents of %tileAlloca and tile 0
+/// scf.for %sliceIdx ... {
+/// %currentSlice = arm_sme.intr.read.horiz {tile_id = 0}
+/// arm_sme.intr.ld1h.horiz %tileAlloca[%sliceIdx, %c0] {tile_id = 0}
+/// vector.store %currentSlice, %tileAlloca[%sliceIdx, %c0]
+/// }
+///
+
+func.func @use_too_many_tiles() {
+ %0 = arm_sme.zero : vector<[4]x[4]xi32>
+ %1 = arm_sme.zero : vector<[4]x[4]xi32>
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ %2 = arm_sme.zero : vector<[8]x[8]xi16>
+ return
+}
+// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 0
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 1
+// AFTER-TILE-ALLOC: arm_sme.zero
+// AFTER-TILE-ALLOC-SAME: tile_id = 16
+
+// AFTER-LLVM-LOWERING-LABEL: @use_too_many_tiles
+// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C8:.*]] = arith.constant 8 : index
+// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
+// AFTER-LLVM-LOWERING-DAG: %[[SVL_H:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
+// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_H]], %[[SVL_H]])
+// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xi16>
+//
+// AFTER-LLVM-LOWERING-NOT: scf.for
+// AFTER-LLVM-LOWERING: arm_sme.intr.zero
+//
+// AFTER-LLVM-LOWERING-NOT: scf.for
+// AFTER-LLVM-LOWERING: arm_sme.intr.zero
+//
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
+// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz
+// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz
+// AFTER-LLVM-LOWERING-NEXT: vector.store
+// AFTER-LLVM-LOWERING-NEXT: }
+// AFTER-LLVM-LOWERING: arm_sme.intr.zero
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
+// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz
+// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz
+// AFTER-LLVM-LOWERING-NEXT: vector.store
+// AFTER-LLVM-LOWERING-NEXT: }
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index 1f895e4984ba84..7c887ced160b14 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -35,7 +35,7 @@ func.func @za_b() {
func.func @za_b__out_of_tiles() {
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
return
}
@@ -44,7 +44,7 @@ func.func @za_b__out_of_tiles() {
func.func @za_b_overlapping_za_q() {
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -79,7 +79,7 @@ func.func @za_h__out_of_tiles() {
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// CHECK-NEXT: tile_id = 1
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
return
}
@@ -136,7 +136,7 @@ func.func @za_h_overlapping_za_q() {
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -174,7 +174,7 @@ func.func @za_s__out_of_tiles() {
%za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
%za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
@@ -218,7 +218,7 @@ func.func @za_s_overlapping_za_q() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -268,7 +268,7 @@ func.func @za_d__out_of_tiles() {
%za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
%za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
@@ -291,7 +291,7 @@ func.func @za_d_overlapping_za_q() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -365,7 +365,7 @@ func.func @za_q__out_of_tiles() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-error at +1 {{ran out of SME virtual tiles!}}
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
new file mode 100644
index 00000000000000..ea48fa77861cf1
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -0,0 +1,78 @@
+
+// RUN: mlir-opt %s \
+// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles \
+// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN: -convert-arm-sme-to-llvm -convert-vector-to-llvm=enable-arm-sve -cse \
+// RUN: -canonicalize -test-lower-to-llvm -verify-diagnostics | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -e=main -entry-point-result=void \
+// RUN: -march=aarch64 -mattr="+sve,+sme" \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
+// RUN: FileCheck %s
+
+/// This function uses too many tiles! There are only two i16 tiles (ZA0.H and
+/// ZA1.H), but this function uses five i16 tiles! Very expensive spills/reloads
+/// will be inserted to emulate the extra three tiles. Note: This is only done
+/// to avoid the compiler erroring out but is expected to have very poor
+/// performance (hence the warning).
+func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: memref<?x?xi16>) {
+ %c0 = arith.constant 0 : index
+ %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+ %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+
+ // CHECK-LABEL: tile_a:
+ // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
+ vector.print str "tile_a:"
+ vector.print %tile_a : vector<[8]x[8]xi16>
+ // CHECK-LABEL: tile_b:
+ // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
+ vector.print str "tile_b:"
+ vector.print %tile_b : vector<[8]x[8]xi16>
+ // CHECK-LABEL: tile_c:
+ // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
+ vector.print str "tile_c:"
+ vector.print %tile_c : vector<[8]x[8]xi16>
+ // CHECK-LABEL: tile_d:
+ // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
+ vector.print str "tile_d:"
+ vector.print %tile_d : vector<[8]x[8]xi16>
+ // CHECK-LABEL: tile_e:
+ // CHECK-COUNT-8: ( 4, 4, 4, 4, 4, 4, 4, 4
+ vector.print str "tile_e:"
+ vector.print %tile_e : vector<[8]x[8]xi16>
+ return
+}
+
+func.func @get_svl() -> index attributes { enable_arm_streaming_ignore, arm_locally_streaming }{
+ %vscale = vector.vscale
+ return %vscale : index
+}
+
+func.func @main() {
+ %c16 = arith.constant 16 : index
+ %svl = call @get_svl() : () -> index
+ %svl_h = arith.muli %c16, %svl : index
+
+ %two = arith.constant 2 : i16
+ %three = arith.constant 3 : i16
+ %four = arith.constant 4 : i16
+
+ %memA = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+ %memB = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+ %memC = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
+
+ linalg.fill ins(%two : i16) outs(%memA : memref<?x?xi16>)
+ linalg.fill ins(%three : i16) outs(%memB : memref<?x?xi16>)
+ linalg.fill ins(%four : i16) outs(%memC : memref<?x?xi16>)
+
+ func.call @use_too_many_tiles(%memA, %memB, %memC) : (memref<?x?xi16>, memref<?x?xi16>, memref<?x?xi16>) -> ()
+ return
+}
>From 590de4e5959ae3b35e7c4f436ad56e11afdd51ff Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 11:59:43 +0000
Subject: [PATCH 2/6] fixups
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 3 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 47 +++++++++----------
.../ArmSME/Transforms/TileAllocation.cpp | 5 +-
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 2 +-
mlir/test/Dialect/ArmSME/tile-allocation.mlir | 18 +++----
.../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 18 +++----
6 files changed, 43 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index adb3fae87e1017..d80b73c810646f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -102,8 +102,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
/*methodName=*/"getTileType",
- /*arguments=*/(ins),
- /*methodBody=*/[{}]
+ /*arguments=*/(ins)
>
];
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 131f734b4c7485..8995c2a46367c6 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -33,7 +33,7 @@ using namespace mlir;
namespace {
-static constexpr StringLiteral kInMemoryTileId("arm_sme.in_memory_tile_id");
+static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
/// Helper to create a arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
static Operation *createLoadTileSliceIntrinsic(
@@ -132,7 +132,7 @@ IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
return tileId;
}
-/// Creates a alloca matching the size of tile used by `tileOp`. The alloca is
+/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
/// placed in the first block of the function.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
@@ -140,14 +140,13 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
arm_sme::ArmSMETileOpInterface tileOp) {
RewriterBase::InsertionGuard g(rewriter);
// Move to the first operation in the function.
- rewriter.setInsertionPoint(&func.getBlocks().front().front());
+ rewriter.setInsertionPointToStart(&func.getBlocks().front());
// Create an alloca matching the tile size of the `tileOp`.
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- auto tileElementType =
- llvm::cast<VectorType>(tileOp.getTileType()).getElementType();
+ auto tileElementType = tileOp.getTileType().getElementType();
auto memrefType = MemRefType::get(
{ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
- auto minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
+ unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
auto minElementsOp =
rewriter.create<arith::ConstantIndexOp>(loc, minElements);
auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
@@ -157,10 +156,9 @@ createAllocaForTile(RewriterBase &rewriter, Location loc,
}
/// Finds or creates an alloca for a spill of a tile.
-static memref::AllocaOp
-getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
- FunctionOpInterface func,
- arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
+static memref::AllocaOp getOrCreateAllocaForTile(
+ RewriterBase &rewriter, Location loc, FunctionOpInterface func,
+ arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
// Find an alloca at the top of the function tagged with a
// 'arm_sme.in_memory_tile_id' that matches `tileId`.
for (auto &op : func.getBlocks().front()) {
@@ -168,7 +166,7 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
if (!alloca)
continue;
auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
- alloca->getDiscardableAttr(kInMemoryTileId));
+ alloca->getDiscardableAttr(kInMemoryTileIdAttr));
if (!inMemoryTileId)
continue;
if (inMemoryTileId.getInt() == tileId)
@@ -176,7 +174,7 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
}
// Otherwise, create a new alloca:
auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
- alloca->setDiscardableAttr(kInMemoryTileId,
+ alloca->setDiscardableAttr(kInMemoryTileIdAttr,
rewriter.getI32IntegerAttr(tileId));
return alloca;
}
@@ -188,23 +186,24 @@ getOrCreateTileMemory(RewriterBase &rewriter, Location loc,
///
/// Example:
///
+/// // Note: <IN MEMORY TILE> = tile ID >= 16.
/// arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
/// // At function entry:
-/// %alloca = memref.alloca ... : memref<?x?xty>
+/// %spill = memref.alloca ... : memref<?x?xty>
///
/// // Around op:
/// scf.for %slice_idx {
/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
-/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
-/// vector.store %current_slice, %alloca[%slice_idx, %c0]
+/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %spill[%slice_idx, %c0]
/// }
/// arm_sme.tile_op { tile_id = 0 }
/// scf.for %slice_idx {
/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
-/// "arm_sme.intr.ld1h.horiz"(%alloca, %slice_idx) <{tile_id = 0 : i32}>
-/// vector.store %current_slice, %alloca[%slice_idx, %c0]
+/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
+/// vector.store %current_slice, %spill[%slice_idx, %c0]
/// }
///
/// Note that these spills/fills are not inserted earlier as the concept of a
@@ -232,20 +231,16 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
// does not already exist).
auto loc = tileOp.getLoc();
auto func = tileOp->getParentOfType<FunctionOpInterface>();
- auto tileAlloca = getOrCreateTileMemory(rewriter, loc, func, tileOp,
- tileOp.getTileId().getInt());
+ auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
+ tileOp.getTileId().getInt());
// Step 2. Assign the op a real tile ID.
- // For simplicity, we always use tile 0.
+ // For simplicity, we always use tile 0 (which always exists).
auto zeroTileId = rewriter.getI32IntegerAttr(0);
- {
- rewriter.startRootUpdate(tileOp);
- tileOp.setTileId(zeroTileId);
- rewriter.finalizeRootUpdate(tileOp);
- }
+ rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
VectorType tileVectorType = tileOp.getTileType();
- auto sliceType = VectorType::Builder(tileOp.getTileType()).dropDim(0);
+ auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
auto emitTileSwap = [&] {
emitFullTileSwap(rewriter, loc, tileAlloca,
*arm_sme::getSMETileType(tileVectorType), sliceType,
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index a77b218bc1a60b..3c089d47d28609 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -241,9 +241,8 @@ struct AssignTileIDsPattern
tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase);
setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1);
tileOp->emitWarning(
- "failed to allocate physical tile to operation, all tile "
- "operations will go through memory, expect "
- "performance degradation");
+ "failed to allocate SME virtual tile to operation, all tile "
+ "operations will go through memory, expect degraded performance");
}
// Set all operations dependent on `tileOp` to use the same tile ID.
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 9908f04b7c8557..999acbfc66bef4 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -54,7 +54,7 @@
func.func @use_too_many_tiles() {
%0 = arm_sme.zero : vector<[4]x[4]xi32>
%1 = arm_sme.zero : vector<[4]x[4]xi32>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%2 = arm_sme.zero : vector<[8]x[8]xi16>
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index 7c887ced160b14..9c368dd4fa23f8 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -35,7 +35,7 @@ func.func @za_b() {
func.func @za_b__out_of_tiles() {
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
return
}
@@ -44,7 +44,7 @@ func.func @za_b__out_of_tiles() {
func.func @za_b_overlapping_za_q() {
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -79,7 +79,7 @@ func.func @za_h__out_of_tiles() {
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// CHECK-NEXT: tile_id = 1
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
return
}
@@ -136,7 +136,7 @@ func.func @za_h_overlapping_za_q() {
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -174,7 +174,7 @@ func.func @za_s__out_of_tiles() {
%za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
%za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
return
}
@@ -218,7 +218,7 @@ func.func @za_s_overlapping_za_q() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -268,7 +268,7 @@ func.func @za_d__out_of_tiles() {
%za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
%za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
return
}
@@ -291,7 +291,7 @@ func.func @za_d_overlapping_za_q() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
@@ -365,7 +365,7 @@ func.func @za_q__out_of_tiles() {
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
return
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index ea48fa77861cf1..ef5a8742687511 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -21,11 +21,11 @@ func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: mem
%c0 = arith.constant 0 : index
%tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
%tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate physical tile to operation, all tile operations will go through memory, expect performance degradation}}
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
// CHECK-LABEL: tile_a:
@@ -61,17 +61,17 @@ func.func @main() {
%svl = call @get_svl() : () -> index
%svl_h = arith.muli %c16, %svl : index
- %two = arith.constant 2 : i16
- %three = arith.constant 3 : i16
- %four = arith.constant 4 : i16
+ %c2 = arith.constant 2 : i16
+ %c3 = arith.constant 3 : i16
+ %c4 = arith.constant 4 : i16
%memA = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
%memB = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
%memC = memref.alloca(%svl_h, %svl_h) : memref<?x?xi16>
- linalg.fill ins(%two : i16) outs(%memA : memref<?x?xi16>)
- linalg.fill ins(%three : i16) outs(%memB : memref<?x?xi16>)
- linalg.fill ins(%four : i16) outs(%memC : memref<?x?xi16>)
+ linalg.fill ins(%c2 : i16) outs(%memA : memref<?x?xi16>)
+ linalg.fill ins(%c3 : i16) outs(%memB : memref<?x?xi16>)
+ linalg.fill ins(%c4 : i16) outs(%memC : memref<?x?xi16>)
func.call @use_too_many_tiles(%memA, %memB, %memC) : (memref<?x?xi16>, memref<?x?xi16>, memref<?x?xi16>) -> ()
return
>From f59d2f9b32500d69e740c8109d9cf8d904ecb679 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 14:21:50 +0000
Subject: [PATCH 3/6] fixups
- Show alloca usage in tests
- Add test showing some very excessive spills
- Document a possible API to reduce spills
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 21 +++++-
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 73 +++++++++++++++++--
2 files changed, 85 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 8995c2a46367c6..646ae18ad345ca 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -210,7 +210,26 @@ static memref::AllocaOp getOrCreateAllocaForTile(
/// register, and the need to swap the contents, can't really be represented
/// correctly at a high level in MLIR.
///
-/// TODO: Reduce the spills/reloads to single slices where possible.
+/// TODO: Reduce the spills/reloads to single slices where possible (and omit
+/// redundant reloads). This could be done via a method on the
+/// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
+///
+/// `tileOp.getZaUsage()` could return:
+///
+/// struct ArmSMEOpZAUsage {
+/// enum class Kind {
+/// TileRead, // Omit store after tile operation.
+/// TileWrite, // Omit load before tile operation.
+/// TileReadWrite, // Needs both tile load and store.
+/// SliceRead, // Spill single slice and omit store after operation.
+/// SliceWrite, // Spill single slice and omit load before operation.
+/// SliceReadWrite // Spill single slice.
+/// };
+/// Value sliceIndex {};
+/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
+/// };
+///
+}
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 999acbfc66bef4..ffa249f9986019 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -4,8 +4,6 @@
// RUN: -split-input-file -verify-diagnostics | \
// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING
-// -----
-
/// Checks tile spill/reloads are inserted around in-memory tiles (i.e. tiles
/// that were not assigned a physical SME tile).
///
@@ -51,6 +49,10 @@
/// }
///
+// -----
+
+/// Note: In this example loads into ZA are inserted before the zero instruction.
+/// These are obviously redundant, but there's no checks to avoid this.
func.func @use_too_many_tiles() {
%0 = arm_sme.zero : vector<[4]x[4]xi32>
%1 = arm_sme.zero : vector<[4]x[4]xi32>
@@ -83,14 +85,69 @@ func.func @use_too_many_tiles() {
//
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz
-// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz
-// AFTER-LLVM-LOWERING-NEXT: vector.store
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
// AFTER-LLVM-LOWERING: arm_sme.intr.zero
// AFTER-LLVM-LOWERING: scf.for
// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_H]] step %[[C1]] {
-// AFTER-LLVM-LOWERING: arm_sme.intr.read.horiz
-// AFTER-LLVM-LOWERING-NEXT: arm_sme.intr.ld1h.horiz
-// AFTER-LLVM-LOWERING-NEXT: vector.store
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1h.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
+
+// -----
+
+/// Note: In this example an entire tile swap is inserted before/after the
+/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
+/// single tile slice (and can omit the initial load, like in the previous example).
+func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
+ %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %mask = vector.constant_mask [4] : vector<[4]xi1>
+ %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ return %loadSlice : vector<[4]x[4]xf32>
+}
+// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
+// AFTER-TILE-ALLOC: arm_sme.get_tile
+// AFTER-TILE-ALLOC-SAME: tile_id = 0
+// AFTER-TILE-ALLOC: arm_sme.load_tile_slice
+// AFTER-TILE-ALLOC-SAME: tile_id = 16
+
+// AFTER-LLVM-LOWERING-LABEL: @very_excessive_spills
+// AFTER-LLVM-LOWERING-DAG: %[[C0:.*]] = arith.constant 0 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C1:.*]] = arith.constant 1 : index
+// AFTER-LLVM-LOWERING-DAG: %[[C4:.*]] = arith.constant 4 : index
+// AFTER-LLVM-LOWERING-DAG: %[[VSCALE:.*]] = vector.vscale
+// AFTER-LLVM-LOWERING-DAG: %[[SVL_S:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// AFTER-LLVM-LOWERING-DAG: %[[TILE_ALLOCA:.*]] = memref.alloca(%[[SVL_S]], %[[SVL_S]])
+// AFTER-LLVM-LOWERING-SAME: {arm_sme.in_memory_tile_id = 16 : i32} : memref<?x?xf32>
+//
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING-NEXT: }
+// AFTER-LLVM-LOWERING: "arm_sme.intr.ld1w.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING: scf.for
+// AFTER-LLVM-LOWERING-SAME: %[[C0]] to %[[SVL_S]] step %[[C1]] {
+// AFTER-LLVM-LOWERING: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[TILE_ALLOCA]]
+// AFTER-LLVM-LOWERING: %[[BASE_PTR:.*]] = llvm.extractvalue %[[MEM_DESC]][1]
+// AFTER-LLVM-LOWERING: %[[SLICE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]]
+// AFTER-LLVM-LOWERING: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"{{.*}} <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: "arm_sme.intr.ld1w.horiz"({{.*}}, %[[SLICE_PTR]], {{.*}}) <{tile_id = 0 : i32}>
+// AFTER-LLVM-LOWERING-NEXT: vector.store %[[SLICE]], %[[TILE_ALLOCA]]
// AFTER-LLVM-LOWERING-NEXT: }
>From 861429847d864e9f7530e8925f5497e8fc0cd7b7 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 14:25:13 +0000
Subject: [PATCH 4/6] Remove newline
---
.../Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index ef5a8742687511..fe125c9f3cf160 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -1,4 +1,3 @@
-
// RUN: mlir-opt %s \
// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles \
// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
>From b8d34d7ec9f1ba31105bdc651a7de50ebed8defb Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 14:45:03 +0000
Subject: [PATCH 5/6] Fix build error
---
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 646ae18ad345ca..3fd053bef90621 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -229,7 +229,6 @@ static memref::AllocaOp getOrCreateAllocaForTile(
/// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
/// };
///
-}
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
>From d375d3740150bc80dd4062ba77e5ec05655c4817 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 21 Dec 2023 16:08:31 +0000
Subject: [PATCH 6/6] fixups
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 3 +--
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 19 ++++++++-----------
.../ArmSME/Transforms/TileAllocation.cpp | 10 +++++-----
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 2 +-
4 files changed, 15 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index d80b73c810646f..973d83ff362b88 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -101,8 +101,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
- /*methodName=*/"getTileType",
- /*arguments=*/(ins)
+ /*methodName=*/"getTileType"
>
];
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 3fd053bef90621..35aab4640ec9f0 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -195,15 +195,15 @@ static memref::AllocaOp getOrCreateAllocaForTile(
///
/// // Around op:
/// scf.for %slice_idx {
-/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
-/// vector.store %current_slice, %spill[%slice_idx, %c0]
+/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
/// }
/// arm_sme.tile_op { tile_id = 0 }
/// scf.for %slice_idx {
-/// %current_slice = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
+/// %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
/// "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
-/// vector.store %current_slice, %spill[%slice_idx, %c0]
+/// vector.store %slice_to_save, %spill[%slice_idx, %c0]
/// }
///
/// Note that these spills/fills are not inserted earlier as concept of a
@@ -307,15 +307,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
auto predicateType = sliceType.clone(rewriter.getI1Type());
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));
- // Create zero padding vector (never used due to all-true predicate).
- auto zeroVector = rewriter.create<arith::ConstantOp>(
- loc, sliceType, rewriter.getZeroAttr(sliceType));
+ // Create padding vector (never used due to all-true predicate).
+ auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
// Get a pointer to the current slice.
auto slicePtr =
getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
// Read the value of the current slice from ZA.
auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
- loc, sliceType, zeroVector, allTruePredicate, tileId, sliceIndexI32);
+ loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
// Load the new tile slice back from memory into ZA.
createLoadTileSliceIntrinsic(
rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
@@ -383,9 +382,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = zero.getLoc();
- arm_sme::ArmSMETileType tileType =
- *arm_sme::getSMETileType(zero.getVectorType());
-
auto tileId = getTileIdOrError(zero);
if (!tileId)
return failure();
@@ -394,6 +390,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
// The base mask is just the mask to zero the first tile (of a size).
// These masks are derived from:
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
+ arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
auto baseMaskForSize = [&] {
switch (tileType) {
case arm_sme::ArmSMETileType::ZAB:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 3c089d47d28609..51a85f516319f0 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -63,7 +63,7 @@ namespace {
static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
static constexpr StringLiteral
- kNextTileMemoryIndex("arm_sme.next_in_memory_tile_id");
+ kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
enum class TileMask : unsigned {
// clang-format off
@@ -207,17 +207,16 @@ struct AssignTileIDsPattern
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
PatternRewriter &rewriter) const override {
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
if (tileOp.getTileId())
return failure();
+ auto func = tileOp->getParentOfType<FunctionOpInterface>();
auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
func->getDiscardableAttr(name)))
return unsigned(attr.getInt());
return defaultVal;
};
-
auto setDiscardableIntAttr = [&](StringRef name, auto value) {
rewriter.updateRootInPlace(tileOp, [&] {
func->setDiscardableAttr(name,
@@ -238,8 +237,9 @@ struct AssignTileIDsPattern
else {
// If we could not find a real tile, set use a virtual tile ID (ID >= 16).
// A later pass will insert the necessary spills and reloads.
- tileId = getDiscardableIntAttr(kNextTileMemoryIndex, kInMemoryTileIdBase);
- setDiscardableIntAttr(kNextTileMemoryIndex, *tileId + 1);
+ tileId =
+ getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
+ setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
tileOp->emitWarning(
"failed to allocate SME virtual tile to operation, all tile "
"operations will go through memory, expect degraded performance");
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index ffa249f9986019..6a4ac2dfb05bb5 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -23,7 +23,7 @@
///
/// arm_sme.zero { tile_id = 16 } : vector<[8]x[8]xi16>
///
-/// This works like normal till the final lowering to LLVM, where spills and
+/// This works like normal until the final lowering to LLVM, where spills and
/// reloads will be inserted around uses of in-memory tiles.
///
/// So the above example becomes:
More information about the Mlir-commits
mailing list