[Mlir-commits] [mlir] [mlir][ArmSME] Support vertical layout in load and store ops (PR #66758)
Cullen Rhodes
llvmlistbot at llvm.org
Thu Sep 21 00:52:36 PDT 2023
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/66758
>From 463c7c24d4c7ccb5237829b492d1f1a0ff3ea6e2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 17 Sep 2023 11:50:14 +0000
Subject: [PATCH 1/6] [mlir][ArmSME] Support vertical layout in load and store
ops
In SME a ZA tile slice is a one-dimensional set of horizontally or
vertically contiguous elements within a ZA tile. Currently the load and
store ops only support horizontal tile slices. This patch adds a tile
slice layout attribute to the load and store ops to support both
horizontal and vertical tile slices.
When lowering from the Vector dialect, horizontal layout is the default.
---
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h | 5 +
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 116 ++--
.../mlir/Dialect/ArmSME/IR/CMakeLists.txt | 7 +
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 18 +-
.../VectorToArmSME/VectorToArmSME.cpp | 14 +-
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp | 12 +
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt | 1 +
.../Transforms/LegalizeForLLVMExport.cpp | 93 ++-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 42 +-
.../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir | 6 +-
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 401 +++++++++++++
mlir/test/Dialect/ArmSME/roundtrip.mlir | 540 +++++++++++++++---
.../Dialect/ArmSME/vector-ops-to-sme.mlir | 16 +-
.../Vector/CPU/ArmSME/test-load-vertical.mlir | 110 ++++
14 files changed, 1193 insertions(+), 188 deletions(-)
create mode 100644 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..1a4984f3bd6ba27 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
#ifndef ARMSME_OPS
#define ARMSME_OPS
+include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
}];
let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+ let useDefaultAttributePrinterParser = 1;
}
//===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
"::llvm::cast<VectorType>($_self).getElementType())"
".getWidth())">;
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+ I32EnumAttrCase<"Horizontal", 0, "hor">,
+ I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+ let cppNamespace = "::mlir::arm_sme";
+ let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+ "layout"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// ArmSME op definitions
//===----------------------------------------------------------------------===//
@@ -239,27 +259,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
let summary = "Tile load operation";
let description = [{
Loads a 2D SME "virtual tile" from memory defined by a base and indices,
- with the shape defined by the 2D scalable vector type of the result tile.
- The slice of memory must be contiguous. The memref must be either rank 1 or
- rank 2 with dynamic dimensions, since the operation is scalable, and the
- element type must be a scalar that matches the element type of the result.
+ with the shape defined by the 2D scalable vector type of the result tile. A
+ tile slice layout attribute specifies whether the slices of the tile being
+ loaded are horizontal or vertical. The slice of memory must be contiguous.
+ The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+ the operation is scalable, and the element type must be a scalar that
+ matches the element type of the result.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+ Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a FP 32-bit element ZA tile from memory.
+ Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a 128-bit element ZA tile from memory.
+ Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
```mlir
- %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs SMETile:$result);
@@ -274,7 +300,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
}];
let assemblyFormat =
- "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+ "$layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($result)";
}
def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +309,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
let description = [{
Stores a 2D SME "virtual tile" to memory defined by a base and indices,
with the shape defined by the 2D scalable vector type of the tile being
- stored. The slice of memory must be contiguous. The memref must be either
- rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
- and the element type must be a scalar that matches the element type of the
- result.
+ stored. A tile slice layout attribute specifies whether the slices of the
+ tile being stored are horizontal or vertical. The slice of memory must be
+ contiguous. The memref must be either rank 1 or rank 2 with dynamic
+ dimensions, since the operation is scalable, and the element type must be a
+ scalar that matches the element type of the tile being stored.
+
+ The default tile slice layout when lowering from higher-level dialects is
+ horizontal.
- Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+ Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Store a FP 32-bit element ZA tile to memory.
+ Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Store a 128-bit element ZA tile to memory.
+ Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
```mlir
- arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins SMETile:$valueToStore,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -314,8 +346,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
}
}];
- let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
- "`:` type($base) `,` type($valueToStore)";
+ let assemblyFormat =
+ "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+ "`:` type($base) `,` type($valueToStore)";
}
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +359,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is loaded to. The updated tile is returned as the result.
+ slice is loaded to. A tile slice layout attribute specifies whether the
+ tile slice being loaded at the given index is horizontal or vertical. The
+ updated tile is returned as the result.
The slice of memory read is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the result.
- Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+ Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+ Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+ Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
```mlir
- %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to load from">:$base,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
let results = (outs SMETile:$result);
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
}];
let assemblyFormat = [{
- $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+ $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
attr-dict `:` type($base) `,` type($result)
}];
}
@@ -374,29 +410,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
slice is defined by the dimension of the 2D scalable vector type pointed by
the index. A tile slice index describes where in the input tile the tile
- slice is stored from.
+ slice is stored from. A tile slice layout attribute specifies whether the
+ tile slice being stored from the given index is horizontal or vertical.
The slice of memory written is defined by a base and indices and must be
contiguous. The memref must be either rank 1 or rank 2, have dynamic
dimensions since the operation is scalable, and the element type must be a
scalar that matches the element type of the input tile.
- Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+ Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
```
- Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+ Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```
- Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+ Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
```mlir
- arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
```
}];
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+ ArmSME_TileSliceLayoutAttr:$layout,
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
Variadic<Index>:$indices);
let extraClassDeclaration = [{
@@ -409,7 +447,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
}];
let assemblyFormat = [{
- $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+ $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
attr-dict `:` type($base) `,` type($tile)
}];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
set(LLVM_TARGET_DEFINITIONS ArmSME.td)
mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
///
/// BEFORE:
/// ```mlir
-/// %tile = arm_sme.tile_load %src[%c0, %c0] :
+/// %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
/// memref<?x?xi32>, vector<[4]x[4]xi32>
/// ```
///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+/// %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
/// %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
getMemrefIndices(tileLoadOp.getIndices(),
tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
- rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
- tileLoadOp.getBase(), tile,
- memrefIndices, tileSliceIndex);
+ rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+ memrefIndices, tileSliceIndex);
rewriter.setInsertionPointAfter(forOp);
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
///
/// BEFORE:
/// ```mlir
-/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+/// arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
/// : memref<?x?xi32>, vector<[4]x[4]xi32
/// ```
///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// %min_svl_s = arith.constant 4 : index
/// %svl_s = arith.muli %min_svl_s, %vscale : index
/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-/// arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-/// : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+/// %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
/// }
/// ```
struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
numTileSlices, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
- tileStoreOp.getBase(), memrefIndices);
+ tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..feaec0e035ed9fd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -65,8 +65,8 @@ namespace {
///
/// is converted to:
///
-/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-/// vector<[16]x[16]xi8>
+/// arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- writeOp, writeOp.getVector(), writeOp.getSource(),
- writeOp.getIndices());
+ writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+ writeOp.getSource(), writeOp.getIndices());
return success();
}
};
@@ -97,7 +97,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
- load, load.getVectorType(), load.getBase(), load.getIndices());
+ load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+ load.getBase(), load.getIndices());
return success();
}
@@ -113,7 +114,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
return failure();
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
- store, store.getValueToStore(), store.getBase(), store.getIndices());
+ store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+ store.getBase(), store.getIndices());
return success();
}
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..92fb146691a0beb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::arm_sme;
@@ -22,13 +24,23 @@ using namespace mlir::arm_sme;
#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
void ArmSMEDialect::initialize() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ade33..79f6a46c7c5889e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
DEPENDS
MLIRArmSMEIncGen
+ MLIRArmSMEEnumsIncGen
LINK_LIBS PUBLIC
MLIRIR
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..eeb822aae09180b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -204,29 +204,51 @@ struct LoadTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
auto tileI32 = castTileIDToI32(tile, loc, rewriter);
- // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
+ arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+
+ // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
- rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
break;
case 16:
- rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
break;
case 32:
- rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
break;
case 64:
- rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
break;
case 128:
- rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+ loc, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
break;
}
@@ -280,28 +302,50 @@ struct StoreTileSliceToArmSMELowering
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
+ arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+
switch (tileElementWidth) {
default:
llvm_unreachable("unexpected element type!");
case 8:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 16:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 32:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 64:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
case 128:
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ if (layout == arm_sme::TileSliceLayout::Horizontal)
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ else
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
break;
}
@@ -479,7 +523,12 @@ void mlir::configureArmSMELegalizeForExportTarget(
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
+ arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+ arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+ arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+ arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+ arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+ arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
arm_sme::aarch64_sme_za_disable>();
target.addLegalOp<GetTileID>();
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 9ab1d79794d7659..13c2e0479ab3957 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
-// CHECK-LABEL: func.func @arm_sme_tile_load(
-// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -11,18 +11,28 @@
// CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+// CHECK-NEXT: arm_sme.load_tile_slice <hor>, %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
// -----
-// CHECK-LABEL: func.func @arm_sme_tile_store(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
-// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: @arm_sme_tile_load_ver
+// CHECK: arm_sme.load_tile_slice <ver>
+func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
@@ -30,9 +40,19 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], <hor>, %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store_ver
+// CHECK: arm_sme.store_tile_slice {{.*}} <ver>
+func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..49bf188f913bc8c 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -15,7 +15,7 @@
func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.zero : vector<[16]x[16]xi8>
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
@@ -32,7 +32,7 @@ func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
// CHECK: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile = arm_sme.tile_load <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return %tile : vector<[16]x[16]xi8>
}
@@ -46,6 +46,6 @@ func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
// CHECK: "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
new file mode 100644
index 000000000000000..c249372a013308d
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -0,0 +1,401 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// Test conversion of higher-level ArmSME ops to LLVM intrinsics.
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func.func @arm_sme_load_tile_slice_hor_i8(
+// CHECK-SAME: %[[SRC:.*]]: memref<?x?xi8>,
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK: %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: return
+// CHECK: }
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
+// CHECK: arm_sme.intr.ld1w.horiz
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
+// CHECK: arm_sme.intr.ld1d.horiz
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
+// CHECK: arm_sme.intr.ld1q.horiz
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
+// CHECK: arm_sme.intr.ld1w.horiz
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
+// CHECK: arm_sme.intr.ld1d.horiz
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
+// CHECK: arm_sme.intr.ld1b.vert
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
+// CHECK: arm_sme.intr.ld1w.vert
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
+// CHECK: arm_sme.intr.ld1d.vert
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
+// CHECK: arm_sme.intr.ld1q.vert
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
+// CHECK: arm_sme.intr.ld1w.vert
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
+// CHECK: arm_sme.intr.ld1d.vert
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_store_tile_slice_hor_i8(
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME: %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK: %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]] : i64
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK: %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK: "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: return
+// CHECK: }
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
+// CHECK: arm_sme.intr.st1q.horiz
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
+// CHECK: arm_sme.intr.st1b.vert
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
+// CHECK: arm_sme.intr.st1w.vert
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
+// CHECK: arm_sme.intr.st1d.vert
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
+// CHECK: arm_sme.intr.st1q.vert
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
+// CHECK: arm_sme.intr.st1w.vert
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
+// CHECK: arm_sme.intr.st1d.vert
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index bae48be87b2dcdc..66518107c17bc4f 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
// CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
%0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
@@ -70,6 +74,10 @@ func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64
return %0 : vector<[2]x[2]xf64>
}
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
@@ -142,6 +150,10 @@ func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64
return %0 : i64
}
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_get_tile_id_i8() -> i8 {
@@ -182,6 +194,10 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
return %0 : i128
}
+//===----------------------------------------------------------------------===//
+// arm_sme.zero
+//===----------------------------------------------------------------------===//
+
// -----
func.func @arm_sme_zero_i8() {
@@ -254,332 +270,676 @@ func.func @arm_sme_zero_f64() {
return
}
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load_hor_i8(%src : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
// -----
-func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load_hor_i16(%src : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
// -----
-func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_load_hor_i32(%src : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
// -----
-func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_i64(%src : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
// -----
-func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_load_hor_i128(%src : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
// -----
-func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_load_hor_f16(%src : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
// -----
-func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_load_hor_bf16(%src : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
// -----
-func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_load_hor_f32(%src : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
// -----
-func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_load_hor_f64(%src : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
// -----
-func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) {
- // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_tile_load_ver_i8(%src : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
- %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
// -----
-func.func @arm_sme_tile_store_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+func.func @arm_sme_tile_load_ver_i16(%src : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i32(%src : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i64(%src : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i128(%src : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f16(%src : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_bf16(%src : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f32(%src : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store_hor_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
// -----
-func.func @arm_sme_tile_store_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+func.func @arm_sme_tile_store_hor_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
// -----
-func.func @arm_sme_tile_store_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+func.func @arm_sme_tile_store_hor_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
// -----
-func.func @arm_sme_tile_store_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+func.func @arm_sme_tile_store_hor_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
// -----
-func.func @arm_sme_tile_store_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
+func.func @arm_sme_tile_store_hor_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
// -----
-func.func @arm_sme_tile_store_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+func.func @arm_sme_tile_store_hor_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
// -----
-func.func @arm_sme_tile_store_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+func.func @arm_sme_tile_store_hor_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
// -----
-func.func @arm_sme_tile_store_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+func.func @arm_sme_tile_store_hor_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
- // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ // CHECK: arm_sme.tile_store {{.*}}, <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
// -----
-func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+func.func @arm_sme_tile_store_hor_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
// CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
- arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
// -----
-func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_store_ver_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
// -----
-func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_store_ver_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
// -----
-func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_ver_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
// -----
-func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_store_ver_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
// -----
-func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_store_ver_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
// -----
-func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_store_ver_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
// -----
-func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_store_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
// -----
-func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+ // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
// -----
-func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
- // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
- %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
// -----
-func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+ // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
return
}
// -----
-func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
return
}
// -----
-func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
// -----
-func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
return
}
// -----
-func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
return
}
// -----
-func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
return
}
// -----
-func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
return
}
// -----
-func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
return
}
// -----
-func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
// CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
%c0 = arith.constant 0 : index
- arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
return
}
// -----
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+ // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+ %c0 = arith.constant 0 : index
+ arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+ return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
func.func @arm_sme_move_vector_to_tile_slice_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index cb35de11ab5b3ed..3012ed156578059 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -4,7 +4,7 @@
// CHECK-SAME: %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi8>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -17,7 +17,7 @@ func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xi16>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref<?x?xi16>
@@ -30,7 +30,7 @@ func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref<?x?xi32>
@@ -43,7 +43,7 @@ func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
@@ -56,7 +56,7 @@ func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xf16>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
@@ -69,7 +69,7 @@ func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
@@ -82,7 +82,7 @@ func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref
// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
@@ -95,7 +95,7 @@ func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
new file mode 100644
index 000000000000000..ea9cc61d931266e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -0,0 +1,110 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE: -march=aarch64 -mattr=+sve,+sme \
+// DEFINE: -e %{entry_point} -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+ %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<12 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr<array<12 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @printTileEnd() {
+ %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<10 x i8>>
+ %1 = llvm.mlir.constant(0 : index) : i64
+ %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr<array<10 x i8>>, i64, i64) -> !llvm.ptr<i8>
+ llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+ return
+}
+
+func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1_i32 = arith.constant 1 : i32
+
+ // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+ %vscale = vector.vscale
+ %min_elts_s = arith.constant 4 : index
+ %svl_s = arith.muli %min_elts_s, %vscale : index
+ %za_s_size = arith.muli %svl_s, %svl_s : index
+
+ // Allocate memory.
+ %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+ %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+ // Fill each "row" of "mem1" with row number.
+ //
+ // For example, assuming an SVL of 128 bits:
+ //
+ // 0, 0, 0, 0
+ // 1, 1, 1, 1
+ // 2, 2, 2, 2
+ // 3, 3, 3, 3
+ //
+ %init_0 = arith.constant 0 : i32
+ scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+ %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+ vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ %val_next = arith.addi %val, %c1_i32 : i32
+ scf.yield %val_next : i32
+ }
+
+ // Load tile from "mem1" vertically. "mem1" is 1-D (memref<?xi32>), so it
+ // is addressed with a single index, matching the vector.store to "mem2".
+ %0 = arm_sme.tile_load <ver>, %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Store tile back to "mem2" to print.
+ // TODO: Support vector.print for 2-D scalable vectors so we don't have to
+ // spill to memory and reload to print.
+ vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+ // Dump "mem1". The smallest SVL is 128 bits, so the tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 0, 0, 0, 0
+ // CHECK-NEXT: ( 1, 1, 1, 1
+ // CHECK-NEXT: ( 2, 2, 2, 2
+ // CHECK-NEXT: ( 3, 3, 3, 3
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %za_s_size step %svl_s {
+ %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.print %tileslice : vector<[4]xi32>
+ }
+ func.call @printTileEnd() : () -> ()
+
+ // Dump "mem2". The smallest SVL is 128 bits, so the tile will be at least
+ // 4x4xi32.
+ //
+ // CHECK: TILE BEGIN
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK-NEXT: ( 0, 1, 2, 3
+ // CHECK: TILE END
+ func.call @printTileBegin() : () -> ()
+ scf.for %i = %c0 to %za_s_size step %svl_s {
+ %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+ vector.print %tileslice : vector<[4]xi32>
+ }
+ func.call @printTileEnd() : () -> ()
+
+ return
+}
+
+// printCString takes only a pointer (no length), so each string must carry
+// its own trailing NUL; the globals are not implicitly NUL-terminated.
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A\00")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A\00")
>From 4f0273dbcbe726a422dd2d18abc8fa4ef8f539a2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 08:43:49 +0000
Subject: [PATCH 2/6] reformat switch
---
.../Transforms/LegalizeForLLVMExport.cpp | 138 ++++++++++--------
1 file changed, 76 insertions(+), 62 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index eeb822aae09180b..18147542e2bca73 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -207,49 +207,56 @@ struct LoadTileSliceToArmSMELowering
arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
// Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 16:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 16:
rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 32:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 32:
rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 64:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 64:
rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
- tileI32, tileSliceI32);
- break;
- case 128:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
loc, allActiveMask, ptr, tileI32, tileSliceI32);
- else
+ break;
+ }
+ } else {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+ tileI32, tileSliceI32);
+ break;
+ case 128:
rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
tileI32, tileSliceI32);
- break;
+ break;
+ }
}
// The load intrinsics have no result, replace 'arm_sme.tile_load' with
@@ -304,49 +311,56 @@ struct StoreTileSliceToArmSMELowering
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
- switch (tileElementWidth) {
- default:
- llvm_unreachable("unexpected element type!");
- case 8:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ if (layout == arm_sme::TileSliceLayout::Horizontal) {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 16:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 16:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 32:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 32:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 64:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 64:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- else
- rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
- storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
- case 128:
- if (layout == arm_sme::TileSliceLayout::Horizontal)
+ break;
+ case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- else
+ break;
+ }
+ } else {
+ switch (tileElementWidth) {
+ default:
+ llvm_unreachable("unexpected element type!");
+ case 8:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 16:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 32:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 64:
+ rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
+ storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+ break;
+ case 128:
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
- break;
+ break;
+ }
}
return success();
>From efc55f5fa6ab0db6e07ee1fc816a351131e79dfb Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 09:26:33 +0000
Subject: [PATCH 3/6] check types in arm-sme-to-llvm.mlir
---
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 68 +++++++++----------
1 file changed, 34 insertions(+), 34 deletions(-)
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index c249372a013308d..c582877ee89306f 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -33,7 +33,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
-// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -43,7 +43,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
-// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -53,7 +53,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
-// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -63,7 +63,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
-// CHECK: arm_sme.intr.ld1q.horiz
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -73,7 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
-// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -83,7 +83,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
-// CHECK: arm_sme.intr.ld1h.horiz
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -93,7 +93,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
-// CHECK: arm_sme.intr.ld1w.horiz
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -103,7 +103,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
-// CHECK: arm_sme.intr.ld1d.horiz
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -113,7 +113,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
-// CHECK: arm_sme.intr.ld1b.vert
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -123,7 +123,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
-// CHECK: arm_sme.intr.ld1h.vert
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -133,7 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
-// CHECK: arm_sme.intr.ld1w.vert
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -143,7 +143,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
-// CHECK: arm_sme.intr.ld1d.vert
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -153,7 +153,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
-// CHECK: arm_sme.intr.ld1q.vert
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -163,7 +163,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
-// CHECK: arm_sme.intr.ld1h.vert
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -173,7 +173,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
-// CHECK: arm_sme.intr.ld1h.vert
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -183,7 +183,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
-// CHECK: arm_sme.intr.ld1w.vert
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -193,7 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
// -----
// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
-// CHECK: arm_sme.intr.ld1d.vert
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
%c0 = arith.constant 0 : index
%tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -233,7 +233,7 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
-// CHECK: arm_sme.intr.st1h.horiz
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -243,7 +243,7 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
-// CHECK: arm_sme.intr.st1w.horiz
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -253,7 +253,7 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
-// CHECK: arm_sme.intr.st1d.horiz
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -263,7 +263,7 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
-// CHECK: arm_sme.intr.st1q.horiz
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -273,7 +273,7 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
-// CHECK: arm_sme.intr.st1h.horiz
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -283,7 +283,7 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
-// CHECK: arm_sme.intr.st1h.horiz
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -293,7 +293,7 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
-// CHECK: arm_sme.intr.st1w.horiz
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -303,7 +303,7 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
-// CHECK: arm_sme.intr.st1d.horiz
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -313,7 +313,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
-// CHECK: arm_sme.intr.st1b.vert
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -323,7 +323,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
-// CHECK: arm_sme.intr.st1h.vert
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -333,7 +333,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
-// CHECK: arm_sme.intr.st1w.vert
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -343,7 +343,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
-// CHECK: arm_sme.intr.st1d.vert
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -353,7 +353,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
-// CHECK: arm_sme.intr.st1q.vert
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -363,7 +363,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
-// CHECK: arm_sme.intr.st1h.vert
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -373,7 +373,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
-// CHECK: arm_sme.intr.st1h.vert
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -383,7 +383,7 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
-// CHECK: arm_sme.intr.st1w.vert
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -393,7 +393,7 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
// -----
// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
-// CHECK: arm_sme.intr.st1d.vert
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
%c0 = arith.constant 0 : index
arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
>From 9785b05913fc61a53803108a692024985f705de6 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 09:27:03 +0000
Subject: [PATCH 4/6] add comment to integration test
---
.../Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir | 2 ++
1 file changed, 2 insertions(+)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index ea9cc61d931266e..19ff126d2173a13 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -71,6 +71,7 @@ func.func @entry() {
// to memory and reload to print.
vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+ // 1. ORIGINAL HORIZONTAL LAYOUT
// Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
//
@@ -87,6 +88,7 @@ func.func @entry() {
}
func.call @printTileEnd() : () -> ()
+ // 2. VERTICAL LAYOUT
// Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
//
>From 5182fda3db76532364e8602307b93f339d0cfa95 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 20 Sep 2023 09:42:10 +0000
Subject: [PATCH 5/6] replace enumsincgen with attrincgen
---
mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt | 3 +--
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt | 2 +-
2 files changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 7afd0d014541687..617809e482b2caa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -9,5 +9,4 @@ mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
-add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
-add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
+add_public_tablegen_target(MLIRArmSMEAttrDefsIncGen)
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 79f6a46c7c5889e..85f90a8303d466f 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -6,7 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
DEPENDS
MLIRArmSMEIncGen
- MLIRArmSMEEnumsIncGen
+ MLIRArmSMEAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRIR
>From a3322e80b7a33aba7376072999665628fd15e389 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 21 Sep 2023 08:52:27 +0100
Subject: [PATCH 6/6] Update arm-sme-to-llvm.mlir to drop "higher-level"
---
mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index c582877ee89306f..c87983113fccdb2 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-// Test conversion of higher-level ArmSME ops to LLVM intrinsics.
+// Test conversion of ArmSME ops to LLVM intrinsics.
//===----------------------------------------------------------------------===//
// arm_sme.load_tile_slice
More information about the Mlir-commits
mailing list