[Mlir-commits] [mlir] 9e1b825 - [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops

Cullen Rhodes llvmlistbot at llvm.org
Tue Aug 1 01:20:25 PDT 2023


Author: Cullen Rhodes
Date: 2023-08-01T08:20:02Z
New Revision: 9e1b8253214572fea7b7f90f9f8dff28f75444bc

URL: https://github.com/llvm/llvm-project/commit/9e1b8253214572fea7b7f90f9f8dff28f75444bc
DIFF: https://github.com/llvm/llvm-project/commit/9e1b8253214572fea7b7f90f9f8dff28f75444bc.diff

LOG: [mlir][ArmSME] Add conversion from ArmSME to SCF to materialize loops

Currently, a loop is materialized when lowering ArmSME loads and stores
to intrinsics. This patch introduces two new ops to the ArmSME dialect
that map 1-1 to intrinsics:

  1. arm_sme.load_tile_slice  - Loads a 1D tile slice from
     memory into a 2D SME "virtual tile".
  2. arm_sme.store_tile_slice - Stores a 1D tile slice from a 2D SME
     "virtual tile" into memory.

It also adds a new conversion pass, '-convert-arm-sme-to-scf', that
materializes the loops using these ops. The existing load/store lowering
to intrinsics is updated to use these ops.

Depends on D156517

Discourse thread:
https://discourse.llvm.org/t/loop-materialization-in-armsme/72354

Reviewed By: awarzynski, dcaballe, WanderAway

Differential Revision: https://reviews.llvm.org/D156467

Added: 
    mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
    mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
    mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSME/roundtrip.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
new file mode 100644
index 00000000000000..3a28ca11862afd
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
@@ -0,0 +1,29 @@
+//===- ArmSMEToSCF.h - Convert ArmSME to SCF dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+#define MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Create the `-convert-arm-sme-to-scf` pass, which rewrites a subset of the
+/// ArmSME ops (the tile-level loads and stores) into SCF loops.
+std::unique_ptr<Pass> createConvertArmSMEToSCFPass();
+
+/// Add to `patterns` the rewrites that lower `arm_sme.tile_load` and
+/// `arm_sme.tile_store` into `scf.for` loops over individual tile slices.
+void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 71d5c4aa267ee1..014e976586af4c 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 04d65964478da6..9608d771a5dd52 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1108,6 +1108,22 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
   let dependentDialects = ["arm_sme::ArmSMEDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSMEToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
+  let summary = "Lower the operations from the ArmSME dialect into the SCF "
+                "dialect";
+  let description = [{
+    Materializes the loops over tile slices that are implied by the
+    tile-level ArmSME memory ops. `arm_sme.tile_load` and `arm_sme.tile_store`
+    are rewritten into `scf.for` loops whose trip count is the number of tile
+    slices (`vector.vscale` times the minimum number of slices for the element
+    type), issuing one `arm_sme.load_tile_slice` / `arm_sme.store_tile_slice`
+    per iteration. Other ops are left unchanged.
+
+    The resulting slice ops can then be lowered to SME intrinsics, e.g. with
+    `-convert-vector-to-llvm="enable-arm-sme"`.
+  }];
+  let constructor = "mlir::createConvertArmSMEToSCFPass()";
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "arith::ArithDialect",
+    "vector::VectorDialect",
+    "arm_sme::ArmSMEDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToLLVM
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 01925d023b7902..11b96f20acdfae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -16,6 +16,7 @@
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 //===----------------------------------------------------------------------===//
 // ArmSME dialect definition
@@ -307,6 +308,102 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
                        "`:` type($base) `,` type($valueToStore)";
 }
 
+def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+    AllTypesMatch<["tile", "result"]>
+]> {
+  let summary = "Tile slice load and update operation";
+  let description = [{
+    Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
+    slice is defined by the dimension of the 2D scalable vector type pointed to
+    by the index. A tile slice index describes where in the input tile the tile
+    slice is loaded to. The updated tile is returned as the result.
+
+    The slice of memory read is defined by a base and indices and must be
+    contiguous. The memref must be either rank 1 or rank 2, have dynamic
+    dimensions since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the result.
+
+    Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+    ```mlir
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+    ```mlir
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+    ```mlir
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+    ```
+  }];
+  // NOTE(review): unlike `store_tile_slice`, `$base` carries no memory-effect
+  // decorator. `-convert-arm-sme-to-scf` creates this op without using its
+  // result, so marking `$base` as `[MemRead]` alone would likely let DCE/CSE
+  // erase the op as trivially dead. Revisit together with that lowering.
+  let arguments = (ins
+      Arg<AnyMemRef, "the reference to load from">:$base,
+      SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
+  let results = (outs SMETile:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+
+  let assemblyFormat = [{
+    $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+      attr-dict `:` type($base) `,` type($result)
+  }];
+}
+
+def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
+  let summary = "Tile slice store operation";
+  let description = [{
+    Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
+    slice is defined by the dimension of the 2D scalable vector type pointed to
+    by the index. A tile slice index describes where in the input tile the tile
+    slice is stored from.
+
+    The slice of memory written is defined by a base and indices and must be
+    contiguous. The memref must be either rank 1 or rank 2, have dynamic
+    dimensions since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the input tile.
+
+    Example 1: Store a vector<[16]xi8> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Store a vector<[4]xf32> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    ```mlir
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    ```
+  }];
+  let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+      Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+      Variadic<Index>:$indices);
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
+    }
+  }];
+
+  // Note that the memref type is printed before the tile type, even though the
+  // tile is the first operand.
+  let assemblyFormat = [{
+    $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+      attr-dict `:` type($base) `,` type($tile)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME Intrinsic op definitions
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
new file mode 100644
index 00000000000000..e143726cf234f2
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -0,0 +1,187 @@
+//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to SCF.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+/// Create an `scf.for` loop over all tile slices of a tile whose element type
+/// is `tileElementType`, at the current insertion point of `builder`.
+///
+/// The loop runs from 0 to
+/// `vscale * arm_sme::getSMETileSliceMinNumElts(tileElementType)` with step 1,
+/// and its induction variable is the tile slice index. The insertion point is
+/// not moved; callers position it inside the loop body themselves.
+static scf::ForOp createLoopOverTileSlices(OpBuilder &builder, Location loc,
+                                           Type tileElementType) {
+  auto step = builder.create<arith::ConstantIndexOp>(loc, 1);
+  auto minTileSlices = builder.create<arith::ConstantIndexOp>(
+      loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+  auto vscale =
+      builder.create<vector::VectorScaleOp>(loc, builder.getIndexType());
+  auto lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
+  // Number of tile slices: the minimum number for this element type scaled by
+  // the runtime vscale.
+  auto numTileSlices =
+      builder.create<arith::MulIOp>(loc, minTileSlices, vscale);
+  return builder.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+}
+
+namespace {
+
+/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
+/// using `arm_sme.load_tile_slice`.
+///
+///  BEFORE:
+///  ```mlir
+///  %tile = arm_sme.tile_load %src[%c0, %c0] :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: `%tile_update` is not used; all uses of the original
+/// `arm_sme.tile_load` are replaced with `%tile`. The memref indices of
+/// `arm_sme.tile_load` are currently ignored and the tile slice index is used
+/// as the only memref index.
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+    // use as input tile to 'arm_sme.load_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto forOp = createLoopOverTileSlices(rewriter, loc, tileElementType);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+    // TODO: use indices
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile. Its result (the updated tile) is not used.
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
+        tileSliceIndex);
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
+/// slice using `arm_sme.store_tile_slice`.
+///
+///  BEFORE:
+///  ```mlir
+///  arm_sme.tile_store %tile, %dest[%c0, %c0]
+///    : memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
+///      : memref<?x?xi32>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: the memref indices of `arm_sme.tile_store` are currently ignored and
+/// the tile slice index is used as the only memref index.
+struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
+  using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileStoreOp.getLoc();
+    auto tileType = tileStoreOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+
+    // Create a loop that stores each ZA tile slice to memory.
+    auto forOp = createLoopOverTileSlices(rewriter, loc, tileElementType);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+    // TODO: use indices
+    // Replace 'arm_sme.tile_store' with an 'arm_sme.store_tile_slice' created
+    // at the current insertion point, i.e. inside the loop body.
+    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
+        tileStoreOp.getBase(), tileSliceIndex);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
+  // Both patterns are plain OpRewritePatterns and only need the context.
+  MLIRContext *context = patterns.getContext();
+  patterns.add<TileLoadOpConversion, TileStoreOpConversion>(context);
+}
+
+namespace {
+
+/// Applies `populateArmSMEToSCFConversionPatterns` as a partial conversion:
+/// `arm_sme.tile_load` and `arm_sme.tile_store` must be rewritten, while the
+/// ArmSME, Vector, Arith and SCF dialects are otherwise legal.
+struct ConvertArmSMEToSCFPass
+    : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
+  void runOnOperation() override {
+    MLIRContext &ctx = getContext();
+
+    ConversionTarget target(ctx);
+    target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
+                           arith::ArithDialect, scf::SCFDialect>();
+    target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+
+    RewritePatternSet patterns(&ctx);
+    populateArmSMEToSCFConversionPatterns(patterns);
+
+    LogicalResult result =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
+  // The caller takes ownership of the newly created pass.
+  std::unique_ptr<Pass> pass = std::make_unique<ConvertArmSMEToSCFPass>();
+  return pass;
+}

diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
new file mode 100644
index 00000000000000..3bf4d7082afe42
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRArmSMEToSCF
+  ArmSMEToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  # ArmSMEToSCF.cpp builds arith, scf and vector ops and defines a pass
+  # directly, so link those libraries explicitly instead of relying on
+  # transitive dependencies.
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArmSMEDialect
+  MLIRArmSMEUtils
+  MLIRPass
+  MLIRSCFDialect
+  MLIRTransforms
+  MLIRVectorDialect
+  )

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 3ec4eb4382a91f..9fabeae0710383 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(ArithCommon)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
+add_subdirectory(ArmSMEToSCF)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLibm)

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index b3a747a8fe8448..d61bde971647dc 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -138,68 +138,40 @@ Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
   llvm_unreachable("memref has unexpected rank!");
 }
 
-/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row
-/// using `arm_sme.intr.ld1*.horiz`.
-///
-///  BEFORE:
-///  ```mlir
-///  %tile = arm_sme.tile_load %base[%c0, %c0] :
-///     memref<?x?xi32>, vector<[4]x[4]xi32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %tile_id = arm_sme.get_tile_id : i32
-///  %vscale = vector.vscale
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %min_svl_s = arith.constant 4 : index
-///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  scf.for %tile_slice = %c0 to %svl_s step %c1 {
-///    // (...)
-///    "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-///         (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-///  }
-///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-///  ```
-struct TileLoadToArmSMELowering
-    : public ConvertOpToLLVMPattern<arm_sme::TileLoadOp> {
-  using ConvertOpToLLVMPattern<arm_sme::TileLoadOp>::ConvertOpToLLVMPattern;
+/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
+struct LoadTileSliceToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
+  using ConvertOpToLLVMPattern<
+      arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
-                  arm_sme::TileLoadOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
+                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = tileLoadOp.getLoc();
-    auto tileType = tileLoadOp.getVectorType();
+    auto loc = loadTileSliceOp.getLoc();
+    auto tileType = loadTileSliceOp.getVectorType();
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.get_tile_id' op.
-    auto tile = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
+    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
+    // loaded to.
+    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+        loc, rewriter.getIntegerType(tileElementWidth),
+        loadTileSliceOp.getTile());
 
-    // Create a loop that loads each ZA tile slice from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
         loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     // This describes both the number of ZA tile slices and the number of
     // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
 
     // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
-    auto memRefType = tileLoadOp.getMemRefType();
-    auto tileSlice = forOp.getInductionVar();
+    auto memRefType = loadTileSliceOp.getMemRefType();
+    auto tileSlice = loadTileSliceOp.getTileSliceIndex();
     // TODO: The 'indices' argument for the 'base' memref is currently ignored,
     // 'tileSliceIndex' should be added to 'indices[0]'.
     Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -241,52 +213,27 @@ struct TileLoadToArmSMELowering
       break;
     }
 
-    rewriter.setInsertionPointAfter(forOp);
-
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(tileLoadOp, tileType,
-                                                           tile);
+    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
+                                                           tileType, tile);
 
     return success();
   }
 };
 
-/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row
-/// using `arm_sme.intr.st1*.horiz`.
-///
-///  BEFORE:
-///  ```mlir
-///  arm_sme.tile_store %value, %base[%c0, %c0] : memref<?x?xi32>,
-///     vector<[4]x[4]xi32
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
-///  %vscale = vector.vscale
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %min_svl_s = arith.constant 4 : index
-///  %svl_s = arith.muli %min_svl_s, %vscale : index
-///  scf.for %tile_slice = %c0 to %svl_s step %c1 {
-///    // (...)
-///    "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-///         (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-///  }
-///  ```
-struct TileStoreToArmSMELowering
-    : public ConvertOpToLLVMPattern<arm_sme::TileStoreOp> {
-  using ConvertOpToLLVMPattern<arm_sme::TileStoreOp>::ConvertOpToLLVMPattern;
+/// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
+struct StoreTileSliceToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
+  using ConvertOpToLLVMPattern<
+      arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
-                  arm_sme::TileStoreOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
+                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = tileStoreOp.getLoc();
-    auto tileType = tileStoreOp.getVectorType();
+    auto loc = storeTileSliceOp.getLoc();
+    auto tileType = storeTileSliceOp.getVectorType();
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
@@ -294,27 +241,21 @@ struct TileStoreToArmSMELowering
     // being stored.
     auto tile = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(tileElementWidth),
-        tileStoreOp.getValueToStore());
+        storeTileSliceOp.getTile());
 
-    // Create a loop that stores each ZA tile slice to memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
         loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     // This describes both the number of ZA tile slices and the number of
     // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
-    auto memRefType = tileStoreOp.getMemRefType();
-    auto tileSlice = forOp.getInductionVar();
+    auto memRefType = storeTileSliceOp.getMemRefType();
+    auto tileSlice = storeTileSliceOp.getTileSliceIndex();
     // TODO: The 'indices' argument for the 'base' memref is currently ignored,
     // 'tileSliceIndex' should be added to 'indices[0]'.
     Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -340,19 +281,19 @@ struct TileStoreToArmSMELowering
       llvm_unreachable("unexpected element type!");
     case 8:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 16:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 32:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 64:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     }
 
@@ -403,6 +344,6 @@ void mlir::configureArmSMELegalizeForExportTarget(
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-  patterns.add<ZeroOpConversion, TileLoadToArmSMELowering,
-               TileStoreToArmSMELowering>(converter);
+  // `arm_sme.tile_load` / `arm_sme.tile_store` are not matched here; they are
+  // expected to be expanded first (e.g. by `-convert-arm-sme-to-scf`) into
+  // loops of `arm_sme.load_tile_slice` / `arm_sme.store_tile_slice`, which
+  // the patterns below lower to SME intrinsics.
+  patterns.add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
+               LoadTileSliceToArmSMELowering>(converter);
 }

diff  --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
new file mode 100644
index 00000000000000..b64c79663038c0
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
+
+// `arm_sme.tile_load` is expanded into a tile ID plus cast, followed by an
+// `scf.for` over `vscale * 4` tile slices (i32 elements) that performs one
+// `arm_sme.load_tile_slice` per iteration.
+// CHECK-LABEL: func.func @arm_sme_tile_load(
+// CHECK-SAME:                               %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-NEXT:    %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-NEXT:    %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// `arm_sme.tile_store` is expanded into an `scf.for` over `vscale * 4` tile
+// slices (i32 elements) that performs one `arm_sme.store_tile_slice` per
+// iteration.
+// CHECK-LABEL: func.func @arm_sme_tile_store(
+// CHECK-SAME:                                %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}

diff  --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
new file mode 100644
index 00000000000000..2c26c62ad42481
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+
+// This test verifies that the temporary casts emitted when lowering to
+// intrinsics (to preserve data flow) are correct. Canonicalization will
+// remove these casts.
+
+// CHECK-LABEL: @arm_sme_zero
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: arm_sme.intr.zero
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_zero(%dst : memref<?x?xi8>) {
+  %zero_idx = arith.constant 0 : index
+  %zeroed = arm_sme.zero : vector<[16]x[16]xi8>
+  arm_sme.tile_store %zeroed, %dst[%zero_idx, %zero_idx] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_load
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: }
+// CHECK: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load(%src : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store(
+// CHECK-SAME:                      %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_tile_store(%vals : vector<[16]x[16]xi8>, %dst : memref<?x?xi8>) {
+  %zero_idx = arith.constant 0 : index
+  arm_sme.tile_store %vals, %dst[%zero_idx, %zero_idx] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}

diff  --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index a53be0d47d2cd2..022ae272c4a35a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -351,3 +351,165 @@ func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0, %c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 8d52ebb417b80d..402da661cfee44 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: @transfer_write_2d_zero_i8(
 // CHECK-SAME:                             %[[ARG0:.*]]: memref<?x?xi8>)
@@ -8,13 +8,13 @@
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-DAG:  "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
@@ -35,6 +35,7 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-SAME:                  %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -43,7 +44,7 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
@@ -52,7 +53,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
@@ -66,6 +66,7 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
 // CHECK-SAME:                                     %[[ARG0:.*]]: memref<?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -74,8 +75,11 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT:   %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
+// CHECK-NEXT:   %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64
+// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
 // CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
@@ -83,7 +87,6 @@ func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
@@ -97,11 +100,11 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
 // CHECK-LABEL: @vector_load_i16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -113,12 +116,12 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
 // CHECK-LABEL: @vector_load_i32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
 // CHECK-NOT:   arith.extui %[[TILE_ID]]
 // CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -130,11 +133,11 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
 // CHECK-LABEL: @vector_load_i64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -146,11 +149,11 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
 // CHECK-LABEL: @vector_load_f16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -162,11 +165,11 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
 // CHECK-LABEL: @vector_load_bf16(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xbf16>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -178,12 +181,12 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
 // CHECK-LABEL: @vector_load_f32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
 // CHECK-NOT:   arith.extui %[[TILE_ID]]
 // CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -195,11 +198,11 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
 // CHECK-LABEL: @vector_load_f64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
 // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
-// CHECK:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -212,7 +215,6 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -221,7 +223,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
@@ -242,9 +245,9 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
 // CHECK-LABEL: @vector_store_i16(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
 func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
@@ -258,9 +261,9 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
 // CHECK-LABEL: @vector_store_i32(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
 // CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
 // CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
@@ -275,9 +278,9 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
 // CHECK-LABEL: @vector_store_i64(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
 func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
@@ -291,9 +294,9 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
 // CHECK-LABEL: @vector_store_f16(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
 func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
@@ -307,9 +310,9 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
 // CHECK-LABEL: @vector_store_bf16(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
 func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
@@ -322,9 +325,9 @@ func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf
 // CHECK-LABEL: @vector_store_f32(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
 // CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
 // CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
@@ -339,9 +342,9 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
 // CHECK-LABEL: @vector_store_f64(
 // CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
 // CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
 func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index f0db7529fd062e..0da4b1cf319e6d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index e2991d18a03a1c..082419ce05eba3 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \


        


More information about the Mlir-commits mailing list