[Mlir-commits] [mlir] [mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings (PR #73922)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Dec 6 05:09:12 PST 2023
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/73922
>From 56738497f3b20a304efd6695b9a7955928b1d5db Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 24 Nov 2023 10:26:40 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] More precisely model dataflow in ArmSME to
SCF lowerings
Since #73253 we now have loops in SSA form for tiles (i.e. loops that take
`iter_args` and yield a new tile), so this patch updates lowerings to
use that. This is an NFC change, as it still lowers to the same intrinsics,
but it makes the IR less 'surprising' at a higher level, and may be
recognised by more transforms.
---
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 81 ++++++++-----
.../VectorToArmSME/VectorToArmSME.cpp | 113 ++++++++++--------
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 19 +--
.../test/Dialect/ArmSME/arith-ops-to-sme.mlir | 14 ++-
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 7 +-
5 files changed, 132 insertions(+), 102 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index fece03040dbb8..849afa36a5ff1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// AFTER:
/// ```mlir
/// %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-/// %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+/// %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
/// %vscale = vector.vscale
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-/// %ptrue_s, %tile, %tile_slice_idx
+/// %ptrue_s, %iter_tile, %tile_slice_idx
/// : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
auto tileElementType = tileType.getElementType();
// Allocate a new SME tile.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ auto currentTile = forOp.getRegionIterArg(0);
+ auto loadSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
-/// %tile = arm_sme.zero : vector<[4]x[4]xi32>
+/// %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
/// %num_rows = arith.constant 2 : index
/// %num_cols = vector.create_mask %c4 : vector<[4]xi1>
-/// scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// %tile_update = arm_sme.load_tile_slice
-/// %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+/// %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
/// memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
///
@@ -202,14 +209,15 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive rows
// however, no load will occur so these need to be zeroed.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
rewriter, loc, tileType);
// Create a loop to load the active tile slices from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto upperBound = numRows;
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+ ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
@@ -217,17 +225,20 @@ struct TileLoadOpWithMaskAndPadZeroConversion
// tile.
SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
upperBound, memrefIndices, loc, rewriter);
- tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
- memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ auto loadSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+ rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
+ currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
/// ```mlir
/// ...
/// %pad_1d = arith.constant dense<1> : vector<[4]xi32>
-/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...
/// %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
/// %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
/// : memref<?x?xi32>, vector<[4]xi1>,
/// vector<[4]xi32> into vector<[4]xi32>
/// // Insert slice into tile
-/// arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
-/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %slice, %iter_tile, %tile_slice_idx :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
/// ```
struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
loc, rewriter.getI32Type(), numCols);
// Allocate a new SME tile.
- auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+ auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
rewriter, loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+ step, ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
// Combine masks.
auto rowIsActive = rewriter.create<arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
- tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
- rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
- tileLoadOp.getLayout());
+ auto moveSlice =
+ tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+ rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+ tileSliceIndex, tileLoadOp.getLayout());
+ rewriter.create<scf::YieldOp>(loc, ValueRange{moveSlice});
rewriter.setInsertionPointAfter(forOp);
- // Replace 'arm_sme.tile_load' with the tile.
- rewriter.replaceOp(tileLoadOp, tile);
+ // Replace 'arm_sme.tile_load' with the result.
+ rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
return success();
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 4b3fd26c6d59e..47cdffe7a5af1 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -26,21 +26,26 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
}
/// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index. Sets the IR Builder insertion point as the loop body.
-/// Callers of this method are responsible for restoring it if needed.
-static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
- Type eltType) {
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via the callback, which returns the next tile value.
+template <typename LoopBodyCallback>
+static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
+ Location loc, Value initTile,
+ LoopBodyCallback callback) {
+ OpBuilder::InsertionGuard g(rewriter);
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
- loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+ loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
- auto forOp =
- rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+ ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
+ auto nextTile = callback(forOp);
+ rewriter.create<scf::YieldOp>(loc, ValueRange{nextTile});
return forOp;
}
@@ -242,27 +247,25 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
// Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
// ops that broadcast the constant to each tile slice.
- OpBuilder::InsertionGuard g(rewriter);
auto loc = constantOp.getLoc();
// Unpack 1-d vector type from 2-d vector type.
- auto tileSliceType =
- VectorType::get(tileType.getShape().drop_front(), tileElementType,
- /*scalableDims=*/{true});
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, constantOp1D, tile, tileSliceIndex);
-
- rewriter.replaceOp(constantOp, tile);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+ // slice.
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+ });
+ rewriter.replaceOp(constantOp, forOp.getResult(0));
return success();
}
@@ -277,9 +280,13 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+/// {
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
///
/// Supports scalar, 0-d vector, and 1-d vector broadcasts.
@@ -293,20 +300,16 @@ struct BroadcastOpToArmSMELowering
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = broadcastOp.getLoc();
auto srcType = broadcastOp.getSourceType();
auto srcVectorType = dyn_cast<VectorType>(srcType);
- auto tileElementType = tileType.getElementType();
Value broadcastOp1D;
if (srcType.isIntOrFloat() ||
(srcVectorType && (srcVectorType.getRank() == 0))) {
// Broadcast scalar or 0-d vector to 1-d vector.
- auto tileSliceType =
- VectorType::get(tileType.getShape().drop_front(), tileElementType,
- /*scalableDims=*/{true});
+ VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, broadcastOp.getSource());
} else if (srcVectorType && (srcVectorType.getRank() == 1))
@@ -315,18 +318,20 @@ struct BroadcastOpToArmSMELowering
else
return failure();
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Create a loop over ZA tile slices.
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
- // tile slice.
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, tile, tileSliceIndex);
-
- rewriter.replaceOp(broadcastOp, tile);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+ // to each tile slice.
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ });
+
+ rewriter.replaceOp(broadcastOp, forOp.getResult(0));
return success();
}
@@ -341,9 +346,13 @@ struct BroadcastOpToArmSMELowering
/// is converted to:
///
/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-/// arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-/// %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+/// {
+/// %tile_update = arm_sme.move_vector_to_tile_slice
+/// %broadcast_to_1d, %iter_tile, %tile_slice_index :
+/// vector<[4]xi32> into vector<[4]x[4]xi32>
+/// scf.yield %tile_update : vector<[4]x[4]xi32>
/// }
///
/// This is identical to vector.broadcast of a scalar.
@@ -356,11 +365,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = splatOp.getLoc();
-
auto srcType = splatOp.getOperand().getType();
- auto tileElementType = tileType.getElementType();
assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
// Avoid unused-variable warning when building without assertions.
@@ -371,17 +377,19 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
loc, tileSliceType, splatOp.getInput());
- auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Next, create a loop over ZA tile slices and "move" the generated 1-d
// vector to each slice.
- auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
- auto tileSliceIndex = forOp.getInductionVar();
-
- rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
- loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+ auto forOp =
+ createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+ auto tileSliceIndex = forOp.getInductionVar();
+ auto currentTile = forOp.getRegionIterArg(0);
+ return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+ });
- rewriter.replaceOp(splatOp, tile);
+ rewriter.replaceOp(splatOp, forOp.getResult(0));
return success();
}
@@ -424,7 +432,6 @@ struct TransposeOpToArmSMELowering
if (permutation[0] != 1 || permutation[1] != 0)
return failure();
- OpBuilder::InsertionGuard g(rewriter);
auto loc = transposeOp.getLoc();
// Allocate buffer to store input tile to.
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index efefc6c49e08f..f2787aa72ae59 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,16 +6,17 @@
// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK-DAG: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: %[[NEW_TILE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -40,10 +41,11 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
-// CHECK-DAG: %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-DAG: %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: %[[NEW_TILE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
@@ -68,7 +70,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-DAG: %[[NUM_COLS_I32:.*]] = arith.index_castui %[[NUM_COLS]] : index to i32
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
// CHECK-NEXT: %[[ROW_IS_ACTIVE_SEXT_I32:.*]] = arith.extsi %[[ROW_IS_ACTIVE]] : i1 to i32
// CHECK-NEXT: %[[MASK:.*]] = arith.andi %[[ROW_IS_ACTIVE_SEXT_I32]], %[[NUM_COLS_I32]] : i32
@@ -77,7 +79,8 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK: arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index b8db105f9c601..43564bb3ac628 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -95,11 +95,12 @@ func.func @arith_constant_dense_2d_zero_f64() {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
+// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
-// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8>
+// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[16]x[16]xi8>) {
+// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8>
+// CHECK: scf.yield %[[NEW_TILE]] : vector<[16]x[16]xi8>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[16]x[16]xi8>) -> ()
func.func @arith_constant_dense_2d_nonzero_i8() {
%two = arith.constant dense<2> : vector<[16]x[16]xi8>
@@ -114,11 +115,12 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
+// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
-// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64>
+// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[2]x[2]xf64>) {
+// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64>
+// CHECK: scf.yield %[[NEW_TILE]] : vector<[2]x[2]xf64>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[2]x[2]xf64>) -> ()
func.func @arith_constant_dense_2d_nonzero_f64() {
%two = arith.constant dense<2.0> : vector<[2]x[2]xf64>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 5bc147c60f3a6..6ea949d9c1650 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -452,11 +452,12 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
-// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK: %[[C10:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
+// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> ()
func.func @broadcast_vec2d_from_i32(%arg0: i32) {
%0 = vector.broadcast %arg0 : i32 to vector<[4]x[4]xi32>
>From c1cd243d281519a2cda34e8fe6db1d3b4097b0ba Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 30 Nov 2023 12:17:27 +0000
Subject: [PATCH 2/3] Fixups
---
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 6 +++---
.../lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 2 +-
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir | 12 ++++++------
mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir | 8 ++++----
4 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 849afa36a5ff1..c3c9780318a9e 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -128,7 +128,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
+ rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
@@ -233,7 +233,7 @@ struct TileLoadOpWithMaskAndPadZeroConversion
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
+ rewriter.create<scf::YieldOp>(loc, loadSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
@@ -363,7 +363,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
tileSliceIndex, tileLoadOp.getLayout());
- rewriter.create<scf::YieldOp>(loc, ValueRange{moveSlice});
+ rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 47cdffe7a5af1..d49bdedce16c0 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -45,7 +45,7 @@ static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
ValueRange{initTile});
rewriter.setInsertionPointToStart(forOp.getBody());
auto nextTile = callback(forOp);
- rewriter.create<scf::YieldOp>(loc, ValueRange{nextTile});
+ rewriter.create<scf::YieldOp>(loc, nextTile.getResult());
return forOp;
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index f2787aa72ae59..5d79a0405114a 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -15,8 +15,8 @@
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: %[[NEW_TILE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
+// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -44,8 +44,8 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
// CHECK-DAG: %[[TILE_ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[TILE_ZERO]]) -> (vector<[4]x[4]xi32>) {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: %[[NEW_TILE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
-// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
+// CHECK-NEXT: %[[TILE_UPDATE:.*]] = arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
@@ -79,8 +79,8 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
-// CHECK-NEXT: scf.yield %[[NEW_TILE]] : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index 43564bb3ac628..ae2d0f40f03af 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -99,8 +99,8 @@ func.func @arith_constant_dense_2d_zero_f64() {
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[16]x[16]xi8>) {
-// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8>
-// CHECK: scf.yield %[[NEW_TILE]] : vector<[16]x[16]xi8>
+// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8>
+// CHECK: scf.yield %[[TILE_UPDATE]] : vector<[16]x[16]xi8>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[16]x[16]xi8>) -> ()
func.func @arith_constant_dense_2d_nonzero_i8() {
%two = arith.constant dense<2> : vector<[16]x[16]xi8>
@@ -119,8 +119,8 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
// CHECK: %[[TILE:.*]] = scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] iter_args(%[[CURRENT_TILE:.*]] = %[[INIT_TILE]]) -> (vector<[2]x[2]xf64>) {
-// CHECK: %[[NEW_TILE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64>
-// CHECK: scf.yield %[[NEW_TILE]] : vector<[2]x[2]xf64>
+// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[CURRENT_TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64>
+// CHECK: scf.yield %[[TILE_UPDATE]] : vector<[2]x[2]xf64>
// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[2]x[2]xf64>) -> ()
func.func @arith_constant_dense_2d_nonzero_f64() {
%two = arith.constant dense<2.0> : vector<[2]x[2]xf64>
>From c8411d3fdf9affb0b315f4fdc4c61028375b877d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 6 Dec 2023 13:07:41 +0000
Subject: [PATCH 3/3] Update comment
---
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d49bdedce16c0..312e89c8f100d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -249,7 +249,9 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
// ops that broadcast the constant to each tile slice.
auto loc = constantOp.getLoc();
- // Unpack 1-d vector type from 2-d vector type.
+ // To fill a tile with a constant, we create a 1-D splat of the constant,
+ // then move that into each tile slice (the largest unit we can set at once,
+ // outside of operations like the outerproduct).
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
More information about the Mlir-commits
mailing list