[Mlir-commits] [mlir] e67f323 - [mlir][armsme][vector] Replace splat with broadcast (#148024)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 23 10:18:13 PDT 2025
Author: James Newling
Date: 2025-07-23T10:18:09-07:00
New Revision: e67f3237d6242d1c362fa52e782ddfd5ae54a8af
URL: https://github.com/llvm/llvm-project/commit/e67f3237d6242d1c362fa52e782ddfd5ae54a8af
DIFF: https://github.com/llvm/llvm-project/commit/e67f3237d6242d1c362fa52e782ddfd5ae54a8af.diff
LOG: [mlir][armsme][vector] Replace splat with broadcast (#148024)
Part of the deprecation of `vector.splat`.
RFC: https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/4
Added:
Modified:
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 9bc3fa3473398..8a2e3b639aaa7 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
/*scalableDims=*/{true});
- auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
+ auto allActiveMask =
+ vector::BroadcastOp::create(rewriter, loc, predTy, one);
// Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 9a37b30c14813..e28d51220244c 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
// Splat pad into 1-D vector matching type of tile slice.
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
+ auto pad1DOp =
+ vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
tileLoadOp.getBase(),
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 125ea1eb60ed6..9efa34a9a3acc 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
}
};
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-/// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-/// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-/// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-/// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-/// {
-/// %tile_update = arm_sme.insert_tile_slice
-/// %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-/// vector<[4]xi32> into vector<[4]x[4]xi32>
-/// scf.yield %tile_update : vector<[4]x[4]xi32>
-/// }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
- using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::SplatOp splatOp,
- PatternRewriter &rewriter) const final {
- auto tileType = splatOp.getResult().getType();
- if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
- return failure();
-
- auto loc = splatOp.getLoc();
- auto srcType = splatOp.getOperand().getType();
-
- assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
- // Avoid unused-variable warning when building without assertions.
- (void)srcType;
-
- // First, broadcast the scalar to a 1-d vector.
- VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
- Value broadcastOp1D = vector::BroadcastOp::create(
- rewriter, loc, tileSliceType, splatOp.getInput());
-
- auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
-
- auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
- Value currentTile) {
- auto nextTile = arm_sme::InsertTileSliceOp::create(
- b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
- return nextTile.getResult();
- };
-
- // Next, create a loop over ZA tile slices and "move" the generated 1-d
- // vector to each slice.
- auto forOp =
- createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
- rewriter.replaceOp(splatOp, forOp.getResult(0));
-
- return success();
- }
-};
-
/// Conversion pattern for vector.transpose.
///
/// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
}
};
+// Convert all `vector.splat` to `vector.broadcast`. There is a path from
+// `vector.broadcast` to ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+ using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+ PatternRewriter &rewriter) const final {
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getInput());
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+ patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..6f2766ddc6e6e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK: %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
More information about the Mlir-commits mailing list