[Mlir-commits] [mlir] [mlir][armsme][vector] Replace splat with broadcast (PR #148024)
James Newling
llvmlistbot at llvm.org
Wed Jul 23 09:52:47 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/148024
>From 35db575a7fced8afc0ebced36984184ecb16fa83 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Jul 2025 11:34:20 -0700
Subject: [PATCH 1/2] changes for splat->broadcast for armsme
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 2 +-
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 2 +-
.../VectorToArmSME/VectorToArmSME.cpp | 76 ++++---------------
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 2 +-
4 files changed, 18 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 9bc3fa3473398..5f59a73fd249d 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -607,7 +607,7 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
+ auto allActiveMask = vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 9a37b30c14813..200ddf761855b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -327,7 +327,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
+ auto pad1DOp = vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 125ea1eb60ed6..9efa34a9a3acc 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-/// {
-/// %tile_update = arm_sme.insert_tile_slice
-/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-/// vector<[4]xi32> into vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
- auto tileType = splatOp.getResult().getType();
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto loc = splatOp.getLoc();
- auto srcType = splatOp.getOperand().getType();
-
- assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
- // Avoid unused-variable warning when building without assertions.
- (void)srcType;
-
- // First, broadcast the scalar to a 1-d vector.
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = vector::BroadcastOp::create(
- rewriter, loc, tileSliceType, splatOp.getInput());
-
- auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
-
- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
- Value currentTile) {
- auto nextTile = arm_sme::InsertTileSliceOp::create(
- b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- return nextTile.getResult();
- };
-
- // Next, create a loop over ZA tile slices and "move" the generated 1-d
- // vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
- rewriter.replaceOp(splatOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};
+// Rewrite every `vector.splat` as the equivalent `vector.broadcast`; the
+// `BroadcastOpToArmSMELowering` pattern provides the path on to ArmSME.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..6f2766ddc6e6e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
>From 44dd36b44d84a50d9c2c61fc35237bae3ac16528 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 23 Jul 2025 09:53:47 -0700
Subject: [PATCH 2/2] formatting after rebase
---
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 3 ++-
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 3 ++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 5f59a73fd249d..8a2e3b639aaa7 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = vector::BroadcastOp::create(rewriter, loc, predTy, one);
+ auto allActiveMask =
+ vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 200ddf761855b..e28d51220244c 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
+ auto pad1DOp =
+ vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),
More information about the Mlir-commits
mailing list