[Mlir-commits] [mlir] 75a71c2 - [mlir][ArmSME] Support vertical layout in load and store ops (#66758)
llvmlistbot at llvm.org
Mon Sep 25 01:34:27 PDT 2023
Author: Cullen Rhodes
Date: 2023-09-25T09:34:23+01:00
New Revision: 75a71c27c12a943333405a3299c100b04f65a37e
URL: https://github.com/llvm/llvm-project/commit/75a71c27c12a943333405a3299c100b04f65a37e
DIFF: https://github.com/llvm/llvm-project/commit/75a71c27c12a943333405a3299c100b04f65a37e.diff
LOG: [mlir][ArmSME] Support vertical layout in load and store ops (#66758)
In SME a ZA tile slice is a one-dimensional set of horizontally or
vertically contiguous elements within a ZA tile. Currently the load and
store ops only support horizontal tile slices. This patch adds a tile
slice layout attribute to the load and store ops to support both
horizontal and vertical tile slices.
When lowering from the Vector dialect, horizontal layout is the default.
Added:
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
Modified:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 1ca284a3e70dcec..01a8670fd3817b0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+ let useDefaultAttributePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+/// Layout of a ZA tile slice within a tile: `Horizontal` (value 0, spelled
+/// `horizontal`) or `Vertical` (value 1, spelled `vertical`). Used by the
+/// ArmSME load/store ops to pick between the `*_horiz` and `*_vert` intrinsics.
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+ I32EnumAttrCase<"Horizontal", 0, "horizontal">,
+ I32EnumAttrCase<"Vertical", 1, "vertical">,
+]> {
+ let cppNamespace = "::mlir::arm_sme";
+ // Do not emit a standalone specialized attribute for this enum; the
+ // dialect-scoped ArmSME_TileSliceLayoutAttr below wraps it instead.
+ let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+ "layout"> {
+ // Printed/parsed inside angle brackets, e.g. `<vertical>` or `<horizontal>`,
+ // as seen in the op assembly formats that use `(`,` $layout^)?`.
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -240,28 +260,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the result tile.
- The slice of memory must be contiguous. The memref must be either rank 1 or
- rank 2 with dynamic dimensions, since the operation is scalable, and the
- element type must be a scalar that matches the element type of the result.
+ An optional tile slice layout attribute specifies whether the slices of the
+ tile being loaded are horizontal (default) or vertical. The slice of memory
+ must be contiguous. The memref must be either rank 1 or rank 2 with dynamic
+ dimensions, since the operation is scalable, and the element type must be a
+ scalar that matches the element type of the result.
- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+ Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
```mlir
%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a FP 32-bit element ZA tile from memory.
+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a 128-bit element ZA tile from memory.
+ Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
- Variadic<Index>:$indices);
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+ Variadic<Index>:$indices,
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ );
let results = (outs SMETile:$result);
let extraClassDeclaration = [{
@@ -274,7 +299,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];
let assemblyFormat =
- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+ "$base `[` $indices `]` (`,` $layout^)? attr-dict "
+ "`:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,29 +308,33 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
- stored. The slice of memory must be contiguous. The memref must be either
- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
- and the element type must be a scalar that matches the element type of the
- result.
+ stored. An optional tile slice layout attribute specifies whether the
+ slices of the tile being stored are horizontal (default) or vertical. The
+ slice of memory must be contiguous. The memref must be either rank 1 or
+ rank 2 with dynamic dimensions, since the operation is scalable, and the
+ element type must be a scalar that matches the element type of the tile
+ being stored.
- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+ Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
```mlir
arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store a FP 32-bit element ZA tile to memory.
+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a 128-bit element ZA tile to memory.
+ Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices);
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Variadic<Index>:$indices,
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ );
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -314,8 +344,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];
- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
- "`:` type($base) `,` type($valueToStore)";
+ let assemblyFormat =
+ "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
+ "`:` type($base) `,` type($valueToStore)";
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,31 +357,36 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is loaded to. The updated tile is returned as the result.
+ slice is loaded to. An optional tile slice layout attribute specifies
+ whether the tile slice being loaded at the given index is horizontal
+ (default) or vertical. The updated tile is returned as the result.
The slice of memory read is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
```mlir
%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from">:$base,
- SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
+ Arg<AnyMemRef, "the reference to load from">:$base,
+ SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ );
let results = (outs SMETile:$result);
let extraClassDeclaration = [{
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];
let assemblyFormat = [{
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+ $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
attr-dict `:` type($base) `,` type($result)
}];
}
@@ -374,31 +410,36 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is stored from.
+ slice is stored from. An optional tile slice layout attribute specifies
+ whether the tile slice being stored from the given index is horizontal
+ (default) or vertical.
The slice of memory written is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.
- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+ Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
```mlir
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
```
- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
```
- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
- Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
- Variadic<Index>:$indices);
+ Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
+ Variadic<Index>:$indices,
+ DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
+ "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+ );
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
@@ -409,7 +450,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
let assemblyFormat = [{
- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+ $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
attr-dict `:` type($base) `,` type($tile)
}];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..617809e482b2caa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,9 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..b128165f75b9e81 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
- tileLoadOp.getBase(), tile,
- memrefIndices, tileSliceIndex);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
+ tileSliceIndex, tileLoadOp.getLayout());
rewriter.setInsertionPointAfter(forOp);
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+/// arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
/// : memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
+/// <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
- tileStoreOp.getBase(), memrefIndices);
+ tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 7cbc382b0050a6e..25fed2c477a1886 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,7 +12,9 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arm_sme;
@@ -23,13 +25,23 @@ using namespace mlir::arm_sme;
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
void ArmSMEDialect::initialize() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ade33..85f90a8303d466f 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
DEPENDS
MLIRArmSMEIncGen
+ MLIRArmSMEAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRIR
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..18147542e2bca73 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -204,30 +204,59 @@ struct LoadTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
- // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
+ arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+
+ // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 128:
+ rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ }
+ } else {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 128:
+ rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ }
}
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
@@ -280,29 +309,58 @@ struct StoreTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
+ arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 128:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ }
+ } else {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 128:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ }
}
return success();
@@ -479,7 +537,12 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 9ab1d79794d7659..95b51317cb0cf1b 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @arm_sme_tile_load(
-// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -12,7 +12,7 @@
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
@@ -20,9 +20,19 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
// -----
-// CHECK-LABEL: func.func @arm_sme_tile_store(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
-// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: @arm_sme_tile_load_ver
+// CHECK: arm_sme.load_tile_slice {{.*}} <vertical>
+func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
@@ -31,8 +41,18 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store_ver
+// CHECK: arm_sme.store_tile_slice {{.*}} <vertical>
+func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
new file mode 100644
index 000000000000000..a6b5217181481de
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -0,0 +1,401 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// Test conversion of ArmSME ops to LLVM intrinsics.
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i8(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi8>,
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK: %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: return
+// CHECK: }
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_store_tile_slice_hor_i8(
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK: %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK: "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: return
+// CHECK: }
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index bae48be87b2dcdc..3232bad7996b486 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
// CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
@@ -70,6 +74,10 @@ func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64
return %0 : vector<[2]x[2]xf64>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
@@ -142,6 +150,10 @@ func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64
return %0 : i64
}
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_get_tile_id_i8() -> i8 {
@@ -182,6 +194,10 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
return %0 : i128
}
+//===----------------------------------------------------------------------===//
+// arm_sme.zero
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_zero_i8() {
@@ -254,10 +270,14 @@ func.func @arm_sme_zero_f64() {
return
}
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
// -----
-func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load_hor_i8(%src : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
@@ -265,8 +285,8 @@ func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
// -----
-func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_load_hor_i16(%src : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
@@ -274,8 +294,8 @@ func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) {
// -----
-func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_i32(%src : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
@@ -283,8 +303,8 @@ func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) {
// -----
-func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_load_hor_i64(%src : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
@@ -292,8 +312,8 @@ func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) {
// -----
-func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_load_hor_i128(%src : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
@@ -301,8 +321,8 @@ func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) {
// -----
-func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_load_hor_f16(%src : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
@@ -310,8 +330,8 @@ func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) {
// -----
-func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_load_hor_bf16(%src : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
@@ -319,8 +339,8 @@ func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) {
// -----
-func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_load_hor_f32(%src : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
@@ -328,8 +348,8 @@ func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) {
// -----
-func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_tile_load_hor_f64(%src : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
@@ -337,8 +357,103 @@ func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) {
// -----
-func.func @arm_sme_tile_store_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load_ver_i8(%src : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i16(%src : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i32(%src : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i64(%src : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i128(%src : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f16(%src : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_bf16(%src : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f32(%src : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+/// Layout is optional and horizontal is the default, verify it's still parsed.
+func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load %src[%c0, %c0], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store_hor_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
@@ -346,8 +461,8 @@ func.func @arm_sme_tile_store_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_store_hor_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
@@ -355,8 +470,8 @@ func.func @arm_sme_tile_store_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
@@ -364,8 +479,8 @@ func.func @arm_sme_tile_store_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_store_hor_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
@@ -373,8 +488,8 @@ func.func @arm_sme_tile_store_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_store_hor_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
@@ -382,8 +497,8 @@ func.func @arm_sme_tile_store_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<
// -----
-func.func @arm_sme_tile_store_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_store_hor_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
@@ -391,8 +506,8 @@ func.func @arm_sme_tile_store_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_store_hor_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
@@ -400,8 +515,8 @@ func.func @arm_sme_tile_store_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<
// -----
-func.func @arm_sme_tile_store_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_store_hor_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
@@ -409,8 +524,8 @@ func.func @arm_sme_tile_store_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x
// -----
-func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_tile_store_hor_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
@@ -418,8 +533,103 @@ func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x
// -----
-func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_store_ver_i8(%za : vector<[16]x[16]xi8>, %base : memref<?x?xi8>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i16(%za : vector<[8]x[8]xi16>, %base : memref<?x?xi16>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i32(%za : vector<[4]x[4]xi32>, %base : memref<?x?xi32>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i64(%za : vector<[2]x[2]xi64>, %base : memref<?x?xi64>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i128(%za : vector<[1]x[1]xi128>, %base : memref<?x?xi128>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_f16(%za : vector<[8]x[8]xf16>, %base : memref<?x?xf16>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_bf16(%za : vector<[8]x[8]xbf16>, %base : memref<?x?xbf16>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_f32(%za : vector<[4]x[4]xf32>, %base : memref<?x?xf32>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_f64(%za : vector<[2]x[2]xf64>, %base : memref<?x?xf64>) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ arm_sme.tile_store %za, %base[%zero, %zero], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+/// Layout is optional and defaults to horizontal: an explicit <horizontal> must still parse, and is elided when printed.
+func.func @arm_sme_tile_store_hor_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, %dest[%c0, %c0], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
return
@@ -427,8 +637,8 @@ func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]
// -----
-func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
return
@@ -436,8 +646,8 @@ func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8
// -----
-func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
return
@@ -445,8 +655,8 @@ func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4
// -----
-func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
return
@@ -454,8 +664,8 @@ func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2
// -----
-func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
return
@@ -463,8 +673,8 @@ func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<
// -----
-func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
return
@@ -472,8 +682,8 @@ func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8
// -----
-func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
@@ -481,8 +691,8 @@ func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<
// -----
-func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
return
@@ -490,8 +700,8 @@ func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4
// -----
-func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
return
@@ -499,8 +709,103 @@ func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2
// -----
-func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_ver_i8(%base : memref<?x?xi8>, %za : vector<[16]x[16]xi8>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i16(%base : memref<?x?xi16>, %za : vector<[8]x[8]xi16>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i32(%base : memref<?x?xi32>, %za : vector<[4]x[4]xi32>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i64(%base : memref<?x?xi64>, %za : vector<[2]x[2]xi64>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i128(%base : memref<?x?xi128>, %za : vector<[1]x[1]xi128>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f16(%base : memref<?x?xf16>, %za : vector<[8]x[8]xf16>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_bf16(%base : memref<?x?xbf16>, %za : vector<[8]x[8]xbf16>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f32(%base : memref<?x?xf32>, %za : vector<[4]x[4]xf32>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f64(%base : memref<?x?xf64>, %za : vector<[2]x[2]xf64>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+/// The layout is optional (default: horizontal); an explicit <horizontal> must still parse and is elided on print.
+func.func @arm_sme_load_tile_slice_hor_i8(%base : memref<?x?xi8>, %za : vector<[16]x[16]xi8>, %slice_idx : index) {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %za_new = arm_sme.load_tile_slice %base[%zero], %za, %slice_idx, <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
@@ -508,8 +813,8 @@ func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
@@ -517,8 +822,8 @@ func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
@@ -526,8 +831,8 @@ func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
@@ -535,8 +840,8 @@ func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
@@ -544,8 +849,8 @@ func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_sli
// -----
-func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
@@ -553,8 +858,8 @@ func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
@@ -562,8 +867,8 @@ func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_sli
// -----
-func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
@@ -571,8 +876,8 @@ func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice
// -----
-func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
- // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
@@ -580,6 +885,101 @@ func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice
// -----
+func.func @arm_sme_store_tile_slice_ver_i8(%za : vector<[16]x[16]xi8>, %slice_idx : index, %base : memref<?x?xi8>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i16(%za : vector<[8]x[8]xi16>, %slice_idx : index, %base : memref<?x?xi16>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i32(%za : vector<[4]x[4]xi32>, %slice_idx : index, %base : memref<?x?xi32>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i64(%za : vector<[2]x[2]xi64>, %slice_idx : index, %base : memref<?x?xi64>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i128(%za : vector<[1]x[1]xi128>, %slice_idx : index, %base : memref<?x?xi128>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f16(%za : vector<[8]x[8]xf16>, %slice_idx : index, %base : memref<?x?xf16>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_bf16(%za : vector<[8]x[8]xbf16>, %slice_idx : index, %base : memref<?x?xbf16>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f32(%za : vector<[4]x[4]xf32>, %slice_idx : index, %base : memref<?x?xf32>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f64(%za : vector<[2]x[2]xf64>, %slice_idx : index, %base : memref<?x?xf64>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+/// The layout is optional (default: horizontal); an explicit <horizontal> must still parse and is elided on print.
+func.func @arm_sme_store_tile_slice_hor_i8(%za : vector<[16]x[16]xi8>, %slice_idx : index, %base : memref<?x?xi8>) -> () {
+ %zero = arith.constant 0 : index
+ // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.store_tile_slice %za, %slice_idx, %base[%zero], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
func.func @arm_sme_move_vector_to_tile_slice_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
new file mode 100644
index 000000000000000..d57001daf855f3d
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -0,0 +1,123 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Exercises arm_sme.tile_load with the <vertical> layout: each contiguous
+// row of memory is loaded into a vertical tile slice (a column of the tile),
+// so the tile read back row by row is the transpose of the source.
+
+// Provided by the runner utils libraries passed via -shared-libs. It takes a
+// C string, so the globals printed through it must be NUL-terminated.
+llvm.func @printCString(!llvm.ptr<i8>)
+
+// Prints the "TILE BEGIN" marker that precedes each dumped tile.
+func.func @printTileBegin() {
+ %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<12 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<12 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+// Prints the "TILE END" marker that follows each dumped tile.
+func.func @printTileEnd() {
+ %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<10 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1]
+ : (!llvm.ptr<array<10 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1_i32 = arith.constant 1 : i32
+
+ // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+ %vscale = vector.vscale
+ %min_elts_s = arith.constant 4 : index
+ %svl_s = arith.muli %min_elts_s, %vscale : index
+ %za_s_size = arith.muli %svl_s, %svl_s : index
+
+ // Allocate memory.
+ %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+ %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+ // Fill each "row" of "mem1" with row number.
+ //
+ // For example, assuming an SVL of 128-bits:
+ //
+ // 0, 0, 0, 0
+ // 1, 1, 1, 1
+ // 2, 2, 2, 2
+ // 3, 3, 3, 3
+ //
+ %init_0 = arith.constant 0 : i32
+ scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+ %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+ vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ %val_next = arith.addi %val, %c1_i32 : i32
+ scf.yield %val_next : i32
+ }
+
+ // Load tile from "mem1" vertically. "mem1" is 1-D, so exactly one base
+ // index is passed (one index per memref dimension, as with the
+ // vector.load/vector.store ops in this function).
+ %0 = arm_sme.tile_load %mem1[%c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Store tile back to "mem2" to print.
+ // TODO: Support vector.print for 2-D scalable vectors so that we don't have
+ // to spill to memory and reload to print.
+ vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // 1. ORIGINAL HORIZONTAL LAYOUT
+ // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 3, 3, 3, 3
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %za_s_size step %svl_s {
+ %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.print %tileslice : vector<[4]xi32>
+ }
+ func.call @printTileEnd() : () -> ()
+
+ // 2. VERTICAL LAYOUT
+ // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %za_s_size step %svl_s {
+ %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.print %tileslice : vector<[4]xi32>
+ }
+ func.call @printTileEnd() : () -> ()
+
+ return
+}
+
+// Marker strings printed via @printCString. The trailing \00 terminates the
+// C string, and the array types used in the print helpers count it.
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A\00")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A\00")
More information about the Mlir-commits
mailing list