[Mlir-commits] [mlir] [mlir][ArmSME] Refactor ArmSMEToSCF to use shared loop-building helper (NFC) (PR #79172)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Jan 24 02:54:25 PST 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/79172
>From b9bd637d21b404178be9c20e17fbbc35529864f9 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 23 Jan 2024 16:36:05 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Refactor ArmSMEToSCF to use shared
loop-building helper (NFC)
This means the bug fix in the next patch only needs to change one place,
rather than three separate rewrites.
Note: `TileLoadOpWithMaskAndPadZeroConversion` has been merged into
`TileLoadOpConversion`, since after this change the two rewrites were
nearly identical.
---
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 348 +++++++-----------
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 4 +-
2 files changed, 136 insertions(+), 216 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index c9d7c0c313b5c8d..85ff10387628e46 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -30,11 +30,12 @@ namespace {
/// `outIndices`:
/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
/// rank 2: (indices[0] + tileSliceIndex, indices[1])
-void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
- Value tileSliceNumElts,
- SmallVectorImpl<Value> &outIndices, Location loc,
- PatternRewriter &rewriter) {
+SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
+ Value tileSliceIndex,
+ Value tileSliceNumElts, Location loc,
+ PatternRewriter &rewriter) {
assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+ SmallVector<Value, 2> outIndices;
auto tileSliceOffset = tileSliceIndex;
if (rank == 1)
@@ -47,99 +48,92 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
if (rank == 2)
outIndices.push_back(indices[1]);
-}
-/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
-/// using `arm_sme.load_tile_slice`.
-///
-/// BEFORE:
-/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
-/// memref<?x?xi32>, vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
-/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-/// %ptrue_s, %iter_tile, %tile_slice_idx
-/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-/// ```
-struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
- using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
- PatternRewriter &rewriter) const override {
- if (tileLoadOp.getMask())
- return rewriter.notifyMatchFailure(tileLoadOp,
- "op has mask, apply masked patterns");
-
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileLoadOp.getLoc();
- auto tileType = tileLoadOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- // Allocate a new SME tile.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ return outIndices;
+}
- // Create a loop that loads each ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
- step, ValueRange{initTile});
+/// Creates an scf.for for the load/store of an ArmSME tile.
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+ PatternRewriter &rewriter, Location loc, VectorType tileType,
+ ValueRange memrefIndices, int memrefRank, Value mask, Value initTile,
+ function_ref<Value(/*index=*/Value, ValueRange, /*predicate=*/Value,
+ /*currentTile=*/Value)>
+ makeLoopBody) {
+ PatternRewriter::InsertionGuard guard(rewriter);
+
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType()));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto predicateType =
+ VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B,
+ // SVL_H, ..., SVL_Q).
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+ Value predicate;
+ Value upperBound;
+ if (mask) {
+ auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ loc, "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+
+ auto maskDim0 = createMaskOp.getOperands()[0];
+ auto maskDim1 = createMaskOp.getOperands()[1];
+
+ upperBound = maskDim0;
+ predicate =
+ rewriter.create<vector::CreateMaskOp>(loc, predicateType, maskDim1);
+ } else {
+ upperBound = numTileSlices;
+ // No mask. Create an 'all true' predicate for the tile slice.
+ predicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+ }
- rewriter.setInsertionPointToStart(forOp.getBody());
+ bool hasCarriedArgs = bool(initTile);
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+ hasCarriedArgs ? ValueRange{initTile}
+ : ValueRange{});
- // Create an 'all true' predicate for the tile slice.
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto allTruePredicate = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ Value tileSliceIndex = forOp.getInductionVar();
- // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
- // tile.
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- numTileSlices, memrefIndices, loc, rewriter);
- auto currentTile = forOp.getRegionIterArg(0);
- auto loadSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
- currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
+ auto adjustedIndices = getMemrefIndices(
+ memrefIndices, memrefRank, tileSliceIndex, numTileSlices, loc, rewriter);
+ auto nextTile = makeLoopBody(
+ tileSliceIndex, adjustedIndices, predicate,
+ /*currentTile=*/hasCarriedArgs ? forOp.getRegionIterArg(0) : Value{});
- rewriter.setInsertionPointAfter(forOp);
+ assert(bool(nextTile) == hasCarriedArgs);
+ if (nextTile)
+ rewriter.create<scf::YieldOp>(loc, nextTile);
- // Replace 'arm_sme.tile_load' with the result.
- rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+ return forOp;
+}
- return success();
- }
-};
+FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
+ PatternRewriter &rewriter, Location loc, VectorType tileType,
+ ValueRange memrefIndices, int memrefRank, Value mask,
+ function_ref<void(/*index=*/Value, ValueRange, /*predicate=*/Value)>
+ makeLoopBody) {
+ return createLoadStoreForOverTileSlices(
+ rewriter, loc, tileType, memrefIndices, memrefRank, mask, Value{},
+ [&](Value index, ValueRange adjustedIndices, Value predicate,
+ Value) -> Value {
+ makeLoopBody(index, adjustedIndices, predicate);
+ return {};
+ });
+}
-/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
///
/// BEFORE:
/// ```mlir
@@ -168,77 +162,56 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// ```
///
/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
-struct TileLoadOpWithMaskAndPadZeroConversion
- : public OpRewritePattern<arm_sme::TileLoadOp> {
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
PatternRewriter &rewriter) const override {
- OpBuilder::InsertionGuard g(rewriter);
auto loc = tileLoadOp.getLoc();
auto tileType = tileLoadOp.getVectorType();
+ auto mask = tileLoadOp.getMask();
- auto maskOp = tileLoadOp.getMask();
- if (!maskOp)
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has no mask, needs unmasked pattern");
-
- auto padOp = tileLoadOp.getPadding();
- assert(padOp && "expected padding when masking!");
+ Value initTile;
+ if (mask) {
+ auto padOp = tileLoadOp.getPadding();
+ assert(padOp && "expected padding when masking!");
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
- auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
- if (!constPadOp || constPadOp.getValue() !=
- rewriter.getZeroAttr(tileType.getElementType()))
- return rewriter.notifyMatchFailure(
- tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
-
- auto numRows = createMaskOp.getOperands()[0];
- auto numCols = createMaskOp.getOperands()[1];
-
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
- auto numColsOp =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+ auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+ if (!constPadOp || constPadOp.getValue() !=
+ rewriter.getZeroAttr(tileType.getElementType()))
+ return rewriter.notifyMatchFailure(
+ tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
- // Initialize tile with zero to satisfy padding. Inactive cols will be
- // zeroed anyway since the loads use zeroing predication. For inactive rows
- // however, no load will occur so these need to be zeroed.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, tileType);
+ // Initialize tile with zero to satisfy padding. Inactive cols will be
+ // zeroed anyway since the loads use zeroing predication. For inactive
+ // rows however, no load will occur so these need to be zeroed.
+ initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ rewriter, loc, tileType);
+ } else {
+ // Allocate a new SME tile.
+ initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ rewriter, loc, tileType);
+ }
// Create a loop to load the active tile slices from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = numRows;
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
- ValueRange{initTile});
-
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
- // tile.
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- auto currentTile = forOp.getRegionIterArg(0);
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- upperBound, memrefIndices, loc, rewriter);
- auto loadSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
- currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
-
- rewriter.setInsertionPointAfter(forOp);
+ auto forOp = createLoadStoreForOverTileSlices(
+ rewriter, loc, tileType, tileLoadOp.getIndices(),
+ tileLoadOp.getMemRefType().getRank(), mask, initTile,
+ [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate,
+ Value currentTile) -> Value {
+ // Create 'arm_sme.load_tile_slice' to load tile slice from memory
+ // into tile.
+ return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
+ currentTile, memrefIndices, tileSliceIndex,
+ tileLoadOp.getLayout());
+ });
+
+ if (failed(forOp))
+ return forOp;
// Replace 'arm_sme.tile_load' with the result.
- rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
+ rewriter.replaceOp(tileLoadOp, forOp->getResult(0));
return success();
}
@@ -345,10 +318,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto maskOp1D = rewriter.create<vector::CreateMaskOp>(
loc, predicateType, maskIndex.getResult());
- SmallVector<Value> memrefIndices;
- getMemrefIndices(tileLoadOp.getIndices(),
- tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
- numTileSlices, memrefIndices, loc, rewriter);
+ auto memrefIndices = getMemrefIndices(
+ tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(),
+ tileSliceIndex, numTileSlices, loc, rewriter);
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
@@ -400,77 +372,25 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
PatternRewriter &rewriter) const override {
- OpBuilder::InsertionGuard g(rewriter);
- auto loc = tileStoreOp.getLoc();
- auto tileType = tileStoreOp.getVectorType();
- auto tileElementType = tileType.getElementType();
-
- auto predicateType =
- VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-
- Value maskCols;
- Value upperBound;
- auto maskOp = tileStoreOp.getMask();
- if (maskOp) {
- auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
- if (!createMaskOp)
- return rewriter.notifyMatchFailure(
- tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
- "currently supported");
-
- auto numRows = createMaskOp.getOperands()[0];
- auto numCols = createMaskOp.getOperands()[1];
-
- upperBound = numRows;
- maskCols =
- rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
- } else {
- // Store all tile slices if no mask.
- auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
- auto vscale =
- rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
- // This describes both the number of ZA tile slices and the number of
- // elements in a vector of SVL bits for a given element type (SVL_B,
- // SVL_H,
- // ..., SVL_Q).
- auto numTileSlices =
- rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
- upperBound = numTileSlices;
- // Create an 'all true' predicate for the tile slice.
- maskCols = rewriter.create<arith::ConstantOp>(
- loc, DenseElementsAttr::get(predicateType, true));
- }
-
// Create a loop that stores each (active) active ZA tile slice from memory.
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-
- rewriter.setInsertionPointToStart(forOp.getBody());
-
- SmallVector<Value> memrefIndices;
- auto tileSliceIndex = forOp.getInductionVar();
- getMemrefIndices(tileStoreOp.getIndices(),
- tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
- upperBound, memrefIndices, loc, rewriter);
-
- tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
- rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
- tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
-
- return success();
+ return createLoadStoreForOverTileSlices(
+ rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
+ tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
+ tileStoreOp.getMask(),
+ [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
+ tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+ rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+ predicate, tileStoreOp.getBase(), memrefIndices,
+ tileStoreOp.getLayout());
+ });
}
};
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns
- .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
- TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
- patterns.getContext());
+ patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadNonZeroConversion,
+ TileStoreOpConversion>(patterns.getContext());
}
namespace {
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 5d79a0405114a29..292f9a4d411ff7c 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -11,9 +11,9 @@
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
-// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
>From dc62aeaee40739f4d4887381581788feb103b100 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 24 Jan 2024 10:33:29 +0000
Subject: [PATCH 2/2] Fixups
---
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 25 ++++++++-----------
1 file changed, 11 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 85ff10387628e46..056d71030bd7986 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -26,8 +26,7 @@ namespace mlir {
using namespace mlir;
namespace {
-/// Adjusts `indices` as follows for a given tile slice and returns them in
-/// `outIndices`:
+/// Returns adjusted (1-D or 2-D) `indices` for a tile slice as follows:
/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
/// rank 2: (indices[0] + tileSliceIndex, indices[1])
SmallVector<Value, 2> getMemrefIndices(ValueRange indices, unsigned rank,
@@ -135,11 +134,11 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
/// Lower `arm_sme.tile_load` without a mask, or with a mask and a zero pad.
///
+/// With a mask:
+///
/// BEFORE:
/// ```mlir
/// %pad = arith.constant 0 : i32
-/// %num_rows = arith.constant 2 : index
-/// %num_cols = arith.constant 4 : index
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -147,12 +146,10 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
///
/// AFTER:
/// ```mlir
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
-/// %num_rows = arith.constant 2 : index
-/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
-/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
+/// %mask_cols = vector.create_mask %num_cols : vector<[4]xi1>
+/// %loop_rows = arith.minsi %num_rows, %svl_s : index
+/// %tile = scf.for %tile_slice_idx = %c0 to %loop_rows step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice
/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
@@ -161,6 +158,9 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
/// }
/// ```
///
+/// Without a mask the lowering is nearly identical, except that %init_tile is
+/// an arm_sme.get_tile, %mask_cols is all-true, and %loop_rows is %svl_s.
+///
/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
@@ -221,9 +221,6 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// %pad = arith.constant 1 : i32
-/// %num_rows = arith.constant 2 : index
-/// %num_cols = arith.constant 4 : index
/// %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
/// %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -232,7 +229,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
-/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+/// %pad_1d = vector.splat %pad : vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
@@ -372,7 +369,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
PatternRewriter &rewriter) const override {
- // Create a loop that stores each (active) active ZA tile slice from memory.
+ // Create a loop that stores each active ZA tile slice to memory.
return createLoadStoreForOverTileSlices(
rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),
tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
More information about the Mlir-commits
mailing list