[Mlir-commits] [mlir] [mlir][ArmSME] Fix loop bounds of masked loads/stores (PR #78983)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue Jan 23 09:22:33 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/78983
>From b9bd637d21b404178be9c20e17fbbc35529864f9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 16:36:05 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Refactor ArmSMEToSCF to use a shared
loop-building helper (NFC)
This means the bug fix (next patch) only needs to change one place, rather
than three separate rewrites.
Note: `TileLoadOpWithMaskAndPadZeroConversion` has been merged into
`TileLoadOpConversion`, since after this change those two rewrites were
pretty much identical.
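For readers skimming the patch: the refactor below centralizes the loop construction in a single helper that takes the per-slice work as a callback. The following is a minimal, standalone C++ sketch of that shape only — hypothetical names, `std::function` standing in for `llvm::function_ref`, and plain integers standing in for MLIR values; the real helper, `createLoadStoreForOverTileSlices`, is in the diff.

```cpp
#include <functional>
#include <iostream>
#include <optional>

// Hypothetical stand-in for the shared loop builder: it owns the bounds and
// the optional loop-carried value, while each "rewrite" only supplies a body.
int forOverTileSlices(int numTileSlices, std::optional<int> initTile,
                      const std::function<int(int /*sliceIdx*/, int /*currentTile*/)> &body) {
  int carried = initTile.value_or(0);
  for (int sliceIdx = 0; sliceIdx < numTileSlices; ++sliceIdx)
    carried = body(sliceIdx, carried);
  return carried;
}

int main() {
  // A "load"-like use: accumulate into the loop-carried value.
  int result = forOverTileSlices(/*numTileSlices=*/4, /*initTile=*/0,
                                 [](int sliceIdx, int currentTile) {
                                   return currentTile + sliceIdx;
                                 });
  // A "store"-like use: no carried value, side effects only.
  forOverTileSlices(4, std::nullopt, [](int sliceIdx, int) {
    std::cout << "store slice " << sliceIdx << "\n";
    return 0;
  });
  std::cout << "result = " << result << "\n";
  return 0;
}
```

In the actual patch the carried value is the SME tile (the `iter_args` of the `scf.for`), and the per-slice body emits `arm_sme.load_tile_slice` or `arm_sme.store_tile_slice`.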
---
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 348 +++++++-----------
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 4 +-
2 files changed, 136 insertions(+), 216 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index c9d7c0c313b5c8..85ff10387628e4 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -30,11 +30,12 @@ namespace {
/// `outIndices`:
/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
/// rank 2: (indices[0] + tileSliceIndex, indices[1])
-void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
- Value tileSliceNumElts,
- SmallVectorImpl<Value> &outIndices, Location loc,
- PatternRewriter &rewriter) {
+SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
+ Value tileSliceIndex,
+ Value tileSliceNumElts, Location loc,
+ PatternRewriter &rewriter) {
assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
if (rank == 1)
@@ -47,99 +48,92 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
if (rank == 2)
outIndices.push_back(indices[1]);
-}
-/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
-/// using `arm_sme.load_tile_slice`.
-///
-/// BEFORE:
-/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
-/// memref<?x?xi32>, vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
-/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-/// %ptrue_s, %iter_tile, %tile_slice_idx
-/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-/// ```
-struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
- using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
- PatternRewriter &rewriter) const override {
- if (tileLoadOp.getMask())
- return rewriter.notifyMatchFailure(tileLoadOp,
- "op has mask, apply masked patterns");
-
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileLoadOp.getLoc();
- auto tileType = tileLoadOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- // Allocate a new SME tile.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ return outIndices;
+}
- // Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+/// Creates an scf.for for the load/store of an ArmSME tile.
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+ PatternRewriter &rewriter, Location loc, VectorType tileType,
+ ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
+ function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
+ /*currentTile=*/Value)>
+ makeLoopBody) {
+ PatternRewriter::InsertionGuard guard(rewriter);
+
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B,
+ // SVL_H, ..., SVL_Q).
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+ Value predicate;
+ Value upperBound;
+ if (mask) {
+ auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ loc, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto maskDim0 = createMaskOp.getOperands()[0];
+ auto maskDim1 = createMaskOp.getOperands()[1];
+
+ upperBound = maskDim0;
+ predicate =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ } else {
+ upperBound = numTileSlices;
+ // No mask. Create an 'all true' predicate for the tile slice.
+ predicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+ }
- rewriter.setInsertionPointToStart(forOp.getBody());
+ bool hasCarriedArgs = bool(initTile);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile}
+ : ValueRange{});
- // Create an 'all true' predicate for the tile slice.
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ Value tileSliceIndex = forOp.getInductionVar();
- // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
- // tile.
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- numTileSlices, memrefIndices, loc, rewriter);
- auto currentTile = forOp.getRegionIterArg(0);
- auto loadSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
- currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
+ auto adjustedIndices = getMemrefIndices(
+ memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
+ auto nextTile = makeLoopBody(
+ tileSliceIndex, adjustedIndices, predicate,
+ /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
- rewriter.setInsertionPointAfter(forOp);
+ assert(bool(nextTile) == hasCarriedArgs);
+ if (nextTile)
+ rewriter.create<scf::YieldOp>(loc, nextTile);
- // Replace 'arm_sme.tile_load' with the result.
- rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+ return forOp;
+}
- return success();
- }
-};
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+ PatternRewriter &rewriter, Location loc, VectorType tileType,
+ ValueRange memrefIndices, int memrefRank, Value mask,
+ function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
+ makeLoopBody) {
+ return createLoadStoreForOverTileSlices(
+ rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
+ [&](Value index, ValueRange adjustedIndices, Value predicate,
+ Value) -> Value {
+ makeLoopBody(index, adjustedIndices, predicate);
+ return {};
+ });
+}
-/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
///
/// BEFORE:
/// ```mlir
@@ -168,77 +162,56 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// ```
///
/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
-struct TileLoadOpWithMaskAndPadZeroConversion
- : public OpRewritePattern<arm_sme::TileLoadOp> {
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
- OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
+ auto mask = tileLoadOp.getMask();
- auto maskOp = tileLoadOp.getMask();
- if (!maskOp)
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has no mask, needs unmasked pattern");
-
- auto padOp = tileLoadOp.getPadding();
- assert(padOp && "expected padding when masking!");
+ Value initTile;
+ if (mask) {
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
- auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
- if (!constPadOp || constPadOp.getValue() !=
- rewriter.getZeroAttr(tileType.getElementType()))
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
-
- auto numRows = createMaskOp.getOperands()[0];
- auto numCols = createMaskOp.getOperands()[1];
-
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto numColsOp =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
- // Initialize tile with zero to satisfy padding. Inactive cols will be
- // zeroed anyway since the loads use zeroing predication. For inactive rows
- // however, no load will occur so these need to be zeroed.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, tileType);
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive
+ // rows however, no load will occur so these need to be zeroed.
+ initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ rewriter, loc, tileType);
+ } else {
+ // Allocate a new SME tile.
+ initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ rewriter, loc, tileType);
+ }
// Create a loop to load the active tile slices from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = numRows;
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- ValueRange{initTile});
-
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
- // tile.
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- upperBound, memrefIndices, loc, rewriter);
- auto loadSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
- currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
-
- rewriter.setInsertionPointAfter(forOp);
+ auto forOp = createLoadStoreForOverTileSlices(
+ rewriter, loc, tileType, tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), mask, initTile,
+ [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
+ Value currentTile) -> Value {
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory
+ // into tile.
+ return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
+ });
+
+ if (failed(forOp))
+ return forOp;
// Replace 'arm_sme.tile_load' with the result.
- rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+ rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
return success();
}
@@ -345,10 +318,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
loc, predicateType, maskIndex.getResult());
- SmallVector<Value> memrefIndices;
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- numTileSlices, memrefIndices, loc, rewriter);
+ auto memrefIndices = getMemrefIndices(
+ tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
+ tileSliceIndex, numTileSlices, loc, rewriter);
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
@@ -400,77 +372,25 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
PatternRewriter &rewriter) const override {
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileStoreOp.getLoc();
- auto tileType = tileStoreOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-
- Value maskCols;
- Value upperBound;
- auto maskOp = tileStoreOp.getMask();
- if (maskOp) {
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
- auto numRows = createMaskOp.getOperands()[0];
- auto numCols = createMaskOp.getOperands()[1];
-
- upperBound = numRows;
- maskCols =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
- } else {
- // Store all tile slices if no mask.
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B,
- // SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
- upperBound = numTileSlices;
- // Create an 'all true' predicate for the tile slice.
- maskCols = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
- }
-
// Create a loop that stores each active ZA tile slice to memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- getMemrefIndices(tileStoreOp.getIndices(),
- tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
- upperBound, memrefIndices, loc, rewriter);
-
- tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
- rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
- tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
-
- return success();
+ return createLoadStoreForOverTileSlices(
+ rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
+ tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
+ tileStoreOp.getMask(),
+ [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
+ tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+ rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+ predicate, tileStoreOp.getBase(), memrefIndices,
+ tileStoreOp.getLayout());
+ });
}
};
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns
- .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
- TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
- patterns.getContext());
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
+ TileStoreOpConversion>(patterns.getContext());
}
namespace {
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 5d79a0405114a2..292f9a4d411ff7 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -11,9 +11,9 @@
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
-// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
>From 5c1bccbae611f846bb487805f02677495dad79e2 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 22 Jan 2024 12:01:48 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Fix loop bounds of masked loads/stores
Previously, for masked tile loads/stores we used the row count from the
`vector.create_mask` operation directly as the upper bound of the `scf.for`
over the tile slices. This was not correct: `create_mask` allows its
operands to be greater than the size of the corresponding vector dimension,
in which case the loop's upper bound must be clamped to the number of tile
slices.
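As a concrete illustration (assuming vscale = 1, so a 32-bit-element tile has 4 slices): `vector.create_mask` with a row operand of 6 is legal and simply yields an all-true row dimension, but using 6 as the `scf.for` upper bound would step onto tile slices 4 and 5, which do not exist. The sketch below is plain C++ arithmetic with hypothetical variable names, showing only the clamp; the patch itself emits `arith.index_cast`/`arith.minsi` ops, as shown in the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical values: a 32-bit element tile with vscale = 1.
  const int64_t vscale = 1;
  const int64_t numTileSlices = 4 * vscale; // SVL_S elements == ZA tile slices
  const int64_t maskRows = 6;               // dim-0 operand of vector.create_mask

  // Before the fix the loop upper bound was maskRows (6), overrunning the tile.
  // After the fix it is clamped, matching the arith.minsi emitted by the patch.
  const int64_t upperBound = std::min(maskRows, numTileSlices);
  assert(upperBound == 4 && "loop must not exceed the number of tile slices");
  return 0;
}
```

With the clamp in place, the masked load/store loops can never step past the last ZA tile slice, regardless of the `create_mask` operands.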
---
.../lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 13 ++++++++++++-
.../Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir | 18 ++++++++++++++++--
2 files changed, 28 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 85ff10387628e4..adf3aca91ba8b5 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -86,7 +86,18 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
auto maskDim0 = createMaskOp.getOperands()[0];
auto maskDim1 = createMaskOp.getOperands()[1];
- upperBound = maskDim0;
+ // The upper bound of the loop must be clamped at `numTileSlices` as
+ // `vector.create_mask` allows operands to be greater than the size of a
+ // dimension.
+ auto numRowI64 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI64Type(), maskDim0);
+ auto numTileSlicesI64 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI64Type(), numTileSlices);
+ auto upperBoundI64 =
+ rewriter.create<arith::MinSIOp>(loc, numRowI64, numTileSlicesI64);
+ upperBound = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getIndexType(), upperBoundI64);
+
predicate =
rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
} else {
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 292f9a4d411ff7..6c393bc38af9c7 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -39,10 +39,17 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG: %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG: %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG: %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG: %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
// CHECK-DAG: %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
@@ -150,9 +157,16 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG: %[[NUM_ROWS_I64:.*]] = arith.index_cast %[[NUM_ROWS]] : index to i64
+// CHECK-DAG: %[[NUM_TILE_SLICES_I64:.*]] = arith.index_cast %[[NUM_TILE_SLICES]] : index to i64
+// CHECK-DAG: %[[LOOP_UPPER_BOUND_I64:.*]] = arith.minsi %[[NUM_ROWS_I64]], %[[NUM_TILE_SLICES_I64]] : i64
+// CHECK-DAG: %[[LOOP_UPPER_BOUND:.*]] = arith.index_cast %[[LOOP_UPPER_BOUND_I64]] : i64 to index
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[LOOP_UPPER_BOUND]] step %[[C1]] {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {