[Mlir-commits] [mlir] [mlir][ArmSME] More precisely model dataflow in ArmSME to SCF lowerings (PR #73922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 30 02:45:50 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

Since #73253, loops over tiles in SSA form (i.e. loops that take `iter_args` and yield a new tile) are supported, so this patch updates the ArmSME lowerings to this form. This is an NFC (non-functional change), as it still lowers to the same intrinsics, but it makes the IR less 'surprising' at a higher level, and the resulting loops may be recognised by more transforms.


Example:

IR Before:
```mlir
scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 
{
   arm_sme.move_vector_to_tile_slice 
     %broadcast_to_1d, %tile, %tile_slice_index : 
     vector<[4]xi32> into vector<[4]x[4]xi32>
}
// ... later use %tile
```
IR After:
```mlir
%broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
    step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
{
   %tile_update = arm_sme.move_vector_to_tile_slice
      %broadcast_to_1d, %iter_tile, %tile_slice_index :
      vector<[4]xi32> into vector<[4]x[4]xi32>
   scf.yield %tile_update : vector<[4]x[4]xi32>
}
// ... later use %broadcast_to_tile
```




---

Patch is 28.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73922.diff


5 Files Affected:

- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+49-32) 
- (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+60-53) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+11-8) 
- (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+8-6) 
- (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+4-3) 


``````````diff
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6d5..72b476c9f049537 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,16 +61,18 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  AFTER:
 ///  ```mlir
 ///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+///  %init_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %ptrue_s, %tile, %tile_slice_idx
+///      %ptrue_s, %iter_tile, %tile_slice_idx
 ///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -88,7 +90,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto tileElementType = tileType.getElementType();
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
@@ -103,8 +105,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+                                             step, ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
@@ -121,14 +123,17 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-        rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
-        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    auto currentTile = forOp.getRegionIterArg(0);
+    auto loadSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+            rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate,
+            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
@@ -150,13 +155,15 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  ```mlir
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
-///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %init_tile = arm_sme.zero : vector<[4]x[4]xi32>
 ///  %num_rows = arith.constant 2 : index
 ///  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
-///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %num_rows step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    %tile_update = arm_sme.load_tile_slice
-///      %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+///      %src[%tile_slice_idx], %num_cols, %iter_tile, %tile_slice_idx :
 ///      memref<?x?xi32>, vector<[1]xi32>, vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 ///
@@ -202,14 +209,15 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // Initialize tile with zero to satisfy padding. Inactive cols will be
     // zeroed anyway since the loads use zeroing predication. For inactive rows
     // however, no load will occur so these need to be zeroed.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
         rewriter, loc, tileType);
 
     // Create a loop to load the active tile slices from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto upperBound = numRows;
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+                                             ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
@@ -217,17 +225,20 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // tile.
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
+    auto currentTile = forOp.getRegionIterArg(0);
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-        rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
-        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    auto loadSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+            rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp,
+            currentTile, memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{loadSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
@@ -249,15 +260,18 @@ struct TileLoadOpWithMaskAndPadZeroConversion
 ///  ```mlir
 ///  ...
 ///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
-///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///  %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
+///                iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
 ///    ...
 ///    %mask_1d = vector.create_mask <combined_mask> : vector<[4]xi1>
 ///    %slice = vector.maskedload %base[%tile_slice_idx, %c0], %mask_1d, %pad_1d
 ///      : memref<?x?xi32>, vector<[4]xi1>,
 ///        vector<[4]xi32> into vector<[4]xi32>
 ///    // Insert slice into tile
-///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
-///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///    %tile_update = arm_sme.move_vector_to_tile_slice
+///      %slice, %iter_tile, %tile_slice_idx :
+///      vector<[4]xi32> into vector<[4]x[4]xi32>
+///    scf.yield %tile_update : vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpWithMaskAndPadNonZeroConversion
@@ -298,7 +312,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         loc, rewriter.getI32Type(), numCols);
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
@@ -310,12 +324,13 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices,
+                                             step, ValueRange{initTile});
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
     auto tileSliceIndex = forOp.getInductionVar();
+    auto currentTile = forOp.getRegionIterArg(0);
 
     // Combine masks.
     auto rowIsActive = rewriter.create<arith::CmpIOp>(
@@ -344,14 +359,16 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
-        rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
-        tileLoadOp.getLayout());
+    auto moveSlice =
+        tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+            rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
+            tileSliceIndex, tileLoadOp.getLayout());
+    rewriter.create<scf::YieldOp>(loc, ValueRange{moveSlice});
 
     rewriter.setInsertionPointAfter(forOp);
 
-    // Replace 'arm_sme.tile_load' with the tile.
-    rewriter.replaceOp(tileLoadOp, tile);
+    // Replace 'arm_sme.tile_load' with the result.
+    rewriter.replaceOp(tileLoadOp, forOp.getResult(0));
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a84772d..250c9914b8c2823 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -26,21 +26,26 @@ static bool isSplatZero(Type elemType, DenseElementsAttr val) {
 }
 
 /// Generates a for loop over ZA tile slices where the induction variable is
-/// the tile slice index. Sets the IR Builder insertion point as the loop body.
-/// Callers of this method are responsible for restoring it if needed.
-static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
-                                        Type eltType) {
+/// the tile slice index and each iteration yields a new tile. Loop body is
+/// built via the callback, which returns the next tile value.
+template <typename LoopBodyCallback>
+static scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter,
+                                           Location loc, Value initTile,
+                                           LoopBodyCallback callback) {
+  OpBuilder::InsertionGuard g(rewriter);
   auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
   auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-      loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+      loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
   auto vscale =
       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   auto numTileSlices =
       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-  auto forOp =
-      rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
+                                           ValueRange{initTile});
   rewriter.setInsertionPointToStart(forOp.getBody());
+  auto nextTile = callback(forOp);
+  rewriter.create<scf::YieldOp>(loc, ValueRange{nextTile});
   return forOp;
 }
 
@@ -242,27 +247,25 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
 
     // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
     // ops that broadcast the constant to each tile slice.
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = constantOp.getLoc();
 
     // Unpack 1-d vector type from 2-d vector type.
-    auto tileSliceType =
-        VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                        /*scalableDims=*/{true});
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
     auto denseAttr1D = DenseElementsAttr::get(
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
-
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, constantOp1D, tile, tileSliceIndex);
-
-    rewriter.replaceOp(constantOp, tile);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
+          // slice.
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, constantOp1D, currentTile, tileSliceIndex);
+        });
+    rewriter.replaceOp(constantOp, forOp.getResult(0));
 
     return success();
   }
@@ -277,9 +280,13 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
 /// is converted to:
 ///
 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+///   {
+///     %tile_update = arm_sme.move_vector_to_tile_slice
+///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
+///        vector<[4]xi32> into vector<[4]x[4]xi32>
+///     scf.yield %tile_update : vector<[4]x[4]xi32>
 ///   }
 ///
 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
@@ -293,20 +300,16 @@ struct BroadcastOpToArmSMELowering
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = broadcastOp.getLoc();
 
     auto srcType = broadcastOp.getSourceType();
     auto srcVectorType = dyn_cast<VectorType>(srcType);
-    auto tileElementType = tileType.getElementType();
 
     Value broadcastOp1D;
     if (srcType.isIntOrFloat() ||
         (srcVectorType && (srcVectorType.getRank() == 0))) {
       // Broadcast scalar or 0-d vector to 1-d vector.
-      auto tileSliceType =
-          VectorType::get(tileType.getShape().drop_front(), tileElementType,
-                          /*scalableDims=*/{true});
+      VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
       broadcastOp1D = rewriter.create<vector::BroadcastOp>(
           loc, tileSliceType, broadcastOp.getSource());
     } else if (srcVectorType && (srcVectorType.getRank() == 1))
@@ -315,18 +318,20 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop over ZA tile slices.
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
-    // tile slice.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
-
-    rewriter.replaceOp(broadcastOp, tile);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value
+          // to each tile slice.
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+        });
+
+    rewriter.replaceOp(broadcastOp, forOp.getResult(0));
 
     return success();
   }
@@ -341,9 +346,13 @@ struct BroadcastOpToArmSMELowering
 /// is converted to:
 ///
 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
-///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
-///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
-///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
+///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
+///   {
+///     %tile_update = arm_sme.move_vector_to_tile_slice
+///        %broadcast_to_1d, %iter_tile, %tile_slice_index :
+///        vector<[4]xi32> into vector<[4]x[4]xi32>
+///     scf.yield %tile_update : vector<[4]x[4]xi32>
 ///   }
 ///
 /// This is identical to vector.broadcast of a scalar.
@@ -356,11 +365,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = splatOp.getLoc();
-
     auto srcType = splatOp.getOperand().getType();
-    auto tileElementType = tileType.getElementType();
 
     assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
     // Avoid unused-variable warning when building without assertions.
@@ -371,17 +377,19 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
-    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
-    auto tileSliceIndex = forOp.getInductionVar();
-
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+    auto forOp =
+        createLoopOverTileSlices(rewriter, loc, initTile, [&](auto forOp) {
+          auto tileSliceIndex = forOp.getInductionVar();
+          auto currentTile = forOp.getRegionIterArg(0);
+          return rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+              loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
+        });
 
-    rewriter.replaceOp(splatOp, tile);
+    rewriter.replaceOp(splatOp, forOp.getResult(0));
 
     return success();
   }
@@ -424,7 +432,6 @@ struct TransposeOpToArmSMELowering
     if (permutation[0] != 1 || permutation[1] != 0)
       return failure();
 
-    OpBuilder::InsertionGuard g(rewriter);
     auto loc = transposeOp.getLoc();
 
     // Allocate buffer to store input tile to.
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7c0..9fe80192809f310 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,16 +6,17 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vecto...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/73922


More information about the Mlir-commits mailing list