[Mlir-commits] [mlir] ca9a335 - [mlir][ArmSME] Add tile load op and extend tile store tile size support

Cullen Rhodes llvmlistbot at llvm.org
Tue Jul 25 01:39:55 PDT 2023


Author: Cullen Rhodes
Date: 2023-07-25T08:28:36Z
New Revision: ca9a3354d04b15366088d7831b40f891e3d77b95

URL: https://github.com/llvm/llvm-project/commit/ca9a3354d04b15366088d7831b40f891e3d77b95
DIFF: https://github.com/llvm/llvm-project/commit/ca9a3354d04b15366088d7831b40f891e3d77b95.diff

LOG: [mlir][ArmSME] Add tile load op and extend tile store tile size support

This extends the existing 'arm_sme.tile_store' op to support all tile
sizes and adds a new op 'arm_sme.tile_load', as well as lowerings from
vector -> custom ops and custom ops -> intrinsics. Currently there's no
lowering for i128.
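
A minimal sketch of the two lowering stages for a 32-bit tile load (pass
names taken from the tests):

```mlir
// Stage 1 (-convert-vector-to-arm-sme): vector op to ArmSME custom op.
%tile_0 = vector.load %base[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
// becomes:
%tile_1 = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

// Stage 2 (-convert-vector-to-llvm="enable-arm-sme"): custom op to intrinsics,
// i.e. an scf.for over the tile slices calling "arm_sme.intr.ld1w.horiz"; see
// the BEFORE/AFTER comments in LegalizeForLLVMExport.cpp below.
```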

Depends on D154867

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D155306

Added: 
    mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
    mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
    mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/lib/Dialect/ArmSME/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Dialect/ArmSME/roundtrip.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 09f8bfb314a6e9..caa6e384bdb2b5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -224,21 +224,74 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def TileLoadOp : ArmSME_Op<"tile_load"> {
+  let summary = "Tile load operation";
+  let description = [{
+    Loads a 2D SME "virtual tile" from memory defined by a base and indices,
+    with the shape defined by the 2D scalable vector type of the result tile.
+    The slice of memory must be contiguous. The memref must be either rank 1 or
+    rank 2 with dynamic dimensions, since the operation is scalable, and the
+    element type must be a scalar that matches the element type of the result.
+
+    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Load an FP 32-bit element ZA tile from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Load a 128-bit element ZA tile from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    ```
+  }];
+  let arguments = (ins
+      Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+      Variadic<Index>:$indices);
+  let results = (outs SMETile:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
 def TileStoreOp : ArmSME_Op<"tile_store"> {
   let summary = "Tile store operation";
   let description = [{
-    Store a 2D SME "virtual tile" to memory.
-
-    NOTE: At the moment it is assumed that the element type is `i8` and that
-    there's only one "virtual tile".
+    Stores a 2D SME "virtual tile" to memory defined by a base and indices,
+    with the shape defined by the 2D scalable vector type of the tile being
+    stored. The slice of memory must be contiguous. The memref must be either
+    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
+    and the element type must be a scalar that matches the element type of
+    the vector being stored.
+
+    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    ```
 
-    Example:
+    Example 2: Store an FP 32-bit element ZA tile to memory.
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    ```
 
+    Example 3: Store a 128-bit element ZA tile to memory.
     ```mlir
-    arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8> 
+    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
-  let arguments = (ins nxnxv16i8:$valueToStore,
+  let arguments = (ins SMETile:$valueToStore,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -304,7 +357,7 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 class ArmSME_IntrLoadOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic>,
       Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
-                 Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
+                 Arg<LLVM_AnyPointer, "Load address">,
                  Arg<I32, "Virtual tile ID">,
                  Arg<I32, "Tile slice">)>;
 

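As a usage sketch (hypothetical function, type order as in the op
descriptions above), the two ops compose into a whole-tile copy:

```mlir
func.func @copy_tile(%src : memref<?x?xf32>, %dst : memref<?x?xf32>) {
  %c0 = arith.constant 0 : index
  // Load a full ZA tile from %src, then store it out to %dst.
  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
  arm_sme.tile_store %tile, %dst[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}
```
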
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
new file mode 100644
index 00000000000000..554b9f11923066
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -0,0 +1,38 @@
+//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various utilities for the ArmSME
+// dialect. These are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
+#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+namespace mlir {
+namespace arm_sme {
+
+/// Return minimum number of elements for the given element `type` in
+/// a vector of SVL bits.
+unsigned getSMETileSliceMinNumElts(Type type);
+
+/// Returns true if `type` is a valid element type for an SME tile or false
+/// otherwise.
+bool isValidSMETileElementType(Type type);
+
+/// Returns true if `vType` is a valid vector type for an SME tile or false
+/// otherwise.
+bool isValidSMETileVectorType(VectorType vType);
+
+} // namespace arm_sme
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65e914e8b..715816a90128cd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
+  MLIRArmSMEUtils
   MLIRLLVMCommonConversion
   )

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index cd0d99c5b5074f..4106b04877ec52 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -76,9 +77,42 @@ struct TransferWriteToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.load.
+struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        load, load.getVectorType(), load.getBase(), load.getIndices());
+
+    return success();
+  }
+};
+
+/// Conversion pattern for vector.store.
+struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp store,
+                                PatternRewriter &rewriter) const override {
+    if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
+        store, store.getValueToStore(), store.getBase(), store.getIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<TransferWriteToArmSMELowering>(&ctx);
+  patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
+               VectorStoreToArmSMELowering>(&ctx);
 }

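Both patterns are guarded by `arm_sme::isValidSMETileVectorType`, so only
loads and stores of 2-D all-scalable vectors whose shape is the minimal tile
for their element type are rewritten; a sketch of what matches and what is
left alone:

```mlir
// Rewritten to arm_sme.tile_load (2-D, all dims scalable, [4]x[4] for f32):
%t0 = vector.load %a[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
// Not matched (fixed-size 1-D load), handled by the regular vector lowering:
%t1 = vector.load %b[%c0] : memref<?xf32>, vector<4xf32>
```
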
diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
index 9f57627c321fb0..31167e6af908b9 100644
--- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 991beae0bec9cf..8f485db4e8438b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
+  MLIRArmSMEUtils
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e837432410de89..b3a747a8fe8448 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -19,7 +20,6 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
-static constexpr unsigned kMinNumElts = 16;
 static constexpr unsigned kZeroZAMask = 255;
 
 namespace {
@@ -50,7 +50,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
     return success();
   }
 };
-} // namespace
 
 /// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
 /// value. The latter is a nop, which should be folded away (e.g. during
@@ -95,68 +94,285 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
   }
 };
 
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics, or returns `tile` unchanged if it is already an i32.
+Value castTileIDToI32(Value tile, Location loc,
+                      ConversionPatternRewriter &rewriter) {
+  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+             tile.getDefiningOp())) &&
+         "expected ArmSME GetTileID or CastVectorToTile op!");
+  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+  if (tileElementWidth < 32)
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  if (tileElementWidth > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+  return tile;
+}
+
+/// Returns the following:
+/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does
+///   the arithmetic.
+/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the
+///   index by the number of elements in a vector of SVL bits.
+/// * otherwise aborts via `llvm_unreachable`.
+Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
+                           Value tileSliceNumElts, Location loc,
+                           ConversionPatternRewriter &rewriter) {
+  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+
+  auto tileSliceIndexI64 = rewriter.create<arith::IndexCastUIOp>(
+      loc, rewriter.getI64Type(), tileSliceIndex);
+
+  if (rank == 1) {
+    auto tileSliceNumEltsI64 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI64Type(), tileSliceNumElts);
+    return rewriter.create<arith::MulIOp>(loc, tileSliceIndexI64,
+                                          tileSliceNumEltsI64);
+  }
+
+  if (rank == 2)
+    return tileSliceIndexI64;
+
+  llvm_unreachable("memref has unexpected rank!");
+}
+
+/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics.
+///
+/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row
+/// using `arm_sme.intr.ld1*.horiz`.
+///
+///  BEFORE:
+///  ```mlir
+///  %tile = arm_sme.tile_load %base[%c0, %c0] :
+///     memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice = %c0 to %svl_s step %c1 {
+///    // (...)
+///    "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
+///         (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+///  }
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  ```
+struct TileLoadToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::TileLoadOp> {
+  using ConvertOpToLLVMPattern<arm_sme::TileLoadOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                  arm_sme::TileLoadOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tile = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    // This describes both the number of ZA tile slices and the number of
+    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+    // ..., SVL_Q).
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
+    auto memRefType = tileLoadOp.getMemRefType();
+    auto tileSlice = forOp.getInductionVar();
+    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
+    // 'tileSliceIndex' should be added to 'indices[0]'.
+    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
+                                                numTileSlices, loc, rewriter);
+    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+                                           {tileSliceIndex}, rewriter);
+
+    // Cast tile slice to i32 for intrinsic.
+    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI32Type(), tileSlice);
+
+    // Create all active predicate mask.
+    auto one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
+                                  /*scalableDims=*/{true});
+    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    switch (tileElementWidth) {
+    default:
+      llvm_unreachable("unexpected element type!");
+    case 8:
+      rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, allActiveMask, ptr,
+                                                       tileI32, tileSliceI32);
+      break;
+    case 16:
+      rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, allActiveMask, ptr,
+                                                       tileI32, tileSliceI32);
+      break;
+    case 32:
+      rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, allActiveMask, ptr,
+                                                       tileI32, tileSliceI32);
+      break;
+    case 64:
+      rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
+                                                       tileI32, tileSliceI32);
+      break;
+    }
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // The load intrinsics have no result, so replace 'arm_sme.tile_load' with
+    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
+    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(tileLoadOp, tileType,
+                                                           tile);
+
+    return success();
+  }
+};
+
+/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics.
+///
+/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row
+/// using `arm_sme.intr.st1*.horiz`.
 ///
 ///  BEFORE:
 ///  ```mlir
-///     arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-///     vector<[16]x[16]xi8
+///  arm_sme.tile_store %value, %base[%c0, %c0] : memref<?x?xi32>,
+///     vector<[4]x[4]xi32>
 ///  ```
 ///
 ///  AFTER:
 ///  ```mlir
-///      %vscale = "llvm.intr.vscale"() : () -> index
-///      %c0 = arith.constant 0 : index
-///      %c1 = arith.constant 1 : index
-///      %c16 = arith.constant 16 : index
-///      %vec_size = arith.muli %c16, %vscale : index
-///      scf.for %row_idx = %c0 to %vec_size step %c1 {
-///        // (...)
-///        "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> ()
+///  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice = %c0 to %svl_s step %c1 {
+///    // (...)
+///    "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
+///         (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+///  }
 ///  ```
-struct TileStoreOpConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
-  using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+struct TileStoreToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::TileStoreOp> {
+  using ConvertOpToLLVMPattern<arm_sme::TileStoreOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(TileStoreOp store, OpAdaptor adaptor,
+  matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
+                  arm_sme::TileStoreOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = store.getLoc();
+    auto loc = tileStoreOp.getLoc();
+    auto tileType = tileStoreOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
+    // being stored.
+    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+        loc, rewriter.getIntegerType(tileElementWidth),
+        tileStoreOp.getValueToStore());
 
-    // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
-    // vectors in ZA) and stores each ZA vector to memory.
+    // Create a loop that stores each ZA tile slice to memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    // This describes both the number of ZA tile slices and the number of
+    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+    // ..., SVL_Q).
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
-    auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
-        loc, rewriter.getI64Type(), forOp.getInductionVar());
-    auto offset =
-        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
-    Value ptr =
-        getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(),
-                             ValueRange{vnumI64, offset}, rewriter);
-    auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
-        loc, rewriter.getI32Type(), forOp.getInductionVar());
-    rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
-
-    rewriter.eraseOp(store);
+    // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
+    auto memRefType = tileStoreOp.getMemRefType();
+    auto tileSlice = forOp.getInductionVar();
+    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
+    // 'tileSliceIndex' should be added to 'indices[0]'.
+    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
+                                                numTileSlices, loc, rewriter);
+    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+                                           {tileSliceIndex}, rewriter);
+
+    // Cast tile slice to i32 for intrinsic.
+    auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
+        loc, rewriter.getI32Type(), tileSlice);
+
+    // Create all active predicate mask.
+    auto one = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI1Type(),
+        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
+                                  /*scalableDims=*/{true});
+    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+    Value tileI32 = castTileIDToI32(tile, loc, rewriter);
+    switch (tileElementWidth) {
+    default:
+      llvm_unreachable("unexpected element type!");
+    case 8:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
+          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      break;
+    case 16:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
+          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      break;
+    case 32:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
+          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      break;
+    case 64:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
+          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      break;
+    }
+
     return success();
   }
 };
 
+} // namespace
+
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
-                    arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
-                    arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
-                    arm_sme::aarch64_sme_za_disable>();
+  target.addLegalOp<
+      scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+      arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz,
+      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable,
+      arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
 
   // Mark 'func.func' ops as legal if either:
@@ -187,5 +403,6 @@ void mlir::configureArmSMELegalizeForExportTarget(
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-  patterns.add<TileStoreOpConversion, ZeroOpConversion>(converter);
+  patterns.add<ZeroOpConversion, TileLoadToArmSMELowering,
+               TileStoreToArmSMELowering>(converter);
 }

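The `castTileIDToI32` helper above normalizes the tile ID for the intrinsics,
which always take an i32; a sketch of the three cases as they appear in the
tests below:

```mlir
// 8/16-bit element tiles: the tile ID is i8/i16 and is zero-extended.
%tid_b = arith.extui %tile_id_i8 : i8 to i32
// 32-bit element tiles: the tile ID is already i32 and is passed unchanged.
// 64-bit element tiles: the tile ID is i64 and is truncated.
%tid_d = arith.trunci %tile_id_i64 : i64 to i32
```
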
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
new file mode 100644
index 00000000000000..da8517aaf80a9f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRArmSMEUtils
+  Utils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
+  MLIRDialect
+  MLIRIR
+  )

diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
new file mode 100644
index 00000000000000..a5908a5a8f330f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
+unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+  assert(isValidSMETileElementType(type) && "invalid tile type!");
+  return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
+}
+
+bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+  // TODO: add support for i128.
+  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
+         type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() ||
+         type.isF64();
+}
+
+bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+  if ((vType.getRank() != 2) || !vType.allDimsScalable())
+    return false;
+
+  // TODO: add support for i128.
+  auto elemType = vType.getElementType();
+  if (!isValidSMETileElementType(elemType))
+    return false;
+
+  unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+    return false;
+
+  return true;
+}

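A worked example of the helpers above, as a sketch (the function and names
here are illustrative, not part of the patch): with a minimum streaming
vector length of 128 bits, an f64 tile slice holds at least 128 / 64 = 2
elements, so the valid f64 tile type is `vector<[2]x[2]xf64>`:

```cpp
#include <cassert>

#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

static void checkF64Tile(MLIRContext *ctx) {
  Type f64 = Float64Type::get(ctx);
  // 128 / 64 == 2 elements per tile slice.
  unsigned n = arm_sme::getSMETileSliceMinNumElts(f64);
  // The rank-2, all-scalable vector<[2]x[2]xf64> is the valid f64 tile type.
  auto tileType = VectorType::get({n, n}, f64, /*scalableDims=*/{true, true});
  assert(arm_sme::isValidSMETileVectorType(tileType));
}
```
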
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 5c1f3a9e26db0d..66be8a20fbb1b4 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// -----
-
 func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
   // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
@@ -194,6 +192,87 @@ func.func @arm_sme_zero() -> () {
 
 // -----
 
+func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) -> () {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index

diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index cb52ab5ff1f134..9c76a4c48a5746 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,28 +1,30 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
 
-// CHECK-LABEL: @transfer_write_2d_zero_i8
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-LABEL: @transfer_write_2d_zero_i8(
+// CHECK-SAME:                             %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
-// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
-// CHECK-NEXT:   %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
-// CHECK-NEXT:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]]  : i64
-// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]]  : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
-// CHECK-NEXT:   %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
-// CHECK-NEXT:   "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -30,3 +32,329 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: @vector_load_i8(
+// CHECK-SAME:                  %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
+// CHECK-SAME:                                     %[[ARG0:.*]]: memref<?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64
+// CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
+// CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i16(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK:   arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
+func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return %tile : vector<[8]x[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i32(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
+// CHECK:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT:   arith.extui %[[TILE_ID]]
+// CHECK-NOT:   arith.trunci %[[TILE_ID]]
+// CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return %tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i64(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
+// CHECK:   arm_sme.intr.ld1d.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
+func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return %tile : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f16(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK:   arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return %tile : vector<[8]x[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_bf16(
+// CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xbf16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK:   arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
+func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return %tile : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f32(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT:   arith.extui %[[TILE_ID]]
+// CHECK-NOT:   arith.trunci %[[TILE_ID]]
+// CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
+func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return %tile : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f64(
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
+// CHECK:   arm_sme.intr.ld1d.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
+func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
+  %c0 = arith.constant 0 : index
+  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return %tile : vector<[2]x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i8(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT:   %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT:   %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i16(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xi16>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK:   arm_sme.intr.st1h.horiz
+func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i32(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
+// CHECK:     %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
+// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
+// CHECK:       arm_sme.intr.st1w.horiz
+func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i64(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
+// CHECK:   arm_sme.intr.st1d.horiz
+func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f16(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK:   arm_sme.intr.st1h.horiz
+func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_bf16(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xbf16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK:   arm_sme.intr.st1h.horiz
+func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f32(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK:     %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
+// CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:     %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
+// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
+// CHECK:       arm_sme.intr.st1w.horiz
+func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f64(
+// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
+// CHECK:   arm_sme.intr.st1d.horiz
+func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}

diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
new file mode 100644
index 00000000000000..f0db7529fd062e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN:   --entry-function=za0_d_f64 \
+// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// Integration test demonstrating load/store to/from an SME ZA tile.
+
+llvm.func @printF64(f64)
+llvm.func @printOpen()
+llvm.func @printClose()
+llvm.func @printComma()
+llvm.func @printNewline()
+
+func.func @za0_d_f64() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c0_f64 = arith.constant 0.0 : f64
+  %c1_f64 = arith.constant 1.0 : f64
+  %c1_index = arith.constant 1 : index
+
+  %min_elts_d = arith.constant 2 : index
+  %vscale = vector.vscale
+
+  // "svl" refers to the Streaming Vector Length and "svl_d" the number of
+  // 64-bit elements in a vector of SVL bits.
+  %svl_d = arith.muli %min_elts_d, %vscale : index
+
+  // Allocate "mem1" and fill each "row" with row number.
+  //
+  // For example, assuming an SVL of 256-bits:
+  //
+  //   0.1, 0.1, 0.1, 0.1
+  //   1.1, 1.1, 1.1, 1.1
+  //   2.1, 2.1, 2.1, 2.1
+  //   3.1, 3.1, 3.1, 3.1
+  //
+  %tilesize = arith.muli %svl_d, %svl_d : index
+  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
+  %init_0 = arith.constant 0.1 : f64
+  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
+    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
+    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+    %val_next = arith.addf %val, %c1_f64 : f64
+    scf.yield %val_next : f64
+  }
+
+  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xi64.
+  //
+  // CHECK-ZA0_D:      ( 0.1, 0.1
+  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
+  scf.for %i = %c0 to %tilesize step %svl_d {
+    %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_d step %c1_index {
+      %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+      llvm.call @printF64(%elem) : (f64) -> ()
+      %last_i = arith.subi %svl_d, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+  }
+
+  // Load ZA0.D from "mem1"
+  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Allocate "mem2" to store ZA0.D to
+  %mem2 = memref.alloca(%tilesize) : memref<?xf64>
+
+  // Zero "mem2"
+  scf.for %i = %c0 to %tilesize step %c1_index {
+    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
+  }
+
+  // Verify "mem2" is zeroed by doing an add reduction with initial value of
+  // zero
+  %init_0_f64 = arith.constant 0.0 : f64
+  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
+    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+
+    %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
+      %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
+      scf.yield %inner_add_reduce_next : f64
+    }
+
+    %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
+    scf.yield %add_reduce_next : f64
+  }
+
+  // CHECK-ZA0_D: 0
+  vector.print %add_reduce : f64
+
+  // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at
+  // least 2x2xi64.
+  //
+  // CHECK-ZA0_D-NEXT: ( 0, 0
+  // CHECK-ZA0_D-NEXT: ( 0, 0
+  scf.for %i = %c0 to %tilesize step %svl_d {
+    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_d step %c1_index {
+      %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+      llvm.call @printF64(%elem) : (f64) -> ()
+      %last_i = arith.subi %svl_d, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+  }
+
+  // Verify "mem1" != "mem2"
+  %init_1 = arith.constant 1 : i64
+  %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
+    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>
+
+    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+      %t_i64 = arith.extui %t : i1 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // CHECK-ZA0_D: 1
+  vector.print %mul_reduce_0 : i64
+
+  // Store ZA0.D to "mem2"
+  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Verify "mem1" == "mem2"
+  %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
+    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>
+
+    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+      %t_i64 = arith.extui %t : i1 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // CHECK-ZA0_D-NEXT: 1
+  vector.print %mul_reduce_1 : i64
+
+  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xi64.
+  //
+  // CHECK-ZA0_D-NEXT: ( 0.1, 0.1
+  // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
+  scf.for %i = %c0 to %tilesize step %svl_d {
+    %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_d step %c1_index {
+      %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+      llvm.call @printF64(%elem) : (f64) -> ()
+      %last_i = arith.subi %svl_d, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+  }
+
+  %c0_i32 = arith.constant 0 : i32
+  return %c0_i32 : i32
+}




More information about the Mlir-commits mailing list