[Mlir-commits] [mlir] ca9a335 - [mlir][ArmSME] Add tile load op and extend tile store tile size support
Cullen Rhodes
llvmlistbot at llvm.org
Tue Jul 25 01:39:55 PDT 2023
Author: Cullen Rhodes
Date: 2023-07-25T08:28:36Z
New Revision: ca9a3354d04b15366088d7831b40f891e3d77b95
URL: https://github.com/llvm/llvm-project/commit/ca9a3354d04b15366088d7831b40f891e3d77b95
DIFF: https://github.com/llvm/llvm-project/commit/ca9a3354d04b15366088d7831b40f891e3d77b95.diff
LOG: [mlir][ArmSME] Add tile load op and extend tile store tile size support
This extends the existing 'arm_sme.tile_store' op to support all tile
sizes and adds a new op 'arm_sme.tile_load', as well as lowerings from
vector -> custom ops and custom ops -> intrinsics. Currently there's no
lowering for i128.
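As a rough sketch of the two lowering stages (pass names as used in the tests
below), for a 32-bit tile load:

```mlir
// Stage 1 (-convert-vector-to-arm-sme): one-to-one swap to the custom op.
%tile = vector.load %base[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
// ==> %tile = arm_sme.tile_load %base[%c0, %c0]
//       : memref<?x?xi32>, vector<[4]x[4]xi32>

// Stage 2 (-convert-vector-to-llvm="enable-arm-sme"): a loop over ZA tile
// slices, one "arm_sme.intr.ld1w.horiz" per slice (see patterns below).
```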
Depends on D154867
Reviewed By: awarzynski, dcaballe
Differential Revision: https://reviews.llvm.org/D155306
Added:
mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
mlir/lib/Dialect/ArmSME/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Dialect/ArmSME/roundtrip.mlir
mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 09f8bfb314a6e9..caa6e384bdb2b5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -224,21 +224,74 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
+def TileLoadOp : ArmSME_Op<"tile_load"> {
+ let summary = "Tile load operation";
+ let description = [{
+ Loads a 2D SME "virtual tile" from memory defined by a base and indices,
+ with the shape defined by the 2D scalable vector type of the result tile.
+ The slice of memory must be contiguous. The memref must be either rank 1 or
+ rank 2 with dynamic dimensions, since the operation is scalable, and the
+ element type must be a scalar that matches the element type of the result.
+
+ Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+ ```mlir
+ %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ ```
+
+ Example 2: Load a 32-bit floating-point element ZA tile from memory.
+ ```mlir
+ %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
+
+ Example 3: Load a 128-bit element ZA tile from memory.
+ ```mlir
+ %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ ```
+ }];
+ let arguments = (ins
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Variadic<Index>:$indices);
+ let results = (outs SMETile:$result);
+
+ let extraClassDeclaration = [{
+ MemRefType getMemRefType() {
+ return ::llvm::cast<MemRefType>(getBase().getType());
+ }
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+
+ let assemblyFormat =
+ "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
def TileStoreOp : ArmSME_Op<"tile_store"> {
let summary = "Tile store operation";
let description = [{
- Store a 2D SME "virtual tile" to memory.
-
- NOTE: At the moment it is assumed that the element type is `i8` and that
- there's only one "virtual tile".
+ Stores a 2D SME "virtual tile" to memory at a location given by a base and
+ indices, with the shape given by the 2D scalable vector type of the tile
+ being stored. The slice of memory must be contiguous. The memref must be
+ either rank 1 or rank 2 with dynamic dimensions, since the tile size is
+ scalable, and its element type must be a scalar that matches the element
+ type of the tile being stored.
+
+ Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+ ```mlir
+ arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ ```
- Example:
+ Example 2: Store a 32-bit floating-point element ZA tile to memory.
+ ```mlir
+ arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ ```
+ Example 3: Store a 128-bit element ZA tile to memory.
```mlir
- arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
- let arguments = (ins nxnxv16i8:$valueToStore,
+ let arguments = (ins SMETile:$valueToStore,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -304,7 +357,7 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
class ArmSME_IntrLoadOp<string mnemonic>
: ArmSME_IntrOp<mnemonic>,
Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
- Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
+ Arg<LLVM_AnyPointer, "Load address">,
Arg<I32, "Virtual tile ID">,
Arg<I32, "Tile slice">)>;
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
new file mode 100644
index 00000000000000..554b9f11923066
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -0,0 +1,38 @@
+//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various utilities for the ArmSME
+// dialect. These are not passes by themselves but are used either by passes,
+// optimization sequences, or in turn by other transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
+#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+namespace mlir {
+namespace arm_sme {
+
+/// Returns the minimum number of elements for the given element `type` in
+/// a vector of SVL bits.
+unsigned getSMETileSliceMinNumElts(Type type);
+
+/// Returns true if `type` is a valid element type for an SME tile or false
+/// otherwise.
+bool isValidSMETileElementType(Type type);
+
+/// Returns true if `vType` is a valid vector type for an SME tile or false
+/// otherwise.
+bool isValidSMETileVectorType(VectorType vType);
+
+} // namespace arm_sme
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65e914e8b..715816a90128cd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
LINK_LIBS PUBLIC
MLIRArmSMEDialect
+ MLIRArmSMEUtils
MLIRLLVMCommonConversion
)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index cd0d99c5b5074f..4106b04877ec52 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
@@ -76,9 +77,42 @@ struct TransferWriteToArmSMELowering
}
};
+/// Conversion pattern for vector.load.
+struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
+ using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::LoadOp load,
+ PatternRewriter &rewriter) const override {
+ if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+ load, load.getVectorType(), load.getBase(), load.getIndices());
+
+ return success();
+ }
+};
+
+/// Conversion pattern for vector.store.
+struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
+ using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::StoreOp store,
+ PatternRewriter &rewriter) const override {
+ if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
+ store, store.getValueToStore(), store.getBase(), store.getIndices());
+
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<TransferWriteToArmSMELowering>(&ctx);
+ patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
+ VectorStoreToArmSMELowering>(&ctx);
}
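As a sketch, these patterns are a one-to-one op swap; no loop is introduced at
this stage, and loads/stores of non-tile vector types are left for other
lowerings:

```mlir
vector.store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
// ==> arm_sme.tile_store %tile, %base[%c0, %c0]
//       : memref<?x?xi8>, vector<[16]x[16]xi8>
```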
diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
index 9f57627c321fb0..31167e6af908b9 100644
--- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 991beae0bec9cf..8f485db4e8438b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
LINK_LIBS PUBLIC
MLIRArmSMEDialect
+ MLIRArmSMEUtils
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index e837432410de89..b3a747a8fe8448 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -19,7 +20,6 @@
using namespace mlir;
using namespace mlir::arm_sme;
-static constexpr unsigned kMinNumElts = 16;
static constexpr unsigned kZeroZAMask = 255;
namespace {
@@ -50,7 +50,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
return success();
}
};
-} // namespace
/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
/// value. The latter is a nop, which should be folded away (e.g. during
@@ -95,68 +94,285 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
}
};
-/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row
-/// using 'arm_sme.intr.str'.
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics, or returns `tile` unchanged if it is already i32.
+Value castTileIDToI32(Value tile, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+ tile.getDefiningOp())) &&
+ "expected ArmSME GetTileID or CastVectorToTile op!");
+ unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+ if (tileElementWidth < 32)
+ return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+ if (tileElementWidth > 32)
+ return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+ return tile;
+}
+
+/// Returns the pointer index for a tile slice:
+/// * for rank 2 memrefs, `tileSliceIndex`, since `getStridedElementPtr` does
+/// the arithmetic;
+/// * for rank 1 memrefs, `tileSliceIndex * tileSliceNumElts`, adjusting the
+/// index by the number of elements in a vector of SVL bits;
+/// * any other rank is invalid and hits `llvm_unreachable`.
+Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
+ Value tileSliceNumElts, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+
+ auto tileSliceIndexI64 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI64Type(), tileSliceIndex);
+
+ if (rank == 1) {
+ auto tileSliceNumEltsI64 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI64Type(), tileSliceNumElts);
+ return rewriter.create<arith::MulIOp>(loc, tileSliceIndexI64,
+ tileSliceNumEltsI64);
+ }
+
+ if (rank == 2)
+ return tileSliceIndexI64;
+
+ llvm_unreachable("memref has unexpected rank!");
+}
+
+/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics.
+///
+/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row
+/// using `arm_sme.intr.ld1*.horiz`.
+///
+/// BEFORE:
+/// ```mlir
+/// %tile = arm_sme.tile_load %base[%c0, %c0] :
+/// memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
+/// // (...)
+/// "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
+/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+/// }
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// ```
+struct TileLoadToArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::TileLoadOp> {
+ using ConvertOpToLLVMPattern<arm_sme::TileLoadOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+ arm_sme::TileLoadOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = tileLoadOp.getLoc();
+ auto tileType = tileLoadOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ // Create 'arm_sme.get_tile_id' op.
+ auto tile = rewriter.create<arm_sme::GetTileID>(
+ loc, rewriter.getIntegerType(tileElementWidth));
+
+ // Create a loop that loads each ZA tile slice from memory.
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+ auto vscale =
+ rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+ // ..., SVL_Q).
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
+ auto memRefType = tileLoadOp.getMemRefType();
+ auto tileSlice = forOp.getInductionVar();
+ // TODO: The 'indices' argument for the 'base' memref is currently ignored,
+ // 'tileSliceIndex' should be added to 'indices[0]'.
+ Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
+ numTileSlices, loc, rewriter);
+ Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ {tileSliceIndex}, rewriter);
+
+ // Cast tile slice to i32 for intrinsic.
+ auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI32Type(), tileSlice);
+
+ // Create all active predicate mask.
+ auto one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI1Type(),
+ rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+ auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
+ /*scalableDims=*/{true});
+ auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+ auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ }
+
+ rewriter.setInsertionPointAfter(forOp);
+
+ // The load intrinsics have no result; replace 'arm_sme.tile_load' with
+ // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
+ rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(tileLoadOp, tileType,
+ tile);
+
+ return success();
+ }
+};
+
+/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics.
+///
+/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row
+/// using `arm_sme.intr.st1*.horiz`.
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref<?x?xi8>,
-/// vector<[16]x[16]xi8
+/// arm_sme.tile_store %value, %base[%c0, %c0] : memref<?x?xi32>,
+/// vector<[4]x[4]xi32>
/// ```
///
/// AFTER:
/// ```mlir
-/// %vscale = "llvm.intr.vscale"() : () -> index
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %c16 = arith.constant 16 : index
-/// %vec_size = arith.muli %c16, %vscale : index
-/// scf.for %row_idx = %c0 to %vec_size step %c1 {
-/// // (...)
-/// "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> ()
+/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
+/// // (...)
+/// "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
+/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+/// }
/// ```
-struct TileStoreOpConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
- using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
+struct TileStoreToArmSMELowering
+ : public ConvertOpToLLVMPattern<arm_sme::TileStoreOp> {
+ using ConvertOpToLLVMPattern<arm_sme::TileStoreOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(TileStoreOp store, OpAdaptor adaptor,
+ matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
+ arm_sme::TileStoreOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = store.getLoc();
+ auto loc = tileStoreOp.getLoc();
+ auto tileType = tileStoreOp.getVectorType();
+ auto tileElementType = tileType.getElementType();
+ unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+ // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
+ // being stored.
+ auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+ loc, rewriter.getIntegerType(tileElementWidth),
+ tileStoreOp.getValueToStore());
- // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
- // vectors in ZA) and stores each ZA vector to memory.
+ // Create a loop that stores each ZA tile slice to memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, kMinNumElts);
+ auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+ loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ // This describes both the number of ZA tile slices and the number of
+ // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+ // ..., SVL_Q).
+ auto numTileSlices =
+ rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+ auto forOp =
+ rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
rewriter.setInsertionPointToStart(forOp.getBody());
- // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
- auto vnumI64 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI64Type(), forOp.getInductionVar());
- auto offset =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
- Value ptr =
- getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(),
- ValueRange{vnumI64, offset}, rewriter);
- auto vnumI32 = rewriter.create<arith::IndexCastUIOp>(
- loc, rewriter.getI32Type(), forOp.getInductionVar());
- rewriter.create<arm_sme::aarch64_sme_str>(loc, vnumI32, ptr);
-
- rewriter.eraseOp(store);
+ // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
+ auto memRefType = tileStoreOp.getMemRefType();
+ auto tileSlice = forOp.getInductionVar();
+ // TODO: The 'indices' argument for the 'base' memref is currently ignored,
+ // 'tileSliceIndex' should be added to 'indices[0]'.
+ Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
+ numTileSlices, loc, rewriter);
+ Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
+ {tileSliceIndex}, rewriter);
+
+ // Cast tile slice to i32 for intrinsic.
+ auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
+ loc, rewriter.getI32Type(), tileSlice);
+
+ // Create all active predicate mask.
+ auto one = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI1Type(),
+ rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
+ auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
+ /*scalableDims=*/{true});
+ auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+
+ Value tileI32 = castTileIDToI32(tile, loc, rewriter);
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
+ tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
+ tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
+ tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
+ tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ }
+
return success();
}
};
+} // namespace
+
void mlir::configureArmSMELegalizeForExportTarget(
LLVMConversionTarget &target) {
- target.addLegalOp<scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_za_enable,
- arm_sme::aarch64_sme_za_disable>();
+ target.addLegalOp<
+ scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
+ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
+ arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+ arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz,
+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable,
+ arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
// Mark 'func.func' ops as legal if either:
@@ -187,5 +403,6 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
- patterns.add<TileStoreOpConversion, ZeroOpConversion>(converter);
+ patterns.add<ZeroOpConversion, TileLoadToArmSMELowering,
+ TileStoreToArmSMELowering>(converter);
}
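A worked sketch of the indexing in `getTileSlicePtrIndex` above (`%i` and
`%svl_d` are hypothetical names): for rank-2 memrefs the slice index is the
row index and `getStridedElementPtr` applies the row stride; for rank-1
memrefs the tile slices are assumed contiguous, so the index is scaled by the
slice length:

```mlir
// Rank-1 case for an f64 tile: slice %i starts at element %i * %svl_d.
%i_i64     = arith.index_castui %i : index to i64
%svl_d_i64 = arith.index_castui %svl_d : index to i64
%offset    = arith.muli %i_i64, %svl_d_i64 : i64
```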
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
new file mode 100644
index 00000000000000..da8517aaf80a9f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRArmSMEUtils
+ Utils.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
+
+ LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
+ MLIRDialect
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
new file mode 100644
index 00000000000000..a5908a5a8f330f
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
+unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+ assert(isValidSMETileElementType(type) && "invalid tile type!");
+ return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
+}
+
+bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+ // TODO: add support for i128.
+ return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
+ type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() ||
+ type.isF64();
+}
+
+bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+ if ((vType.getRank() != 2) || !vType.allDimsScalable())
+ return false;
+
+ // TODO: add support for i128.
+ auto elemType = vType.getElementType();
+ if (!isValidSMETileElementType(elemType))
+ return false;
+
+ unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+ if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+ return false;
+
+ return true;
+}
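To make the arithmetic concrete (`MinStreamingVectorLengthInBits` = 128), the
types accepted by `isValidSMETileVectorType` are 2-D, all-dims-scalable
vectors whose dimensions both equal 128 / element bitwidth; a sketch (function
name hypothetical):

```mlir
// 128/8 = 16, 128/16 = 8, 128/32 = 4, 128/64 = 2 elements per tile slice.
func.func private @sme_tile_shapes(vector<[16]x[16]xi8>, vector<[8]x[8]xf16>,
                                   vector<[4]x[4]xi32>, vector<[2]x[2]xf64>)
// Fixed-size or non-square shapes are rejected; i128 is rejected for now
// (see the TODOs above), though the tile ops themselves accept it.
```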
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 5c1f3a9e26db0d..66be8a20fbb1b4 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,7 +1,5 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
-// -----
-
func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
// CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
@@ -194,6 +192,87 @@ func.func @arm_sme_zero() -> () {
// -----
+func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) -> () {
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index cb52ab5ff1f134..9c76a4c48a5746 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,28 +1,30 @@
// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
-// CHECK-LABEL: @transfer_write_2d_zero_i8
-// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-LABEL: @transfer_write_2d_zero_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32
// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG: %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
-// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
-// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
-// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
-// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64
-// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
-// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
-// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32
-// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -30,3 +32,329 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
return
}
+// -----
+
+// CHECK-LABEL: @vector_load_i8(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64
+// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+ return %tile : vector<[16]x[16]xi8>
+}
+
+
+// -----
+
+// CHECK-LABEL: @vector_load_i16(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
+func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return %tile : vector<[8]x[8]xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i32(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT: arith.extui %[[TILE_ID]]
+// CHECK-NOT: arith.trunci %[[TILE_ID]]
+// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return %tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_i64(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
+// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
+func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return %tile : vector<[2]x[2]xi64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f16(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
+func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return %tile : vector<[8]x[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_bf16(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
+// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
+func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return %tile : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f32(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT: arith.extui %[[TILE_ID]]
+// CHECK-NOT: arith.trunci %[[TILE_ID]]
+// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
+func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return %tile : vector<[4]x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_load_f64(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
+// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
+func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
+ %c0 = arith.constant 0 : index
+ %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return %tile : vector<[2]x[2]xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i8(
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i16(
+// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xi16>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i32(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
+// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_i64(
+// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi64>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f16(
+// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xf16>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_bf16(
+// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
+// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+// -----
+
+// CHECK-LABEL: @vector_store_f32(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
+// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index
+// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]]
+// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]]
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f64(
+// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
+// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
+// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index
+// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+ %c0 = arith.constant 0 : index
+ vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
new file mode 100644
index 00000000000000..f0db7529fd062e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -0,0 +1,192 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN: -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN: --entry-function=za0_d_f64 \
+// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// Integration test demonstrating load/store to/from SME ZA tile.
+
+llvm.func @printF64(f64)
+llvm.func @printOpen()
+llvm.func @printClose()
+llvm.func @printComma()
+llvm.func @printNewline()
+
+func.func @za0_d_f64() -> i32 {
+ %c0 = arith.constant 0 : index
+ %c0_f64 = arith.constant 0.0 : f64
+ %c1_f64 = arith.constant 1.0 : f64
+ %c1_index = arith.constant 1 : index
+
+ %min_elts_d = arith.constant 2 : index
+ %vscale = vector.vscale
+
+ // "svl" refers to the Streaming Vector Length and "svl_d" the number of
+ // 64-bit elements in a vector of SVL bits.
+ %svl_d = arith.muli %min_elts_d, %vscale : index
+
+ // Allocate "mem1" and fill each "row" with row number.
+ //
+ // For example, assuming an SVL of 256-bits:
+ //
+ // 0.1, 0.1, 0.1, 0.1
+ // 1.1, 1.1, 1.1, 1.1
+ // 2.1, 2.1, 2.1, 2.1
+ // 3.1, 3.1, 3.1, 3.1
+ //
+ %tilesize = arith.muli %svl_d, %svl_d : index
+ %mem1 = memref.alloca(%tilesize) : memref<?xf64>
+ %init_0 = arith.constant 0.1 : f64
+ scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
+ %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
+ vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+ %val_next = arith.addf %val, %c1_f64 : f64
+ scf.yield %val_next : f64
+ }
+
+ // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+ // 2x2xi64.
+ //
+ // CHECK-ZA0_D: ( 0.1, 0.1
+ // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
+ scf.for %i = %c0 to %tilesize step %svl_d {
+ %tileslice = vector.load %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+
+ llvm.call @printOpen() : () -> ()
+ scf.for %i2 = %c0 to %svl_d step %c1_index {
+ %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+ llvm.call @printF64(%elem) : (f64) -> ()
+ %last_i = arith.subi %svl_d, %c1_index : index
+ %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+ scf.if %isNotLastIter {
+ llvm.call @printComma() : () -> ()
+ }
+ }
+ llvm.call @printClose() : () -> ()
+ llvm.call @printNewline() : () -> ()
+ }
+
+ // Load ZA0.D from "mem1"
+ %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+ // Allocate "mem2" to store ZA0.D to
+ %mem2 = memref.alloca(%tilesize) : memref<?xf64>
+
+ // Zero "mem2"
+ scf.for %i = %c0 to %tilesize step %c1_index {
+ memref.store %c0_f64, %mem2[%i] : memref<?xf64>
+ }
+
+ // Verify "mem2" is zeroed by doing an add reduction with initial value of
+ // zero
+ %init_0_f64 = arith.constant 0.0 : f64
+ %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
+ %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+
+ %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
+ %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
+ %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
+ scf.yield %inner_add_reduce_next : f64
+ }
+
+ %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
+ scf.yield %add_reduce_next : f64
+ }
+
+ // CHECK-ZA0_D: 0
+ vector.print %add_reduce : f64
+
+ // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at
+ // least 2x2xi64.
+ //
+ // CHECK-ZA0_D-NEXT: ( 0, 0
+ // CHECK-ZA0_D-NEXT: ( 0, 0
+ scf.for %i = %c0 to %tilesize step %svl_d {
+ %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+
+ llvm.call @printOpen() : () -> ()
+ scf.for %i2 = %c0 to %svl_d step %c1_index {
+ %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+ llvm.call @printF64(%elem) : (f64) -> ()
+ %last_i = arith.subi %svl_d, %c1_index : index
+ %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+ scf.if %isNotLastIter {
+ llvm.call @printComma() : () -> ()
+ }
+ }
+ llvm.call @printClose() : () -> ()
+ llvm.call @printNewline() : () -> ()
+ }
+
+ // Verify "mem1" != "mem2"
+ %init_1 = arith.constant 1 : i64
+ %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
+ %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64>
+
+ %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+ %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+ %t_i64 = arith.extui %t : i1 to i64
+ %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+ scf.yield %inner_mul_reduce_next : i64
+ }
+
+ %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+ scf.yield %mul_reduce_next : i64
+ }
+
+ // CHECK-ZA0_D: 1
+ vector.print %mul_reduce_0 : i64
+
+ // Store ZA0.D to "mem2"
+ vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+ // Verify "mem1" == "mem2"
+ %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
+ %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+ %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>
+
+ %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+ %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+ %t_i64 = arith.extui %t : i1 to i64
+ %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+ scf.yield %inner_mul_reduce_next : i64
+ }
+
+ %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+ scf.yield %mul_reduce_next : i64
+ }
+
+ // CHECK-ZA0_D-NEXT: 1
+ vector.print %mul_reduce_1 : i64
+
+ // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+ // 2x2xi64.
+ //
+ // CHECK-ZA0_D-NEXT: ( 0.1, 0.1
+ // CHECK-ZA0_D-NEXT: ( 1.1, 1.1
+ scf.for %i = %c0 to %tilesize step %svl_d {
+ %tileslice = vector.load %mem2[%i] : memref<?xf64>, vector<[2]xf64>
+
+ llvm.call @printOpen() : () -> ()
+ scf.for %i2 = %c0 to %svl_d step %c1_index {
+ %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64>
+ llvm.call @printF64(%elem) : (f64) -> ()
+ %last_i = arith.subi %svl_d, %c1_index : index
+ %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+ scf.if %isNotLastIter {
+ llvm.call @printComma() : () -> ()
+ }
+ }
+ llvm.call @printClose() : () -> ()
+ llvm.call @printNewline() : () -> ()
+ }
+
+ %c0_i32 = arith.constant 0 : i32
+ return %c0_i32 : i32
+}