[Mlir-commits] [mlir] [mlir][armsme][vector] Replace splat with broadcast (PR #148024)

James Newling llvmlistbot at llvm.org
Wed Jul 23 09:52:47 PDT 2025


https://github.com/newling updated https://github.com/llvm/llvm-project/pull/148024

>From 35db575a7fced8afc0ebced36984184ecb16fa83 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 10 Jul 2025 11:34:20 -0700
Subject: [PATCH 1/2] changes for splat->broadcast for armsme

---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  2 +-
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  2 +-
 .../VectorToArmSME/VectorToArmSME.cpp         | 76 ++++---------------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  2 +-
 4 files changed, 18 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 9bc3fa3473398..5f59a73fd249d 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -607,7 +607,7 @@ struct InsertTileSliceConversion
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
     auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                   /*scalableDims=*/{true});
-    auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one);
+    auto allActiveMask = vector::BroadcastOp::create(rewriter, loc, predTy, one);
 
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 9a37b30c14813..200ddf761855b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -327,7 +327,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
 
     // Splat pad into 1-D vector matching type of tile slice.
     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp);
+    auto pad1DOp = vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
 
     auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
                                                   tileLoadOp.getBase(),
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 125ea1eb60ed6..9efa34a9a3acc 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -255,66 +255,6 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
-/// Conversion pattern for vector.splat.
-///
-/// Example:
-///
-///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
-///
-/// is converted to:
-///
-///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
-///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
-///   {
-///     %tile_update = arm_sme.insert_tile_slice
-///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
-///        vector<[4]xi32> into vector<[4]x[4]xi32>
-///     scf.yield %tile_update : vector<[4]x[4]xi32>
-///   }
-///
-/// This is identical to vector.broadcast of a scalar.
-struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
-  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
-                                PatternRewriter &rewriter) const final {
-    auto tileType = splatOp.getResult().getType();
-    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
-      return failure();
-
-    auto loc = splatOp.getLoc();
-    auto srcType = splatOp.getOperand().getType();
-
-    assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
-    // Avoid unused-variable warning when building without assertions.
-    (void)srcType;
-
-    // First, broadcast the scalar to a 1-d vector.
-    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    Value broadcastOp1D = vector::BroadcastOp::create(
-        rewriter, loc, tileSliceType, splatOp.getInput());
-
-    auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
-
-    auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
-                            Value currentTile) {
-      auto nextTile = arm_sme::InsertTileSliceOp::create(
-          b, loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
-      return nextTile.getResult();
-    };
-
-    // Next, create a loop over ZA tile slices and "move" the generated 1-d
-    // vector to each slice.
-    auto forOp =
-        createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
-
-    rewriter.replaceOp(splatOp, forOp.getResult(0));
-
-    return success();
-  }
-};
-
 /// Conversion pattern for vector.transpose.
 ///
 /// Stores the input tile to memory and reloads vertically.
@@ -791,11 +731,25 @@ struct ExtractFromCreateMaskToPselLowering
   }
 };
 
+/// Rewrites every `vector.splat` as a `vector.broadcast` with the same input
+/// and result type. `vector.broadcast` reaches ArmSME via another pattern.
+struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
+  using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::SplatOp splatOp,
+                                PatternRewriter &rewriter) const final {
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+                                                     splatOp.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
+  patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
                TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
                TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4ae710aa29113..6f2766ddc6e6e 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -87,7 +87,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[MASK_INDEX:.*]] = arith.index_cast %[[MASK]] : i32 to index
 // CHECK-NEXT:        %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[PAD_1D:.*]] = vector.broadcast %[[PAD]] : i32 to vector<[4]xi32>
 // CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
 // CHECK:             %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 // CHECK-NEXT:        scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>

>From 44dd36b44d84a50d9c2c61fc35237bae3ac16528 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 23 Jul 2025 09:53:47 -0700
Subject: [PATCH 2/2] formatting after rebase

---
 mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 3 ++-
 mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp   | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 5f59a73fd249d..8a2e3b639aaa7 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -607,7 +607,8 @@ struct InsertTileSliceConversion
         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
     auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
                                   /*scalableDims=*/{true});
-    auto allActiveMask = vector::BroadcastOp::create(rewriter, loc, predTy, one);
+    auto allActiveMask =
+        vector::BroadcastOp::create(rewriter, loc, predTy, one);
 
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (insertTileSliceOp.getLayout()) {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 200ddf761855b..e28d51220244c 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -327,7 +327,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
 
     // Splat pad into 1-D vector matching type of tile slice.
     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
-    auto pad1DOp = vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
+    auto pad1DOp =
+        vector::BroadcastOp::create(rewriter, loc, tileSliceType, padOp);
 
     auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType,
                                                   tileLoadOp.getBase(),



More information about the Mlir-commits mailing list